Files
codeql/python/extractor/semmle/python/parser/tsg_parser.py
Taus 2ef3ae9860 Python: Improve parser logging/timing/customisability
Does a bunch of things, unfortunately all in the same place, so my
apologies in advance for a slightly complicated commit.

As for the changes themselves, this commit

- Adds timers for the old and new parsers. This means we get the overall
time spent on these parts of the extractor if the extractor is run with
`DEBUG` output shown.
- Adds logging information (at the `DEBUG` level) to show which
invocations of the parsers happen when, and whether they succeed or not.
- Adds support for using an environment variable named
`CODEQL_PYTHON_DISABLE_OLD_PARSER` to disable using the old parser
entirely. This makes it easier to test the new parser in isolation.
- Fixes a bug where we did not check whether a parse with the new parser
had already succeeded, and so would do a superfluous second parse.
2024-10-30 13:58:46 +00:00

496 lines
21 KiB
Python

# tsg_parser.py
# Functions and classes used for parsing Python files using `tree-sitter-graph`
from ast import literal_eval
import sys
import os
import semmle.python.parser
from semmle.python.parser.ast import copy_location, decode_str, split_string
from semmle.python import ast
import subprocess
from itertools import groupby
# Flip to True to get verbose dumps of the intermediate parser state on stdout.
DEBUG = False


def debug_print(*args, **kwargs):
    """Forward all arguments to `print`, but only when the module-level `DEBUG` flag is set."""
    if not DEBUG:
        return
    print(*args, **kwargs)
# Graph node ids are plain integers. Wrapping them in this class lets the rest of the parser tell
# "a reference to node 5" apart from "the integer literal 5".
class Node(object):
    """A reference to a graph node, identified by its integer `id`."""

    def __init__(self, id):
        self.id = id

    def __repr__(self):
        template = "Node({})"
        return template.format(self.id)
# Wrapper for graph nodes that contain comments. There is no `ast.Comment` class, because the old
# parser never creates comment nodes: it reads comments from the token stream of the file instead.
class Comment(object):
    """A source comment, holding its raw `text`."""

    def __init__(self, text):
        self.text = text

    def __repr__(self):
        template = "Comment({})"
        return template.format(self.text)
class SyntaxErrorNode(object):
    """Marker node for a syntax error in the parsed file; `source` holds the node's `source` field."""

    def __init__(self, source):
        self.source = source

    def __repr__(self):
        template = "SyntaxErrorNode({})"
        return template.format(self.source)
# Lookup table from the node kind names found in the `tsg-python` output (the `_kind` attribute)
# to the class that should be instantiated for them. Every class in `semmle.python.ast` that
# derives from `AstBase` is registered under its own name, followed by the two wrapper classes
# defined in this file.
tsg_to_ast = {
    kind: node_class
    for kind, node_class in vars(semmle.python.ast).items()
    if isinstance(node_class, type) and ast.AstBase in node_class.__mro__
}
tsg_to_ast.update({"Comment": Comment, "SyntaxErrorNode": SyntaxErrorNode})
# For every AST node class, the tuple of field names that the node carries. The order of the names
# in each tuple is the order in which the fields appear in the AST dump.
#
# The entries below cannot be derived automatically, so they are listed by hand.
ast_fields = {
    ast.Module: ("body",),  # Note: has no `__slots__` to inspect
    Comment: ("text",),  # Note: not an `ast` class
    SyntaxErrorNode: ("source",),  # Note: not an `ast` class
}
# Leaf nodes that have no fields of their own.
ast_fields.update(dict.fromkeys(
    (ast.Continue, ast.Break, ast.Pass, ast.Ellipsis, ast.MatchWildcardPattern),
    (),
))

# The slots of the AST base class hold the location information, which is printed separately, so
# they must not be reported as ordinary fields on each node.
ignored_fields = semmle.python.ast.AstBase.__slots__

# Derive the field tuples of all remaining classes from their `__slots__`.
for name, cls in vars(semmle.python.ast).items():
    if name.startswith("_") or not hasattr(cls, "__slots__"):
        continue
    slots = tuple(slot for slot in cls.__slots__ if slot not in ignored_fields)
    if slots:
        ast_fields[cls] = slots
