mirror of
https://github.com/github/codeql.git
synced 2026-06-06 14:07:06 +02:00
Python: Copy Python extractor to codeql repo
This commit is contained in:
153
python/extractor/semmle/python/parser/__init__.py
Normal file
153
python/extractor/semmle/python/parser/__init__.py
Normal file
@@ -0,0 +1,153 @@
|
||||
|
||||
# Black's version of lib2to3 (modified)
|
||||
from blib2to3.pytree import type_repr
|
||||
from blib2to3 import pygram
|
||||
from blib2to3.pgen2 import driver, token
|
||||
from blib2to3.pgen2.parse import ParseError, Parser
|
||||
from . import ast
|
||||
from blib2to3.pgen2 import tokenize, grammar
|
||||
from blib2to3.pgen2.token import tok_name
|
||||
from semmle.profiling import timers
|
||||
|
||||
pygram.initialize()
|
||||
syms = pygram.python_symbols
|
||||
|
||||
|
||||
GRAMMARS = [
|
||||
("Python 3", pygram.python3_grammar),
|
||||
("Python 3 without async", pygram.python3_grammar_no_async),
|
||||
("Python 2 with print as function", pygram.python2_grammar_no_print_statement),
|
||||
("Python 2", pygram.python2_grammar),
|
||||
]
|
||||
|
||||
|
||||
SKIP_IF_SINGLE_CHILD_NAMES = {
|
||||
'atom',
|
||||
'power',
|
||||
'test',
|
||||
'not_test',
|
||||
'and_test',
|
||||
'or_test',
|
||||
'suite',
|
||||
'testlist',
|
||||
'expr',
|
||||
'xor_expr',
|
||||
'and_expr',
|
||||
'shift_expr',
|
||||
'arith_expr',
|
||||
'term',
|
||||
'factor',
|
||||
'testlist_gexp',
|
||||
'exprlist',
|
||||
'testlist_safe',
|
||||
'old_test',
|
||||
'comparison',
|
||||
}
|
||||
|
||||
SKIP_IF_SINGLE_CHILD = {
|
||||
val for name, val in
|
||||
syms.__dict__.items()
|
||||
if name in SKIP_IF_SINGLE_CHILD_NAMES
|
||||
}
|
||||
|
||||
|
||||
class Leaf(object):
|
||||
|
||||
__slots__ = "type", "value", "start", "end"
|
||||
|
||||
def __init__(self, type, value, start, end):
|
||||
self.type = type
|
||||
self.value = value
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a canonical string representation."""
|
||||
return "%s(%s, %r)" % (self.__class__.__name__,
|
||||
self.name,
|
||||
self.value)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return tok_name.get(self.type, self.type)
|
||||
|
||||
class Node(object):
|
||||
|
||||
__slots__ = "type", "children", "used_names"
|
||||
|
||||
def __init__(self, type, children):
|
||||
self.type = type
|
||||
self.children = children
|
||||
|
||||
@property
|
||||
def start(self):
|
||||
node = self
|
||||
while isinstance(node, Node):
|
||||
node = node.children[0]
|
||||
return node.start
|
||||
|
||||
@property
|
||||
def end(self):
|
||||
node = self
|
||||
while isinstance(node, Node):
|
||||
node = node.children[-1]
|
||||
return node.end
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a canonical string representation."""
|
||||
return "%s(%s, %r)" % (self.__class__.__name__,
|
||||
self.name,
|
||||
self.children)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return type_repr(self.type)
|
||||
|
||||
def convert(gr, raw_node):
|
||||
type, value, context, children = raw_node
|
||||
if children or type in gr.number2symbol:
|
||||
# If there's exactly one child, return that child instead of
|
||||
# creating a new node.
|
||||
if len(children) == 1 and type in SKIP_IF_SINGLE_CHILD:
|
||||
return children[0]
|
||||
return Node(type, children)
|
||||
else:
|
||||
start, end = context
|
||||
return Leaf(type, value, start, end)
|
||||
|
||||
def parse_tokens(gr, tokens):
|
||||
"""Parse a series of tokens and return the syntax tree."""
|
||||
p = Parser(gr, convert)
|
||||
p.setup()
|
||||
for tkn in tokens:
|
||||
type, value, start, end = tkn
|
||||
if type in (tokenize.COMMENT, tokenize.NL):
|
||||
continue
|
||||
if type == token.OP:
|
||||
type = grammar.opmap[value]
|
||||
if type == token.INDENT:
|
||||
value = ""
|
||||
if p.addtoken(type, value, (start, end)):
|
||||
break
|
||||
else:
|
||||
# We never broke out -- EOF is too soon (how can this happen???)
|
||||
raise parse.ParseError("incomplete input",
|
||||
type, value, ("", start))
|
||||
return p.rootnode
|
||||
|
||||
|
||||
def parse(tokens, logger):
|
||||
"""Given a string with source, return the lib2to3 Node."""
|
||||
for name, grammar in GRAMMARS:
|
||||
try:
|
||||
with timers["parse"]:
|
||||
cpt = parse_tokens(grammar, tokens)
|
||||
with timers["rewrite"]:
|
||||
return ast.convert(logger, cpt)
|
||||
except ParseError as pe:
|
||||
lineno, column = pe.context[1]
|
||||
logger.debug("%s at line %d, column %d using grammar for %s", pe, lineno, column, name)
|
||||
exc = SyntaxError("Syntax Error")
|
||||
exc.lineno = lineno
|
||||
exc.offset = column
|
||||
raise exc
|
||||
1491
python/extractor/semmle/python/parser/ast.py
Normal file
1491
python/extractor/semmle/python/parser/ast.py
Normal file
File diff suppressed because it is too large
Load Diff
151
python/extractor/semmle/python/parser/dump_ast.py
Normal file
151
python/extractor/semmle/python/parser/dump_ast.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# dump_ast.py
|
||||
|
||||
# Functions for dumping the internal Python AST in a human-readable format.
|
||||
|
||||
import sys
|
||||
import semmle.python.parser.tokenizer
|
||||
import semmle.python.parser.tsg_parser
|
||||
from semmle.python.parser.tsg_parser import ast_fields
|
||||
from semmle.python import ast
|
||||
from semmle import logging
|
||||
from semmle.python.modules import PythonSourceModule
|
||||
|
||||
|
||||
|
||||
def get_fields(cls):
|
||||
"""Gets the fields of the given class, followed by the fields of its (single-inheritance)
|
||||
superclasses, if any.
|
||||
Only includes fields for classes in `ast_fields`."""
|
||||
if cls not in ast_fields:
|
||||
return ()
|
||||
s = cls.__bases__[0]
|
||||
return ast_fields[cls] + get_fields(s)
|
||||
|
||||
def missing_fields(known, node):
|
||||
"""Returns a list of fields in `node` that are not in `known`."""
|
||||
return [field
|
||||
for field in dir(node)
|
||||
if field not in known
|
||||
and not field.startswith("_")
|
||||
and not field in ("lineno", "col_offset")
|
||||
and not (isinstance(node, ast.Name) and field == "id")
|
||||
]
|
||||
|
||||
class AstDumper(object):
|
||||
def __init__(self, output=sys.stdout, no_locations=False):
|
||||
self.output = output
|
||||
self.show_locations = not no_locations
|
||||
|
||||
def visit(self, node, level=0, visited=None):
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if node in visited:
|
||||
output.write("{} CYCLE DETECTED!\n".format(indent))
|
||||
return
|
||||
visited = visited.union({node})
|
||||
output = self.output
|
||||
cls = node.__class__
|
||||
name = cls.__name__
|
||||
indent = ' ' * level
|
||||
if node is None: # Special case for `None` to avoid printing `NoneType`.
|
||||
name = 'None'
|
||||
if cls == str: # Special case for bare strings
|
||||
output.write("{}{}\n".format(indent, repr(node)))
|
||||
return
|
||||
# In some places, we have non-AST nodes in lists, and since these don't have a location, we
|
||||
# simply print their name instead.
|
||||
# `ast.arguments` is special -- it has fields but no location
|
||||
if hasattr(node, 'lineno') and not isinstance(node, ast.arguments) and self.show_locations:
|
||||
position = (node.lineno, node.col_offset, node._end[0], node._end[1])
|
||||
output.write("{}{}: [{}, {}] - [{}, {}]\n".format(indent, name, *position))
|
||||
else:
|
||||
output.write("{}{}\n".format(indent, name))
|
||||
|
||||
|
||||
fields = get_fields(cls)
|
||||
unknown = missing_fields(fields, node)
|
||||
if unknown:
|
||||
output.write("{}UNKNOWN FIELDS: {}\n".format(indent, unknown))
|
||||
for field in fields:
|
||||
value = getattr(node, field, None)
|
||||
# By default, the `parenthesised` field on expressions has no value, so it's easier to
|
||||
# just not print it in that case.
|
||||
if field == "parenthesised" and value is None:
|
||||
continue
|
||||
# Likewise, the default value for `is_async` is `False`, so we don't need to print it.
|
||||
if field == "is_async" and value is False:
|
||||
continue
|
||||
output.write("{} {}:".format(indent,field))
|
||||
if isinstance(value, list):
|
||||
output.write(" [")
|
||||
if len(value) == 0:
|
||||
output.write("]\n")
|
||||
continue
|
||||
output.write("\n")
|
||||
for n in value:
|
||||
self.visit(n, level+2, visited)
|
||||
output.write("{} ]\n".format(indent))
|
||||
# Some AST classes are special in that the identity of the object is the only thing
|
||||
# that matters (and they have no location info). For this reason we simply print the name.
|
||||
elif isinstance(value, (ast.expr_context, ast.boolop, ast.cmpop, ast.operator, ast.unaryop)):
|
||||
output.write(' {}\n'.format(value.__class__.__name__))
|
||||
elif isinstance(value, ast.AstBase):
|
||||
output.write("\n")
|
||||
self.visit(value, level+2, visited)
|
||||
else:
|
||||
output.write(' {}\n'.format(repr(value)))
|
||||
|
||||
|
||||
class StdoutLogger(logging.Logger):
|
||||
def log(self, level, fmt, *args):
|
||||
sys.stdout.write(fmt % args + "\n")
|
||||
|
||||
def old_parser(inputfile, logger):
|
||||
mod = PythonSourceModule(None, inputfile, logger)
|
||||
logger.close()
|
||||
return mod.old_py_ast
|
||||
|
||||
def args_parser():
|
||||
'Parse command_line, returning options, arguments'
|
||||
from optparse import OptionParser
|
||||
usage = "usage: %prog [options] python-file"
|
||||
parser = OptionParser(usage=usage)
|
||||
parser.add_option("-o", "--old", help="Dump old AST.", action="store_true")
|
||||
parser.add_option("-n", "--new", help="Dump new AST.", action="store_true")
|
||||
parser.add_option("-l", "--no-locations", help="Don't include location info in dump", action="store_true")
|
||||
parser.add_option("-d", "--debug", help="Print debug information.", action="store_true")
|
||||
return parser
|
||||
|
||||
def main():
|
||||
parser = args_parser()
|
||||
options, args = parser.parse_args(sys.argv[1:])
|
||||
|
||||
if options.debug:
|
||||
global DEBUG
|
||||
DEBUG = True
|
||||
|
||||
if len(args) != 1:
|
||||
sys.stderr.write("Error: wrong number of arguments.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
inputfile = args[0]
|
||||
|
||||
if options.old and options.new:
|
||||
sys.stderr.write("Error: options --old and --new are mutually exclusive.\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not (options.old or options.new):
|
||||
sys.stderr.write("Error: Must specify either --old or --new.\n")
|
||||
sys.exit(1)
|
||||
|
||||
with StdoutLogger() as logger:
|
||||
|
||||
if options.old:
|
||||
ast = old_parser(inputfile, logger)
|
||||
else:
|
||||
ast = semmle.python.parser.tsg_parser.parse(inputfile, logger)
|
||||
AstDumper(no_locations=options.no_locations).visit(ast)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
1146
python/extractor/semmle/python/parser/tokenizer.py
Normal file
1146
python/extractor/semmle/python/parser/tokenizer.py
Normal file
File diff suppressed because it is too large
Load Diff
495
python/extractor/semmle/python/parser/tsg_parser.py
Normal file
495
python/extractor/semmle/python/parser/tsg_parser.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# tsg_parser.py
|
||||
|
||||
# Functions and classes used for parsing Python files using `tree-sitter-graph`
|
||||
|
||||
from ast import literal_eval
|
||||
import sys
|
||||
import os
|
||||
import semmle.python.parser
|
||||
from semmle.python.parser.ast import copy_location, decode_str, split_string
|
||||
from semmle.python import ast
|
||||
import subprocess
|
||||
from itertools import groupby
|
||||
|
||||
DEBUG = False
|
||||
def debug_print(*args, **kwargs):
|
||||
if DEBUG:
|
||||
print(*args, **kwargs)
|
||||
|
||||
# Node ids are integers, and so to distinguish them from actual integers we wrap them in this class.
|
||||
class Node(object):
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
def __repr__(self):
|
||||
return "Node({})".format(self.id)
|
||||
|
||||
# A wrapper for nodes containing comments. The old parser does not create such nodes (and therefore
|
||||
# there is no `ast.Comment` class) since it accesses the comments via the tokens for the given file.
|
||||
class Comment(object):
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
def __repr__(self):
|
||||
return "Comment({})".format(self.text)
|
||||
|
||||
class SyntaxErrorNode(object):
|
||||
def __init__(self, source):
|
||||
self.source = source
|
||||
def __repr__(self):
|
||||
return "SyntaxErrorNode({})".format(self.source)
|
||||
|
||||
# Mapping from tree-sitter CPT node kinds to their corresponding AST node classes.
|
||||
tsg_to_ast = {name: cls
|
||||
for name, cls in semmle.python.ast.__dict__.items()
|
||||
if isinstance(cls, type) and ast.AstBase in cls.__mro__
|
||||
}
|
||||
tsg_to_ast["Comment"] = Comment
|
||||
tsg_to_ast["SyntaxErrorNode"] = SyntaxErrorNode
|
||||
|
||||
# Mapping from AST node class to the fields of the node. The order of the fields is the order in
|
||||
# which they will be output in the AST dump.
|
||||
#
|
||||
# These fields cannot be extracted automatically, so we set them manually.
|
||||
ast_fields = {
|
||||
ast.Module: ("body",), # Note: has no `__slots__` to inspect
|
||||
Comment: ("text",), # Note: not an `ast` class
|
||||
SyntaxErrorNode: ("source",), # Note: not an `ast` class
|
||||
ast.Continue: (),
|
||||
ast.Break: (),
|
||||
ast.Pass: (),
|
||||
ast.Ellipsis: (),
|
||||
ast.MatchWildcardPattern: (),
|
||||
}
|
||||
|
||||
# Fields that we don't want to dump on every single AST node. These are just the slots of the AST
|
||||
# base class, consisting of all of the location information (which we print in a different way).
|
||||
ignored_fields = semmle.python.ast.AstBase.__slots__
|
||||
|
||||
# Extract fields for the remaining AST classes
|
||||
for name, cls in semmle.python.ast.__dict__.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
if not hasattr(cls, "__slots__"):
|
||||
continue
|
||||
slots = tuple(field for field in cls.__slots__ if field not in ignored_fields)
|
||||
if not slots:
|
||||
continue
|
||||
ast_fields[cls] = slots
|
||||
|
||||
# A mapping from strings to the AST node classes that represent things like operators.
|
||||
# These have to be handled specially, because they have no location information.
|
||||
locationless = {
|
||||
"and": ast.And,
|
||||
"or": ast.Or,
|
||||
"not": ast.Not,
|
||||
"uadd": ast.UAdd,
|
||||
"usub": ast.USub,
|
||||
"+": ast.Add,
|
||||
"-": ast.Sub,
|
||||
"~": ast.Invert,
|
||||
"**": ast.Pow,
|
||||
"<<": ast.LShift,
|
||||
">>": ast.RShift,
|
||||
"&": ast.BitAnd,
|
||||
"|": ast.BitOr,
|
||||
"^": ast.BitXor,
|
||||
"load": ast.Load,
|
||||
"store": ast.Store,
|
||||
"del" : ast.Del,
|
||||
"param" : ast.Param,
|
||||
}
|
||||
locationless.update(semmle.python.parser.ast.TERM_OP_CLASSES)
|
||||
locationless.update(semmle.python.parser.ast.COMP_OP_CLASSES)
|
||||
locationless.update(semmle.python.parser.ast.AUG_ASSIGN_OPS)
|
||||
|
||||
if 'CODEQL_EXTRACTOR_PYTHON_ROOT' in os.environ:
|
||||
platform = os.environ['CODEQL_PLATFORM']
|
||||
ext = ".exe" if platform == "win64" else ""
|
||||
tools = os.path.join(os.environ['CODEQL_EXTRACTOR_PYTHON_ROOT'], "tools", platform)
|
||||
tsg_command = [os.path.join(tools, "tsg-python" + ext )]
|
||||
else:
|
||||
# Get the path to the current script
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
tsg_python_path = os.path.join(script_path, "../../../tsg-python")
|
||||
cargo_file = os.path.join(tsg_python_path, "Cargo.toml")
|
||||
tsg_command = ["cargo", "run", "--quiet", "--release", "--manifest-path="+cargo_file]
|
||||
|
||||
def read_tsg_python_output(path, logger):
|
||||
# Mapping from node id (an integer) to a dictionary containing attribute data.
|
||||
node_attr = {}
|
||||
# Mapping a start node to a map from attribute names to lists of (value, end_node) pairs.
|
||||
edge_attr = {}
|
||||
|
||||
command_args = tsg_command + [path]
|
||||
p = subprocess.Popen(command_args, stdout=subprocess.PIPE)
|
||||
for line in p.stdout:
|
||||
line = line.decode(sys.getfilesystemencoding())
|
||||
line = line.rstrip()
|
||||
if line.startswith("node"): # e.g. `node 5`
|
||||
current_node = int(line.split(" ")[1])
|
||||
d = {}
|
||||
node_attr[current_node] = d
|
||||
in_node = True
|
||||
elif line.startswith("edge"): # e.g. `edge 5 -> 6`
|
||||
current_start, current_end = tuple(map(int, line[4:].split("->")))
|
||||
d = edge_attr.setdefault(current_start, {})
|
||||
in_node = False
|
||||
else: # attribute, e.g. `_kind: "Class"`
|
||||
key, value = line[2:].split(": ", 1)
|
||||
if value.startswith("[graph node"): # e.g. `_skip_to: [graph node 5]`
|
||||
value = Node(int(value.split(" ")[2][:-1]))
|
||||
elif value == "#true": # e.g. `_is_parenthesised: #true`
|
||||
value = True
|
||||
elif value == "#false": # e.g. `top: #false`
|
||||
value = False
|
||||
elif value == "#null": # e.g. `exc: #null`
|
||||
value = None
|
||||
else: # literal values, e.g. `name: "k1.k2"` or `level: 5`
|
||||
try:
|
||||
if key =="s" and value[0] == '"': # e.g. `s: "k1.k2"`
|
||||
value = evaluate_string(value)
|
||||
else:
|
||||
value = literal_eval(value)
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
value = value.decode(sys.getfilesystemencoding())
|
||||
except UnicodeDecodeError:
|
||||
# just include the bytes as-is
|
||||
pass
|
||||
except Exception as ex:
|
||||
# We may not know the location at this point -- for instance if we forgot to set
|
||||
# it -- but `get_location_info` will degrade gracefully in this case.
|
||||
loc = ":".join(str(i) for i in get_location_info(d))
|
||||
error = ex.args[0] if ex.args else "unknown"
|
||||
logger.warning("Error '{}' while parsing value {} at {}:{}\n".format(error, repr(value), path, loc))
|
||||
if in_node:
|
||||
d[key] = value
|
||||
else:
|
||||
d.setdefault(key, []).append((value, current_end))
|
||||
p.stdout.close()
|
||||
p.terminate()
|
||||
p.wait()
|
||||
logger.info("Read {} nodes and {} edges from TSG output".format(len(node_attr), len(edge_attr)))
|
||||
return node_attr, edge_attr
|
||||
|
||||
def evaluate_string(s):
|
||||
s = literal_eval(s)
|
||||
prefix, quotes, content = split_string(s, None)
|
||||
ends_with_illegal_character = False
|
||||
# If the string ends with the same quote character as the outer quotes (and/or backslashes)
|
||||
# (e.g. the first string part of `f"""hello"{0}"""`), we must take care to not accidently create
|
||||
# the ending quotes at the wrong place. To do this, we insert an extra space at the end (that we
|
||||
# then must remember to remove later on.)
|
||||
if content.endswith(quotes[0]) or content.endswith('\\'):
|
||||
ends_with_illegal_character = True
|
||||
content = content + " "
|
||||
s = prefix.strip("fF") + quotes + content + quotes
|
||||
s = literal_eval(s)
|
||||
if isinstance(s, bytes):
|
||||
s = decode_str(s)
|
||||
if ends_with_illegal_character:
|
||||
s = s[:-1]
|
||||
return s
|
||||
|
||||
def resolve_node_id(id, node_attr):
|
||||
"""Finds the end of a sequence of nodes linked by `_skip_to` fields, starting at `id`."""
|
||||
while "_skip_to" in node_attr[id]:
|
||||
id = node_attr[id]["_skip_to"].id
|
||||
return id
|
||||
|
||||
def get_context(id, node_attr, logger):
|
||||
"""Gets the context of the node with the given `id`. This is either whatever is stored in the
|
||||
`ctx` attribute of the node, or the result of dereferencing a sequence of `_inherited_ctx` attributes."""
|
||||
|
||||
while "ctx" not in node_attr[id]:
|
||||
if "_inherited_ctx" not in node_attr[id]:
|
||||
logger.error("No context for node {} with attributes {}\n".format(id, node_attr[id]))
|
||||
# A missing context is most likely to be a "load", so return that.
|
||||
return ast.Load()
|
||||
id = node_attr[id]["_inherited_ctx"].id
|
||||
return locationless[node_attr[id]["ctx"]]()
|
||||
|
||||
def get_location_info(attrs):
|
||||
"""Returns the location information for a node, depending on which fields are set.
|
||||
|
||||
In particular, more specific fields take precedence over (and overwrite) less specific fields.
|
||||
So, `_start_line` and `_start_column` take precedence over `location_start`, which takes
|
||||
precedence over `_location`. Likewise when `end` replaces `start` above.
|
||||
|
||||
If part of the location information is missing, the string `"???"` is substituted for the
|
||||
missing bits.
|
||||
"""
|
||||
start_line = "???"
|
||||
start_column = "???"
|
||||
end_line = "???"
|
||||
end_column = "???"
|
||||
if "_location" in attrs:
|
||||
(start_line, start_column, end_line, end_column) = attrs["_location"]
|
||||
if "_location_start" in attrs:
|
||||
(start_line, start_column) = attrs["_location_start"]
|
||||
if "_location_end" in attrs:
|
||||
(end_line, end_column) = attrs["_location_end"]
|
||||
if "_start_line" in attrs:
|
||||
start_line = attrs["_start_line"]
|
||||
if "_start_column" in attrs:
|
||||
start_column = attrs["_start_column"]
|
||||
if "_end_line" in attrs:
|
||||
end_line = attrs["_end_line"]
|
||||
if "_end_column" in attrs:
|
||||
end_column = attrs["_end_column"]
|
||||
# Lines in the `tsg-python` output is 0-indexed, but the AST expects them to be 1-indexed.
|
||||
if start_line != "???":
|
||||
start_line += 1
|
||||
if end_line != "???":
|
||||
end_line += 1
|
||||
return (start_line, start_column, end_line, end_column)
|
||||
|
||||
list_fields = {
|
||||
ast.arguments: ("annotations", "defaults", "kw_defaults", "kw_annotations"),
|
||||
ast.Assign: ("targets",),
|
||||
ast.BoolOp: ("values",),
|
||||
ast.Bytes: ("implicitly_concatenated_parts",),
|
||||
ast.Call: ("positional_args", "named_args"),
|
||||
ast.Case: ("body",),
|
||||
ast.Class: ("body",),
|
||||
ast.ClassExpr: ("type_parameters", "bases", "keywords"),
|
||||
ast.Compare: ("ops", "comparators",),
|
||||
ast.comprehension: ("ifs",),
|
||||
ast.Delete: ("targets",),
|
||||
ast.Dict: ("items",),
|
||||
ast.ExceptStmt: ("body",),
|
||||
ast.For: ("body",),
|
||||
ast.Function: ("type_parameters", "args", "kwonlyargs", "body"),
|
||||
ast.Global: ("names",),
|
||||
ast.If: ("body",),
|
||||
ast.Import: ("names",),
|
||||
ast.List: ("elts",),
|
||||
ast.Match: ("cases",),
|
||||
ast.MatchClassPattern: ("positional", "keyword"),
|
||||
ast.MatchMappingPattern: ("mappings",),
|
||||
ast.MatchOrPattern: ("patterns",),
|
||||
ast.MatchSequencePattern: ("patterns",),
|
||||
ast.Module: ("body",),
|
||||
ast.Nonlocal: ("names",),
|
||||
ast.Print: ("values",),
|
||||
ast.Set: ("elts",),
|
||||
ast.Str: ("implicitly_concatenated_parts",),
|
||||
ast.TypeAlias: ("type_parameters",),
|
||||
ast.Try: ("body", "handlers", "orelse", "finalbody"),
|
||||
ast.Tuple: ("elts",),
|
||||
ast.While: ("body",),
|
||||
# ast.FormattedStringLiteral: ("arguments",),
|
||||
}
|
||||
|
||||
def create_placeholder_args(cls):
|
||||
""" Returns a dictionary containing the placeholder arguments necessary to create an AST node.
|
||||
|
||||
In most cases these arguments will be assigned the value `None`, however for a few classes we
|
||||
must substitute the empty list, as this is enforced by asserts in the constructor.
|
||||
"""
|
||||
if cls in (ast.Raise, ast.Ellipsis):
|
||||
return {}
|
||||
fields = ast_fields[cls]
|
||||
args = {field: None for field in fields if field != "is_async"}
|
||||
for field in list_fields.get(cls, ()):
|
||||
args[field] = []
|
||||
if cls in (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp):
|
||||
del args["function"]
|
||||
del args["iterable"]
|
||||
return args
|
||||
|
||||
def parse(path, logger):
|
||||
node_attr, edge_attr = read_tsg_python_output(path, logger)
|
||||
debug_print("node_attr:", node_attr)
|
||||
debug_print("edge_attr:", edge_attr)
|
||||
nodes = {}
|
||||
# Nodes that need to be fixed up after building the graph
|
||||
fixups = {}
|
||||
# Reverse index from node object to node id.
|
||||
node_id = {}
|
||||
# Create all the node objects
|
||||
for id, attrs in node_attr.items():
|
||||
if "_is_literal" in attrs:
|
||||
nodes[id] = attrs["_is_literal"]
|
||||
continue
|
||||
if "_kind" not in attrs:
|
||||
logger.error("Error: Graph node {} with attributes {} has no `_kind`!\n".format(id, attrs))
|
||||
continue
|
||||
# This is not the node we are looking for (so don't bother creating it).
|
||||
if "_skip_to" in attrs:
|
||||
continue
|
||||
cls = tsg_to_ast[attrs["_kind"]]
|
||||
args = ast_fields[cls]
|
||||
obj = cls(**create_placeholder_args(cls))
|
||||
nodes[id] = obj
|
||||
node_id[obj] = id
|
||||
# If this node needs fixing up afterwards, add it to the fixups map.
|
||||
if "_fixup" in attrs:
|
||||
fixups[id] = obj
|
||||
# Set all of the node attributes
|
||||
for id, node in nodes.items():
|
||||
attrs = node_attr[id]
|
||||
if "_is_literal" in attrs:
|
||||
continue
|
||||
expected_fields = ast_fields[type(node)]
|
||||
|
||||
# Set up location information.
|
||||
node.lineno, node.col_offset, end_line, end_column = get_location_info(attrs)
|
||||
node._end = (end_line, end_column)
|
||||
|
||||
if isinstance(node, SyntaxErrorNode):
|
||||
exc = SyntaxError("Syntax Error")
|
||||
exc.lineno = node.lineno
|
||||
exc.offset = node.col_offset
|
||||
raise exc
|
||||
|
||||
# Set up context information, if any
|
||||
if "ctx" in expected_fields:
|
||||
node.ctx = get_context(id, node_attr, logger)
|
||||
# Set the fields.
|
||||
for field, val in attrs.items():
|
||||
if field.startswith("_"): continue
|
||||
if field == "ctx": continue
|
||||
if field != "parenthesised" and field not in expected_fields:
|
||||
logger.warning("Unknown field {} found among {} in node {}\n".format(field, attrs, id))
|
||||
|
||||
# For fields that point to other AST nodes.
|
||||
if isinstance(val, Node):
|
||||
val = resolve_node_id(val.id, node_attr)
|
||||
setattr(node, field, nodes[val])
|
||||
# Special case for `Num.n`, which should be coerced to an int.
|
||||
elif isinstance(node, ast.Num) and field == "n":
|
||||
node.n = literal_eval(val.rstrip("lL"))
|
||||
# Special case for `Name.variable`, for which we must create a new `Variable` object
|
||||
elif isinstance(node, ast.Name) and field == "variable":
|
||||
node.variable = ast.Variable(val)
|
||||
# Special case for location-less leaf-node subclasses of `ast.Node`, such as `ast.Add`.
|
||||
elif field == "op" and val in locationless.keys():
|
||||
setattr(node, field, locationless[val]())
|
||||
else: # Any other value, usually literals of various kinds.
|
||||
setattr(node, field, val)
|
||||
|
||||
# Create all fields pointing to lists of values.
|
||||
for start, field_map in edge_attr.items():
|
||||
start = resolve_node_id(start, node_attr)
|
||||
parent = nodes[start]
|
||||
extra_fields = {}
|
||||
for field_name, value_end in field_map.items():
|
||||
# Sort children by index (in case they were visited out of order)
|
||||
children = [nodes[resolve_node_id(end, node_attr)] for _index, end in sorted(value_end)]
|
||||
# Skip any comments.
|
||||
children = [child for child in children if not isinstance(child, Comment)]
|
||||
# Special case for `Compare.ops`, a list of comparison operators
|
||||
if isinstance(parent, ast.Compare) and field_name == "ops":
|
||||
parent.ops = [locationless[v]() for v in children]
|
||||
elif field_name.startswith("_"):
|
||||
# We can only set the attributes given in `__slots__` on the `start` node, and so we
|
||||
# must handle fields starting with `_` specially. In this case, we simply record the
|
||||
# values and then subsequently update `edge_attr` to refer to these values. This
|
||||
# makes it act as a pseudo-field, that we can access as long as we know the `id`
|
||||
# corresponding to a given node (for which we have the `node_id` map).
|
||||
extra_fields[field_name] = children
|
||||
else:
|
||||
setattr(parent, field_name, children)
|
||||
if extra_fields:
|
||||
# Extend the existing map in `node_attr` with the extra fields.
|
||||
node_attr[start].update(extra_fields)
|
||||
|
||||
# Fixup any nodes that need it.
|
||||
for id, node in fixups.items():
|
||||
if isinstance(node, (ast.JoinedStr, ast.Str)):
|
||||
fix_strings(id, node, node_attr, node_id, logger)
|
||||
|
||||
debug_print("nodes:", nodes)
|
||||
if not nodes:
|
||||
# if the file referenced by path is empty, return an empty module:
|
||||
if os.path.getsize(path) == 0:
|
||||
module = ast.Module([])
|
||||
module.lineno = 1
|
||||
module.col_offset = 0
|
||||
module._end = (1, 0)
|
||||
return module
|
||||
else:
|
||||
raise SyntaxError("Syntax Error")
|
||||
# Fix up start location of outer `Module`.
|
||||
module = nodes[0]
|
||||
if module.body:
|
||||
# Get the location of the first non-comment node.
|
||||
module.lineno = module.body[0].lineno
|
||||
else:
|
||||
# No children! File must contain only comments! Pick the end location as the start location.
|
||||
module.lineno = module._end[0]
|
||||
return module
|
||||
|
||||
|
||||
def get_JoinedStr_children(children):
|
||||
"""
|
||||
Folds the `Str` and `expr` parts of a `JoinedStr` into a single list, and does this for each
|
||||
`JoinedStr` in `children`. Top-level `StringPart`s are included in the output directly.
|
||||
"""
|
||||
for child in children:
|
||||
if isinstance(child, ast.JoinedStr):
|
||||
for value in child.values:
|
||||
yield value
|
||||
elif isinstance(child, ast.StringPart):
|
||||
yield child
|
||||
else:
|
||||
raise ValueError("Unexpected node type: {}".format(type(child)))
|
||||
|
||||
def concatenate_stringparts(stringparts, logger):
|
||||
"""Concatenates the strings contained in the list of `stringparts`."""
|
||||
try:
|
||||
return "".join(decode_str(stringpart.s) for stringpart in stringparts)
|
||||
except Exception as ex:
|
||||
logger.error("Unable to concatenate string %s getting error %s", stringparts, ex)
|
||||
return stringparts[0].s
|
||||
|
||||
|
||||
def fix_strings(id, node, node_attr, node_id, logger):
|
||||
"""
|
||||
Reassociates the `StringPart` children of an implicitly concatenated f-string (`JoinedStr`)
|
||||
"""
|
||||
# Tests whether something is a string child
|
||||
is_string = lambda node: isinstance(node, ast.StringPart)
|
||||
|
||||
# We have two cases to consider. Either we're given something that came from a
|
||||
# `concatenated_string`, or something that came from an `formatted_string`. The latter case can
|
||||
# be seen as a special case of the former where the list of children we consider is just the
|
||||
# single f-string.
|
||||
children = node_attr[id].get("_children", [node])
|
||||
if isinstance(node, ast.Str):
|
||||
# If the outer node is a `Str`, then we don't have to reassociate, since there are no
|
||||
# f-strings.
|
||||
# In this case we simply have to create the concatenation of its constituent parts.
|
||||
node.implicitly_concatenated_parts = children
|
||||
node.s = concatenate_stringparts(children, logger)
|
||||
node.prefix = children[0].prefix
|
||||
else:
|
||||
# Otherwise, we first have to get the flattened list of all of the strings and/or
|
||||
# expressions.
|
||||
flattened_children = get_JoinedStr_children(children)
|
||||
groups = [list(n) for _, n in groupby(flattened_children, key=is_string)]
|
||||
# At this point, `values` is a list of lists, where each sublist is either:
|
||||
# - a list of `StringPart`s, or
|
||||
# - a singleton list containing an `expr`.
|
||||
# Crucially, `StringPart` is _not_ an `expr`.
|
||||
combined_values = []
|
||||
for group in groups:
|
||||
first = group[0]
|
||||
if isinstance(first, ast.expr):
|
||||
# If we have a list of expressions (which may happen if an interpolation contains
|
||||
# multiple distinct expressions, such as f"{foo:{bar}}", which uses interpolation to
|
||||
# also specify the padding dynamically), we simply append it.
|
||||
combined_values.extend(group)
|
||||
else:
|
||||
# Otherwise, we have a list of `StringPart`s, and we need to create a `Str` node to
|
||||
# it.
|
||||
|
||||
combined_string = concatenate_stringparts(group, logger)
|
||||
str_node = ast.Str(combined_string, first.prefix, None)
|
||||
copy_location(first, str_node)
|
||||
# The end location should be the end of the last part (even if there is only one part).
|
||||
str_node._end = group[-1]._end
|
||||
if len(group) > 1:
|
||||
str_node.implicitly_concatenated_parts = group
|
||||
combined_values.append(str_node)
|
||||
node.values = combined_values
|
||||
Reference in New Issue
Block a user