# A mapping from strings to the AST node classes that represent things like operators and
# expression contexts. These have to be handled specially, because they have no location
# information.
#
# The table is consulted by `get_context` (for `ctx` values) and by `parse` (for `op` fields and
# for the entries of `Compare.ops`); every lookup instantiates the class that is found.
locationless = {
    # Boolean and unary operators.
    "and": ast.And,
    "or": ast.Or,
    "not": ast.Not,
    "uadd": ast.UAdd,
    "usub": ast.USub,
    # Arithmetic and bitwise operators.
    "+": ast.Add,
    "-": ast.Sub,
    "~": ast.Invert,
    "**": ast.Pow,
    "<<": ast.LShift,
    ">>": ast.RShift,
    "&": ast.BitAnd,
    "|": ast.BitOr,
    "^": ast.BitXor,
    # Expression contexts.
    "load": ast.Load,
    "store": ast.Store,
    "del" : ast.Del,
    "param" : ast.Param,
}
# Merge in the operator tables that `semmle.python.parser.ast` already defines. On a duplicate
# key the later `update` call wins.
locationless.update(semmle.python.parser.ast.TERM_OP_CLASSES)
locationless.update(semmle.python.parser.ast.COMP_OP_CLASSES)
locationless.update(semmle.python.parser.ast.AUG_ASSIGN_OPS)
# Work out how to invoke `tsg-python`. The resulting `tsg_command` is the argument-list prefix to
# which the path of the file to parse gets appended.
if 'CODEQL_EXTRACTOR_PYTHON_ROOT' not in os.environ:
    # No extractor root configured: build and run `tsg-python` through cargo, locating its
    # manifest relative to the directory containing this script.
    script_path = os.path.dirname(os.path.realpath(__file__))
    tsg_python_path = os.path.join(script_path, "../../../tsg-python")
    cargo_file = os.path.join(tsg_python_path, "Cargo.toml")
    tsg_command = [
        "cargo",
        "run",
        "--quiet",
        "--release",
        "--manifest-path=" + cargo_file,
    ]
else:
    # Extractor root configured: use the binary under `<root>/tools/<platform>/`.
    platform = os.environ['CODEQL_PLATFORM']
    ext = {"win64": ".exe"}.get(platform, "")
    tools = os.path.join(os.environ['CODEQL_EXTRACTOR_PYTHON_ROOT'], "tools", platform)
    tsg_command = [os.path.join(tools, "tsg-python" + ext)]
def _parse_tsg_value(key, value, attrs, path, logger):
    """Convert the textual `value` of the attribute `key` into a Python object.

    `attrs` is the attribute dictionary of the node or edge currently being read. It is only used
    to report a location if the value cannot be parsed; in that case a warning is logged and the
    value is returned as it was when the failure happened (normally the raw text).
    """
    if value.startswith("[graph node"): # e.g. `_skip_to: [graph node 5]`
        return Node(int(value.split(" ")[2][:-1]))
    if value == "#true": # e.g. `_is_parenthesised: #true`
        return True
    if value == "#false": # e.g. `top: #false`
        return False
    if value == "#null": # e.g. `exc: #null`
        return None
    # literal values, e.g. `name: "k1.k2"` or `level: 5`
    try:
        if key == "s" and value[0] == '"': # e.g. `s: "k1.k2"`
            return evaluate_string(value)
        value = literal_eval(value)
        if isinstance(value, bytes):
            try:
                value = value.decode(sys.getfilesystemencoding())
            except UnicodeDecodeError:
                # just include the bytes as-is
                pass
    except Exception as ex:
        # We may not know the location at this point -- for instance if we forgot to set
        # it -- but `get_location_info` will degrade gracefully in this case.
        loc = ":".join(str(i) for i in get_location_info(attrs))
        error = ex.args[0] if ex.args else "unknown"
        logger.warning("Error '{}' while parsing value {} at {}:{}\n".format(error, repr(value), path, loc))
    return value


def read_tsg_python_output(path, logger):
    """Run `tsg-python` on the file at `path` and read the graph it prints on stdout.

    Returns a pair `(node_attr, edge_attr)`:
      - `node_attr` maps a node id (an integer) to a dictionary of that node's attributes.
      - `edge_attr` maps a start node id to a dictionary from attribute name to a list of
        `(value, end_node_id)` pairs.

    Values that cannot be parsed are reported through `logger.warning` and kept as text. The
    subprocess is always closed, terminated and waited for, even if reading its output fails.
    """
    # Mapping from node id (an integer) to a dictionary containing attribute data.
    node_attr = {}
    # Mapping a start node to a map from attribute names to lists of (value, end_node) pairs.
    edge_attr = {}
    command_args = tsg_command + [path]
    p = subprocess.Popen(command_args, stdout=subprocess.PIPE)
    try:
        for line in p.stdout:
            line = line.decode(sys.getfilesystemencoding())
            line = line.rstrip()
            if line.startswith("node"): # e.g. `node 5`
                current_node = int(line.split(" ")[1])
                d = {}
                node_attr[current_node] = d
                in_node = True
            elif line.startswith("edge"): # e.g. `edge 5 -> 6`
                current_start, current_end = tuple(map(int, line[4:].split("->")))
                d = edge_attr.setdefault(current_start, {})
                in_node = False
            else: # attribute, e.g. `_kind: "Class"`
                # Attribute lines are indented by two characters under their `node`/`edge` header.
                key, value = line[2:].split(": ", 1)
                value = _parse_tsg_value(key, value, d, path, logger)
                if in_node:
                    d[key] = value
                else:
                    d.setdefault(key, []).append((value, current_end))
    finally:
        # Release the pipe and reap the child on every exit path, so that a malformed line does
        # not leave a stray `tsg-python` process behind.
        p.stdout.close()
        p.terminate()
        p.wait()
    logger.debug("Read {} nodes and {} edges from TSG output".format(len(node_attr), len(edge_attr)))
    return node_attr, edge_attr
def evaluate_string(s):
    """Evaluate the quoted string literal `s` and return the text it denotes.

    `s` is first evaluated once to get the string it contains, which is then split into prefix,
    quotes and content. The content is re-quoted (with any `f`/`F` stripped from the prefix) and
    evaluated again; a `bytes` result is passed through `decode_str`.
    """
    prefix, quotes, content = split_string(literal_eval(s), None)
    # Content that ends with the outer quote character or with a backslash (e.g. the first string
    # part of `f"""hello"{0}"""`) would make the closing quotes land in the wrong place once it is
    # re-quoted. Pad such content with one space, which is removed again after evaluation.
    needs_padding = content.endswith((quotes[0], '\\'))
    if needs_padding:
        content = content + " "
    evaluated = literal_eval(prefix.strip("fF") + quotes + content + quotes)
    if isinstance(evaluated, bytes):
        evaluated = decode_str(evaluated)
    return evaluated[:-1] if needs_padding else evaluated
def resolve_node_id(id, node_attr):
    """Follow the chain of `_skip_to` links starting at `id` and return the id at its end."""
    current = id
    while True:
        attrs = node_attr[current]
        if "_skip_to" not in attrs:
            return current
        current = attrs["_skip_to"].id
def get_context(id, node_attr, logger):
    """Return a fresh context object for the node with the given `id`.

    The context comes from the node's own `ctx` attribute if it has one; otherwise the
    `_inherited_ctx` links are followed until a node with a `ctx` attribute is reached. If the
    chain ends without one, an error is logged and a `Load` context is returned.
    """
    current = id
    while True:
        attrs = node_attr[current]
        if "ctx" in attrs:
            return locationless[attrs["ctx"]]()
        if "_inherited_ctx" not in attrs:
            logger.error("No context for node {} with attributes {}\n".format(current, attrs))
            # A missing context is most likely to be a "load", so return that.
            return ast.Load()
        current = attrs["_inherited_ctx"].id
def get_location_info(attrs):
    """Return `(start_line, start_column, end_line, end_column)` for a node's attributes.

    More specific attributes override less specific ones: `_start_line` and `_start_column` take
    precedence over `_location_start`, which takes precedence over `_location`. The same holds
    for the corresponding `end` attributes.

    Any component that no attribute provides is returned as the string `"???"`. Line numbers that
    are known are converted from the 0-indexed form used by `tsg-python` to 1-indexed.
    """
    missing = "???"
    # Least specific first; each subsequent lookup falls back to what has been found so far.
    start_line, start_column, end_line, end_column = attrs.get("_location", (missing,) * 4)
    start_line, start_column = attrs.get("_location_start", (start_line, start_column))
    end_line, end_column = attrs.get("_location_end", (end_line, end_column))
    start_line = attrs.get("_start_line", start_line)
    start_column = attrs.get("_start_column", start_column)
    end_line = attrs.get("_end_line", end_line)
    end_column = attrs.get("_end_column", end_column)
    # Lines in the `tsg-python` output are 0-indexed, but the AST expects them to be 1-indexed.
    if start_line != missing:
        start_line += 1
    if end_line != missing:
        end_line += 1
    return (start_line, start_column, end_line, end_column)
# Mapping from AST node class to the names of those of its fields that hold lists.
#
# `create_placeholder_args` uses this table to pass an empty list (rather than `None`) for each of
# these fields when it builds the placeholder arguments for a node.
list_fields = {
    ast.arguments: ("annotations", "defaults", "kw_defaults", "kw_annotations"),
    ast.Assign: ("targets",),
    ast.BoolOp: ("values",),
    ast.Bytes: ("implicitly_concatenated_parts",),
    ast.Call: ("positional_args", "named_args"),
    ast.Case: ("body",),
    ast.Class: ("body",),
    ast.ClassExpr: ("type_parameters", "bases", "keywords"),
    ast.Compare: ("ops", "comparators",),
    ast.comprehension: ("ifs",),
    ast.Delete: ("targets",),
    ast.Dict: ("items",),
    ast.ExceptStmt: ("body",),
    ast.For: ("body",),
    ast.Function: ("type_parameters", "args", "kwonlyargs", "body"),
    ast.Global: ("names",),
    ast.If: ("body",),
    ast.Import: ("names",),
    ast.List: ("elts",),
    ast.Match: ("cases",),
    ast.MatchClassPattern: ("positional", "keyword"),
    ast.MatchMappingPattern: ("mappings",),
    ast.MatchOrPattern: ("patterns",),
    ast.MatchSequencePattern: ("patterns",),
    ast.Module: ("body",),
    ast.Nonlocal: ("names",),
    ast.Print: ("values",),
    ast.Set: ("elts",),
    ast.Str: ("implicitly_concatenated_parts",),
    ast.TypeAlias: ("type_parameters",),
    ast.Try: ("body", "handlers", "orelse", "finalbody"),
    ast.Tuple: ("elts",),
    ast.While: ("body",),
    # ast.FormattedStringLiteral: ("arguments",),
}
def create_placeholder_args(cls):
    """Build the keyword arguments needed to construct a placeholder instance of `cls`.

    Every field of the class (apart from `is_async`) is mapped to `None`, except for the fields
    named in `list_fields`, which are mapped to a fresh empty list because the constructors
    assert on that. `Raise` and `Ellipsis` are constructed without any arguments, and the
    comprehension classes do not receive `function` or `iterable`.
    """
    if cls in (ast.Raise, ast.Ellipsis):
        return {}
    args = {}
    for field in ast_fields[cls]:
        if field != "is_async":
            args[field] = None
    args.update((field, []) for field in list_fields.get(cls, ()))
    if cls in (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp):
        for unwanted in ("function", "iterable"):
            del args[unwanted]
    return args
def _create_nodes(node_attr, logger):
    """Instantiate a placeholder object for every graph node that needs one.

    Returns a triple `(nodes, fixups, node_id)`:
      - `nodes` maps a node id to the object created for it (or, for nodes with an `_is_literal`
        attribute, to the value of that attribute),
      - `fixups` maps a node id to its object for every node that has a `_fixup` attribute,
      - `node_id` is the reverse index from created object to node id.

    Nodes without a `_kind` are reported through `logger.error` and skipped; nodes with a
    `_skip_to` attribute are skipped silently.
    """
    nodes = {}
    # Nodes that need to be fixed up after building the graph
    fixups = {}
    # Reverse index from node object to node id.
    node_id = {}
    for ident, attrs in node_attr.items():
        if "_is_literal" in attrs:
            nodes[ident] = attrs["_is_literal"]
            continue
        if "_kind" not in attrs:
            logger.error("Error: Graph node {} with attributes {} has no `_kind`!\n".format(ident, attrs))
            continue
        # This is not the node we are looking for (so don't bother creating it).
        if "_skip_to" in attrs:
            continue
        cls = tsg_to_ast[attrs["_kind"]]
        obj = cls(**create_placeholder_args(cls))
        nodes[ident] = obj
        node_id[obj] = ident
        # If this node needs fixing up afterwards, add it to the fixups map.
        if "_fixup" in attrs:
            fixups[ident] = obj
    return nodes, fixups, node_id


def _set_node_attributes(nodes, node_attr, logger):
    """Set location, context and all non-list fields on the objects in `nodes`.

    Raises `SyntaxError` (with `lineno` and `offset` filled in) as soon as a `SyntaxErrorNode` is
    encountered.
    """
    for ident, node in nodes.items():
        attrs = node_attr[ident]
        if "_is_literal" in attrs:
            continue
        expected_fields = ast_fields[type(node)]
        # Set up location information.
        node.lineno, node.col_offset, end_line, end_column = get_location_info(attrs)
        node._end = (end_line, end_column)
        if isinstance(node, SyntaxErrorNode):
            exc = SyntaxError("Syntax Error")
            exc.lineno = node.lineno
            exc.offset = node.col_offset
            raise exc
        # Set up context information, if any
        if "ctx" in expected_fields:
            node.ctx = get_context(ident, node_attr, logger)
        # Set the fields.
        for field, val in attrs.items():
            # Underscore-prefixed attributes are bookkeeping, and `ctx` was handled above.
            if field.startswith("_") or field == "ctx":
                continue
            if field != "parenthesised" and field not in expected_fields:
                logger.warning("Unknown field {} found among {} in node {}\n".format(field, attrs, ident))
            # For fields that point to other AST nodes.
            if isinstance(val, Node):
                val = resolve_node_id(val.id, node_attr)
                setattr(node, field, nodes[val])
            # Special case for `Num.n`, which should be coerced to an int.
            elif isinstance(node, ast.Num) and field == "n":
                node.n = literal_eval(val.rstrip("lL"))
            # Special case for `Name.variable`, for which we must create a new `Variable` object
            elif isinstance(node, ast.Name) and field == "variable":
                node.variable = ast.Variable(val)
            # Special case for location-less leaf-node subclasses of `ast.Node`, such as `ast.Add`.
            elif field == "op" and val in locationless:
                setattr(node, field, locationless[val]())
            else: # Any other value, usually literals of various kinds.
                setattr(node, field, val)


def _set_list_fields(nodes, node_attr, edge_attr):
    """Set every list-valued field described by `edge_attr` on its parent object.

    Fields whose name starts with `_` cannot be stored on the parent object; they are stored in
    the parent's entry in `node_attr` instead (which is therefore mutated).
    """
    for start, field_map in edge_attr.items():
        start = resolve_node_id(start, node_attr)
        parent = nodes[start]
        extra_fields = {}
        for field_name, value_end in field_map.items():
            # Sort children by index (in case they were visited out of order)
            children = [nodes[resolve_node_id(end, node_attr)] for _index, end in sorted(value_end)]
            # Skip any comments.
            children = [child for child in children if not isinstance(child, Comment)]
            # Special case for `Compare.ops`, a list of comparison operators
            if isinstance(parent, ast.Compare) and field_name == "ops":
                parent.ops = [locationless[v]() for v in children]
            elif field_name.startswith("_"):
                # We can only set the attributes given in `__slots__` on the `start` node, and so
                # we must handle fields starting with `_` specially. In this case, we simply
                # record the values and then subsequently update `node_attr` to refer to these
                # values. This makes it act as a pseudo-field, that we can access as long as we
                # know the `id` corresponding to a given node (for which we have the `node_id`
                # map).
                extra_fields[field_name] = children
            else:
                setattr(parent, field_name, children)
        if extra_fields:
            # Extend the existing map in `node_attr` with the extra fields.
            node_attr[start].update(extra_fields)


def parse(path, logger):
    """Parse the Python file at `path` using `tsg-python` and return the resulting `ast.Module`.

    The graph printed by `tsg-python` is turned into AST objects in three passes: create the
    objects, set their location/context/scalar fields, and finally set their list-valued fields.
    Nodes marked with `_fixup` (strings and f-strings) are post-processed by `fix_strings`.

    An empty file yields an empty module located at line 1, column 0.

    Raises `SyntaxError` if the graph contains a `SyntaxErrorNode`, or if no nodes were produced
    for a non-empty file.
    """
    node_attr, edge_attr = read_tsg_python_output(path, logger)
    debug_print("node_attr:", node_attr)
    debug_print("edge_attr:", edge_attr)

    # Create all the node objects
    nodes, fixups, node_id = _create_nodes(node_attr, logger)
    # Set all of the node attributes
    _set_node_attributes(nodes, node_attr, logger)
    # Create all fields pointing to lists of values.
    _set_list_fields(nodes, node_attr, edge_attr)

    # Fixup any nodes that need it.
    for ident, node in fixups.items():
        if isinstance(node, (ast.JoinedStr, ast.Str)):
            fix_strings(ident, node, node_attr, node_id, logger)

    debug_print("nodes:", nodes)
    if not nodes:
        # if the file referenced by path is empty, return an empty module:
        if os.path.getsize(path) == 0:
            module = ast.Module([])
            module.lineno = 1
            module.col_offset = 0
            module._end = (1, 0)
            return module
        else:
            raise SyntaxError("Syntax Error")
    # Fix up start location of outer `Module`.
    module = nodes[0]
    if module.body:
        # Get the location of the first non-comment node.
        module.lineno = module.body[0].lineno
    else:
        # No children! File must contain only comments! Pick the end location as the start location.
        module.lineno = module._end[0]
    return module
def get_JoinedStr_children(children):
    """
    Yields the flattened parts of `children`: for a `JoinedStr` child each of its `values` in
    turn, and for a `StringPart` child the child itself.

    Raises `ValueError` on reaching a child of any other type.
    """
    for child in children:
        if isinstance(child, ast.JoinedStr):
            parts = child.values
        elif isinstance(child, ast.StringPart):
            parts = (child,)
        else:
            raise ValueError("Unexpected node type: {}".format(type(child)))
        for part in parts:
            yield part
def concatenate_stringparts(stringparts, logger):
    """Return the decoded `s` values of all `stringparts` joined into a single string.

    If decoding or joining fails, the error is logged and the (undecoded) `s` of the first part
    is returned instead.
    """
    try:
        decoded = [decode_str(part.s) for part in stringparts]
        return "".join(decoded)
    except Exception as ex:
        logger.error("Unable to concatenate string {} getting error {}".format(stringparts, ex))
        return stringparts[0].s
def fix_strings(id, node, node_attr, node_id, logger):
    """
    Fixes up a string node: a `Str` gets its parts concatenated, and a `JoinedStr` (f-string) gets
    its `StringPart` children reassociated into `Str` nodes.
    """
    # The node either came from a `concatenated_string`, in which case its parts are recorded
    # under `_children`, or from a single `formatted_string`, which is treated as the special case
    # where the only child is the node itself.
    children = node_attr[id].get("_children", [node])

    if isinstance(node, ast.Str):
        # No f-strings are involved, so nothing has to be reassociated: the node simply becomes
        # the concatenation of its constituent parts.
        node.implicitly_concatenated_parts = children
        node.s = concatenate_stringparts(children, logger)
        node.prefix = children[0].prefix
        return

    def is_string(candidate):
        # Tests whether something is a string child
        return isinstance(candidate, ast.StringPart)

    # Flatten all strings and expressions into one sequence, then split it into maximal runs of
    # `StringPart`s and runs of everything else. Each entry of `groups` is therefore either
    #   - a list of `StringPart`s, or
    #   - a list of `expr`s.
    # Crucially, `StringPart` is _not_ an `expr`.
    groups = [
        list(members)
        for _is_str, members in groupby(get_JoinedStr_children(children), key=is_string)
    ]
    combined_values = []
    for group in groups:
        first = group[0]
        if not isinstance(first, ast.expr):
            # A run of `StringPart`s: merge it into a single new `Str` node.
            str_node = ast.Str(concatenate_stringparts(group, logger), first.prefix, None)
            copy_location(first, str_node)
            # The end location should be the end of the last part (even if there is only one part).
            str_node._end = group[-1]._end
            if len(group) > 1:
                str_node.implicitly_concatenated_parts = group
            combined_values.append(str_node)
        else:
            # A run of expressions. There can be several in a row when an interpolation contains
            # multiple distinct expressions, such as f"{foo:{bar}}", which uses interpolation to
            # also specify the padding dynamically. These are kept as they are.
            combined_values.extend(group)
    node.values = combined_values