Python: Copy Python extractor to codeql repo

This commit is contained in:
Taus
2024-02-28 15:15:21 +00:00
parent 297a17975d
commit 6dec323cfc
369 changed files with 165346 additions and 0 deletions

View File

@@ -0,0 +1,560 @@
'''Meta nodes for defining database relations'''
from abc import abstractmethod
from semmle.util import fprintf
PREFIX = 'py_'
__all__ = [ 'order' ]
parent_nodes = {}
class Node(object):
    'Node in the attribute tree, describing relations'
    # Monotonically increasing counter used to give every Node a stable,
    # creation-ordered _index.
    next_id = 0

    def __init__(self):
        Node.next_id += 1
        self._index = Node.next_id
        # Cache for the unique_parent property (None = not yet computed).
        self._unique_parent = None

    @property
    def parents(self):
        # Parent links are kept in the module-level parent_nodes table,
        # not on the node itself.
        return parent_of(self)

    def add_child(self, child):
        # Child relationships are recorded as parent links on the child.
        child.add_parent(self)

    def db_key(self, name):
        # Primary-key column declaration for the generated dbscheme.
        return 'unique int ' + name + ' : ' + self.db_name()

    def is_sub_type(self):
        return False

    @staticmethod
    def is_union_type():
        return False

    def is_case_type(self):
        return False

    @staticmethod
    def is_list():
        return False

    @staticmethod
    def is_primitive():
        return False

    def prune(self, node_set):
        # Default: keep the node unchanged; UnionNode overrides this to
        # drop members that are not in node_set.
        return self

    @abstractmethod
    def child_offsets(self, n):
        pass

    @abstractmethod
    def write_fields(self, out):
        pass

    @abstractmethod
    def ql_name(self):
        pass

    @property
    def unique_parent(self):
        '''Whether this node occupies at most one child offset in each of
        its parents.  The result is cached in self._unique_parent.
        NOTE(review): if the parent is neither unique nor a union type the
        cache stays None and the (falsy) None is returned on every access.'''
        if self._unique_parent is None:
            parents = self.parents
            if len(parents.child_offsets(self)) < 2:
                self._unique_parent = True
            elif parents.is_union_type():
                self._unique_parent = False
                for t in parents.types:
                    if len(t.child_offsets(self)) > 1:
                        break
                else:
                    self._unique_parent = True
        return self._unique_parent
class PrimitiveNode(Node):
    '''A primitive node: int, str, etc.

    name -- Python-level name of the type.
    db_name -- dbscheme type name ('@'-prefixed if a schema type).
    key -- column-type prefix used when this node appears as a key
           (e.g. 'int').
    '''

    def __init__(self, name, db_name, key, descriptive_name = None):
        Node.__init__(self)
        assert isinstance(name, str)
        self.name = name
        self.super_type = None
        self.layout = []
        self.fields = []
        self.subclasses = set()
        self._key = key
        self._db_name = db_name
        if descriptive_name is None:
            self.descriptive_name = self.name
        else:
            self.descriptive_name = descriptive_name

    def db_key(self, name):
        # Primitives are referenced ('ref'); they never own the key column.
        return self._key + ' ' + name + ' : ' + self._db_name + ' ref'

    @property
    def __name__(self):
        return self.name

    def ql_name(self):
        'Return Java style name if a schema type, otherwise the specified name'
        if self._db_name[0] == '@':
            return capitalize(self.name)
        else:
            return self._db_name

    def relation_name(self):
        return pluralize(PREFIX + self.name)

    def db_name(self):
        return self._db_name

    def add_parent(self, p):
        parent_nodes[self] = UnionNode.join(parent_of(self), p)

    def fixup(self):
        pass

    @staticmethod
    def is_primitive():
        return True

    def child_offsets(self, n):
        # Primitives have no children.
        return set()

    def write_init(self, out):
        # BUG FIX: the string arguments must be quoted, otherwise the
        # generated line is not valid Python source
        # (compare ClassNode.write_init, which quotes its name argument).
        fprintf(out, "%s = PrimitiveNode('%s', '%s', '%s')\n", self.name,
                self.name, self._db_name, self._key)

    def write_fields(self, out):
        pass
def parent_of(node):
    'Return the recorded parent of `node`, or None if it has no parent.'
    return parent_nodes.get(node)
class ClassNode(Node):
    '''A node corresponding to a single AST type.

    Fields are stored as 5-tuples (name, type, descriptive_name,
    artificial, parser_type); compute_layout() converts them into 6-tuple
    layout entries (name, type, index, descriptive_name, artificial,
    parser_type).
    '''

    def __init__(self, name, super_type = None, descriptive_name = None):
        Node.__init__(self)
        assert isinstance(name, str)
        self.name = name
        self._db_name = name
        self.super_type = super_type
        self.layout = []
        # Subclasses inherit their supertype's fields.
        if super_type:
            self.fields = list(super_type.fields)
        else:
            self.fields = []
        self.subclasses = set()
        if super_type:
            super_type.subclasses.add(self)
        if descriptive_name is None:
            self.descriptive_name = self.name.lower()
        else:
            self.descriptive_name = descriptive_name
        # A leading '$' means "use this name verbatim"; otherwise a
        # one-word name is suffixed with the supertype's description.
        if self.descriptive_name[0] == '$':
            self.descriptive_name = self.descriptive_name[1:]
        elif super_type and ' ' not in self.descriptive_name:
            self.descriptive_name += ' ' + super_type.descriptive_name

    def field(self, name, field_type, descriptive_name = None, artificial=False, parser_type = None):
        'Declare a field of this AST type.'
        if descriptive_name is None:
            self.fields.append((name, field_type, name, artificial, parser_type))
        else:
            self.fields.append((name, field_type, descriptive_name, artificial, parser_type))

    def is_stmt_or_expr_subclass(self):
        if self.super_type is None:
            return False
        return self.super_type.name in ('expr', 'stmt')

    def is_sub_type(self):
        if self.super_type is None:
            return False
        return self.super_type.is_case_type()

    def is_case_type(self):
        # A case type has subclasses and is reachable (has a parent).
        return (self.subclasses
                and parent_of(self))

    def fixup(self):
        self.add_children()
        self.compute_layout()

    def add_parent(self, p):
        parent_nodes[self] = UnionNode.join(parent_of(self), p)
        if self.super_type:
            self.super_type.add_parent(p)

    def add_children(self):
        for f, f_node, _, _, _ in self.fields:
            self.add_child(f_node)

    def compute_layout(self):
        'Assign a sequential column index to every field.'
        index = 0
        for f, f_node, docname, artificial, pt in self.fields:
            self.layout.append((f, f_node, index, docname, artificial, pt))
            index += 1

    def relation_name(self):
        return pluralize(PREFIX + self._db_name)

    def set_name(self, name):
        self._db_name = name

    @property
    def __name__(self):
        return self.name

    def ql_name(self):
        if self._db_name == 'str':
            return 'string'
        elif self._db_name in ('int', 'float'):
            # BUG FIX: was `return self.db_name`, which returned the bound
            # method object instead of the primitive type's name.
            return self._db_name
        name = self._db_name
        return ''.join(capitalize(part) for part in name.split('_'))

    def db_name(self):
        return '@' + PREFIX + self._db_name

    def dump(self, out):
        'Write a human-readable description of this node to `out`.'
        def yes_no(b):
            return "yes" if b else "no"
        fprintf(out, "'%s' :\n", self.name)
        fprintf(out, "    QL name: %s\n", self.ql_name())
        fprintf(out, "    Relation name: %s\n", self.relation_name())
        fprintf(out, "    Is case_type %s\n", yes_no(self.is_case_type()))
        fprintf(out, "    Super type: %s\n", self.super_type)
        fprintf(out, "    Layout:\n")
        for l in self.layout:
            # BUG FIX: layout entries are 6-tuples, but the old format
            # string had only five placeholders, raising a TypeError.
            fprintf(out, "        %s, %s, %s, '%s', %s, %s\n", *l)
        fprintf(out, "    Parents: %s\n\n", parent_of(self))

    def write_init(self, out):
        if self.super_type:
            fprintf(out, "%s = ClassNode('%s', %s)\n", self.name,
                    self.name, self.super_type.name)
        else:
            fprintf(out, "%s = ClassNode('%s')\n", self.name, self.name)

    def write_fields(self, out):
        for name, field_type, docname, _, _ in self.fields:
            fprintf(out, "%s.field('%s', %s, '%s')\n", self.name,
                    name, field_type.__name__, docname)
        if self.layout:
            fprintf(out, "\n")

    def __repr__(self):
        return "Node('%s')" % self.name

    def child_offsets(self, n):
        #Only used by db-scheme generator, so can be slow
        found = set()
        for name, node, offset, _, artificial, _ in self.layout:
            if node is n:
                found.add(offset)
        for s in self.subclasses:
            found.update(s.child_offsets(n))
        return found
class ListNode(Node):
    "Node corresponding to a list, parameterized by its member's type"

    def __init__(self, item_node, name=None):
        Node.__init__(self)
        self.list_type = None
        self.layout = ()
        self.super_type = None
        self.item_type = item_node
        self.subclasses = ()
        self.add_child(item_node)
        self.name = name

    @property
    def __name__(self):
        # Explicitly named lists keep their name; otherwise derive one
        # from the element type.
        if self.name is not None:
            return self.name
        assert isinstance(self.item_type.__name__, str)
        return self.item_type.__name__ + '_list'

    @property
    def descriptive_name(self):
        return self.item_type.descriptive_name + ' list'

    def relation_name(self):
        return pluralize(PREFIX + self.__name__)

    def db_name(self):
        return '@' + PREFIX + self.__name__

    def ql_name(self):
        if self.name is not None:
            return capitalize(self.name)
        if self.item_type is str:
            return 'StringList'
        if self.item_type is int:
            return 'IntList'
        if self.item_type is float:
            return 'FloatList'
        return capitalize(self.item_type.ql_name()) + 'List'

    def dump(self, out):
        fprintf(out, "List of %s\n", self.name)
        fprintf(out, "    Parents: %s\n\n", parent_of(self))

    def write_init(self, out):
        fprintf(out, "%s = ListNode(%s)\n",
                self.__name__, self.item_type.__name__)

    def write_fields(self, out):
        pass

    @staticmethod
    def is_list():
        return True

    def __repr__(self):
        return "ListNode(%s)" % self.__name__

    def fixup(self):
        pass

    def add_parent(self, p):
        parent_nodes[self] = UnionNode.join(parent_of(self), p)

    def child_offsets(self, n):
        # Lists report a fixed set of offsets for any child.
        return {0, 1, 2, 3}
# Interning table: frozenset of member types -> UnionNode instance.
_all_unions = {}

class UnionNode(Node):
    '''Node representing a set of AST types.

    Unions are interned via _make_union, so equal member sets share one
    instance.  Never construct directly; use UnionNode.join().
    '''

    def __init__(self, *types):
        Node.__init__(self)
        assert len(types) > 1
        self.types = frozenset(types)
        self.name = None
        self.super_type = None
        self.layout = []
        self.subclasses = ()
        #Whether this node should be visited in auto-generated extractor.
        self.visit = False

    @staticmethod
    def join(t1, t2):
        '''Combine two (possibly None, possibly union) nodes into one.
        If both a type and its supertype end up in the set, only the
        supertype is kept.'''
        if t1 is None:
            return t2
        if t2 is None:
            return t1
        if isinstance(t1, UnionNode):
            all_types = set(t1.types)
        else:
            all_types = set([t1])
        if isinstance(t2, UnionNode):
            all_types = all_types.union(t2.types)
        else:
            all_types.add(t2)
        # Repeatedly drop any member whose supertype is also present.
        done = False
        while not done:
            for n in all_types:
                if n.super_type in all_types:
                    all_types.remove(n)
                    break
            else:
                done = True
        return UnionNode._make_union(all_types)

    @staticmethod
    def _make_union(all_types):
        # A singleton set collapses to the type itself.
        if len(all_types) == 1:
            return next(iter(all_types))
        else:
            key = frozenset(all_types)
            if key in _all_unions:
                u = _all_unions[key]
            else:
                u = UnionNode(*all_types)
                _all_unions[key] = u
            return u

    def set_name(self, name):
        self.name = name

    @staticmethod
    def is_union_type():
        return True

    def write_init(self, out):
        fprintf(out, "%s = UnionNode(%s)\n", self.__name__,
                ', '.join(t.__name__ for t in self.types))
        if self.name:
            # BUG FIX: the generated call must be 'set_name' (the actual
            # method name), not 'setname'.
            fprintf(out, "%s.set_name('%s')\n", self.name, self.name)

    def write_fields(self, out):
        pass

    def fixup(self):
        pass

    def __hash__(self):
        return hash(self.types)

    def __eq__(self, other):
        assert len(self.types) > 1
        if isinstance(other, UnionNode):
            return self.types == other.types
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def __name__(self):
        if self.name is None:
            names = [ n.__name__ for n in self.types ]
            return '_or_'.join(sorted(names))
        else:
            return self.name

    @property
    def descriptive_name(self):
        if self.name is None:
            names = [ n.descriptive_name for n in self.types ]
            return '_or_'.join(sorted(names))
        else:
            return self.name

    def db_name(self):
        return '@' + PREFIX + self.__name__

    def relation_name(self):
        return pluralize(PREFIX + self.__name__)

    def ql_name(self):
        if self.name is None:
            assert len(self.types) > 1
            names = [ n.ql_name() for n in self.types ]
            return 'Or'.join(sorted(names))
        else:
            return ''.join(capitalize(part) for part in self.name.split('_'))

    def add_parent(self, p):
        # A union has no parent of its own; the parent applies to members.
        for n in self.types:
            n.add_parent(p)

    def child_offsets(self, n):
        res = set()
        for t in self.types:
            res = res.union(t.child_offsets(n))
        return res

    def prune(self, node_set):
        'Drop members not in node_set; may collapse to a single type or None.'
        new_set = self.types.intersection(node_set)
        if len(new_set) == len(self.types):
            return self
        if not new_set:
            return None
        return UnionNode._make_union(new_set)
def shorten_name(node):
    '''Give a union parent with an unwieldy auto-generated name the
    shorter name "<node>_parent".'''
    parent = parent_of(node)
    if not isinstance(parent, UnionNode):
        return
    if len(parent.__name__) > 16 and len(parent.__name__) > len(node.__name__) + 4:
        parent.set_name(node.__name__ + '_parent')
def build_node_relations(nodes):
    '''Finalize the node graph and return {name: node} for all live nodes.

    Runs the fixup/rename/prune pipeline in a fixed order:
    fixup -> shorten names -> collect parents -> index case-type
    subclasses -> drop unreachable lists -> prune unions -> re-shorten.
    '''
    nodes = set(nodes)
    for node in nodes:
        node.fixup()
    # Sorted for deterministic renaming of shared union parents.
    for node in sorted(nodes, key=lambda n : n.__name__):
        shorten_name(node)
    node_set = set(nodes)
    # Builtin types may be used directly as nodes; include their parents.
    for node in (str, int, float, bytes):
        p = parent_of(node)
        if p is not None:
            node_set.add(p)
    for node in nodes:
        p = parent_of(node)
        if p is not None:
            node_set.add(p)
    # Number the subclasses of each case type in creation order.
    for n in nodes:
        sub_types = sorted(n.subclasses, key = lambda x : x._index)
        if n.is_case_type():
            for index, item in enumerate(sub_types):
                item.index = index
    for n in list(nodes):
        if not n.parents and n.is_list() and n.name is None:
            #Discard lists with no parents and no name as unreachable
            node_set.remove(n)
    #Prune unused nodes from unions.
    node_set = set(node.prune(node_set) for node in node_set)
    for node in node_set:
        if node in parent_nodes:
            parent_nodes[node] = parent_nodes[node].prune(node_set)
    # Pruning may have produced new (smaller) unions; shorten their names too.
    for node in node_set:
        shorten_name(node)
    result_nodes = {}
    # Pruning can introduce None entries; skip them.
    for n in node_set:
        if n:
            result_nodes[n.__name__] = n
    return result_nodes
def pluralize(name):
    'Return the plural of an identifier, e.g. stmt -> stmts, index -> indices.'
    if name[-1] == 's':
        # 'class' -> 'classes'; names like 'stmts' are already plural.
        return name + 'es' if name[-2] in 'aiuos' else name
    if name.endswith('ex'):
        return name[:-2] + 'ices'
    if name.endswith('y'):
        return name[:-1] + 'ies'
    return name + 's'
def capitalize(name):
    'Unlike the str method capitalize(), leave upper case letters alone'
    head, tail = name[0], name[1:]
    return head.upper() + tail
def order(node):
    '''Sort key for nodes: primitives first, then classes by inheritance
    depth, then lists/unions above their deepest component.'''
    if node.is_primitive():
        return 0
    if isinstance(node, ClassNode):
        depth = 1
        ancestor = node.super_type
        while ancestor:
            depth += 1
            ancestor = ancestor.super_type
        return depth
    if isinstance(node, ListNode):
        return 1 + order(node.item_type)
    assert isinstance(node, UnionNode)
    return 1 + max(order(member) for member in node.types)

View File

@@ -0,0 +1,949 @@
'''
Abstract syntax tree classes.
This is designed to replace the stdlib ast module.
Unlike the stdlib module, it is version independent.
The classes in this file are based on the corresponding types in the cpython interpreter, copyright PSF.
'''
class AstBase(object):
    '''Common base of all AST nodes; carries source-location slots.'''
    __slots__ = "lineno", "col_offset", "_end",

    def __repr__(self):
        # self.__slots__ resolves to the most-derived class's __slots__,
        # so only that class's own fields appear; unset slots show as None.
        args = ",".join(repr(getattr(self, field, None)) for field in self.__slots__)
        return "%s(%s)" % (self.__class__.__name__, args)
class Class(AstBase):
    'AST node representing a class definition'
    __slots__ = "name", "body",

    def __init__(self, name, body):
        self.name = name
        self.body = body

class Function(AstBase):
    'AST node representing a function definition'
    __slots__ = "is_async", "name", "type_parameters", "args", "vararg", "kwonlyargs", "kwarg", "body",

    def __init__(self, name, type_parameters, args, vararg, kwonlyargs, kwarg, body, is_async=False):
        self.name = name
        self.type_parameters = type_parameters
        self.args = args
        self.vararg = vararg
        self.kwonlyargs = kwonlyargs
        self.kwarg = kwarg
        self.body = body
        self.is_async = is_async

class Module(AstBase):
    # AST node for a whole module.  Deliberately(?) has no __slots__, so
    # instances get a __dict__ -- the extractor attaches extra attributes
    # (name, kind, trap_name) at runtime.  TODO confirm this is intended.
    def __init__(self, body):
        self.body = body

class StringPart(AstBase):
    '''Implicitly concatenated part of string literal'''
    __slots__ = "prefix", "text", "s",

    def __init__(self, prefix, text, s):
        self.prefix = prefix
        self.text = text
        self.s = s
class alias(AstBase):
    # One clause of an import statement: `value as asname`.
    __slots__ = "value", "asname",

    def __init__(self, value, asname):
        self.value = value
        self.asname = asname

class arguments(AstBase):
    '''Defaults and annotations of a function's parameters.

    The lists must be index-aligned: defaults with annotations, and
    kw_defaults with kw_annotations.
    '''
    __slots__ = "defaults", "kw_defaults", "annotations", "varargannotation", "kwargannotation", "kw_annotations",

    def __init__(self, defaults, kw_defaults, annotations, varargannotation, kwargannotation, kw_annotations):
        # Enforce the alignment invariants eagerly at construction time.
        if len(defaults) != len(annotations):
            raise AssertionError('len(defaults) != len(annotations)')
        if len(kw_defaults) != len(kw_annotations):
            raise AssertionError('len(kw_defaults) != len(kw_annotations)')
        self.kw_defaults = kw_defaults
        self.defaults = defaults
        self.annotations = annotations
        self.varargannotation = varargannotation
        self.kwargannotation = kwargannotation
        self.kw_annotations = kw_annotations
# Abstract node categories, mirroring the grammar categories of the
# stdlib ast module.  Concrete node types subclass one of these.
class boolop(AstBase):
    pass

class cmpop(AstBase):
    pass

class comprehension(AstBase):
    # One 'for target in iter [if ...]' clause of a comprehension.
    __slots__ = "is_async", "target", "iter", "ifs",

    def __init__(self, target, iter, ifs, is_async=False):
        self.target = target
        self.iter = iter
        self.ifs = ifs
        self.is_async = is_async

class dict_item(AstBase):
    pass

class type_parameter(AstBase):
    pass

class expr(AstBase):
    # 'parenthesised' records whether the source wrapped the expression
    # in parentheses.
    __slots__ = "parenthesised",

class expr_context(AstBase):
    pass

class operator(AstBase):
    pass

class stmt(AstBase):
    pass

class unaryop(AstBase):
    pass

class pattern(AstBase):
    __slots__ = "parenthesised",
# Boolean operators.
class And(boolop):
    pass

class Or(boolop):
    pass

# Comparison operators.
class Eq(cmpop):
    pass

class Gt(cmpop):
    pass

class GtE(cmpop):
    pass

class In(cmpop):
    pass

class Is(cmpop):
    pass

class IsNot(cmpop):
    pass

class Lt(cmpop):
    pass

class LtE(cmpop):
    pass

class NotEq(cmpop):
    pass

class NotIn(cmpop):
    pass
class DictUnpacking(dict_item):
    # '**value' inside a dict display.
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class KeyValuePair(dict_item):
    # 'key: value' inside a dict display.
    __slots__ = "key", "value",

    def __init__(self, key, value):
        self.key = key
        self.value = value

class keyword(dict_item):
    # 'arg=value' keyword argument in a call.
    __slots__ = "arg", "value",

    def __init__(self, arg, value):
        self.arg = arg
        self.value = value
class AssignExpr(expr):
    # Assignment expression (walrus, 'target := value').
    # NOTE(review): __init__ takes (value, target) -- reversed relative to
    # __slots__ order; positional callers must follow the __init__ order.
    __slots__ = "target", "value",

    def __init__(self, value, target):
        self.value = value
        self.target = target

class Attribute(expr):
    __slots__ = "value", "attr", "ctx",

    def __init__(self, value, attr, ctx):
        self.value = value
        self.attr = attr
        self.ctx = ctx

class Await(expr):
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class BinOp(expr):
    __slots__ = "left", "op", "right",

    def __init__(self, left, op, right):
        self.left = left
        self.op = op
        self.right = right

class BoolOp(expr):
    __slots__ = "op", "values",

    def __init__(self, op, values):
        self.op = op
        self.values = values

class Bytes(expr):
    __slots__ = "s", "prefix", "implicitly_concatenated_parts",

    def __init__(self, s, prefix, implicitly_concatenated_parts):
        self.s = s
        self.prefix = prefix
        self.implicitly_concatenated_parts = implicitly_concatenated_parts

class Call(expr):
    __slots__ = "func", "positional_args", "named_args",

    def __init__(self, func, positional_args, named_args):
        self.func = func
        self.positional_args = positional_args
        self.named_args = named_args

class ClassExpr(expr):
    'AST node representing class creation'
    __slots__ = "name", "type_parameters", "bases", "keywords", "inner_scope",

    def __init__(self, name, type_parameters, bases, keywords, inner_scope):
        self.name = name
        self.type_parameters = type_parameters
        self.bases = bases
        self.keywords = keywords
        self.inner_scope = inner_scope
class Compare(expr):
    # Chained comparison: len(ops) == len(comparators).
    __slots__ = "left", "ops", "comparators",

    def __init__(self, left, ops, comparators):
        self.left = left
        self.ops = ops
        self.comparators = comparators

class Dict(expr):
    __slots__ = "items",

    def __init__(self, items):
        self.items = items

class DictComp(expr):
    # 'function' and 'iterable' are not set here; presumably filled in by
    # a later pass (same for the other comprehension nodes) -- TODO confirm.
    __slots__ = "key", "value", "generators", "function", "iterable",

    def __init__(self, key, value, generators):
        self.key = key
        self.value = value
        self.generators = generators

class Ellipsis(expr):
    pass

class Filter(expr):
    '''Filtered expression in a template'''
    __slots__ = "value", "filter",

    def __init__(self, value, filter):
        self.value = value
        self.filter = filter

class FormattedValue(expr):
    # One '{...}' substitution inside an f-string.
    __slots__ = "value", "conversion", "format_spec",

    def __init__(self, value, conversion, format_spec):
        self.value = value
        self.conversion = conversion
        self.format_spec = format_spec

class FunctionExpr(expr):
    'AST node representing function creation'
    __slots__ = "name", "args", "returns", "inner_scope",

    def __init__(self, name, args, returns, inner_scope):
        self.name = name
        self.args = args
        self.returns = returns
        self.inner_scope = inner_scope

class GeneratorExp(expr):
    __slots__ = "elt", "generators", "function", "iterable",

    def __init__(self, elt, generators):
        self.elt = elt
        self.generators = generators

class IfExp(expr):
    # 'body if test else orelse'
    __slots__ = "test", "body", "orelse",

    def __init__(self, test, body, orelse):
        self.test = test
        self.body = body
        self.orelse = orelse

class ImportExpr(expr):
    '''AST node representing module import
    (roughly equivalent to the runtime call to __import__)'''
    __slots__ = "level", "name", "top",

    def __init__(self, level, name, top):
        self.level = level
        self.name = name
        self.top = top

class ImportMember(expr):
    '''AST node representing 'from import'. Similar to Attribute access,
    but during import'''
    __slots__ = "module", "name",

    def __init__(self, module, name):
        self.module = module
        self.name = name

class JoinedStr(expr):
    # An f-string: sequence of Str / FormattedValue parts.
    __slots__ = "values",

    def __init__(self, values):
        self.values = values
class Lambda(expr):
    __slots__ = "args", "inner_scope",

    def __init__(self, args, inner_scope):
        self.args = args
        self.inner_scope = inner_scope

class List(expr):
    __slots__ = "elts", "ctx",

    def __init__(self, elts, ctx):
        self.elts = elts
        self.ctx = ctx

class ListComp(expr):
    __slots__ = "elt", "generators", "function", "iterable",

    def __init__(self, elt, generators):
        self.elt = elt
        self.generators = generators

# Structural pattern matching (PEP 634) statements and patterns.
class Match(stmt):
    __slots__ = "subject", "cases",

    def __init__(self, subject, cases):
        self.subject = subject
        self.cases = cases

class Case(stmt):
    __slots__ = "pattern", "guard", "body",

    def __init__(self, pattern, guard, body):
        self.pattern = pattern
        self.guard = guard
        self.body = body

class Guard(expr):
    __slots__ = "test",

    def __init__(self, test):
        self.test = test

class MatchAsPattern(pattern):
    # 'pattern as alias'
    __slots__ = "pattern", "alias",

    def __init__(self, pattern, alias):
        self.pattern = pattern
        self.alias = alias

class MatchOrPattern(pattern):
    # 'p1 | p2 | ...'
    __slots__ = "patterns",

    def __init__(self, patterns):
        self.patterns = patterns

class MatchLiteralPattern(pattern):
    __slots__ = "literal",

    def __init__(self, literal):
        self.literal = literal

class MatchCapturePattern(pattern):
    __slots__ = "variable",

    def __init__(self, variable):
        self.variable = variable

class MatchWildcardPattern(pattern):
    # '_' -- matches anything, binds nothing.
    __slots__ = []

class MatchValuePattern(pattern):
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class MatchSequencePattern(pattern):
    __slots__ = "patterns",

    def __init__(self, patterns):
        self.patterns = patterns

class MatchStarPattern(pattern):
    # '*target' inside a sequence pattern.
    __slots__ = "target",

    def __init__(self, target):
        self.target = target

class MatchMappingPattern(pattern):
    __slots__ = "mappings",

    def __init__(self, mappings):
        self.mappings = mappings

class MatchDoubleStarPattern(pattern):
    # '**target' inside a mapping pattern.
    __slots__ = "target",

    def __init__(self, target):
        self.target = target

class MatchKeyValuePattern(pattern):
    __slots__ = "key", "value",

    def __init__(self, key, value):
        self.key = key
        self.value = value

class MatchClassPattern(pattern):
    __slots__ = "class_name", "positional", "keyword",

    def __init__(self, class_name, positional, keyword):
        self.class_name = class_name
        self.positional = positional
        self.keyword = keyword

class MatchKeywordPattern(pattern):
    # 'attribute=value' inside a class pattern.
    __slots__ = "attribute", "value",

    def __init__(self, attribute, value):
        self.attribute = attribute
        self.value = value
class Name(expr):
    # Reference to a variable; 'variable' is a Variable object, not a str.
    __slots__ = "variable", "ctx",

    def __init__(self, variable, ctx):
        self.variable = variable
        self.ctx = ctx

    @property
    def id(self):
        # Compatibility shim: stdlib ast.Name exposes the name as 'id'.
        return self.variable.id

class Num(expr):
    # 'text' preserves the original source spelling of the literal.
    __slots__ = "n", "text",

    def __init__(self, n, text):
        self.n = n
        self.text = text

class ParamSpec(type_parameter):
    __slots__ = "name",

    def __init__(self, name):
        self.name = name

class PlaceHolder(expr):
    '''PlaceHolder variable in template ($name)'''
    __slots__ = "variable", "ctx",

    def __init__(self, variable, ctx):
        self.variable = variable
        self.ctx = ctx

    @property
    def id(self):
        return self.variable.id

class Repr(expr):
    # Python 2 backquote expression `value`.
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class Set(expr):
    __slots__ = "elts",

    def __init__(self, elts):
        self.elts = elts

class SetComp(expr):
    __slots__ = "elt", "generators", "function", "iterable",

    def __init__(self, elt, generators):
        self.elt = elt
        self.generators = generators

class Slice(expr):
    '''AST node for a slice as a subclass of expr to simplify Subscripts'''
    __slots__ = "start", "stop", "step",

    def __init__(self, start, stop, step):
        self.start = start
        self.stop = stop
        self.step = step

class Starred(expr):
    __slots__ = "value", "ctx",

    def __init__(self, value, ctx):
        self.value = value
        self.ctx = ctx

class Str(expr):
    __slots__ = "s", "prefix", "implicitly_concatenated_parts",

    def __init__(self, s, prefix, implicitly_concatenated_parts):
        self.s = s
        self.prefix = prefix
        self.implicitly_concatenated_parts = implicitly_concatenated_parts

class Subscript(expr):
    __slots__ = "value", "index", "ctx",

    def __init__(self, value, index, ctx):
        self.value = value
        self.index = index
        self.ctx = ctx
class TemplateDottedNotation(expr):
    '''Unified dot notation expression in a template'''
    __slots__ = "value", "attr", "ctx",

    def __init__(self, value, attr, ctx):
        self.value = value
        self.attr = attr
        self.ctx = ctx

class Tuple(expr):
    __slots__ = "elts", "ctx",

    def __init__(self, elts, ctx):
        self.elts = elts
        self.ctx = ctx

class TypeAlias(stmt):
    # 'type name[type_parameters] = value' (PEP 695).
    __slots__ = "name", "type_parameters", "value",

    def __init__(self, name, type_parameters, value):
        self.name = name
        self.type_parameters = type_parameters
        self.value = value

class TypeVar(type_parameter):
    __slots__ = "name", "bound",

    def __init__(self, name, bound):
        self.name = name
        self.bound = bound

class TypeVarTuple(type_parameter):
    __slots__ = "name",

    def __init__(self, name):
        self.name = name

class UnaryOp(expr):
    __slots__ = "op", "operand",

    def __init__(self, op, operand):
        self.op = op
        self.operand = operand

class Yield(expr):
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class YieldFrom(expr):
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class SpecialOperation(expr):
    # Catch-all node for operations with no dedicated node type.
    __slots__ = "name", "arguments"

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments
# Expression contexts (how an expression is used: load, store, delete...).
class AugLoad(expr_context):
    pass

class AugStore(expr_context):
    pass

class Del(expr_context):
    pass

class Load(expr_context):
    pass

class Param(expr_context):
    pass

class Store(expr_context):
    pass

# Binary operators.
class Add(operator):
    pass

class BitAnd(operator):
    pass

class BitOr(operator):
    pass

class BitXor(operator):
    pass

class Div(operator):
    pass

class FloorDiv(operator):
    pass

class LShift(operator):
    pass

class MatMult(operator):
    pass

class Mod(operator):
    pass

class Mult(operator):
    pass

class Pow(operator):
    pass

class RShift(operator):
    pass

class Sub(operator):
    pass
class AnnAssign(stmt):
    # Annotated assignment: 'target: annotation = value' (value may be None).
    __slots__ = "value", "annotation", "target",

    def __init__(self, value, annotation, target):
        self.value = value
        self.annotation = annotation
        self.target = target

class Assert(stmt):
    __slots__ = "test", "msg",

    def __init__(self, test, msg):
        self.test = test
        self.msg = msg

class Assign(stmt):
    __slots__ = "targets", "value",

    def __init__(self, value, targets):
        self.value = value
        # Chained assignment has multiple targets; always a list.
        assert isinstance(targets, list)
        self.targets = targets

class AugAssign(stmt):
    # 'operation' is the underlying BinOp-like expression.
    __slots__ = "operation",

    def __init__(self, operation):
        self.operation = operation

class Break(stmt):
    pass

class Continue(stmt):
    pass

class Delete(stmt):
    __slots__ = "targets",

    def __init__(self, targets):
        self.targets = targets

class ExceptStmt(stmt):
    '''AST node for except handler, as a subclass of stmt in order
    to better support location and flow control'''
    __slots__ = "type", "name", "body",

    def __init__(self, type, name, body):
        self.type = type
        self.name = name
        self.body = body

class ExceptGroupStmt(stmt):
    '''AST node for except* handler, as a subclass of stmt in order
    to better support location and flow control'''
    __slots__ = "type", "name", "body",

    def __init__(self, type, name, body):
        self.type = type
        self.name = name
        self.body = body

class Exec(stmt):
    # Python 2 'exec' statement.
    __slots__ = "body", "globals", "locals",

    def __init__(self, body, globals, locals):
        self.body = body
        self.globals = globals
        self.locals = locals

class Expr(stmt):
    # Expression statement (expression evaluated for effect).
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class For(stmt):
    __slots__ = "is_async", "target", "iter", "body", "orelse",

    def __init__(self, target, iter, body, orelse, is_async=False):
        self.target = target
        self.iter = iter
        self.body = body
        self.orelse = orelse
        self.is_async = is_async

class Global(stmt):
    __slots__ = "names",

    def __init__(self, names):
        self.names = names

class If(stmt):
    __slots__ = "test", "body", "orelse",

    def __init__(self, test, body, orelse):
        self.test = test
        self.body = body
        self.orelse = orelse

class Import(stmt):
    __slots__ = "names",

    def __init__(self, names):
        self.names = names

class ImportFrom(stmt):
    __slots__ = "module",

    def __init__(self, module):
        self.module = module

class Nonlocal(stmt):
    __slots__ = "names",

    def __init__(self, names):
        self.names = names

class Pass(stmt):
    pass
class Print(stmt):
    # Python 2 'print' statement; 'nl' is whether a newline is emitted.
    __slots__ = "dest", "values", "nl",

    def __init__(self, dest, values, nl):
        self.dest = dest
        self.values = values
        self.nl = nl

class Raise(stmt):
    # No __init__: fields are assigned externally.  Covers both the
    # Python 3 form (exc/cause) and the Python 2 form (type/inst/tback).
    __slots__ = "exc", "cause", "type", "inst", "tback",

class Return(stmt):
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class TemplateWrite(stmt):
    '''Template text'''
    __slots__ = "value",

    def __init__(self, value):
        self.value = value

class Try(stmt):
    __slots__ = "body", "orelse", "handlers", "finalbody",

    def __init__(self, body, orelse, handlers, finalbody):
        self.body = body
        self.orelse = orelse
        self.handlers = handlers
        self.finalbody = finalbody

class While(stmt):
    __slots__ = "test", "body", "orelse",

    def __init__(self, test, body, orelse):
        self.test = test
        self.body = body
        self.orelse = orelse

class With(stmt):
    __slots__ = "is_async", "context_expr", "optional_vars", "body",

    def __init__(self, context_expr, optional_vars, body, is_async=False):
        self.context_expr = context_expr
        self.optional_vars = optional_vars
        self.body = body
        self.is_async = is_async

# Unary operators.
class Invert(unaryop):
    pass

class Not(unaryop):
    pass

class UAdd(unaryop):
    pass

class USub(unaryop):
    pass
class Variable(object):
    '''A named variable, identified by (id, scope).

    Equality and hashing require the scope to be set; comparing or hashing
    a scope-less Variable raises TypeError.
    '''

    def __init__(self, var_id, scope = None):
        assert isinstance(var_id, str), type(var_id)
        self.id = var_id
        self.scope = scope

    def __repr__(self):
        return 'Variable(%r, %r)' % (self.id, self.scope)

    def __eq__(self, other):
        if type(other) is not Variable:
            return False
        if self.scope is None or other.scope is None:
            raise TypeError("Scope not set")
        return (self.scope, self.id) == (other.scope, other.id)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        if self.scope is None:
            raise TypeError("Scope not set")
        # Fixed salt keeps the hash distinct from a bare (id, scope) hash.
        return 391246 ^ hash(self.id) ^ hash(self.scope)

    def is_global(self):
        # Module-scoped variables are globals.
        return isinstance(self.scope, Module)
def iter_fields(node):
    'Yield (name, value) for each slot of `node` that has been assigned.'
    for slot in node.__slots__:
        try:
            yield slot, getattr(node, slot)
        except AttributeError:
            # Unassigned slots are simply skipped.
            continue

View File

@@ -0,0 +1,284 @@
import sys
import os
import inspect
import pkgutil
from semmle.python import ast
from semmle.python.passes.exports import ExportsPass
from semmle.python.passes.lexical import LexicalPass
from semmle.python.passes.flow import FlowPass
from semmle.python.passes.ast_pass import ASTPass
from semmle.python.passes.objects import ObjectPass
from semmle.util import VERSION, uuid, get_analysis_version, get_analysis_major_version
from semmle.util import makedirs, get_source_file_tag, TrapWriter, base64digest
from semmle.cache import Cache
from semmle.logging import WARN, syntax_error_message, Logger
from semmle.profiling import timers
UTRAP_KEY = 'utrap%s' % VERSION
__all__ = [ 'Extractor', 'CachingExtractor' ]
FLAG_SAVE_TYPES = float, complex, bool, int, bytes, str
class Extractor(object):
'''The extractor controls the execution of the all the
specialised passes'''
    def __init__(self, trap_folder, src_archive, options, logger: Logger, diagnostics_writer):
        '''Set up the extraction passes.

        trap_folder -- destination for generated trap files (required).
        src_archive -- archive receiving copies of extracted sources.
        options -- command-line options (uses split/prune/unroll and
                   no_syntax_errors).
        logger -- logger shared with the flow pass.
        diagnostics_writer -- sink for structured diagnostics.
        '''
        assert trap_folder
        self.trap_folder = trap_folder
        self.src_archive = src_archive
        self.object_pass = ObjectPass()
        # Per-module passes, run in this order by _extract_trap_file.
        self.passes = [
            ASTPass(),
            ExportsPass(),
            FlowPass(options.split, options.prune, options.unroll, logger)
        ]
        self.lexical = LexicalPass()
        # Maps module tag -> virtual path of the archived source file.
        self.files = {}
        self.options = options
        self.handle_syntax_errors = not options.no_syntax_errors
        self.logger = logger
        self.diagnostics_writer = diagnostics_writer
    def _handle_syntax_error(self, module, ex):
        '''Report a SyntaxError for `module` and return a substitute AST.

        Writes a diagnostic, emits a trap file recording the error
        location, and returns an empty Module so the remaining passes can
        still run consistently.
        '''
        # Write out diagnostics for the syntax error.
        error = syntax_error_message(ex, module)
        self.diagnostics_writer.write(error)
        # Emit trap for the syntax error
        self.logger.debug("Emitting trap for syntax error in %s", module.path)
        writer = TrapWriter()
        module_id = writer.get_node_id(module)
        # Report syntax error as an alert.
        # Ensure line and col are ints (not None).
        line = ex.lineno if ex.lineno else 0
        # Clamp reported locations that fall past the end of the file.
        if line > len(module.lines):
            line = len(module.lines)
            col = len(module.lines[-1])-1
        else:
            col = ex.offset if ex.offset else 0
        loc_id = writer.get_unique_id()
        writer.write_tuple(u'locations_ast', 'rrdddd',
                           loc_id, module_id, 0, 0, 0, 0)
        syntax_id = u'syntax%d:%d' % (line, col)
        writer.write_tuple(u'locations_ast', 'nrdddd',
                           syntax_id, module_id, line, col+1, line, col+1)
        writer.write_tuple(u'py_syntax_error_versioned', 'nss', syntax_id, ex.msg, get_analysis_major_version())
        trap = writer.get_compressed()
        self.trap_folder.write_trap("syntax-error", module.path, trap)
        #Create an AST equivalent to an empty file, so that the other passes produce consistent output.
        return ast.Module([])
    def _extract_trap_file(self, ast, comments, path):
        '''Run all passes over `ast` and return the compressed trap data.

        Returns None if any pass raises; the failure is logged rather than
        propagated so one bad module does not abort the extraction run.
        '''
        writer = TrapWriter()
        file_tag = get_source_file_tag(self.src_archive.get_virtual_path(path))
        writer.write_tuple(u'py_Modules', 'g', ast.trap_name)
        writer.write_tuple(u'py_module_path', 'gg', ast.trap_name, file_tag)
        try:
            for ex in self.passes:
                with timers[ex.name]:
                    # Only the flow pass needs to know the file name.
                    if isinstance(ex, FlowPass):
                        ex.set_filename(path)
                    ex.extract(ast, writer)
            with timers['lexical']:
                self.lexical.extract(ast, comments, writer)
            with timers['object']:
                self.object_pass.extract(ast, path, writer)
        except Exception as ex:
            self.logger.error("Exception extracting module %s: %s", path, ex)
            self.logger.traceback(WARN)
            return None
        return writer.get_compressed()
    def process_source_module(self, module):
        '''Process a Python source module. Checks that the module has valid
        syntax, then passes the ast, source, etc. to `process_module`.
        Returns None when the module has a syntax error and syntax-error
        handling is disabled.
        '''
        try:
            #Ensure that module does not have invalid syntax before extracting it.
            ast = module.ast
        except SyntaxError as ex:
            self.logger.debug("handle syntax errors is %s", self.handle_syntax_errors)
            if self.handle_syntax_errors:
                # Record the error and continue with an empty stand-in AST.
                ast = self._handle_syntax_error(module, ex)
            else:
                return None
        ast.name = module.name
        ast.kind = module.kind
        ast.trap_name = module.trap_name
        return self.process_module(ast, module.trap_name, module.bytes_source,
                                   module.path, module.comments)
    def process_module(self, ast, module_tag, bytes_source, path, comments):
        'Process a module, generating the trap file for that module'
        self.logger.debug(u"Populating trap file for %s", path)
        ast.trap_name = module_tag
        trap = self._extract_trap_file(ast, comments, path)
        if trap is None:
            # Extraction failed (already logged); nothing to write for this module.
            return None
        with timers['trap']:
            self.trap_folder.write_trap("python", path, trap)
        try:
            with timers['archive']:
                self.copy_source(bytes_source, module_tag, path)
        except Exception:
            # Failure to archive the source must not abort extraction.
            import traceback
            traceback.print_exc()
        return trap
def copy_source(self, bytes_source, module_tag, path):
if bytes_source is None:
return
self.files[module_tag] = self.src_archive.get_virtual_path(path)
self.src_archive.write(path, bytes_source)
    def write_interpreter_data(self, options):
        '''Write interpreter data, such as version numbers and flags.'''
        def write_flag(name, value):
            # Each flag is tagged with the major Python version being analysed.
            writer.write_tuple(u'py_flags_versioned', 'uus', name, value, get_analysis_major_version())
        def write_flags(obj, prefix):
            # Save every public attribute of `obj` with a simple, saveable type.
            pre = prefix + u"."
            for name, value in inspect.getmembers(obj):
                if name[0] == "_":
                    continue
                if type(value) in FLAG_SAVE_TYPES:
                    write_flag(pre + name, str(value))
        writer = TrapWriter()
        # Record the version of the Python interpreter running the extractor itself.
        for index, name in enumerate((u'major', u'minor', u'micro', u'releaselevel', u'serial')):
            writer.write_tuple(u'py_flags_versioned', 'sss', u'extractor_python_version.' + name, str(sys.version_info[index]), get_analysis_major_version())
        write_flags(sys.flags, u'flags')
        write_flags(sys.float_info, u'float')
        write_flags(self.options, u'options')
        write_flag(u'sys.prefix', sys.prefix)
        path = os.pathsep.join(os.path.abspath(p) for p in options.sys_path)
        write_flag(u'sys.path', path)
        if options.path is None:
            path = ''
        else:
            # Store archive-relative paths so the database is machine-independent.
            path = os.pathsep.join(self.src_archive.get_virtual_path(p) for p in options.path)
        if options.language_version:
            write_flag(u'language.version', options.language_version[-1])
        else:
            write_flag(u'language.version', get_analysis_version())
        write_flag(u'extractor.path', path)
        write_flag(u'sys.platform', sys.platform)
        write_flag(u'os.sep', os.sep)
        write_flag(u'os.pathsep', os.pathsep)
        write_flag(u'extractor.version', VERSION)
        if options.context_cost is not None:
            write_flag(u'context.cost', options.context_cost)
        self.trap_folder.write_trap("flags", "$flags", writer.get_compressed())
        if get_analysis_major_version() == 2:
            # Copy the pre-extracted builtins trap
            builtins_trap_data = pkgutil.get_data('semmle.data', 'interpreter2.trap')
            self.trap_folder.write_trap("interpreter", '$interpreter2', builtins_trap_data, extension=".trap")
        else:
            # For Python 3 the special objects are generated on the fly.
            writer = TrapWriter()
            self.object_pass.write_special_objects(writer)
            self.trap_folder.write_trap("interpreter", '$interpreter3', writer.get_compressed())
        # Copy stdlib trap
        if get_analysis_major_version() == 2:
            stdlib_trap_name = '$stdlib_27.trap'
        else:
            stdlib_trap_name = '$stdlib_33.trap'
        stdlib_trap_data = pkgutil.get_data('semmle.data', stdlib_trap_name)
        self.trap_folder.write_trap("stdlib", stdlib_trap_name[:-5], stdlib_trap_data, extension=".trap")
@staticmethod
def from_options(options, trap_dir, archive, logger: Logger, diagnostics_writer):
'''Convenience method to create extractor from options'''
try:
trap_copy_dir = options.trap_cache
caching_extractor = CachingExtractor(trap_copy_dir, options, logger)
except Exception as ex:
if options.verbose and trap_copy_dir is not None:
print ("Failed to create caching extractor: " + str(ex))
caching_extractor = None
worker = Extractor(trap_dir, archive, options, logger, diagnostics_writer)
if caching_extractor:
caching_extractor.set_worker(worker)
return caching_extractor
else:
return worker
    def stop(self):
        # No-op: a plain extractor has no background work to stop.
        pass
def close(self):
'close() must be called, or some information will be not be written'
#Add name tag to file name, so that multiple extractors do not overwrite each other
if self.files:
trapwriter = TrapWriter()
for _, filepath in self.files.items():
trapwriter.write_file(filepath)
self.trap_folder.write_trap('folders', uuid('python') + '/$files', trapwriter.get_compressed())
self.files = set()
for name, timer in sorted(timers.items()):
self.logger.debug("Total time for pass '%s': %0.0fms", name, timer.elapsed)
def hash_combine(x, y):
    'Combine two strings into a single stable hash key.'
    return base64digest(u":".join((x, y)))
class CachingExtractor(object):
    '''The caching extractor has a two stage initialization process.
    After creating the extractor (which will check that the cachedir is valid)
    set_worker(worker) must be called before the CachingExtractor is valid'''

    def __init__(self, cachedir, options, logger: Logger):
        if cachedir is None:
            raise IOError("No cache directory")
        makedirs(cachedir)
        # The wrapped Extractor; must be supplied via set_worker().
        self.worker = None
        self.cache = Cache.for_directory(cachedir, options.verbose)
        self.logger = logger
        self.split = options.split

    def set_worker(self, worker):
        self.worker = worker

    def get_cache_key(self, module):
        '''Compute the cache key for `module` from its path and source text.'''
        key = hash_combine(module.path, module.source)
        if not self.split:
            #Use different key, as not splitting will modify the trap file.
            key = hash_combine(UTRAP_KEY, key)
        return hash_combine(key, module.source)

    def process_source_module(self, module):
        '''Process a Python source module. First look up trap file in cache.
        If no cached trap file is found, then delegate to the normal extractor.
        '''
        if self.worker is None:
            raise Exception("worker is not set")
        key = self.get_cache_key(module)
        trap = self.cache.get(key)
        if trap is None:
            # Cache miss: extract normally and populate the cache.
            trap = self.worker.process_source_module(module)
            if trap is not None:
                self.cache.set(key, trap)
        else:
            self.logger.debug(u"Found cached trap file for %s", module.path)
            self.worker.trap_folder.write_trap("python", module.path, trap)
            try:
                self.worker.copy_source(module.bytes_source, module.trap_name, module.path)
            except Exception:
                # Archiving failures are non-fatal; log and carry on.
                self.logger.traceback(WARN)
        return trap

    def process_module(self, ast, module_tag, source_code, path, comments):
        # NOTE(review): unlike process_source_module, the result is neither
        # cached nor returned here -- confirm that this is intended.
        self.worker.process_module(ast, module_tag, source_code, path, comments)

    def close(self):
        self.worker.close()

    def write_interpreter_data(self, sys_path):
        self.worker.write_interpreter_data(sys_path)

    def stop(self):
        self.worker.stop()

View File

@@ -0,0 +1,377 @@
'''
Classes and functions for converting module names into paths and Extractables.
Implements standard Python import semantics, and is designed to be extensible
to handle additional features like stub and template files.
'''
import sys
import imp
import os.path
from semmle.util import FileExtractable, FolderExtractable, BuiltinModuleExtractable, PY_EXTENSIONS, get_analysis_major_version
from semmle.python.modules import PythonSourceModule, is_script
class Module(object):
    '''A module. Modules are approximations
    to Python module objects and are used for
    analyzing imports.'''

    IS_PACKAGE = False
    path = None
    respect_init = True

    def __init__(self, name, package):
        self.name = name
        self.package = package

    def get_sub_module(self, name):
        '''gets the (immediate) sub-module with the given name'''
        raise NotImplementedError()

    def all_sub_modules(self):
        '''returns an iterable of all the sub-modules of this module'''
        raise NotImplementedError()

    def get_extractable(self):
        '''gets the Extractable for this module'''
        raise NotImplementedError()

    def find(self, name):
        '''Returns the named sub-module of this module if this module
        is a package, otherwise returns `None`'''
        if '.' not in name:
            return self.get_sub_module(name)
        top, rest = name.split(".", 1)
        pkg = self.get_sub_module(top)
        if pkg:
            return pkg.find(rest)
        return None

    def is_package(self):
        return self.IS_PACKAGE
class PyModule(Module):
    ' A Python source code module'

    def __init__(self, name, package, path):
        Module.__init__(self, name, package)
        assert isinstance(path, str)
        self.path = path

    def get_sub_module(self, name):
        # Plain modules never contain sub-modules.
        return None

    def all_sub_modules(self):
        return ()

    def get_extractable(self):
        return FileExtractable(self.path)

    def load(self, logger=None):
        # Parse the source file into a PythonSourceModule.
        return PythonSourceModule(self.name, self.path, logger=logger)

    def __str__(self):
        return "Python module at %s" % self.path
class BuiltinModule(Module):
    ' A built-in module'

    def __init__(self, name, package):
        Module.__init__(self, name, package)

    def get_sub_module(self, name):
        # Builtin modules expose no sub-modules.
        return None

    def all_sub_modules(self):
        return ()

    def get_extractable(self):
        return BuiltinModuleExtractable(self.name)

    def __str__(self):
        return "Builtin module %s" % self.name
class FilePackage(Module):
    ' A normal package. That is a folder with an __init__.py'
    IS_PACKAGE = True

    def __init__(self, name, package, path, respect_init=True):
        Module.__init__(self, name, package)
        assert isinstance(path, str), type(path)
        self.path = path
        self.respect_init = respect_init

    def get_sub_module(self, name):
        # Anonymous packages (name None) produce anonymous sub-modules.
        if self.name:
            modname = self.name + "." + name
        else:
            modname = None
        basepath = os.path.join(self.path, name)
        return _from_base(modname, basepath, self, self.respect_init)

    def all_sub_modules(self):
        return _from_folder(self.name, self.path, self, self.respect_init)

    def load(self):
        # Packages themselves carry no source to load.
        return None

    def get_extractable(self):
        return FolderExtractable(self.path)

    def __str__(self):
        return "Package at %s" % self.path
class PthPackage(Module):
    "A built-in package object generated from a '.pth' file"
    IS_PACKAGE = True

    def __init__(self, name, package, search_path):
        Module.__init__(self, name, package)
        self.search_path = search_path

    def get_sub_module(self, name):
        # Probe each search-path entry in order; first hit wins.
        mname = self.name + "." + name
        for entry in self.search_path:
            found = _from_base(mname, os.path.join(entry, name), self)
            if found is not None:
                return found
        return None

    def all_sub_modules(self):
        for entry in self.search_path:
            for sub in _from_folder(self.name, entry, self):
                yield sub

    def load(self):
        return None

    def __str__(self):
        return "Builtin package (.pth) %s %s" % (self.name, self.search_path)

    def get_extractable(self):
        # Nothing on disk corresponds directly to the .pth package itself.
        return None
#Helper functions
def _from_base(name, basepath, pkg, respect_init=True):
    'Return a module/package object for `basepath`, or None if not importable.'
    if os.path.isdir(basepath):
        has_init = os.path.exists(os.path.join(basepath, "__init__.py"))
        if has_init or not respect_init:
            return FilePackage(name, pkg, basepath, respect_init)
        return None
    # Not a directory: try each recognised source-file extension.
    for ext in PY_EXTENSIONS:
        candidate = basepath + ext
        if os.path.isfile(candidate):
            return PyModule(name, pkg, candidate)
    return None
def _from_folder(name, path, pkg, respect_init=True):
    'Yield a module/package object for each importable entry directly in `path`.'
    for entry in os.listdir(path):
        fullpath = os.path.join(path, entry)
        if os.path.isdir(fullpath):
            if not respect_init or os.path.exists(os.path.join(fullpath, "__init__.py")):
                sub_name = name + "." + entry if name else None
                yield FilePackage(sub_name, pkg, fullpath, respect_init)
        stem, ext = os.path.splitext(entry)
        if ext in PY_EXTENSIONS and os.path.isfile(fullpath):
            sub_name = name + "." + stem if name else None
            yield PyModule(sub_name, pkg, fullpath)
class AbstractFinder(object):
    '''Base class for module finders: maps module names/paths to Module objects.'''

    def find(self, mod_name):
        '''Find an extractable object given a module name'''
        if '.' not in mod_name:
            return self.find_top(mod_name)
        top, rest = mod_name.split(".", 1)
        pkg = self.find_top(top)
        if pkg:
            return pkg.find(rest)
        return None

    def find_top(self, name):
        '''Find module or package object given a simple (dot-less) name'''
        raise NotImplementedError()

    def name_from_path(self, path, extensions):
        '''Find module or package object given a path'''
        raise NotImplementedError()
class PyFinder(AbstractFinder):
    '''Finds Python modules under a single root directory.'''

    __slots__ = [ 'path', 'respect_init', 'logger' ]

    def __init__(self, path, respect_init, logger):
        assert isinstance(path, str), path
        self.path = os.path.abspath(path)
        self.respect_init = respect_init
        self.logger = logger

    def find_top(self, mod_name):
        candidate = os.path.join(self.path, mod_name)
        return _from_base(mod_name, candidate, None, self.respect_init)

    def name_from_path(self, path, extensions):
        rel_path = _relative_subpath(path, self.path)
        if rel_path is None:
            # Not under this finder's root.
            return None
        stem, ext = os.path.splitext(rel_path)
        if ext and ext not in extensions:
            return None
        # Convert the relative file path into a dotted module name.
        return ".".join(stem.split(os.path.sep))
def _relative_subpath(subpath, root):
'Returns the relative path if `subpath` is within `root` or `None` otherwise'
try:
relpath = os.path.relpath(subpath, root)
except ValueError:
#No relative path possible
return None
if relpath.startswith(os.pardir):
#Not in root:
return None
return relpath
class BuiltinFinder(AbstractFinder):
    '''Finder for builtin modules that are already present in the VM
    or can be guaranteed to load successfully'''

    def __init__(self, logger):
        # Snapshot of the modules already loaded into this interpreter.
        self.modules = {}
        for name, module in sys.modules.items():
            self.modules[name] = module
        try:
            # Locate the folder holding the standard library's C extension
            # modules, using _json as a representative member.
            json_file, json_path, _ = imp.find_module("_json")
            # imp.find_module may return an open file object; close it to
            # avoid leaking the handle (we only need the path).
            if json_file is not None:
                json_file.close()
            self.dynload_path = os.path.dirname(json_path)
        except Exception:
            if os.name != "nt":
                logger.warning("Failed to find dynload path")
            self.dynload_path = None

    def builtin_module(self, name):
        '''Create a BuiltinModule for `name`, chained to its parent packages.'''
        if "." in name:
            pname, name = name.rsplit(".", 1)
            return BuiltinModule(name, self.builtin_module(pname))
        return BuiltinModule(name, None)

    def find(self, mod_name):
        mod = super(BuiltinFinder, self).find(mod_name)
        if mod is not None:
            return mod
        #Use `imp` module to find module
        try:
            mod_file, filepath, _ = imp.find_module(mod_name)
        except ImportError:
            return None
        # Close the file handle returned by imp.find_module (previously
        # leaked); only the path is needed below.
        if mod_file is not None:
            mod_file.close()
        #Accept builtin dynamically loaded modules like _ctypes or _json
        if filepath and os.path.dirname(filepath) == self.dynload_path:
            return BuiltinModule(mod_name, None)
        return None

    def find_top(self, mod_name):
        if mod_name in self.modules:
            mod = self.modules[mod_name]
            if hasattr(mod, "__file__"):
                # Backed by a source file, so the normal finders should handle it.
                return None
            if hasattr(mod, "__path__"):
                return PthPackage(mod_name, None, mod.__path__)
            return BuiltinModule(mod_name, None)
        if mod_name in sys.builtin_module_names:
            return BuiltinModule(mod_name, None)
        return None

    def name_from_path(self, path, extensions):
        # Builtin modules have no corresponding source path.
        return None
#Stub file handling
class StubFinder(PyFinder):
    '''Finder rooted at the bundled stub files for library modules.'''

    def __init__(self, logger):
        tools = os.environ.get('ODASA_TOOLS')
        if tools is None:
            # Fall back to a location relative to the extractor itself.
            tools = sys.path[1]
            logger.debug("StubFinder: can't find ODASA_TOOLS, using '%s' instead", tools)
        path = os.path.join(tools, "data", "python", "stubs")
        super(StubFinder, self).__init__(path, True, logger)
def _finders_for_path(path, respect_init, logger):
    'Build the finder chain: stubs first, then each path entry, builtins last.'
    finders = [StubFinder(logger)]
    finders.extend(PyFinder(p, respect_init, logger) for p in path if p)
    finders.append(BuiltinFinder(logger))
    return finders
def finders_from_options_and_env(options, logger):
    '''Return a list of finders from the given command line options'''
    if options.path:
        # Explicitly specified paths take precedence over sys.path entries.
        path = options.path + options.sys_path
    else:
        path = options.sys_path
    path = [os.path.abspath(p) for p in path]
    if options.exclude:
        exclude = set(options.exclude)
        trimmed_path = []
        for p in path:
            for x in exclude:
                # NOTE(review): bare string-prefix match -- excluding '/foo'
                # also drops '/foobar'; confirm whether that is intended.
                if p.startswith(x):
                    break
            else:
                # No exclusion matched: keep this entry.
                trimmed_path.append(p)
        path = trimmed_path
    logger.debug("Finder path: %s", path)
    logger.debug("sys path: %s", sys.path)
    return _finders_for_path(path, options.respect_init, logger)
class Finder(object):
    '''Resolves module names and file paths to Module objects by delegating
    to an ordered chain of finders (first match wins).'''

    def __init__(self, finders, options, logger):
        self.finders = finders
        # Memoises from_path results: path -> Module (or None).
        self.path_map = {}
        self.logger = logger
        self.respect_init = options.respect_init

    def find(self, mod_name):
        '''Return the Module for `mod_name`, or None if no finder resolves it.'''
        for finder in self.finders:
            mod = finder.find(mod_name)
            if mod is not None:
                return mod
        self.logger.debug("Cannot find module '%s'", mod_name)
        return None

    @staticmethod
    def from_options_and_env(options, logger):
        return Finder(finders_from_options_and_env(options, logger), options, logger)

    def from_extractable(self, unit):
        # Only file/folder extractables correspond to source paths.
        if isinstance(unit, FolderExtractable) or isinstance(unit, FileExtractable):
            return self.from_path(unit.path)
        return None

    def from_path(self, path, extensions=PY_EXTENSIONS):
        '''Return a Module/FilePackage for `path` (memoised), or None if the
        path is not recognised as Python source.'''
        if path in self.path_map:
            return self.path_map[path]
        if not path or path == "/":
            return None
        is_python_2 = (get_analysis_major_version() == 2)
        # A folder without __init__.py is only accepted as a package when
        # init files are not respected AND the analysed version is Python 2.
        # NOTE(review): one might expect namespace-package handling for
        # Python 3 here instead -- confirm the condition is intended.
        if os.path.isdir(path) and not os.path.exists(os.path.join(path, "__init__.py")) and (self.respect_init or not is_python_2):
            return None
        # Resolve the parent folder first so packages chain correctly.
        pkg = self.from_path(os.path.dirname(path))
        mod = None
        if os.path.isdir(path):
            mod = FilePackage(None, pkg, path)
        if os.path.isfile(path):
            base, ext = os.path.splitext(path)
            if ext in extensions:
                mod = PyModule(None, pkg, path)
            if is_script(path):
                # Executable scripts stand alone, outside any package.
                mod = PyModule(None, None, path)
        self.path_map[path] = mod
        return mod

    def name_from_path(self, path, extensions=PY_EXTENSIONS):
        '''Return the dotted module name for `path`, or None if unknown.'''
        for finder in self.finders:
            name = finder.name_from_path(path, extensions)
            if name is not None:
                return name
        return None

View File

@@ -0,0 +1,256 @@
import sys
from semmle.python import ast
from collections import namedtuple
from semmle.util import VERSION, get_analysis_major_version
from semmle.cache import Cache
from semmle.logging import INFO
#Maintain distinct version strings for distinct versions of Python
IMPORTS_KEY = 'import%s_%x%x' % (VERSION, sys.version_info[0], sys.version_info[1])
import pickle
__all__ = [ 'CachingModuleImporter', 'ModuleImporter', 'importer_from_options' ]
ImportStar = namedtuple('ImportStar', 'level module')
ImportExpr = namedtuple('ImportExpr', 'level module')
ImportMember = namedtuple('ImportMember', 'level module name')
def safe_string(txt):
    '''Best-effort conversion of `txt` to text; never raises.'''
    try:
        if not isinstance(txt, bytes):
            return str(txt)
        try:
            return txt.decode(sys.getfilesystemencoding(), errors="replace")
        except Exception:
            # Latin-1 can decode any byte sequence.
            return txt.decode("latin-1")
    except Exception:
        return u"?"
class SemmleImportError(Exception):
    '''Raised when a module import cannot be resolved by the extractor.'''

    def __init__(self, module_name, *reasons):
        name = safe_string(module_name)
        detail = u"".join(safe_string(reason) for reason in reasons)
        if detail:
            message = u"Import of %s failed: %s.\n" % (name, detail)
        else:
            message = u"Import of %s failed.\n" % name
        Exception.__init__(self, message)

    def write(self, out=sys.stdout):
        out.write(self.args[0])
class CachingModuleImporter(object):
    '''Wraps a ModuleImporter with an on-disk cache of pickled import nodes.'''

    def __init__(self, cachedir, finder, logger):
        self.worker = ModuleImporter(finder, logger)
        if cachedir is None:
            raise IOError("No cache directory")
        self.cache = Cache.for_directory(cachedir, logger)
        self.logger = logger

    def get_imports(self, module, loaded_module):
        # Only the parsed import nodes are cached; resolution always runs.
        import_nodes = self.get_import_nodes(loaded_module)
        return self.worker.parse_imports(module, import_nodes)

    def get_import_nodes(self, loaded_module):
        '''Return the import AST nodes for `loaded_module`, using the cache
        where possible; falls back to the worker on any cache failure.'''
        key = loaded_module.get_hash_key(IMPORTS_KEY)
        if key is None:
            # No stable hash available, so the result cannot be cached.
            return self.worker.get_import_nodes(loaded_module)
        imports = self.cache.get(key)
        #Unpickle the data
        if imports is not None:
            try:
                imports = pickle.loads(imports)
            except Exception:
                # Corrupt or incompatible cache entry; recompute below.
                self.logger.debug("Failed to unpickle imports for %s", loaded_module.path)
                imports = None
        if imports is None:
            imports = self.worker.get_import_nodes(loaded_module)
            try:
                data = pickle.dumps(imports)
                self.cache.set(key, data)
            except Exception as ex:
                # Shouldn't really fail, but carry on anyway
                self.logger.debug("Failed to save pickled imports to cache for %s: %s", loaded_module.path, ex)
        else:
            self.logger.debug("Cached imports file found for %s", loaded_module.path)
        return imports
class ModuleImporter(object):
    'Discovers and records which modules import which other modules'

    def __init__(self, finder, logger):
        self.finder = finder
        self.logger = logger
        # Import failures already reported, to avoid duplicate warnings.
        self.failures = {}

    def get_imports(self, module, loaded_module):
        '''Yield the modules imported by `module` (loaded as `loaded_module`).'''
        import_nodes = self.get_import_nodes(loaded_module)
        return self.parse_imports(module, import_nodes)

    def get_import_nodes(self, loaded_module):
        'Return list of AST nodes representing imports'
        try:
            return imports_from_ast(loaded_module.py_ast)
        except Exception as ex:
            if isinstance(ex, SyntaxError):
                # Example: `Syntax Error (line 123) in /home/.../file.py`
                self.logger.warning("%s in %s", ex, loaded_module.path)
                # no need to show traceback, it's not an internal bug
            else:
                self.logger.warning("Failed to analyse imports of %s : %s", loaded_module.path, ex)
                self.logger.traceback(INFO)
            return []

    def _relative_import(self, module, level, mod_name, report_failure = True):
        '''Resolve a level-`level` relative import of `mod_name` from `module`.
        Returns None if unresolvable, reporting each failure at most once
        (and only when `report_failure` is set).'''
        for i in range(level):
            parent = module.package
            if parent is None:
                relative_name = level * u'.' + mod_name
                if relative_name not in self.failures:
                    if report_failure:
                        self.logger.warning("Failed to find %s, no parent package of %s", relative_name, module)
                    self.failures[relative_name] = str(module)
                return None
            module = parent
        res = module
        if mod_name:
            res = res.get_sub_module(mod_name)
        if res is None and report_failure:
            relative_name = level * '.' + mod_name
            if relative_name not in self.failures:
                self.logger.warning("Failed to find %s, %s has no module %s", relative_name, module, mod_name)
                self.failures[relative_name] = str(module)
        return res

    def _absolute_import(self, module, mod_name):
        '''Resolve an absolute import of `mod_name`, reporting failures once.'''
        try:
            mod = self.finder.find(mod_name)
        except SemmleImportError as ex:
            if mod_name not in self.failures:
                self.logger.warning("%s", ex)
            self.failures[mod_name] = str(module)
            return None
        return mod

    def parse_imports(self, module, import_nodes):
        '''Yield each distinct module imported via `import_nodes`; for
        imported packages, also yield their `__init__` module.'''
        imports = set()
        #If an imported module is a package, then yield its __init__ module as well
        for imported in self._parse_imports_no_init(module, import_nodes):
            if imported not in imports:
                imports.add(imported)
                assert imported is not None
                yield imported
                if not imported.is_package():
                    continue
                init = imported.get_sub_module(u"__init__")
                if init is not None and init not in imports:
                    yield init

    def _parse_imports_no_init(self, module, import_nodes):
        '''Yield the modules imported by `import_nodes`, without adding
        `__init__` modules for packages.'''
        assert not module.is_package()
        for node in import_nodes:
            if node.module is None:
                # `from . import x` style: relative import of the package itself.
                top = ''
                parts = []
            else:
                parts = node.module.split('.')
                top, parts = parts[0], parts[1:]
            if node.level <= 0:
                if get_analysis_major_version() < 3:
                    #Attempt relative import with level 1
                    imported = self._relative_import(module, 1, top, False)
                    if imported is None:
                        imported = self._absolute_import(module, top)
                else:
                    imported = self._absolute_import(module, top)
            else:
                imported = self._relative_import(module, node.level, top)
            if imported is None:
                self.logger.debug("Unable to resolve import: %s", top)
                continue
            yield imported
            # Walk down the dotted name, yielding each intermediate module.
            for p in parts:
                inner = imported.get_sub_module(p)
                if inner is None:
                    self.logger.debug("Unable to resolve import: %s", p)
                    break
                imported = inner
                yield imported
            if isinstance(node, ImportStar):
                self.logger.debug("Importing all sub modules of %s", imported)
                #If import module is a package then yield all sub_modules.
                for mod in imported.all_sub_modules():
                    yield mod
            elif isinstance(node, ImportMember):
                mod = imported.get_sub_module(node.name)
                # Fix: the None-check here was inverted -- it logged
                # "Unable to resolve" precisely when the sub-module WAS
                # found. Log only on failure; yield only a real module
                # (parse_imports asserts that yielded values are not None).
                if mod is None:
                    self.logger.debug("Unable to resolve import: %s", node.name)
                else:
                    yield mod
def imports_from_ast(the_ast):
    '''Return a list of ImportStar/ImportExpr/ImportMember tuples for every
    import found in `the_ast`, including imports nested inside functions,
    decorator calls and `if __name__ == "__main__":` blocks.'''
    def walk(node, in_function, in_name_main):
        # `in_function`/`in_name_main` record context while descending;
        # they do not currently filter which imports are reported.
        if isinstance(node, ast.Module):
            for import_node in walk(node.body, in_function, in_name_main):
                yield import_node
        elif isinstance(node, ast.ImportFrom):
            # `from module import *`
            yield ImportStar(node.module.level, node.module.name)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                imp = alias.value
                if isinstance(imp, ast.ImportExpr):
                    # `import module`
                    yield ImportExpr(imp.level, imp.name)
                else:
                    # `from module import name`
                    assert isinstance(imp, ast.ImportMember)
                    yield ImportMember(imp.module.level, imp.module.name, imp.name)
        elif isinstance(node, ast.FunctionExpr):
            for _, child in ast.iter_fields(node.inner_scope):
                for import_node in walk(child, True, in_name_main):
                    yield import_node
        elif isinstance(node, ast.Call):
            # Might be a decorator
            for import_node in walk(node.positional_args, in_function, in_name_main):
                yield import_node
        elif isinstance(node, list):
            for n in node:
                for import_node in walk(n, in_function, in_name_main):
                    yield import_node
        elif isinstance(node, ast.stmt):
            name_eq_main = is_name_eq_main(node)
            for _, child in ast.iter_fields(node):
                for import_node in walk(child, in_function, name_eq_main or in_name_main):
                    yield import_node
    return list(walk(the_ast, False, False))
def name_from_expr(expr):
    '''Return the dotted name denoted by `expr` (a Name or an Attribute
    chain); raises ValueError for any other expression.'''
    if isinstance(expr, ast.Attribute):
        return name_from_expr(expr.value) + "." + expr.attr
    if isinstance(expr, ast.Name):
        return expr.id
    raise ValueError("%s is not a name" % expr)
def is_name_eq_main(node):
    '''True if `node` is an `if __name__ == "__main__":` statement.'''
    if not isinstance(node, ast.If):
        return False
    try:
        test = node.test
        # Any unexpected shape (missing attributes, empty comparators)
        # falls through to the except-clause below.
        return test.comparators[0].s == "__main__" and test.left.id == "__name__"
    except Exception:
        return False
def importer_from_options(options, finder, logger):
    '''Create a CachingModuleImporter when a trap cache is configured and
    usable, otherwise fall back to a plain ModuleImporter.'''
    try:
        importer = CachingModuleImporter(options.trap_cache, finder, logger)
    except Exception as ex:
        if options.trap_cache is not None:
            # Use `warning` for consistency with the rest of this module
            # (`warn` is the deprecated alias).
            logger.warning("Failed to create caching importer: %s", ex)
        importer = ModuleImporter(finder, logger)
    return importer

View File

@@ -0,0 +1,504 @@
#Much of the information in this file is hardcoded into parser.
#Modify with care and test well.
#It should be relatively safe to add fields.
from semmle.python.AstMeta import Node, PrimitiveNode, ClassNode, UnionNode, ListNode
from semmle.python.AstMeta import build_node_relations as _build_node_relations
string = PrimitiveNode('str', 'string', 'varchar(1)', 'string')
bytes_ = PrimitiveNode('bytes', 'string', 'varchar(1)')
location = PrimitiveNode('location', '@location', 'unique int')
variable = PrimitiveNode('variable', '@py_variable', 'int')
int_ = PrimitiveNode('int', 'int', 'int')
bool_ = PrimitiveNode('bool', 'boolean', 'boolean')
number = PrimitiveNode('number', 'string', 'varchar(1)')
Module = ClassNode('Module')
Class = ClassNode('Class')
Function = ClassNode('Function')
alias = ClassNode('alias')
arguments = ClassNode('arguments', None, 'parameters definition')
boolop = ClassNode('boolop', None, 'boolean operator')
cmpop = ClassNode('cmpop', None, 'comparison operator')
comprehension = ClassNode('comprehension')
comprehension.field('location', location)
expr = ClassNode('expr', None, 'expression')
expr.field('location', location)
expr.field('parenthesised', bool_, 'parenthesised')
expr_context = ClassNode('expr_context', None, 'expression context')
operator = ClassNode('operator')
stmt = ClassNode('stmt', None, 'statement')
stmt.field('location', location)
unaryop = ClassNode('unaryop', None, 'unary operation')
pattern = ClassNode('pattern')
pattern.field('location', location)
pattern.field('parenthesised', bool_, 'parenthesised')
Add = ClassNode('Add', operator, '+')
And = ClassNode('And', boolop, 'and')
Assert = ClassNode('Assert', stmt)
Assign = ClassNode('Assign', stmt, 'assignment')
Attribute = ClassNode('Attribute', expr)
AugAssign = ClassNode('AugAssign', stmt, 'augmented assignment statement')
AugLoad = ClassNode('AugLoad', expr_context, 'augmented-load')
AugStore = ClassNode('AugStore', expr_context, 'augmented-store')
BinOp = ClassNode('BinOp', expr, 'binary')
#Choose a name more consistent with other Exprs.
BinOp.set_name("BinaryExpr")
BitAnd = ClassNode('BitAnd', operator, '&')
BitOr = ClassNode('BitOr', operator, '|')
BitXor = ClassNode('BitXor', operator, '^')
BoolOp = ClassNode('BoolOp', expr, 'boolean')
#Avoid name clash with boolop
BoolOp.set_name('BoolExpr')
Break = ClassNode('Break', stmt)
Bytes = ClassNode('Bytes', expr)
Call = ClassNode('Call', expr)
ClassExpr = ClassNode('ClassExpr', expr, 'class definition')
Compare = ClassNode('Compare', expr)
Continue = ClassNode('Continue', stmt)
Del = ClassNode('Del', expr_context, 'deletion')
Delete = ClassNode('Delete', stmt)
Dict = ClassNode('Dict', expr, 'dictionary')
DictComp = ClassNode('DictComp', expr, 'dictionary comprehension')
Div = ClassNode('Div', operator, '/')
Ellipsis = ClassNode('Ellipsis', expr)
Eq = ClassNode('Eq', cmpop, '==')
ExceptStmt = ClassNode('ExceptStmt', stmt, 'except block')
ExceptGroupStmt = ClassNode('ExceptGroupStmt', stmt, 'except group block')
Exec = ClassNode('Exec', stmt)
Expr_stmt = ClassNode('Expr', stmt)
Expr_stmt.set_name('Expr_stmt')
FloorDiv = ClassNode('FloorDiv', operator, '//')
For = ClassNode('For', stmt)
FunctionExpr = ClassNode('FunctionExpr', expr, 'function definition')
GeneratorExp = ClassNode('GeneratorExp', expr, 'generator')
Global = ClassNode('Global', stmt)
Gt = ClassNode('Gt', cmpop, '>')
GtE = ClassNode('GtE', cmpop, '>=')
If = ClassNode('If', stmt)
IfExp = ClassNode('IfExp', expr, 'if')
Import = ClassNode('Import', stmt)
ImportExpr = ClassNode('ImportExpr', expr, 'import')
ImportMember = ClassNode('ImportMember', expr, 'from import')
ImportFrom = ClassNode('ImportFrom', stmt, 'import * statement')
In = ClassNode('In', cmpop)
Invert = ClassNode('Invert', unaryop, '~')
Is = ClassNode('Is', cmpop)
IsNot = ClassNode('IsNot', cmpop, 'is not')
LShift = ClassNode('LShift', operator, '<<')
Lambda = ClassNode('Lambda', expr)
List = ClassNode('List', expr)
ListComp = ClassNode('ListComp', expr, 'list comprehension')
Load = ClassNode('Load', expr_context)
Lt = ClassNode('Lt', cmpop, '<')
LtE = ClassNode('LtE', cmpop, '<=')
Match = ClassNode('Match', stmt)
#Avoid name clash with regex match
Match.set_name('MatchStmt')
Case = ClassNode('Case', stmt)
Guard = ClassNode('Guard', expr)
MatchAsPattern = ClassNode('MatchAsPattern', pattern)
MatchOrPattern = ClassNode('MatchOrPattern', pattern)
MatchLiteralPattern = ClassNode('MatchLiteralPattern', pattern)
MatchCapturePattern = ClassNode('MatchCapturePattern', pattern)
MatchWildcardPattern = ClassNode('MatchWildcardPattern', pattern)
MatchValuePattern = ClassNode('MatchValuePattern', pattern)
MatchSequencePattern = ClassNode('MatchSequencePattern', pattern)
MatchStarPattern = ClassNode('MatchStarPattern', pattern)
MatchMappingPattern = ClassNode('MatchMappingPattern', pattern)
MatchDoubleStarPattern = ClassNode('MatchDoubleStarPattern', pattern)
MatchKeyValuePattern = ClassNode('MatchKeyValuePattern', pattern)
MatchClassPattern = ClassNode('MatchClassPattern', pattern)
MatchKeywordPattern = ClassNode('MatchKeywordPattern', pattern)
Mod = ClassNode('Mod', operator, '%')
Mult = ClassNode('Mult', operator, '*')
Name = ClassNode('Name', expr)
Nonlocal = ClassNode('Nonlocal', stmt)
Not = ClassNode('Not', unaryop)
NotEq = ClassNode('NotEq', cmpop, '!=')
NotIn = ClassNode('NotIn', cmpop, 'not in')
Num = ClassNode('Num', expr, 'numeric literal')
Or = ClassNode('Or', boolop)
Param = ClassNode('Param', expr_context, 'parameter')
Pass = ClassNode('Pass', stmt)
Pow = ClassNode('Pow', operator, '**')
Print = ClassNode('Print', stmt)
RShift = ClassNode('RShift', operator, '>>')
Raise = ClassNode('Raise', stmt)
Repr = ClassNode('Repr', expr, 'backtick')
Return = ClassNode('Return', stmt)
Set = ClassNode('Set', expr)
SetComp = ClassNode('SetComp', expr, 'set comprehension')
#Add $ to name to prevent doc-gen adding sub type name
Slice = ClassNode('Slice', expr, '$slice')
Starred = ClassNode('Starred', expr)
Store = ClassNode('Store', expr_context)
Str = ClassNode('Str', expr, 'string literal')
Sub = ClassNode('Sub', operator, '-')
Subscript = ClassNode('Subscript', expr)
Try = ClassNode('Try', stmt)
Tuple = ClassNode('Tuple', expr)
UAdd = ClassNode('UAdd', unaryop, '+')
USub = ClassNode('USub', unaryop, '-')
UnaryOp = ClassNode('UnaryOp', expr, 'unary')
#Avoid name clash with 'unaryop'
UnaryOp.set_name('UnaryExpr')
While = ClassNode('While', stmt)
With = ClassNode('With', stmt)
Yield = ClassNode('Yield', expr)
YieldFrom = ClassNode('YieldFrom', expr, 'yield-from')
alias_list = ListNode(alias)
cmpop_list = ListNode(cmpop)
comprehension_list = ListNode(comprehension)
expr_list = ListNode(expr)
stmt_list = ListNode(stmt)
string_list = ListNode(string)
StringPart = ClassNode('StringPart', None, "implicitly concatenated part")
string_parts_list = ListNode(StringPart)
pattern_list = ListNode(pattern)
#Template AST Nodes
TemplateWrite = ClassNode('TemplateWrite', stmt, "template write statement")
TemplateDottedNotation = ClassNode('TemplateDottedNotation', expr, "template dotted notation expression")
Filter = ClassNode("Filter", expr, "template filter expression")
PlaceHolder = ClassNode('PlaceHolder', expr, "template place-holder expression")
Await = ClassNode('Await', expr)
MatMult = ClassNode('MatMult', operator, '@')
scope = UnionNode(Module, Class, Function)
scope.set_name('scope')
dict_item = ClassNode('dict_item')
#DoubleStar in calls fn(**{'a': 1, 'c': 3}, **{'b': 2, 'd': 4}) or dict displays {'a': 1, **{'b': 2, 'd': 4}}
DictUnpacking = ClassNode('DictUnpacking', dict_item, descriptive_name='dictionary unpacking')
KeyValuePair = ClassNode('KeyValuePair', dict_item, descriptive_name='key-value pair')
keyword = ClassNode('keyword', dict_item, descriptive_name='keyword argument')
#Initial name must match that in ast module.
FormattedStringLiteral = ClassNode("JoinedStr", expr, descriptive_name='formatted string literal')
FormattedStringLiteral.set_name("Fstring")
FormattedValue = ClassNode("FormattedValue", expr, descriptive_name='formatted value')
AnnAssign = ClassNode("AnnAssign", stmt, descriptive_name='annotated assignment')
AssignExpr = ClassNode('AssignExpr', expr, "assignment expression")
SpecialOperation = ClassNode('SpecialOperation', expr, "special operation")
type_parameter = ClassNode('type_parameter', descriptive_name='type parameter')
type_parameter.field('location', location)
type_parameter_list = ListNode(type_parameter)
TypeAlias = ClassNode('TypeAlias', stmt, 'type alias')
ParamSpec = ClassNode('ParamSpec', type_parameter, 'parameter spec')
TypeVar = ClassNode('TypeVar', type_parameter, 'type variable')
TypeVarTuple = ClassNode('TypeVarTuple', type_parameter, 'type variable tuple')
expr_or_stmt = UnionNode(expr, stmt)
dict_item_list = ListNode(dict_item)
ast_node = UnionNode(expr, stmt, pattern, Module, Class, Function, comprehension, StringPart, dict_item, type_parameter)
ast_node.set_name('ast_node')
parameter = UnionNode(Name, Tuple)
parameter.set_name('parameter')
parameter_list = ListNode(parameter)
alias.field('value', expr)
alias.field('asname', expr, 'name')
arguments.field('kw_defaults', expr_list, 'keyword-only default values')
arguments.field('defaults', expr_list, 'default values')
arguments.field('annotations', expr_list)
arguments.field('varargannotation', expr, '*arg annotation')
arguments.field('kwargannotation', expr, '**kwarg annotation')
arguments.field('kw_annotations', expr_list, 'keyword-only annotations')
Assert.field('test', expr, 'value being tested')
Assert.field('msg', expr, 'failure message')
Assign.field('value', expr)
Assign.field('targets', expr_list, 'targets')
Attribute.field('value', expr, 'object')
Attribute.field('attr', string, 'attribute name')
Attribute.field('ctx', expr_context, 'context')
AugAssign.field('operation', BinOp)
BinOp.field('left', expr, 'left sub-expression')
BinOp.field('op', operator, 'operator')
BinOp.field('right', expr, 'right sub-expression')
BoolOp.field('op', boolop, 'operator')
BoolOp.field('values', expr_list, 'sub-expressions')
Bytes.field('s', bytes_, 'value')
Bytes.field('prefix', bytes_, 'prefix')
Bytes.field('implicitly_concatenated_parts', string_parts_list)
Call.field('func', expr, 'callable')
Call.field('positional_args', expr_list, 'positional arguments')
Call.field('named_args', dict_item_list, 'named arguments')
Class.field('name', string)
Class.field('body', stmt_list)
ClassExpr.field('name', string)
ClassExpr.field('bases', expr_list)
ClassExpr.field('keywords', dict_item_list, 'keyword arguments')
ClassExpr.field('inner_scope', Class, 'class scope')
ClassExpr.field('type_parameters', type_parameter_list, 'type parameters')
Compare.field('left', expr, 'left sub-expression')
Compare.field('ops', cmpop_list, 'comparison operators')
Compare.field('comparators', expr_list, 'right sub-expressions')
comprehension.field('iter', expr, 'iterable')
comprehension.field('target', expr)
comprehension.field('ifs', expr_list, 'conditions')
Delete.field('targets', expr_list)
Dict.field('items', dict_item_list)
DictUnpacking.field('location', location)
DictUnpacking.field('value', expr)
DictComp.field('function', Function, 'implementation')
DictComp.field('iterable', expr)
ExceptStmt.field('type', expr)
ExceptStmt.field('name', expr)
ExceptStmt.field('body', stmt_list)
ExceptGroupStmt.field('type', expr)
ExceptGroupStmt.field('name', expr)
ExceptGroupStmt.field('body', stmt_list)
Exec.field('body', expr)
Exec.field('globals', expr)
Exec.field('locals', expr)
Expr_stmt.field('value', expr)
For.field('target', expr)
For.field('iter', expr, 'iterable')
For.field('body', stmt_list)
For.field('orelse', stmt_list, 'else block')
For.field('is_async', bool_, 'async')
Function.field('name', string)
Function.field('args', parameter_list, 'positional parameter list')
Function.field('vararg', expr, 'tuple (*) parameter')
Function.field('kwonlyargs', expr_list, 'keyword-only parameter list')
Function.field('kwarg', expr, 'dictionary (**) parameter')
Function.field('body', stmt_list)
Function.field('is_async', bool_, 'async')
Function.field('type_parameters', type_parameter_list, 'type parameters')
FunctionExpr.field('name', string)
FunctionExpr.field('args', arguments, 'parameters')
FunctionExpr.field('returns', expr, 'return annotation')
FunctionExpr.field('inner_scope', Function, 'function scope')
GeneratorExp.field('function', Function, 'implementation')
GeneratorExp.field('iterable', expr)
Global.field('names', string_list)
If.field('test', expr)
If.field('body', stmt_list, 'if-true block')
If.field('orelse', stmt_list, 'if-false block')
IfExp.field('test', expr)
IfExp.field('body', expr, 'if-true expression')
IfExp.field('orelse', expr, 'if-false expression')
Import.field('names', alias_list, 'alias list')
ImportFrom.set_name('ImportStar')
ImportFrom.field('module', expr)
ImportMember.field('module', expr)
ImportMember.field('name', string)
keyword.field('location', location)
keyword.field('value', expr)
keyword.field('arg', string)
KeyValuePair.field('location', location)
KeyValuePair.field('value', expr)
KeyValuePair.field('key', expr)
Lambda.field('args', arguments, 'arguments')
Lambda.field('inner_scope', Function, 'function scope')
List.field('elts', expr_list, 'element list')
List.field('ctx', expr_context, 'context')
#For Python 3 a new scope is created and these fields are populated:
ListComp.field('function', Function, 'implementation')
ListComp.field('iterable', expr)
#For Python 2 no new scope is created and these are populated:
ListComp.field('generators', comprehension_list)
ListComp.field('elt', expr, 'elements')
Match.field('subject', expr)
Match.field('cases', stmt_list)
Case.field('pattern', pattern)
Case.field('guard', expr)
Case.field('body', stmt_list)
Guard.field('test', expr)
MatchStarPattern.field('target', pattern)
MatchDoubleStarPattern.field('target', pattern)
MatchKeyValuePattern.field('key', pattern)
MatchKeyValuePattern.field('value', pattern)
MatchClassPattern.field('class', expr)
MatchKeywordPattern.field('attribute', expr)
MatchKeywordPattern.field('value', pattern)
MatchAsPattern.field('pattern', pattern)
MatchAsPattern.field('alias', expr)
MatchOrPattern.field('patterns', pattern_list)
MatchLiteralPattern.field('literal', expr)
MatchCapturePattern.field('variable', expr)
MatchValuePattern.field('value', expr)
MatchSequencePattern.field('patterns', pattern_list)
MatchMappingPattern.field('mappings', pattern_list)
MatchClassPattern.field('class_name', expr)
MatchClassPattern.field('positional', pattern_list)
MatchClassPattern.field('keyword', pattern_list)
Module.field('name', string)
Module.field('hash', string , 'hash (not populated)')
Module.field('body', stmt_list)
Module.field('kind', string)
ImportExpr.field('level', int_)
ImportExpr.field('name', string)
ImportExpr.field('top', bool_, 'top level')
Name.field('variable', variable)
Name.field('ctx', expr_context, 'context')
Nonlocal.field('names', string_list)
Num.field('n', number, 'value')
Num.field('text', number)
ParamSpec.field('name', expr)
Print.field('dest', expr, 'destination')
Print.field('values', expr_list)
Print.field('nl', bool_, 'new line')
#Python3 has exc & cause
Raise.field('exc', expr, 'exception')
Raise.field('cause', expr)
#Python2 has type, inst, tback
Raise.field('type', expr)
Raise.field('inst', expr, 'instance')
Raise.field('tback', expr, 'traceback')
Repr.field('value', expr)
Return.field('value', expr)
Set.field('elts', expr_list, 'elements')
SetComp.field('function', Function, 'implementation')
SetComp.field('iterable', expr)
Slice.field('start', expr)
Slice.field('stop', expr)
Slice.field('step', expr)
Starred.field('value', expr)
Starred.field('ctx', expr_context, 'context')
Str.field('s', string, 'text')
Str.field('prefix', string, 'prefix')
Str.field('implicitly_concatenated_parts', string_parts_list)
Subscript.field('value', expr)
Subscript.field('index', expr)
Subscript.field('ctx', expr_context, 'context')
Try.field('body', stmt_list)
Try.field('orelse', stmt_list, 'else block')
Try.field('handlers', stmt_list, 'exception handlers')
Try.field('finalbody', stmt_list, 'finally block')
Tuple.field('elts', expr_list, 'elements')
Tuple.field('ctx', expr_context, 'context')
TypeAlias.field('name', expr)
TypeAlias.field('type_parameters', type_parameter_list)
TypeAlias.field('value', expr)
TypeVar.field('name', expr)
TypeVar.field('bound', expr)
TypeVarTuple.field('name', expr)
UnaryOp.field('op', unaryop, 'operator')
UnaryOp.field('operand', expr)
While.field('test', expr)
While.field('body', stmt_list)
While.field('orelse', stmt_list, 'else block')
With.field('context_expr', expr, 'context manager')
With.field('optional_vars', expr, 'optional variable')
With.field('body', stmt_list)
With.field('is_async', bool_, 'async')
Yield.field('value', expr)
YieldFrom.field('value', expr)
#Template AST Nodes
TemplateWrite.field('value', expr)
TemplateDottedNotation.field('value', expr, 'object')
TemplateDottedNotation.field('attr', string, 'attribute name')
TemplateDottedNotation.field('ctx', expr_context, 'context')
Filter.field('value', expr, 'filtered value')
Filter.field('filter', expr, 'filter')
PlaceHolder.field('variable', variable)
PlaceHolder.field('ctx', expr_context, 'context')
StringPart.field('text', string)
StringPart.field('location', location)
Await.field('value', expr, 'expression waited upon')
FormattedStringLiteral.field('values', expr_list)
FormattedValue.field('value', expr, "expression to be formatted")
FormattedValue.field('conversion', string, 'type conversion')
FormattedValue.field('format_spec', FormattedStringLiteral, 'format specifier')
AnnAssign.field('value', expr)
AnnAssign.field('annotation', expr)
AnnAssign.field('target', expr)
SpecialOperation.field('name', string)
SpecialOperation.field('arguments', expr_list)
AssignExpr.field('value', expr)
AssignExpr.field('target', expr)
def all_nodes():
    '''Collect every `Node` instance defined at module level and compute
    their parent/child relations.'''
    found = [value for value in globals().values() if isinstance(value, Node)]
    return _build_node_relations(found)

View File

@@ -0,0 +1,214 @@
'''Representation of a Python source module: loading, decoding and lazy
parsing of Python source files (see `PythonSourceModule`).'''
import semmle.python.parser.tokenizer
import semmle.python.parser.tsg_parser
import re
import os
from blib2to3.pgen2 import tokenize
import codecs
from semmle.python.passes.labeller import Labeller
from semmle.util import base64digest
from semmle.profiling import timers
__all__ = [ 'PythonSourceModule' ]
class PythonSourceModule(object):
    '''A single Python source file: reads the raw bytes, determines the
    encoding, and lazily exposes the decoded source, lines, tokens and AST.'''
    # Set to "Script" in __init__ when the file starts with a python shebang.
    kind = None
    def __init__(self, name, path, logger, bytes_source = None):
        assert isinstance(path, str), path
        self.name = name # May be None
        self.path = path
        # Read the file from disk only if the caller did not supply the bytes.
        if bytes_source is None:
            with timers["load"]:
                with open(self.path, 'rb') as src:
                    bytes_source = src.read()
        if BIN_PYTHON.match(bytes_source):
            self.kind = "Script"
        # Lazily computed caches, filled in by the properties below.
        self._ast = None
        self._py_ast = None
        self._lines = None
        self._line_types = None
        self._comments = None
        self._tokens = None
        self.logger = logger
        with timers["decode"]:
            self.encoding, self.bytes_source = semmle.python.parser.tokenizer.encoding_from_source(bytes_source)
        if self.encoding != 'utf-8':
            logger.debug("File '%s' has encoding %s.", path, self.encoding)
        try:
            self._source = self.bytes_source.decode(self.encoding)
            self._illegal_encoding = False
        except Exception as ex:
            self.logger.warning("%s has encoding '%s'", path, self.encoding)
            #Set source to a latin-1 decoding of source string (which cannot fail).
            #Attempting to get the AST will raise a syntax error as expected.
            self._source = self.bytes_source.decode("latin-1")
            self._illegal_encoding = str(ex)
        self._source = normalize_line_endings(self._source)
        #Strip BOM
        if self._source.startswith(u'\ufeff'):
            self._source = self._source[1:]
        self._secure_hash = base64digest(self._source)
        assert isinstance(self._source, str)
    @property
    def source(self):
        '''The decoded, newline-normalized source text.'''
        return self._source
    @property
    def lines(self):
        '''List of source lines; each line (except possibly the last) keeps
        its trailing newline. Computed on first access.'''
        if self._lines is None:
            def genline():
                src = self._source
                #Handle non-linux line endings
                src = src.replace("\r\n", "\n").replace("\r", "\n")
                length = len(src)
                start = 0
                while True:
                    end = src.find(u'\n', start)
                    if end < 0:
                        # No further newline: emit the unterminated tail, if any.
                        if start < length:
                            yield src[start:]
                        return
                    yield src[start:end+1]
                    start = end+1
            self._lines = list(genline())
        return self._lines
    @property
    def tokens(self):
        '''List of tokens for the whole source, computed on first access.'''
        if self._tokens is None:
            with timers["tokenize"]:
                tokenizer = semmle.python.parser.tokenizer.Tokenizer(self._source)
                self._tokens = list(tokenizer.tokens())
        return self._tokens
    @property
    def ast(self):
        '''The labelled AST for this module.

        Raises SyntaxError if the file could not be decoded with its declared
        encoding.'''
        # The ast will be modified by the labeller, so we cannot share it with the py_ast property.
        # However, we expect py_ast to be accessed and used before ast, so we avoid reparsing in that case.
        if self._ast is None:
            if self._illegal_encoding:
                message = self._illegal_encoding
                error = SyntaxError(message)
                error.filename = self.path
                error.lineno, error.offset = offending_byte_position(message, self.bytes_source)
                raise error
            self._ast = self.py_ast
            self._ast.trap_name = self.trap_name
            # Drop the raw-AST cache: the labeller mutates the tree in place.
            self._py_ast = None
            with timers["label"]:
                Labeller().apply(self)
        return self._ast
    @property
    def old_py_ast(self):
        # The py_ast is the raw ast from the Python parser.
        if self._py_ast is None:
            self._py_ast = semmle.python.parser.parse(self.tokens, self.logger)
        return self._py_ast
    @property
    def py_ast(self):
        '''The raw (unlabelled) AST: old parser first, tsg-python as fallback.'''
        try:
            # First, try to parse the source with the old Python parser.
            return self.old_py_ast
        except Exception as ex:
            # If that fails, try to parse the source with the new Python parser (unless it has been
            # explicitly disabled).
            #
            # Like PYTHONUNBUFFERED for Python, we treat any non-empty string as meaning the
            # flag is enabled.
            # https://docs.python.org/3/using/cmdline.html#envvar-PYTHONUNBUFFERED
            if os.environ.get("CODEQL_PYTHON_DISABLE_TSG_PARSER"):
                if isinstance(ex, SyntaxError):
                    raise ex
                else:
                    raise SyntaxError("Exception %s while parsing %s" % (ex, self.path))
            else:
                try:
                    self._py_ast = semmle.python.parser.tsg_parser.parse(self.path, self.logger)
                    return self._py_ast
                except SyntaxError as ex:
                    raise ex
                except Exception as ex:
                    raise SyntaxError("Exception %s in tsg-python while parsing %s" % (ex, self.path))
    @property
    def trap_name(self):
        '''Name identifying this module (class name, path and content hash).'''
        return type(self).__name__ + ':' + self.path + ":" + self._secure_hash
    def get_hash_key(self, token):
        '''Return a digest combining path, content hash and `token`.'''
        return base64digest(self.path + u":" + self._secure_hash + token)
    def get_encoding(self):
        'Returns encoding of source'
        return self.encoding
    @property
    def comments(self):
        ''' Returns an iterable of comments in the form:
        test, start, end where start and end are line. column
        pairs'''
        if self._comments is None:
            self._lexical()
        return self._comments
    def close(self):
        '''Release cached source, tokens and AST to free memory.'''
        # NOTE(review): self._tokens is not cleared here -- confirm whether
        # that is intentional.
        self.bytes_source = None
        self._source = None
        self._ast = None
        self._line_types = None
        self._comments = None
        self._lines = None
    def _lexical(self):
        # Collect all comment tokens from the token stream.
        self._comments = []
        for kind, text, start, end in self.tokens:
            if kind == tokenize.COMMENT:
                self._comments.append((text, start, end))
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
NEWLINE = b'\n'
OFFENDING_BYTE_RE = re.compile(r"decode byte \w+ in position (\d+):")
def offending_byte_position(message, string):
    '''Locate the undecodable byte described by the decode-error `message`
    within the raw bytes `string`.

    Returns a (line, column) pair (1-based line, 0-based column), or (0, 0)
    if the message does not mention a byte position.'''
    match = OFFENDING_BYTE_RE.search(message)
    if match is None:
        return (0, 0)
    offset = int(match.group(1))
    before = string[:offset]
    line_number = before.count(NEWLINE) + 1
    column = offset - before.rfind(NEWLINE) - 1
    return (line_number, column)
BIN_PYTHON = re.compile(b'#! *(/usr|/bin|/local)*/?(env)? *python')
def is_script(path):
    '''Is the file at `path` a script? (does it start with #!... python)'''
    try:
        with open(path, "rb") as contents:
            head = contents.read(100)
    except Exception:
        # Unreadable or missing file: treat as "not a script".
        return False
    return BIN_PYTHON.match(head) is not None
def normalize_line_endings(src):
    '''Normalize `src` for the tokenizer/parser: turn Windows line endings
    into plain newlines and guarantee the text ends with a newline.'''
    # The tokenizer expects single-character line terminators.
    normalized = src.replace(u'\r\n', u'\n')
    # The parser expects no unterminated final line.
    if normalized and not normalized.endswith(u'\n'):
        normalized += u'\n'
    return normalized

View File

@@ -0,0 +1,153 @@
# Black's version of lib2to3 (modified)
from blib2to3.pytree import type_repr
from blib2to3 import pygram
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.parse import ParseError, Parser
from . import ast
from blib2to3.pgen2 import tokenize, grammar
from blib2to3.pgen2.token import tok_name
from semmle.profiling import timers
pygram.initialize()
syms = pygram.python_symbols
# Grammars to attempt, in order; `parse` below falls back to the next
# grammar when one fails.
GRAMMARS = [
    ("Python 3", pygram.python3_grammar),
    ("Python 3 without async", pygram.python3_grammar_no_async),
    ("Python 2 with print as function", pygram.python2_grammar_no_print_statement),
    ("Python 2", pygram.python2_grammar),
]
# Grammar symbols that act as pure pass-throughs: a node of one of these
# kinds with exactly one child is replaced by that child (see `convert`).
SKIP_IF_SINGLE_CHILD_NAMES = {
    'atom',
    'power',
    'test',
    'not_test',
    'and_test',
    'or_test',
    'suite',
    'testlist',
    'expr',
    'xor_expr',
    'and_expr',
    'shift_expr',
    'arith_expr',
    'term',
    'factor',
    'testlist_gexp',
    'exprlist',
    'testlist_safe',
    'old_test',
    'comparison',
}
# The numeric symbol ids corresponding to the names above.
SKIP_IF_SINGLE_CHILD = {
    val for name, val in
    syms.__dict__.items()
    if name in SKIP_IF_SINGLE_CHILD_NAMES
}
class Leaf(object):
    '''A terminal node of the concrete parse tree: a token with its type,
    text and (start, end) source positions.'''
    __slots__ = "type", "value", "start", "end"
    def __init__(self, type, value, start, end):
        # Slot order matches the positional argument order.
        for slot, arg in zip(self.__slots__, (type, value, start, end)):
            setattr(self, slot, arg)
    def __repr__(self):
        """Return a canonical string representation."""
        return "{}({}, {!r})".format(self.__class__.__name__, self.name, self.value)
    @property
    def name(self):
        # Human-readable token name, falling back to the raw type code.
        return tok_name.get(self.type, self.type)
class Node(object):
    '''An interior node of the concrete parse tree; `children` holds the
    ordered Leaf/Node instances beneath it.'''
    __slots__ = "type", "children", "used_names"
    def __init__(self, type, children):
        self.type = type
        self.children = children
    def _edge_leaf(self, index):
        # Walk down one edge of the subtree (first or last child) to a leaf.
        cursor = self
        while isinstance(cursor, Node):
            cursor = cursor.children[index]
        return cursor
    @property
    def start(self):
        # Start position of the leftmost leaf.
        return self._edge_leaf(0).start
    @property
    def end(self):
        # End position of the rightmost leaf.
        return self._edge_leaf(-1).end
    def __repr__(self):
        """Return a canonical string representation."""
        return "{}({}, {!r})".format(self.__class__.__name__, self.name, self.children)
    @property
    def name(self):
        # Human-readable grammar-symbol name.
        return type_repr(self.type)
def convert(gr, raw_node):
    '''Convert one raw parser node into a `Leaf` or `Node`.

    Terminals become `Leaf` objects; non-terminals become `Node` objects,
    except that pass-through symbols with a single child collapse to that
    child.'''
    kind, value, context, children = raw_node
    # A terminal: no children and not a known non-terminal symbol.
    if not children and kind not in gr.number2symbol:
        start, end = context
        return Leaf(kind, value, start, end)
    # Collapse single-child pass-through nodes instead of allocating a Node.
    if len(children) == 1 and kind in SKIP_IF_SINGLE_CHILD:
        return children[0]
    return Node(kind, children)
def parse_tokens(gr, tokens):
    """Parse a series of tokens and return the syntax tree.

    `gr` is a blib2to3 grammar and `tokens` yields (type, value, start, end)
    tuples. Raises ParseError if the token stream ends before the grammar
    accepts the input.
    """
    p = Parser(gr, convert)
    p.setup()
    for tkn in tokens:
        type, value, start, end = tkn
        # Comments and non-logical newlines are irrelevant to the grammar.
        if type in (tokenize.COMMENT, tokenize.NL):
            continue
        if type == token.OP:
            # The tokenizer reports every operator as OP; the parser needs
            # the specific token type for each operator.
            type = grammar.opmap[value]
        if type == token.INDENT:
            value = ""
        if p.addtoken(type, value, (start, end)):
            break
    else:
        # We never broke out -- EOF is too soon (how can this happen???)
        # Bug fix: this previously raised `parse.ParseError`, but no name
        # `parse` is in scope in this module -- `ParseError` is imported
        # directly from blib2to3.pgen2.parse, so raising it would itself
        # have crashed with a NameError.
        raise ParseError("incomplete input",
                         type, value, ("", start))
    return p.rootnode
def parse(tokens, logger):
    '''Parse `tokens`, trying each grammar in GRAMMARS in order, and return
    the converted AST. Raises SyntaxError (with the position of the last
    failure) when every grammar rejects the input.'''
    lineno = column = None
    for name, gram in GRAMMARS:
        try:
            with timers["parse"]:
                tree = parse_tokens(gram, tokens)
            with timers["rewrite"]:
                return ast.convert(logger, tree)
        except ParseError as pe:
            # Remember where this grammar failed and move on to the next one.
            lineno, column = pe.context[1]
            logger.debug("%s at line %d, column %d using grammar for %s", pe, lineno, column, name)
    exc = SyntaxError("Syntax Error")
    exc.lineno = lineno
    exc.offset = column
    raise exc

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,151 @@
# dump_ast.py
# Functions for dumping the internal Python AST in a human-readable format.
import sys
import semmle.python.parser.tokenizer
import semmle.python.parser.tsg_parser
from semmle.python.parser.tsg_parser import ast_fields
from semmle.python import ast
from semmle import logging
from semmle.python.modules import PythonSourceModule
def get_fields(cls):
    '''Return the fields of `cls` followed by those of its single-inheritance
    ancestors, walking up `__bases__[0]` while each class is in `ast_fields`.

    Classes absent from `ast_fields` contribute nothing (and stop the walk).'''
    fields = ()
    while cls in ast_fields:
        fields += ast_fields[cls]
        cls = cls.__bases__[0]
    return fields
def missing_fields(known, node):
    """Returns a list of the attributes of `node` that are not in `known`,
    skipping private names, location attributes, and `id` on Name nodes."""
    ignored = ("lineno", "col_offset")
    result = []
    for field in dir(node):
        if field in known or field.startswith("_") or field in ignored:
            continue
        # `id` on a Name node is the variable name, not a schema field.
        if isinstance(node, ast.Name) and field == "id":
            continue
        result.append(field)
    return result
class AstDumper(object):
    '''Writes a human-readable, indented dump of an internal Python AST.

    `output` is any writable text stream (stdout by default); pass
    `no_locations=True` to omit source positions from the dump.'''
    def __init__(self, output=sys.stdout, no_locations=False):
        self.output = output
        self.show_locations = not no_locations
    def visit(self, node, level=0, visited=None):
        '''Recursively dump `node` at indentation `level`.

        `visited` holds the nodes on the current path, used to detect cycles.'''
        if visited is None:
            visited = set()
        # Bug fix: `output` and `indent` must be bound *before* the cycle
        # check below -- previously the cycle branch referenced both before
        # assignment, raising NameError whenever a cycle was detected.
        output = self.output
        cls = node.__class__
        name = cls.__name__
        indent = ' ' * level
        if node in visited:
            output.write("{} CYCLE DETECTED!\n".format(indent))
            return
        visited = visited.union({node})
        if node is None: # Special case for `None` to avoid printing `NoneType`.
            name = 'None'
        if cls == str: # Special case for bare strings
            output.write("{}{}\n".format(indent, repr(node)))
            return
        # In some places, we have non-AST nodes in lists, and since these don't have a location, we
        # simply print their name instead.
        # `ast.arguments` is special -- it has fields but no location
        if hasattr(node, 'lineno') and not isinstance(node, ast.arguments) and self.show_locations:
            position = (node.lineno, node.col_offset, node._end[0], node._end[1])
            output.write("{}{}: [{}, {}] - [{}, {}]\n".format(indent, name, *position))
        else:
            output.write("{}{}\n".format(indent, name))
        fields = get_fields(cls)
        unknown = missing_fields(fields, node)
        if unknown:
            output.write("{}UNKNOWN FIELDS: {}\n".format(indent, unknown))
        for field in fields:
            value = getattr(node, field, None)
            # By default, the `parenthesised` field on expressions has no value, so it's easier to
            # just not print it in that case.
            if field == "parenthesised" and value is None:
                continue
            # Likewise, the default value for `is_async` is `False`, so we don't need to print it.
            if field == "is_async" and value is False:
                continue
            output.write("{} {}:".format(indent,field))
            if isinstance(value, list):
                output.write(" [")
                if len(value) == 0:
                    output.write("]\n")
                    continue
                output.write("\n")
                for n in value:
                    self.visit(n, level+2, visited)
                output.write("{} ]\n".format(indent))
            # Some AST classes are special in that the identity of the object is the only thing
            # that matters (and they have no location info). For this reason we simply print the name.
            elif isinstance(value, (ast.expr_context, ast.boolop, ast.cmpop, ast.operator, ast.unaryop)):
                output.write(' {}\n'.format(value.__class__.__name__))
            elif isinstance(value, ast.AstBase):
                output.write("\n")
                self.visit(value, level+2, visited)
            else:
                output.write(' {}\n'.format(repr(value)))
class StdoutLogger(logging.Logger):
    '''Logger variant that writes every message directly to stdout.'''
    def log(self, level, fmt, *args):
        rendered = fmt % args
        sys.stdout.write(rendered + "\n")
def old_parser(inputfile, logger):
    # Parse `inputfile` with the old (blib2to3-based) parser and return the
    # raw AST; the logger is closed once parsing is done.
    mod = PythonSourceModule(None, inputfile, logger)
    logger.close()
    return mod.old_py_ast
def args_parser():
    '''Build and return the OptionParser for this tool's command line.'''
    from optparse import OptionParser
    parser = OptionParser(usage="usage: %prog [options] python-file")
    # All options are simple boolean flags.
    for short, long_opt, help_text in (
        ("-o", "--old", "Dump old AST."),
        ("-n", "--new", "Dump new AST."),
        ("-l", "--no-locations", "Don't include location info in dump"),
        ("-d", "--debug", "Print debug information."),
    ):
        parser.add_option(short, long_opt, help=help_text, action="store_true")
    return parser
def main():
    # Entry point: parse a single Python file and dump its AST to stdout.
    parser = args_parser()
    options, args = parser.parse_args(sys.argv[1:])
    if options.debug:
        # NOTE(review): sets a module-global DEBUG flag; nothing in this
        # module reads it -- presumably intended for downstream debug
        # output. Confirm it is still needed.
        global DEBUG
        DEBUG = True
    if len(args) != 1:
        sys.stderr.write("Error: wrong number of arguments.\n")
        parser.print_help()
        sys.exit(1)
    inputfile = args[0]
    # --old and --new select the parser; exactly one must be given.
    if options.old and options.new:
        sys.stderr.write("Error: options --old and --new are mutually exclusive.\n")
        sys.exit(1)
    if not (options.old or options.new):
        sys.stderr.write("Error: Must specify either --old or --new.\n")
        sys.exit(1)
    with StdoutLogger() as logger:
        if options.old:
            ast = old_parser(inputfile, logger)
        else:
            ast = semmle.python.parser.tsg_parser.parse(inputfile, logger)
        AstDumper(no_locations=options.no_locations).visit(ast)
if __name__ == '__main__':
    main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,495 @@
# tsg_parser.py
# Functions and classes used for parsing Python files using `tree-sitter-graph`
from ast import literal_eval
import sys
import os
import semmle.python.parser
from semmle.python.parser.ast import copy_location, decode_str, split_string
from semmle.python import ast
import subprocess
from itertools import groupby
DEBUG = False
def debug_print(*args, **kwargs):
    '''Forward the arguments to `print` only when the module-level DEBUG
    flag is enabled; otherwise do nothing.'''
    if not DEBUG:
        return
    print(*args, **kwargs)
# Node ids are integers, and so to distinguish them from actual integers we wrap them in this class.
class Node(object):
    '''Opaque wrapper marking an integer as a tsg-python graph-node id.'''
    def __init__(self, id):
        self.id = id
    def __repr__(self):
        return "Node(%s)" % (self.id,)
# A wrapper for nodes containing comments. The old parser does not create such nodes (and therefore
# there is no `ast.Comment` class) since it accesses the comments via the tokens for the given file.
class Comment(object):
    '''Holds the raw text of a source comment emitted by tsg-python.'''
    def __init__(self, text):
        self.text = text
    def __repr__(self):
        return "Comment(%s)" % (self.text,)
class SyntaxErrorNode(object):
    '''Marks a region of source text that tsg-python failed to parse.'''
    def __init__(self, source):
        self.source = source
    def __repr__(self):
        return "SyntaxErrorNode(%s)" % (self.source,)
# Mapping from tree-sitter CPT node kinds to their corresponding AST node classes.
tsg_to_ast = {name: cls
              for name, cls in semmle.python.ast.__dict__.items()
              if isinstance(cls, type) and ast.AstBase in cls.__mro__
              }
# `Comment` and `SyntaxErrorNode` have no counterpart in `semmle.python.ast`,
# so they are registered manually with the local wrapper classes above.
tsg_to_ast["Comment"] = Comment
tsg_to_ast["SyntaxErrorNode"] = SyntaxErrorNode
# Mapping from AST node class to the fields of the node. The order of the fields is the order in
# which they will be output in the AST dump.
#
# These fields cannot be extracted automatically, so we set them manually.
ast_fields = {
    ast.Module: ("body",), # Note: has no `__slots__` to inspect
    Comment: ("text",), # Note: not an `ast` class
    SyntaxErrorNode: ("source",), # Note: not an `ast` class
    # The remaining classes genuinely have no fields of their own.
    ast.Continue: (),
    ast.Break: (),
    ast.Pass: (),
    ast.Ellipsis: (),
    ast.MatchWildcardPattern: (),
}
# Fields that we don't want to dump on every single AST node. These are just the slots of the AST
# base class, consisting of all of the location information (which we print in a different way).
ignored_fields = semmle.python.ast.AstBase.__slots__
# Extract fields for the remaining AST classes
# Each class contributes the slots it declares itself (location slots from
# the base class are excluded); classes with no remaining slots are skipped.
for name, cls in semmle.python.ast.__dict__.items():
    if name.startswith("_"):
        continue
    if not hasattr(cls, "__slots__"):
        continue
    slots = tuple(field for field in cls.__slots__ if field not in ignored_fields)
    if not slots:
        continue
    ast_fields[cls] = slots
# A mapping from strings to the AST node classes that represent things like operators.
# These have to be handled specially, because they have no location information.
locationless = {
    "and": ast.And,
    "or": ast.Or,
    "not": ast.Not,
    "uadd": ast.UAdd,
    "usub": ast.USub,
    "+": ast.Add,
    "-": ast.Sub,
    "~": ast.Invert,
    "**": ast.Pow,
    "<<": ast.LShift,
    ">>": ast.RShift,
    "&": ast.BitAnd,
    "|": ast.BitOr,
    "^": ast.BitXor,
    "load": ast.Load,
    "store": ast.Store,
    "del" : ast.Del,
    "param" : ast.Param,
}
# Merge in the operator tables shared with the old parser so both parsers
# agree on operator-class lookup.
locationless.update(semmle.python.parser.ast.TERM_OP_CLASSES)
locationless.update(semmle.python.parser.ast.COMP_OP_CLASSES)
locationless.update(semmle.python.parser.ast.AUG_ASSIGN_OPS)
# Locate the tsg-python binary: in a deployed extractor it lives in the
# extractor's tools directory; in a development checkout we build and run it
# via cargo instead.
if 'CODEQL_EXTRACTOR_PYTHON_ROOT' in os.environ:
    platform = os.environ['CODEQL_PLATFORM']
    ext = ".exe" if platform == "win64" else ""
    tools = os.path.join(os.environ['CODEQL_EXTRACTOR_PYTHON_ROOT'], "tools", platform)
    tsg_command = [os.path.join(tools, "tsg-python" + ext )]
else:
    # Get the path to the current script
    script_path = os.path.dirname(os.path.realpath(__file__))
    tsg_python_path = os.path.join(script_path, "../../../tsg-python")
    cargo_file = os.path.join(tsg_python_path, "Cargo.toml")
    tsg_command = ["cargo", "run", "--quiet", "--release", "--manifest-path="+cargo_file]
def read_tsg_python_output(path, logger):
    '''Run tsg-python on the file at `path` and parse its line-oriented
    graph output.

    Returns a pair `(node_attr, edge_attr)` as described by the two
    comments below.'''
    # Mapping from node id (an integer) to a dictionary containing attribute data.
    node_attr = {}
    # Mapping a start node to a map from attribute names to lists of (value, end_node) pairs.
    edge_attr = {}
    command_args = tsg_command + [path]
    p = subprocess.Popen(command_args, stdout=subprocess.PIPE)
    for line in p.stdout:
        line = line.decode(sys.getfilesystemencoding())
        line = line.rstrip()
        if line.startswith("node"): # e.g. `node 5`
            current_node = int(line.split(" ")[1])
            d = {}
            node_attr[current_node] = d
            in_node = True
        elif line.startswith("edge"): # e.g. `edge 5 -> 6`
            current_start, current_end = tuple(map(int, line[4:].split("->")))
            d = edge_attr.setdefault(current_start, {})
            in_node = False
        else: # attribute, e.g. `_kind: "Class"`
            key, value = line[2:].split(": ", 1)
            if value.startswith("[graph node"): # e.g. `_skip_to: [graph node 5]`
                value = Node(int(value.split(" ")[2][:-1]))
            elif value == "#true": # e.g. `_is_parenthesised: #true`
                value = True
            elif value == "#false": # e.g. `top: #false`
                value = False
            elif value == "#null": # e.g. `exc: #null`
                value = None
            else: # literal values, e.g. `name: "k1.k2"` or `level: 5`
                try:
                    if key =="s" and value[0] == '"': # e.g. `s: "k1.k2"`
                        value = evaluate_string(value)
                    else:
                        value = literal_eval(value)
                        if isinstance(value, bytes):
                            try:
                                value = value.decode(sys.getfilesystemencoding())
                            except UnicodeDecodeError:
                                # just include the bytes as-is
                                pass
                except Exception as ex:
                    # We may not know the location at this point -- for instance if we forgot to set
                    # it -- but `get_location_info` will degrade gracefully in this case.
                    loc = ":".join(str(i) for i in get_location_info(d))
                    error = ex.args[0] if ex.args else "unknown"
                    logger.warning("Error '{}' while parsing value {} at {}:{}\n".format(error, repr(value), path, loc))
            # Attach the parsed attribute to the current node or edge map.
            if in_node:
                d[key] = value
            else:
                d.setdefault(key, []).append((value, current_end))
    p.stdout.close()
    p.terminate()
    p.wait()
    logger.info("Read {} nodes and {} edges from TSG output".format(len(node_attr), len(edge_attr)))
    return node_attr, edge_attr
def evaluate_string(s):
    '''Evaluate the quoted textual form `s` of a string literal (as emitted
    by tsg-python) to the Python string value it denotes.'''
    s = literal_eval(s)
    prefix, quotes, content = split_string(s, None)
    ends_with_illegal_character = False
    # If the string ends with the same quote character as the outer quotes (and/or backslashes)
    # (e.g. the first string part of `f"""hello"{0}"""`), we must take care to not accidently create
    # the ending quotes at the wrong place. To do this, we insert an extra space at the end (that we
    # then must remember to remove later on.)
    if content.endswith(quotes[0]) or content.endswith('\\'):
        ends_with_illegal_character = True
        content = content + " "
    # Strip any f-string prefix so literal_eval sees a plain (raw/bytes) literal.
    s = prefix.strip("fF") + quotes + content + quotes
    s = literal_eval(s)
    if isinstance(s, bytes):
        s = decode_str(s)
    if ends_with_illegal_character:
        # Remove the sentinel space added above.
        s = s[:-1]
    return s
def resolve_node_id(id, node_attr):
    '''Follow the chain of `_skip_to` links starting at `id` and return the
    id of the final node in the chain.'''
    current = id
    attrs = node_attr[current]
    while "_skip_to" in attrs:
        current = attrs["_skip_to"].id
        attrs = node_attr[current]
    return current
def get_context(id, node_attr, logger):
    '''Resolve the expression context for the node with the given `id`:
    either its own `ctx` attribute, or the one reached by following the
    chain of `_inherited_ctx` links.'''
    current = id
    while True:
        attrs = node_attr[current]
        if "ctx" in attrs:
            return locationless[attrs["ctx"]]()
        if "_inherited_ctx" not in attrs:
            logger.error("No context for node {} with attributes {}\n".format(current, attrs))
            # A missing context is most likely to be a "load", so return that.
            return ast.Load()
        current = attrs["_inherited_ctx"].id
def get_location_info(attrs):
    """Returns the location information for a node, depending on which fields are set.
    In particular, more specific fields take precedence over (and overwrite) less specific fields.
    So, `_start_line` and `_start_column` take precedence over `_location_start`, which takes
    precedence over `_location`. Likewise when `end` replaces `start` above.
    If part of the location information is missing, the string `"???"` is substituted for the
    missing bits.
    """
    # Apply fields from least to most specific; later assignments overwrite
    # earlier, less specific ones.
    start_line = start_column = end_line = end_column = "???"
    if "_location" in attrs:
        (start_line, start_column, end_line, end_column) = attrs["_location"]
    if "_location_start" in attrs:
        (start_line, start_column) = attrs["_location_start"]
    if "_location_end" in attrs:
        (end_line, end_column) = attrs["_location_end"]
    start_line = attrs.get("_start_line", start_line)
    start_column = attrs.get("_start_column", start_column)
    end_line = attrs.get("_end_line", end_line)
    end_column = attrs.get("_end_column", end_column)
    # Lines in the `tsg-python` output are 0-indexed, but the AST expects them
    # to be 1-indexed.
    if start_line != "???":
        start_line += 1
    if end_line != "???":
        end_line += 1
    return (start_line, start_column, end_line, end_column)
# For each AST class, the fields whose placeholder value must be a (fresh)
# empty list rather than None -- the node constructors assert this.
# See `create_placeholder_args` below.
list_fields = {
    ast.arguments: ("annotations", "defaults", "kw_defaults", "kw_annotations"),
    ast.Assign: ("targets",),
    ast.BoolOp: ("values",),
    ast.Bytes: ("implicitly_concatenated_parts",),
    ast.Call: ("positional_args", "named_args"),
    ast.Case: ("body",),
    ast.Class: ("body",),
    ast.ClassExpr: ("type_parameters", "bases", "keywords"),
    ast.Compare: ("ops", "comparators",),
    ast.comprehension: ("ifs",),
    ast.Delete: ("targets",),
    ast.Dict: ("items",),
    ast.ExceptStmt: ("body",),
    ast.For: ("body",),
    ast.Function: ("type_parameters", "args", "kwonlyargs", "body"),
    ast.Global: ("names",),
    ast.If: ("body",),
    ast.Import: ("names",),
    ast.List: ("elts",),
    ast.Match: ("cases",),
    ast.MatchClassPattern: ("positional", "keyword"),
    ast.MatchMappingPattern: ("mappings",),
    ast.MatchOrPattern: ("patterns",),
    ast.MatchSequencePattern: ("patterns",),
    ast.Module: ("body",),
    ast.Nonlocal: ("names",),
    ast.Print: ("values",),
    ast.Set: ("elts",),
    ast.Str: ("implicitly_concatenated_parts",),
    ast.TypeAlias: ("type_parameters",),
    ast.Try: ("body", "handlers", "orelse", "finalbody"),
    ast.Tuple: ("elts",),
    ast.While: ("body",),
    # ast.FormattedStringLiteral: ("arguments",),
}
def create_placeholder_args(cls):
    """ Returns a dictionary containing the placeholder arguments necessary to create an AST node.
    In most cases these arguments will be assigned the value `None`, however for a few classes we
    must substitute the empty list, as this is enforced by asserts in the constructor.
    """
    if cls in (ast.Raise, ast.Ellipsis):
        return {}
    args = {}
    for field in ast_fields[cls]:
        # `is_async` is not a constructor argument.
        if field == "is_async":
            continue
        args[field] = None
    for field in list_fields.get(cls, ()):
        args[field] = []
    # Comprehension nodes take neither `function` nor `iterable` in their
    # constructors.
    if cls in (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp):
        del args["function"]
        del args["iterable"]
    return args
def parse(path, logger):
    """Parse the tsg-python output for the file at `path` into an AST.

    Builds node objects from the TSG graph, wires up their scalar fields and
    list-valued edges, applies string fixups, and returns the root
    `ast.Module`. Raises `SyntaxError` for syntactically invalid or
    non-empty-but-nodeless input.
    """
    node_attr, edge_attr = read_tsg_python_output(path, logger)
    debug_print("node_attr:", node_attr)
    debug_print("edge_attr:", edge_attr)
    nodes = {}
    # Nodes that need to be fixed up after building the graph
    fixups = {}
    # Reverse index from node object to node id.
    node_id = {}
    # Create all the node objects
    for id, attrs in node_attr.items():
        if "_is_literal" in attrs:
            # Literal graph nodes carry their value directly.
            nodes[id] = attrs["_is_literal"]
            continue
        if "_kind" not in attrs:
            logger.error("Error: Graph node {} with attributes {} has no `_kind`!\n".format(id, attrs))
            continue
        # This is not the node we are looking for (so don't bother creating it).
        if "_skip_to" in attrs:
            continue
        cls = tsg_to_ast[attrs["_kind"]]
        # NOTE(review): `args` is unused here; the fields are recomputed below.
        args = ast_fields[cls]
        obj = cls(**create_placeholder_args(cls))
        nodes[id] = obj
        node_id[obj] = id
        # If this node needs fixing up afterwards, add it to the fixups map.
        if "_fixup" in attrs:
            fixups[id] = obj
    # Set all of the node attributes
    for id, node in nodes.items():
        attrs = node_attr[id]
        if "_is_literal" in attrs:
            continue
        expected_fields = ast_fields[type(node)]
        # Set up location information.
        node.lineno, node.col_offset, end_line, end_column = get_location_info(attrs)
        node._end = (end_line, end_column)
        if isinstance(node, SyntaxErrorNode):
            # Surface parser-detected errors as ordinary SyntaxErrors.
            exc = SyntaxError("Syntax Error")
            exc.lineno = node.lineno
            exc.offset = node.col_offset
            raise exc
        # Set up context information, if any
        if "ctx" in expected_fields:
            node.ctx = get_context(id, node_attr, logger)
        # Set the fields.
        for field, val in attrs.items():
            if field.startswith("_"): continue
            if field == "ctx": continue
            if field != "parenthesised" and field not in expected_fields:
                logger.warning("Unknown field {} found among {} in node {}\n".format(field, attrs, id))
            # For fields that point to other AST nodes.
            if isinstance(val, Node):
                val = resolve_node_id(val.id, node_attr)
                setattr(node, field, nodes[val])
            # Special case for `Num.n`, which should be coerced to an int.
            elif isinstance(node, ast.Num) and field == "n":
                node.n = literal_eval(val.rstrip("lL"))
            # Special case for `Name.variable`, for which we must create a new `Variable` object
            elif isinstance(node, ast.Name) and field == "variable":
                node.variable = ast.Variable(val)
            # Special case for location-less leaf-node subclasses of `ast.Node`, such as `ast.Add`.
            elif field == "op" and val in locationless.keys():
                setattr(node, field, locationless[val]())
            else: # Any other value, usually literals of various kinds.
                setattr(node, field, val)
    # Create all fields pointing to lists of values.
    for start, field_map in edge_attr.items():
        start = resolve_node_id(start, node_attr)
        parent = nodes[start]
        extra_fields = {}
        for field_name, value_end in field_map.items():
            # Sort children by index (in case they were visited out of order)
            children = [nodes[resolve_node_id(end, node_attr)] for _index, end in sorted(value_end)]
            # Skip any comments.
            children = [child for child in children if not isinstance(child, Comment)]
            # Special case for `Compare.ops`, a list of comparison operators
            if isinstance(parent, ast.Compare) and field_name == "ops":
                parent.ops = [locationless[v]() for v in children]
            elif field_name.startswith("_"):
                # We can only set the attributes given in `__slots__` on the `start` node, and so we
                # must handle fields starting with `_` specially. In this case, we simply record the
                # values and then subsequently update `edge_attr` to refer to these values. This
                # makes it act as a pseudo-field, that we can access as long as we know the `id`
                # corresponding to a given node (for which we have the `node_id` map).
                extra_fields[field_name] = children
            else:
                setattr(parent, field_name, children)
        if extra_fields:
            # Extend the existing map in `node_attr` with the extra fields.
            node_attr[start].update(extra_fields)
    # Fixup any nodes that need it.
    for id, node in fixups.items():
        if isinstance(node, (ast.JoinedStr, ast.Str)):
            fix_strings(id, node, node_attr, node_id, logger)
    debug_print("nodes:", nodes)
    if not nodes:
        # if the file referenced by path is empty, return an empty module:
        if os.path.getsize(path) == 0:
            module = ast.Module([])
            module.lineno = 1
            module.col_offset = 0
            module._end = (1, 0)
            return module
        else:
            raise SyntaxError("Syntax Error")
    # Fix up start location of outer `Module`.
    module = nodes[0]
    if module.body:
        # Get the location of the first non-comment node.
        module.lineno = module.body[0].lineno
    else:
        # No children! File must contain only comments! Pick the end location as the start location.
        module.lineno = module._end[0]
    return module
def get_JoinedStr_children(children):
    """
    Folds the `Str` and `expr` parts of a `JoinedStr` into a single list, and does this for each
    `JoinedStr` in `children`. Top-level `StringPart`s are included in the output directly.
    """
    for child in children:
        if isinstance(child, ast.JoinedStr):
            # Splice the f-string's constituent values directly into the output.
            for part in child.values:
                yield part
        elif isinstance(child, ast.StringPart):
            yield child
        else:
            raise ValueError("Unexpected node type: {}".format(type(child)))
def concatenate_stringparts(stringparts, logger):
    """Concatenates the strings contained in the list of `stringparts`."""
    try:
        decoded = [decode_str(part.s) for part in stringparts]
        return "".join(decoded)
    except Exception as ex:
        logger.error("Unable to concatenate string %s getting error %s", stringparts, ex)
        # Best effort: fall back to the first (undecoded) part.
        return stringparts[0].s
def fix_strings(id, node, node_attr, node_id, logger):
    """
    Reassociates the `StringPart` children of an implicitly concatenated f-string (`JoinedStr`)

    `node` is the `JoinedStr` or `Str` node with graph id `id`; `node_attr`
    holds the pseudo-fields (such as `_children`) recorded while building the
    graph. `node_id` maps node objects back to their graph ids.
    """
    # Tests whether something is a string child
    is_string = lambda node: isinstance(node, ast.StringPart)
    # We have two cases to consider. Either we're given something that came from a
    # `concatenated_string`, or something that came from an `formatted_string`. The latter case can
    # be seen as a special case of the former where the list of children we consider is just the
    # single f-string.
    children = node_attr[id].get("_children", [node])
    if isinstance(node, ast.Str):
        # If the outer node is a `Str`, then we don't have to reassociate, since there are no
        # f-strings.
        # In this case we simply have to create the concatenation of its constituent parts.
        node.implicitly_concatenated_parts = children
        node.s = concatenate_stringparts(children, logger)
        node.prefix = children[0].prefix
    else:
        # Otherwise, we first have to get the flattened list of all of the strings and/or
        # expressions.
        flattened_children = get_JoinedStr_children(children)
        groups = [list(n) for _, n in groupby(flattened_children, key=is_string)]
        # At this point, `groups` is a list of lists, where each sublist is either:
        # - a list of `StringPart`s, or
        # - a singleton list containing an `expr`.
        # Crucially, `StringPart` is _not_ an `expr`.
        combined_values = []
        for group in groups:
            first = group[0]
            if isinstance(first, ast.expr):
                # If we have a list of expressions (which may happen if an interpolation contains
                # multiple distinct expressions, such as f"{foo:{bar}}", which uses interpolation to
                # also specify the padding dynamically), we simply append it.
                combined_values.extend(group)
            else:
                # Otherwise, we have a list of `StringPart`s, and we need to create a `Str` node to
                # hold it.
                combined_string = concatenate_stringparts(group, logger)
                str_node = ast.Str(combined_string, first.prefix, None)
                copy_location(first, str_node)
                # The end location should be the end of the last part (even if there is only one part).
                str_node._end = group[-1]._end
                if len(group) > 1:
                    str_node.implicitly_concatenated_parts = group
                combined_values.append(str_node)
        node.values = combined_values

View File

@@ -0,0 +1,11 @@
from abc import abstractmethod
class Pass(object):
    '''Common base class of every extractor pass.

    A pass implements exactly one operation, `extract`, which concrete
    subclasses must override.'''

    @abstractmethod
    def extract(self, module, writer):
        '''Extract trap file data from 'module', writing it to the writer.'''

View File

@@ -0,0 +1,232 @@
from semmle.python import ast
import semmle.python.master
import sys
from semmle.python.passes._pass import Pass
from semmle.util import get_analysis_major_version
__all__ = [ 'ASTPass' ]
class ASTPass(Pass):
    '''Extract relations from AST.
    Use AST.Node objects to guide _walking of AST'''
    name = "ast"
    def __init__(self):
        # Cache of (class_name, field_name) -> child offset in the relation.
        self.offsets = get_offset_table()
    #Entry point
    def extract(self, root, writer):
        '''Walk the module AST `root`, writing AST relations to `writer`.'''
        try:
            self.writer = writer
            if root is None:
                return
            self._emit_variable(ast.Variable("__name__", root))
            self._emit_variable(ast.Variable("__package__", root))
            # Introduce special variable "$" for use by the points-to library.
            self._emit_variable(ast.Variable("$", root))
            writer.write_tuple(u'py_extracted_version', 'gs', root.trap_name, get_analysis_major_version())
            self._walk(root, None, 0, root, None)
        finally:
            # Never retain the writer beyond this call, even on error.
            self.writer = None
    #Tree _walkers
    def _get_walker(self, node):
        '''Select the walker appropriate for `node` (list, AST node or primitive).'''
        if isinstance(node, list):
            return self._walk_list
        elif isinstance(node, ast.AstBase):
            return self._walk_node
        else:
            return self._emit_primitive
    def _walk(self, node, parent, index, scope, description):
        '''Dispatch `node` to the appropriate walker.'''
        self._get_walker(node)(node, parent, index, scope, description)
    def _walk_node(self, node, parent, index, scope, _unused):
        '''Emit `node` itself, then recurse into its fields.'''
        self._emit_node(node, parent, index, scope)
        if type(node) is ast.Name:
            assert (hasattr(node, 'variable') and
                type(node.variable) is ast.Variable), (node, parent, index, scope)
        if type(node) in (ast.Class, ast.Function):
            # Classes and functions open a new scope for their children.
            scope = node
        # For scopes with a `from ... import *` statement introduce special variable "*" for use by the points-to library.
        if isinstance(node, ast.ImportFrom):
            self._emit_variable(ast.Variable("*", scope))
        for field_name, desc, child_node in iter_fields(node):
            try:
                index = self.offsets[(type(node).__name__, field_name)]
                self._walk(child_node, node, index, scope, desc)
            except ConsistencyError:
                # Append context (node type and field) to the error as it propagates.
                ex = sys.exc_info()[1]
                ex.message += ' in ' + type(node).__name__
                if hasattr(node, 'rewritten') and node.rewritten:
                    ex.message += '(rewritten)'
                ex.message += '.' + field_name
                raise
    def _walk_list(self, node, parent, index, scope, description):
        '''Emit the (non-empty) list `node` and walk each of its elements.'''
        assert description.is_list(), description
        if len(node) == 0:
            return
        else:
            self._emit_list(node, parent, index, description)
            for i, child in enumerate(node):
                self._get_walker(child)(child, node, i, scope, description.item_type)
    #Emitters
    def _emit_node(self, ast_node, parent, index, scope):
        '''Write the defining tuple(s) for a single AST node.'''
        t = type(ast_node)
        node = _ast_nodes[t.__name__]
        #Ensure all stmts have a list as a parent.
        if isinstance(ast_node, ast.stmt):
            assert isinstance(parent, list), (ast_node, parent)
        if node.is_sub_type():
            rel_name = node.super_type.relation_name()
            shared_parent = not node.super_type.unique_parent
        else:
            rel_name = node.relation_name()
            shared_parent = node.parents is None or not node.unique_parent
        # Relation names are pluralised.
        if rel_name[-1] != 's':
            rel_name += 's'
        if t.__mro__[1] in (ast.cmpop, ast.operator, ast.expr_context, ast.unaryop, ast.boolop):
            #These nodes may be used more than once, but must have a
            #unique id for each occurrence in the AST
            fields = [ self.writer.get_unique_id() ]
            fmt = 'r'
        else:
            fields = [ ast_node ]
            fmt = 'n'
        if node.is_sub_type():
            fields.append(node.index)
            fmt += 'd'
        if parent:
            fields.append(parent)
            fmt += 'n'
        if shared_parent:
            # Nodes with a shared parent need the child index to disambiguate.
            fields.append(index)
            fmt += 'd'
        self.writer.write_tuple(rel_name, fmt, *fields)
        if t.__mro__[1] in (ast.expr, ast.stmt):
            self.writer.write_tuple(u'py_scopes', 'nn', ast_node, scope)
    def _emit_variable(self, ast_node):
        '''Write the `variable` tuple for the `ast.Variable` `ast_node`.'''
        self.writer.write_tuple(u'variable', 'nns', ast_node, ast_node.scope, ast_node.id)
    def _emit_name(self, ast_node, parent):
        '''Write a variable and link it to its parent node.'''
        self._emit_variable(ast_node)
        self.writer.write_tuple(u'py_variables', 'nn', ast_node, parent)
    def _emit_primitive(self, val, parent, index, scope, description):
        '''Write a primitive (non-AST) field value.
        None/False are the implicit defaults, so nothing is written for them.'''
        if val is None or val is False:
            return
        if isinstance(val, ast.Variable):
            self._emit_name(val, parent)
            return
        assert not isinstance(val, ast.AstBase)
        rel = description.relation_name()
        if val is True:
            # True is encoded purely by the presence of the tuple.
            if description.unique_parent:
                self.writer.write_tuple(rel, 'n', parent)
            else:
                self.writer.write_tuple(rel, 'nd', parent, index)
        else:
            f = format_for_primitive(val, description)
            if description.unique_parent:
                self.writer.write_tuple(rel, f + 'n', val, parent)
            else:
                self.writer.write_tuple(rel, f + 'nd', val, parent, index)
    def _emit_list(self, node, parent, index, description):
        '''Write the tuple representing a (non-empty) child list.'''
        rel_name = description.relation_name()
        if description.unique_parent:
            self.writer.write_tuple(rel_name, 'nn', node, parent)
        else:
            self.writer.write_tuple(rel_name, 'nnd', node, parent, index)
# Mapping from AST node-class names to their schema descriptions.
_ast_nodes = semmle.python.master.all_nodes()
# When analysing Python 2, `TryExcept`/`TryFinally` share the `Try` schema.
if get_analysis_major_version() < 3:
    _ast_nodes['TryExcept'] = _ast_nodes['Try']
    _ast_nodes['TryFinally'] = _ast_nodes['Try']
class ConsistencyError(Exception):
    '''Raised when the extracted AST does not match the expected schema.

    The `message` attribute is extended in place by callers (see
    `ASTPass._walk_node`) to add context as the error propagates.
    '''

    def __init__(self, message=""):
        Exception.__init__(self, message)
        # Store the message explicitly: `Exception.message` does not exist in
        # Python 3, but callers read and mutate this attribute. Previously it
        # was never initialized, so `ex.message += ...` raised AttributeError.
        self.message = message

    def __str__(self):
        return self.message
def iter_fields(node):
    '''Yield (name, description, value) for each schema field present on `node`.'''
    schema = _ast_nodes[type(node).__name__]
    for name, description, _, _, _ in schema.fields:
        if not hasattr(node, name):
            continue
        yield name, description, getattr(node, name)
# Primitive numeric types that the schema's `number` node may represent.
NUMBER_TYPES = (int, float)
def check_matches(node, node_type, owner, field):
    '''Raise a ConsistencyError unless the actual Python type `node_type` is
    acceptable for the schema description `node` (field `field` of `owner`).'''
    if node_type is list and node.is_list():
        return
    if node_type is not list:
        # Accept a match anywhere in the MRO, or a numeric type for `number`.
        if any(t.__name__ == node.__name__ for t in node_type.__mro__):
            return
        if node_type in NUMBER_TYPES and node.__name__ == 'number':
            return
    raise ConsistencyError("Found %s expected %s for field %s of %s" %
                           (node_type.__name__, node.__name__, field, owner.__name__))
def get_offset_table():
    '''Returns mapping of (class_name, field_name)
    pairs to offsets (in relation)'''
    table = {}
    for node in _ast_nodes.values():
        for field, _, offset, _, _, _ in node.layout:
            table[(node.__name__, field)] = offset
    # Python 2's TryExcept/TryFinally share the layout of Try.
    for field, _, offset, _, _, _ in _ast_nodes['Try'].layout:
        table[('TryFinally', field)] = offset
        table[('TryExcept', field)] = offset
    return table
def format_for_primitive(val, description):
    '''Return the trap-writer format character for the primitive value `val`.'''
    if isinstance(val, str):
        return 'u'
    if isinstance(val, bytes):
        return 'b'
    # Non-string primitives: 'd' for schema ints, 'q' for everything else.
    return 'd' if description.__name__ == 'int' else 'q'
class ASTVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found. This function may return a value
    which is forwarded by the `visit` method.
    This class is meant to be subclassed, with the subclass adding visitor
    methods.
    The visitor functions for the nodes are ``'visit_'`` + class name of the node.
    """
    def _get_visit_method(self, node):
        '''Return the `visit_<ClassName>` method for `node`, or `generic_visit`.'''
        method = 'visit_' + node.__class__.__name__
        return getattr(self, method, self.generic_visit)
    def visit(self, node):
        """Visit a node and forward the visitor's return value.

        (Previously the result was discarded, contradicting the class
        docstring's promise that it is forwarded.)"""
        return self._get_visit_method(node)(node)
    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        if isinstance(node, ast.AstBase):
            for _, _, child in iter_fields(node):
                self.visit(child)
        elif isinstance(node, list):
            for item in node:
                self._get_visit_method(item)(item)

View File

@@ -0,0 +1,113 @@
from semmle.python import ast
from semmle.python.passes._pass import Pass
def write_exports(module, exports, writer):
    'Emit a py_exports tuple for each symbol exported by `module`.'
    for symbol in exports:
        writer.write_tuple(u'py_exports', 'ns', module, symbol)
def list_of_symbols_from_expr(expr):
    '''Extract the string constants from a list/tuple display expression.
    Non-string elements are ignored; anything else yields the empty list.'''
    if not isinstance(expr, (ast.List, ast.Tuple)):
        return []
    return [item.s for item in expr.elts if isinstance(item, ast.Str)]
def is___all__(node):
    'Whether `node` is a use of the name `__all__`.'
    try:
        if not isinstance(node, ast.Name):
            return False
        return node.variable.id == '__all__'
    except Exception:
        # Malformed nodes (e.g. missing `variable`) are simply not __all__.
        return False
def __all___from_stmt(stmt):
    '''Returns None if __all__ is not defined.
    If __all__ may be defined then return a conservative approximation'''
    assert isinstance(stmt, ast.stmt)
    if isinstance(stmt, ast.Assign):
        for target in stmt.targets:
            if is___all__(target):
                return list_of_symbols_from_expr(stmt.value)
        return None
    if not isinstance(stmt, ast.If):
        return None
    body_exports = __all___from_stmt_list(stmt.body)
    orelse_exports = __all___from_stmt_list(stmt.orelse) if stmt.orelse else None
    if body_exports is None and orelse_exports is None:
        # __all__ defined on neither branch.
        return None
    if body_exports is None or orelse_exports is None:
        # Defined on only one branch: conservatively export nothing.
        return []
    # Defined on both branches: only symbols common to both are certain.
    return set(body_exports).intersection(set(orelse_exports))
def __all___from_stmt_list(stmts):
    '''Return the last non-None `__all__` approximation found in `stmts`, or
    None if no statement defines `__all__`.'''
    assert isinstance(stmts, list)
    exports = None
    for stmt in stmts:
        candidate = __all___from_stmt(stmt)
        if candidate is not None:
            exports = candidate
    return exports
def is_private_symbol(sym):
    '''A symbol is private iff it starts with an underscore and is not a
    dunder such as `__init__`.'''
    if sym[0] != '_':
        return False
    is_dunder = len(sym) >= 4 and sym.startswith('__') and sym.endswith('__')
    return not is_dunder
def globals_from_tree(node, names):
    'Add all globals defined in the tree to names'
    if isinstance(node, list):
        for child in node:
            globals_from_tree(child, names)
    elif isinstance(node, ast.Assign):
        names.update(
            target.variable.id
            for target in node.targets
            if isinstance(target, ast.Name)
        )
    elif isinstance(node, ast.If) and node.orelse:
        # A name is only certainly bound if both branches bind it.
        body_names = set()
        orelse_names = set()
        globals_from_tree(node.body, body_names)
        globals_from_tree(node.orelse, orelse_names)
        names.update(body_names & orelse_names)
    # Deliberately do not descend into any other node kinds.
def exports_from_ast(node):
    'Get a list of symbols exported by the module from its ast.'
    assert type(node) is ast.Module
    # An explicit __all__ assignment wins; it is searched at top level and
    # inside if-statements, but try-except and loops are ignored.
    explicit = __all___from_stmt_list(node.body)
    if explicit is not None:
        return explicit
    # No explicit __all__: every non-private global assignment is exported.
    assigned = set()
    globals_from_tree(node.body, assigned)
    return [name for name in assigned if not is_private_symbol(name)]
class ExportsPass(Pass):
    '''Finds all 'exports' of a module. An export is a symbol that is defined
    in the __all__ list or, if __all__ is undefined, is defined at top-level
    and is not private'''
    name = "exports"
    def __init__(self):
        pass
    def extract(self, ast, writer):
        # NOTE(review): the parameter name `ast` shadows the imported
        # `semmle.python.ast` module; harmless here since only the module
        # AST value is used in this method.
        exported = exports_from_ast(ast)
        write_exports(ast, exported, writer)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
# Label an AST with symbol-tables.
# Follow ordering specified in Python/symtable.c
from semmle.python import ast
from semmle.python.passes.ast_pass import iter_fields, ASTVisitor
__all__ = [ 'Labeller' ]
class SymbolTable(ASTVisitor):
    '''A symbol table for a Python scope.
    Records uses and definitions, `global` and `nonlocal` statements for names in that scope'''

    def __init__(self, scope):
        self.definitions = set()
        self.uses = set()
        self.declared_as_global = set()
        self.declared_as_nonlocal = set()
        # Walk the scope's immediate children only; the no-op Class/Function
        # visitors below stop the walk at nested scopes.
        for _, _, child in iter_fields(scope):
            self.visit(child)

    # Nested scopes have their own tables; do not descend into them.
    def visit_Class(self, node):
        pass

    visit_Function = visit_Class

    def visit_Name(self, node):
        name = node.variable.id
        ctx = node.ctx
        if isinstance(ctx, ast.Load):
            self.uses.add(name)
        elif isinstance(ctx, (ast.Store, ast.Param, ast.Del)):
            self.definitions.add(name)
        else:
            raise Exception("Unknown context for name: %s" % node.ctx)

    def visit_Global(self, node):
        self.declared_as_global.update(node.names)

    def visit_Nonlocal(self, node):
        self.declared_as_nonlocal.update(node.names)

    def is_bound(self, name):
        '''Whether `name` is bound in this scope (defined here and not
        declared global or nonlocal).'''
        if name in self.declared_as_global or name in self.declared_as_nonlocal:
            return False
        return name in self.definitions
class _LabellingContext(ASTVisitor):
    '''One link in the chain of nested scopes used to resolve variables.'''
    def __init__(self, scope, module = None, outer = None):
        '''Create a labelling context for `scope`. `module` is the module containing the scope,
        and outer is the enclosing context, if any'''
        self.symbols = SymbolTable(scope)
        self.scope = scope
        self.outer = outer
        if module is None:
            # The top-level context: the scope is the module itself.
            module = scope
        self.module = module
    def label(self):
        'Label the node with this context'
        self.visit(self.module)
    def visit_Function(self, node):
        # Nested scope: recurse with a fresh context chained to this one.
        sub_context = _LabellingContext(node, self.module, self)
        for _, _, child in iter_fields(node):
            sub_context.visit(child)
    visit_Class = visit_Function
    def visit_Variable(self, node):
        '''Resolve the defining scope of `node` per Python's scoping rules.'''
        if node.scope is not None:
            # Already resolved (a Variable may be visited more than once).
            return
        name = node.id
        if name in self.symbols.declared_as_global:
            node.scope = self.module
        elif self.symbols.is_bound(name):
            node.scope = self.scope
        else: # Free variable, either implicitly or explicitly via nonlocal.
            outer = self.outer
            while outer is not None:
                if isinstance(outer.scope, ast.Class):
                    # in the code example below, the use of `baz` inside `func` is NOT a reference to the
                    # function defined on the class, but is a reference to a global variable.
                    #
                    # The use of `baz` on class scope -- `bazzed = baz("class-scope")`
                    # -- is a reference to the function defined on the class.
                    #
                    # ```py
                    # class Foo
                    #     def baz(arg):
                    #         return arg + "-baz"
                    #     def func(self):
                    #         return baz("global-scope")
                    #     bazzed = baz("class-scope")
                    # ```
                    #
                    # So we skip over class scopes.
                    #
                    # See ql/python/ql/test/library-tests/variables/scopes/in_class.py
                    # added in https://github.com/github/codeql/pull/10171
                    pass
                elif outer.symbols.is_bound(name):
                    node.scope = outer.scope
                    break
                outer = outer.outer
            else:
                # No enclosing function scope binds the name: it is global.
                node.scope = self.module
class Labeller(object):
    '''Labels the ast using symbols generated by the symtable module'''
    def apply(self, module):
        'Apply this Labeller to the module'
        #Ensure that AST root nodes have a globally consistent identifier
        # A missing AST (presumably an unparseable file) is simply skipped.
        if module.ast is None:
            return
        _LabellingContext(module.ast).label()

View File

@@ -0,0 +1,153 @@
import ast
import sys
import math
from semmle.python.passes.ast_pass import iter_fields
from semmle.python import ast
from semmle.python.passes._pass import Pass
__all__ = [ 'LexicalPass' ]
# Node kinds that are statements or expressions (used for line counting).
STMT_OR_EXPR = ast.expr, ast.stmt
# All node kinds that get a py_locations entry of their own.
LOCATABLE = STMT_OR_EXPR + (ast.pattern, ast.comprehension, ast.StringPart, ast.keyword, ast.KeyValuePair, ast.DictUnpacking, ast.type_parameter)
# Scope kinds that get py_scope_location entries.
CLASS_OR_FUNCTION = ast.Class, ast.Function
# All scope-introducing node kinds.
SCOPES = ast.Class, ast.Function, ast.Module
class LexicalPass(Pass):
    '''Pass that extracts lexical information: comments, source locations
    and line-count metrics.'''
    def extract(self, ast, comments, writer):
        'The entry point'
        # NOTE: the `ast` parameter is the module AST (it shadows the module import).
        LexicalModule(ast, comments, writer).extract()
class LexicalModule(object):
    'Object for extracting lexical information for the given module.'
    def __init__(self, ast, comments, writer):
        assert ast is not None and comments is not None
        self.ast = ast
        self.comments = comments
        self.writer = writer
        # Trap id of the module node; all locations hang off it.
        self.module_id = writer.get_node_id(ast)
    def extract(self):
        '''Entry point: emit the module's scope location, its comments,
        line counts and node locations.'''
        loc_id = self.get_location(0, 0, 0, 0)
        self.writer.write_tuple(u'py_scope_location', 'rr', loc_id, self.module_id)
        self.emit_line_info()
        self.emit_locations(self.ast)
    def emit_line_info(self):
        '''Emit py_comments tuples and the per-scope line-count metrics.'''
        for text, start, end in self.comments:
            #Generate a unique string for comment based on location
            comment_id = str(start + end)
            loc_id = self.get_location(start[0], start[1]+1,
                                       end[0], end[1])
            try:
                self.writer.write_tuple(u'py_comments', 'nsr',
                                        comment_id, text, loc_id)
            except UnicodeDecodeError:
                # Handle non-ascii comments. Should only happen in Py2
                assert sys.hexversion < 0x03000000
                # NOTE(review): "latin8" selects ISO-8859-14; "latin-1" may
                # have been intended here -- confirm.
                text = text.decode("latin8")
                self.writer.write_tuple(u'py_comments', 'nsr',
                                        comment_id, text, loc_id)
        comment_bits = get_comment_bits(self.comments)
        self.emit_line_counts(self.ast, set(), comment_bits)
    def emit_line_counts(self, node, code_lines, comment_bits):
        '''Recursively compute and emit code/comment/docstring line counts.
        `code_lines` accumulates line numbers containing code in the current
        scope; `comment_bits` is the comment-line bitmask for the whole file.'''
        if isinstance(node, SCOPES) and node.body:
            doc_line_count = 0
            stmt0 = node.body[0]
            # A leading string expression is the scope's docstring.
            if type(stmt0) == ast.Expr:
                docstring = stmt0.value
                if isinstance(docstring, ast.Str):
                    doc_line_count = docstring._end[0] - docstring.lineno + 1
            inner_code_lines = set()
            inner_code_lines.add(node.lineno)
            for _, _, child_node in iter_fields(node):
                self.emit_line_counts(child_node, inner_code_lines, comment_bits)
            assert inner_code_lines
            startline = min(inner_code_lines)
            endline = max(inner_code_lines)
            if isinstance(node, ast.Module):
                # The module extends to the last comment, even past the last code line.
                endline = max(endline, last_line(comment_bits))
            comment_line_count = get_lines_in_range(comment_bits, startline, endline)
            code_line_count = len(inner_code_lines) - doc_line_count
            code_lines.update(inner_code_lines)
            self.print_lines(u'code', node, code_line_count)
            self.print_lines(u'comment', node, comment_line_count)
            self.print_lines(u'docstring', node, doc_line_count)
            self.print_lines(u'all', node, endline - startline + 1)
            if isinstance(node, ast.Module):
                total_lines = code_line_count + comment_line_count + doc_line_count
                self.writer.write_tuple(u'numlines', 'rddd', self.module_id, total_lines, code_line_count, comment_line_count + doc_line_count)
        elif isinstance(node, list):
            for n in node:
                self.emit_line_counts(n, code_lines, comment_bits)
        elif isinstance(node, STMT_OR_EXPR):
            for _, _, child_node in iter_fields(node):
                self.emit_line_counts(child_node, code_lines, comment_bits)
            assert hasattr(node, "lineno"), node
            # All lines spanned by this statement/expression count as code.
            line = node.lineno
            endline, _ = node._end
            while line <= endline:
                code_lines.add(line)
                line += 1
    def print_lines(self, name, node, count):
        '''Write a py_<name>lines tuple for `node`.'''
        self.writer.write_tuple(u'py_%slines' % name, 'nd', node, count)
    def get_location(self, bl, bc, el, ec):
        '''Create and return a fresh location entity with the given extent.'''
        loc_id = self.writer.get_unique_id()
        self.writer.write_tuple(u'locations_ast', 'rrdddd',
                                loc_id, self.module_id, bl, bc, el, ec)
        return loc_id
    def emit_locations(self, node):
        '''Recursively emit location tuples for all locatable nodes.'''
        if isinstance(node, ast.AstBase):
            if isinstance(node, LOCATABLE):
                self._write_location(node)
            elif isinstance(node, CLASS_OR_FUNCTION):
                # Scopes get py_scope_location rather than py_locations.
                bl, bc = node.lineno, node.col_offset+1
                el, ec = node._end
                loc_id = self.get_location(bl, bc, el, ec)
                self.writer.write_tuple(u'py_scope_location', 'rn', loc_id, node)
            for _, _, child_node in iter_fields(node):
                self.emit_locations(child_node)
        elif isinstance(node, list):
            for n in node:
                self.emit_locations(n)
    def _write_location(self, node):
        '''Write the py_locations tuple for a single node.'''
        bl, bc = node.lineno, node.col_offset+1
        assert len(node._end) == 2, node
        el, ec = node._end
        loc_id = self.get_location(bl, bc, el, ec)
        self.writer.write_tuple(u'py_locations', 'rn', loc_id, node)
def get_comment_bits(comments):
    '''Return a bitmask with bit `i` set iff line `i` contains (part of) a comment.'''
    bits = 0
    for _, (start_line, _), (end_line, _) in comments:
        for line in range(start_line, end_line + 1):
            bits |= 1 << line
    return bits
def get_lines_in_range(bits, start, end):
    '''Count the set bits of `bits` between positions `start` and `end`
    inclusive. A negative `end` means "no upper bound".'''
    section = bits >> start
    if end >= 0:
        length = end - start + 1
        if length < 0:
            return 0
        section &= (1 << length) - 1
    return bin(section).count('1')
def last_line(n):
    '''Return the position of the highest set bit of the line-bitmask `n`
    (i.e. the last line containing a comment), or 0 if no bit is set.'''
    if n <= 0:
        return 0
    # `int.bit_length` is exact for arbitrarily large masks, unlike the
    # previous `int(math.log(n, 2))`, whose float rounding could be off by
    # one for large bitmasks (e.g. masks covering ~1000 lines).
    return n.bit_length() - 1

View File

@@ -0,0 +1,380 @@
import ast
import sys
from types import ModuleType, GetSetDescriptorType
import hashlib
import os
from semmle.python import ast
from semmle.python.passes._pass import Pass
from semmle.util import get_analysis_major_version
from semmle.python.passes.ast_pass import iter_fields
from semmle.cmdline import is_legal_module_name
'''
The QL library depends on a reasonable one-to-one correspondence
between DB entities and Python objects. However, since QL has only
one notion of equality, but Python has two (`__eq__` and `is`) we need to be careful.
What we want to do is to treat objects like builtin functions and classes as using
reference equality and numbers and strings as using value equality.
In practice this is impossible as we want to distinguish `True` from `1` from `1.0`
even though all these values are equal. However, we want to get as close as possible.
'''
__all__ = [ 'ObjectPass' ]
# AST node kinds whose evaluation yields a runtime object worth tracking.
OBJECT_TYPES = set([ ast.ClassExpr, ast.Call,
                     ast.FunctionExpr, ast.Tuple,
                     ast.Str, ast.Num, ast.List, ast.ListComp, ast.Module,
                     ast.Dict, ast.Ellipsis, ast.Lambda])
# Types from Python 2.7 onwards
OBJECT_TYPES.add(ast.DictComp)
OBJECT_TYPES.add(ast.SetComp)
OBJECT_TYPES.add(ast.Set)
# Builtin numeric types, which use value-based (not identity) merging.
NUMERIC_TYPES = set([int, float, bool])
BUILTINS_NAME = 'builtins'
# Literal node kinds.
LITERALS = (ast.Num, ast.Str)
class _CObject(object):
'''Utility class to wrap arbitrary C objects.
Treat all objects as unique. Rely on naming in the
trap files to merge the objects that we want merged.
'''
__slots__ = ['obj']
def __init__(self, obj):
self.obj = obj
def __eq__(self, other):
if isinstance(other, _CObject):
return self.obj is other.obj
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return id(self.obj)
class ObjectPass(Pass):
'''Generates relations for objects. This includes information about
builtin objects, including their types and members.
It also generates objects for all literal values present in the Python source.'''
    def extract(self, ast, path, writer):
        '''Extract object relations for the module AST `ast` located at `path`.
        The writer is held in `self.writer` for the duration of the pass.'''
        self.writer = writer
        try:
            self._extract_py(ast)
            self._extract_possible_module_names(path)
        finally:
            # Always drop the writer reference, even on error.
            self.writer = None
def _extract_possible_module_names(self, path):
maybe_name, _ = os.path.splitext(path)
maybe_name = maybe_name.replace(os.sep, ".")
while maybe_name.count(".") > 3:
_, maybe_name = maybe_name.split(".", 1)
while True:
if is_legal_module_name(maybe_name):
self._write_module_and_package_names(maybe_name)
if "." not in maybe_name:
return
_, maybe_name = maybe_name.split(".", 1)
def _write_module_and_package_names(self, module_name):
self._write_c_object(module_name, None, False)
while "." in module_name:
module_name, _ = module_name.rsplit(".", 1)
self._write_c_object(module_name, None, False)
    def extract_builtin(self, module, writer):
        '''Extract a builtin (C) `module`, writing its objects to `writer`.'''
        self.writer = writer
        try:
            self._extract_c(module)
        finally:
            # Always drop the writer reference, even on error.
            self.writer = None
    def _extract_c(self, mod):
        '''Extract the builtin module `mod` and everything reachable from it.'''
        # Reset per-module label state before the walk.
        self.next_address_label = 0
        self.address_labels = {}
        self._write_c_object(mod, None, False)
        # Free the label cache once this module is done.
        self.address_labels = None
    def _write_str(self, s):
        '''Write the string object `s` (must be `str`, not `bytes`).'''
        assert type(s) is str
        self._write_c_object(s, None, False)
def _write_c_object(self, obj, label, write_special, string_prefix=""):
ANALYSIS_MAJOR_VERSION = get_analysis_major_version()
# If we're extracting Python 2 code using Python 3, we want to treat `str` as `bytes` for
# the purposes of determining the type, but we still want to treat the _value_ as if it's a `str`.
obj_type = type(obj)
if obj_type == str and ANALYSIS_MAJOR_VERSION == 2 and 'u' not in string_prefix:
obj_type = bytes
cobj = _CObject(obj)
if self.writer.has_written(cobj):
return self.writer.get_node_id(cobj)
obj_label = self.get_label_for_object(obj, label, obj_type)
obj_id = self.writer.get_labelled_id(cobj, obj_label)
#Avoid writing out all the basic types for every C module.
if not write_special and cobj in SPECIAL_OBJECTS:
return obj_id
type_id = self._write_c_object(obj_type, None, write_special)
self.writer.write_tuple(u'py_cobjects', 'r', obj_id)
self.writer.write_tuple(u'py_cobjecttypes', 'rr', obj_id, type_id)
self.writer.write_tuple(u'py_cobject_sources', 'rd', obj_id, 0)
if isinstance(obj, ModuleType) or isinstance(obj, type):
for name, value in sorted(obj.__dict__.items()):
if (obj, name) in SKIPLIST:
continue
val_id = self._write_c_object(value, obj_label + u'$%d' % ANALYSIS_MAJOR_VERSION + name, write_special)
self.writer.write_tuple(u'py_cmembers_versioned', 'rsrs',
obj_id, name, val_id, ANALYSIS_MAJOR_VERSION)
if isinstance(obj, type) and obj is not object:
super_id = self._write_c_object(obj.__mro__[1], None, write_special)
self.writer.write_tuple(u'py_cmembers_versioned', 'rsrs',
obj_id, u".super.", super_id, ANALYSIS_MAJOR_VERSION)
if isinstance(obj, (list, tuple)):
for index, item in enumerate(obj):
item_id = self._write_c_object(item, obj_label + u'$' + str(index), write_special)
self.writer.write_tuple(u'py_citems', 'rdr',
obj_id, index, item_id)
if type(obj) is GetSetDescriptorType:
for name in type(obj).__dict__:
if name == '__name__' or not hasattr(obj, name):
continue
val_id = self._write_c_object(getattr(obj, name), obj_label + u'$%d' % ANALYSIS_MAJOR_VERSION + name, write_special)
self.writer.write_tuple(u'py_cmembers_versioned', 'rsrs',
obj_id, name, val_id, ANALYSIS_MAJOR_VERSION)
if hasattr(obj, '__name__'):
#Use qualified names for classes.
if isinstance(obj, type):
name = qualified_type_name(obj)
# https://bugs.python.org/issue18602
elif isinstance(obj, ModuleType) and obj.__name__ == "io":
name = "_io"
elif obj is EXEC:
name = "exec"
else:
name = obj.__name__
self.writer.write_tuple(u'py_cobjectnames', 'rs',
obj_id, name)
elif type(obj) in NUMERIC_TYPES:
self.writer.write_tuple(u'py_cobjectnames', 'rq',
obj_id, obj)
elif type(obj) is str:
if 'b' in string_prefix:
prefix = u"b"
elif 'u' in string_prefix:
prefix = u"u"
else:
if ANALYSIS_MAJOR_VERSION == 2:
prefix = u"b"
else:
prefix = u"u"
self.writer.write_tuple(u'py_cobjectnames', 'rs',
obj_id, prefix + u"'" + obj + u"'")
elif type(obj) is bytes:
#Convert bytes to a unicode characters one-to-one.
obj_string = u"b'" + obj.decode("latin-1") + u"'"
self.writer.write_tuple(u'py_cobjectnames', 'rs',
obj_id, obj_string)
elif type(obj) is type(None):
self.writer.write_tuple(u'py_cobjectnames', 'rs',
obj_id, u'None')
else:
self.writer.write_tuple(u'py_cobjectnames', 'rs',
obj_id, u'object')
return obj_id
def write_special_objects(self, writer):
'''Write important builtin objects to the trap file'''
self.writer = writer
self.next_address_label = 0
self.address_labels = {}
def write(obj, name, label=None):
obj_id = self._write_c_object(obj, label, True)
self.writer.write_tuple(u'py_special_objects', 'rs', obj_id, name)
for obj, name in SPECIAL_OBJECTS.items():
write(obj.obj, name)
###Artificial objects for use by the type-inferencer - Make sure that they are unique.
write(object(), u"_semmle_unknown_type", u"$_semmle_unknown_type")
write(object(), u"_semmle_undefined_value", u"$_semmle_undefined_value")
self.writer = None
self.address_labels = None
def get_label_for_object(self, obj, default_label, obj_type):
"""Gets a label for an object. Attempt to make this as universal as possible.
The object graph in the database should reflect the real object graph,
only rarely diverging. This should be true even in highly parallel environments
including cases where trap files may be overwritten.
Proviso: Distinct immutable primitive objects may be merged (which should be benign)
For objects without a unambiguous global name, 'default_label' is used.
"""
#This code must be robust against (possibly intentionally) incorrect implementations
#of the object model.
if obj is None:
return u"C_None"
t = type(obj)
t_name = t.__name__
if t is tuple and len(obj) == 0:
return u"C_EmptyTuple"
if obj_type is str:
prefix = u"C_unicode$"
else:
prefix = u"C_bytes$"
if t is str:
obj = obj.encode("utf8", errors='replace')
return prefix + hashlib.sha1(obj).hexdigest()
if t is bytes:
return prefix + hashlib.sha1(obj).hexdigest()
if t in NUMERIC_TYPES:
return u"C_" + t_name + u"$" + repr(obj)
try:
if isinstance(obj, type):
return u"C_" + t_name + u"$" + qualified_type_name(obj)
except Exception:
#Misbehaved object.
return default_label
if t is ModuleType:
return u"C_" + t_name + u"$" + obj.__name__
if t is type(len):
mod_name = obj.__module__
if isinstance(mod_name, str):
if mod_name == BUILTINS_NAME:
mod_name = "builtins"
return u"C_" + t_name + u"$" + mod_name + "." + obj.__name__
return default_label
# Python files -- Extract objects for all numeric and string values.
def _extract_py(self, ast):
self._walk_py(ast)
def _write_literal(self, node):
if isinstance(node, ast.Num):
self._write_c_object(node.n, None, False)
else:
prefix = getattr(node, "prefix", "")
# Output both byte and unicode objects if the relevant objects could exist
# Non-prefixed strings can be either bytes or unicode.
if 'u' not in prefix:
try:
self._write_c_object(node.s.encode("latin-1"), None, False, string_prefix=prefix)
except UnicodeEncodeError:
#If not encodeable as latin-1 then it cannot be bytes
pass
if 'b' not in prefix:
self._write_c_object(node.s, None, False, string_prefix=prefix)
def _walk_py(self, node):
if isinstance(node, ast.AstBase):
if isinstance(node, LITERALS):
self._write_literal(node)
else:
for _, _, child_node in iter_fields(node):
self._walk_py(child_node)
elif isinstance(node, list):
for n in node:
self._walk_py(n)
def a_function():
    # Dummy function: exists only so type(a_function) yields FunctionType below.
    pass
def a_generator_function():
    # Dummy generator: type(a_generator_function()) yields the generator type below.
    yield None
class C(object):
    # Dummy class: type(C().meth) yields the bound-method type below.
    def meth(self):
        pass
#Create an object for 'exec', as parser no longer treats it as statement.
# Use `[].append` as it has the same type as `exec`.
EXEC = [].append
# Map from object -> canonical special name; wrapped in _CObject further below
# so that keys compare by identity rather than value.
SPECIAL_OBJECTS = {
    type(a_function): u"FunctionType",
    type(len): u"BuiltinFunctionType",
    classmethod: u"ClassMethod",
    staticmethod: u"StaticMethod",
    type(sys): u"ModuleType",
    type(a_generator_function()): u"generator",
    None: u"None",
    type(None): u"NoneType",
    True: u"True",
    False: u"False",
    bool: u"bool",
    sys: u"sys",
    Exception: u"Exception",
    BaseException: u"BaseException",
    TypeError: u"TypeError",
    AttributeError: u"AttributeError",
    KeyError: u"KeyError",
    int: u"int",
    float: u"float",
    object: u"object",
    type: u"type",
    tuple: u"tuple",
    dict: u"dict",
    list: u"list",
    set: u"set",
    locals: u"locals",
    globals: u"globals",
    property: u"property",
    type(list.append): u"MethodDescriptorType",
    super: u"super",
    type(C().meth): u"MethodType",
    #For future enhancements
    object(): u"_1",
    object(): u"_2",
    #Make sure we have all version numbers as single character strings.
    b'2': u'b2',
    b'3': u'b3',
    u'2': u'u2',
    u'3': u'u3',
}
SPECIAL_OBJECTS[__import__(BUILTINS_NAME)] = u"builtin_module"
# `str` is named "unicode" so that names are stable across Python 2 and 3.
SPECIAL_OBJECTS[str] = u"unicode"
SPECIAL_OBJECTS[bytes] = u"bytes"
#Store wrapped versions of special objects, so that they compare correctly.
tmp = {}
for obj, name in SPECIAL_OBJECTS.items():
    tmp[_CObject(obj)] = name
SPECIAL_OBJECTS = tmp
del tmp
#List of various attributes VM implementation details we want to skip.
# (object, attribute-name) pairs that are never written as members.
SKIPLIST = set([
    (sys, "exc_value"),
    (sys, "exc_type"),
    (sys, "exc_traceback"),
    (__import__(BUILTINS_NAME), "_"),
])
def qualified_type_name(cls):
    '''Return a version-stable, possibly module-qualified name for class `cls`.'''
    #Special case bytes/str/unicode to make sure they share names across versions
    if cls is bytes:
        return u"bytes"
    if cls is str:
        return u"unicode"
    module = cls.__module__
    # Builtin and (Python 2) exception classes go by their bare name.
    if module in (BUILTINS_NAME, "exceptions"):
        return cls.__name__
    return module + "." + cls.__name__

View File

@@ -0,0 +1,450 @@
'''
Prune the flow-graph, eliminating edges with impossible constraints.
For example:
1. if x:
2. if x == 0:
3. pass
The edge from `x == 0` to pass (line 2 to line 3) is impossible as `x` cannot be zero to
reach line 2.
While code like the above is unlikely in source code, it is quite common after splitting.
'''
from semmle.python import ast
import cmath
from collections import defaultdict
from semmle.python.passes.ast_pass import ASTVisitor
import semmle.util as util
from semmle.python.ast import Lt, LtE, Eq, NotEq, Gt, GtE, Is, IsNot
__all__ = [ 'do_pruning' ]
# On Python 3 there is a single integer type; used in isinstance() checks.
INT_TYPES = int
# Classes representing constraint on branches, for pruning.
# For example, the constraint `x` allows pruning an edge with the constraint `x == 0`
# since if `x` is True it cannot be zero.
class Truthy(object):
    '''Constraint arising from a bare truth test: `x` or `not x`.'''

    def __init__(self, sense):
        # sense is True for `x`, False for `not x`.
        self.sense = sense

    def invert(self):
        # Flip to the opposite singleton.
        return VAR_IS_FALSE if self.sense else VAR_IS_TRUE

    def contradicts(self, other):
        '''Holds if self and other are contradictory.'''
        if self.sense:
            return other.constrainsVariableToBeFalse()
        return other.constrainsVariableToBeTrue()

    def constrainsVariableToBeTrue(self):
        '''Holds if this constrains the variable such that `bool(var) is True`'''
        return self.sense

    def constrainsVariableToBeFalse(self):
        '''Holds if this constrains the variable such that `bool(var) is False`'''
        return not self.sense

    def __repr__(self):
        return "True" if self.sense else "False"
class IsNone(object):
    '''Constraint arising from an identity test against None:
    `x is None` or `x is not None`.'''

    def __init__(self, sense):
        # sense is True for `x is None`, False for `x is not None`.
        self.sense = sense

    def contradicts(self, other):
        if self is VAR_IS_NONE:
            # `x is None` implies x is falsy and, trivially, not non-None.
            return other is VAR_IS_NOT_NONE or other is VAR_IS_TRUE
        # `x is not None` only rules out `x is None`.
        return other is VAR_IS_NONE

    def invert(self):
        return VAR_IS_NOT_NONE if self.sense else VAR_IS_NONE

    def constrainsVariableToBeTrue(self):
        # Neither form guarantees truthiness (non-None values may be falsy).
        return False

    def constrainsVariableToBeFalse(self):
        return self is VAR_IS_NONE

    def __repr__(self):
        return "Is None" if self.sense else "Is Not None"
class ComparedToConst(object):
    '''A test of the form `x == k`, `x < k`, etc.'''

    def __init__(self, op, k):
        #We can treat is/is not as ==/!= as we only
        #compare with simple literals which are always interned.
        if op is Is:
            op = Eq
        elif op is IsNot:
            op = NotEq
        self.op = op
        self.k = k

    def invert(self):
        # The logical negation of `x OP k` is `x INVERT_OP[OP] k`.
        return ComparedToConst(INVERT_OP[self.op], self.k)

    def constrainsVariableToBeTrue(self):
        # Holds if the comparison forces `x != 0` (hence bool(x) is True,
        # for numeric x). Each case excludes 0 from the possible values.
        if self.op == Eq:
            return self.k != 0
        if self.op == NotEq:
            return self.k == 0
        if self.op == GtE:
            return self.k > 0
        if self.op == Gt:
            return self.k >= 0
        if self.op == LtE:
            return self.k < 0
        if self.op == Lt:
            return self.k <= 0
        return False

    def constrainsVariableToBeFalse(self):
        # Only `x == 0` forces bool(x) to be False.
        return self.op == Eq and self.k == 0

    def contradicts(self, other):
        # First the truthiness/None interactions with non-comparison constraints.
        if self.constrainsVariableToBeTrue() and other is VAR_IS_FALSE:
            return True
        if self.constrainsVariableToBeFalse() and other is VAR_IS_TRUE:
            return True
        if self.op == Eq and other is VAR_IS_NONE:
            # `x == k` (k an int) cannot hold when x is None.
            return True
        if not isinstance(other, ComparedToConst):
            return False
        # Pairwise comparison-vs-comparison cases. Conservative: may miss
        # some contradictions, but must never report a false one.
        if self.op == Eq:
            if other.op == NotEq:
                return self.k == other.k
            if other.op == Eq:
                return self.k != other.k
            if other.op == Lt:
                return self.k >= other.k
            if other.op == LtE:
                return self.k > other.k
            if other.op == Gt:
                return self.k <= other.k
            if other.op == GtE:
                return self.k < other.k
            return False
        if self.op == Lt:
            if other.op == Eq or other.op == Gt or other.op == GtE:
                return self.k <= other.k
            return False
        if self.op == LtE:
            if other.op == Eq or other.op == GtE:
                return self.k < other.k
            if other.op == Gt:
                return self.k <= other.k
            return False
        # self.op is NotEq, Gt or GtE here; two lower-bound/exclusion
        # constraints can never contradict each other.
        if other.op in (NotEq, Gt, GtE):
            return False
        # Otherwise delegate to the symmetric case handled above.
        return other.contradicts(self)

    def __repr__(self):
        return "%s %d" % (OP_NAME[self.op], self.k)
# Logical negation of each comparison operator.
INVERT_OP = {
    Eq: NotEq,
    NotEq: Eq,
    Lt: GtE,
    LtE: Gt,
    Gt: LtE,
    GtE: Lt
}
# Printable form of each operator (for __repr__ only).
OP_NAME = {
    Eq: "==",
    NotEq: "!=",
    Lt: "<",
    LtE: "<=",
    Gt: ">",
    GtE: ">=",
}
# Singleton constraints; identity comparisons (`is VAR_IS_NONE` etc.)
# elsewhere rely on these being the only instances used.
VAR_IS_TRUE = Truthy(True)
VAR_IS_FALSE = Truthy(False)
VAR_IS_NONE = IsNone(True)
VAR_IS_NOT_NONE = IsNone(False)
# Constraint implied by using one of the named constants directly.
NAME_CONSTS = {
    "True" : VAR_IS_TRUE,
    "False": VAR_IS_FALSE,
    "None": VAR_IS_NONE,
}
class SkippedVisitor(ASTVisitor):
    '''Collects Name nodes whose value may be mutated through attribute
    access or subscripting (`x.attr`, `x[i]`).'''

    def __init__(self):
        self.nodes = set()

    def _record_value(self, node):
        # Only plain names are of interest; anything else cannot be an
        # SSA variable use.
        value = node.value
        if isinstance(value, ast.Name):
            self.nodes.add(value)

    # Both node kinds are handled identically.
    visit_Subscript = _record_value
    visit_Attribute = _record_value
class NonlocalVisitor(ASTVisitor):
    '''Collects every identifier declared `nonlocal` in the visited tree.'''

    def __init__(self):
        # Set of variable names appearing in `nonlocal` statements.
        self.names = set()

    def visit_Nonlocal(self, node):
        # node.names is the list of identifiers in the statement;
        # set.update is the idiomatic bulk-add.
        self.names.update(node.names)
class GlobalVisitor(ASTVisitor):
    '''Collects every identifier declared `global` in the visited tree.'''

    def __init__(self):
        # Set of variable names appearing in `global` statements.
        self.names = set()

    def visit_Global(self, node):
        # Bulk-add all declared names.
        self.names.update(node.names)
class KeptVisitor(ASTVisitor):
    '''Collects definition nodes (import aliases whose value is a constant)
    whose variables must never be skipped.'''

    def __init__(self):
        self.nodes = set()

    #Keep imports
    def visit_alias(self, node):
        # Only aliases bound to a constant-valued expression are of interest.
        if const_value(node.value) is None:
            return
        target = node.asname
        if hasattr(target, 'variable'):
            self.nodes.add(target)
def skipped_variables(tree, graph, use_map):
    '''Returns a collection of SsaVariables that
    are skipped as possibly mutated.
    Variables are skipped if their values may be mutated
    in such a way that it might alter their boolean value.
    This means that they have an attribute accessed, or are subscripted.
    However, modules are always true, so are never skipped.
    '''
    # Names that are subscripted or have an attribute accessed.
    mutated_finder = SkippedVisitor()
    mutated_finder.visit(tree)
    mutated_nodes = mutated_finder.nodes
    skiplist = {var for node, var in use_map.items() if node.node in mutated_nodes}
    # Definitions that must always be kept (imports bound to constants).
    keeper = KeptVisitor()
    keeper.visit(tree)
    kept_nodes = keeper.nodes
    keeplist = set()
    for var in use_map.values():
        defn = graph.get_ssa_definition(var)
        if defn and defn.node in kept_nodes:
            keeplist.add(var)
    return skiplist - keeplist
def get_branching_edges(tree, graph, use_map):
    '''Returns an iterator of pred, succ, var, bool tuples
    representing edges and the boolean value or None-ness that
    the ssa variable holds on that edge.
    '''
    for pred, succ, ann in graph.edges():
        # Only conditional (true/false annotated) edges are of interest.
        if ann not in (util.TRUE_EDGE, util.FALSE_EDGE):
            continue
        #Handle 'not' expressions.
        invert = ann == util.FALSE_EDGE
        test = pred
        # Walk back through a chain of `not`s, flipping the sense each time.
        while isinstance(test.node, ast.UnaryOp):
            if not isinstance(test.node.op, ast.Not):
                break
            preds = graph.pred[test]
            if len(preds) != 1:
                # Ambiguous predecessor: stop unwrapping.
                break
            test = preds[0]
            invert = not invert
        t = comparison_kind(graph, test)
        if t is None:
            continue
        val, use = t
        if invert:
            val = val.invert()
        if use in use_map:
            yield pred, succ, use_map[use], val
def effective_constants_definitions(bool_const_defns, graph, branching_edges):
    '''Returns a mapping of var -> list of (node, effective-constant)
    representing the effective boolean constant definitions.
    A constant definition is an assignment to a
    SSA variable 'var' such that bool(var) is a constant
    for all uses of that variable dominated by the definition.
    A (SSA) variable is effectively constant if it assigned
    a constant, or it is guarded by a test.
    '''
    consts = defaultdict(list)
    # Direct constant assignments.
    for var in graph.ssa_variables():
        defn = graph.get_ssa_definition(var)
        if not defn:
            continue
        if defn.node in bool_const_defns:
            consts[var].append((defn, bool_const_defns[defn.node]))
    # Guarding tests: the edge target acts as a definition point, but only
    # when it has a single predecessor (otherwise the constraint may not hold
    # on all paths into it).
    for _pred, succ, var, bval in branching_edges:
        if len(graph.pred[succ]) == 1:
            consts[var].append((succ, bval))
    return consts
def do_pruning(tree, graph):
    '''Remove edges from `graph` whose constraints are impossible given
    dominating constant definitions or guarding tests in `tree`.
    Iterates to a fixed point, since removing edges can expose further
    impossible edges.'''
    v = BoolConstVisitor()
    v.visit(tree)
    # Variables declared nonlocal/global may be reassigned out of view,
    # so constraints on them cannot be trusted.
    nonlocals = NonlocalVisitor()
    nonlocals.visit(tree)
    global_vars = GlobalVisitor()
    global_vars.visit(tree)
    bool_const_defns = v.const_defns
    #Need to repeatedly do this until we reach a fixed point
    while True:
        # Map each CFG use of a plain name to its SSA variable.
        use_map = {}
        for node, var in graph.ssa_uses():
            if isinstance(node.node, ast.Name):
                use_map[node] = var
        skiplist = skipped_variables(tree, graph, use_map)
        edges = list(get_branching_edges(tree, graph, use_map))
        consts = effective_constants_definitions(bool_const_defns, graph, edges)
        # Cache of dominance sets, computed lazily per definition.
        dominated_by = {}
        #Look for effectively constant definitions that dominate edges on
        #which the relative variable has the inverse sense.
        #Put edges to be removed in a set, as an edge could be removed for
        #multiple reasons.
        to_be_removed = set()
        for pred, succ, var, bval in edges:
            if var not in consts:
                continue
            # Possibly-mutated variables: only None-ness constraints are safe.
            if var in skiplist and bval in (VAR_IS_TRUE, VAR_IS_FALSE):
                continue
            if var.variable.id in nonlocals.names:
                continue
            if var.variable.id in global_vars.names:
                continue
            for defn, const_kind in consts[var]:
                if not const_kind.contradicts(bval):
                    continue
                if defn not in dominated_by:
                    dominated_by[defn] = graph.dominated_by(defn)
                # The contradiction only applies where the definition dominates.
                if pred in dominated_by[defn]:
                    to_be_removed.add((pred, succ))
        #Delete simply dead edges (like `if False: ...` )
        for pred, succ, ann in graph.edges():
            if ann == util.TRUE_EDGE:
                val = VAR_IS_TRUE
            elif ann == util.FALSE_EDGE:
                val = VAR_IS_FALSE
            else:
                continue
            b = const_value(pred.node)
            if b is None:
                continue
            if b.contradicts(val):
                to_be_removed.add((pred, succ))
        if not to_be_removed:
            # Fixed point reached.
            break
        for pred, succ in to_be_removed:
            graph.remove_edge(pred, succ)
        # Dominance and related caches are invalid after edge removal.
        graph.clear_computed()
class BoolConstVisitor(ASTVisitor):
    '''Look for assignments of a boolean constant to a variable.
    self.const_defns holds a mapping from the AST node for the definition
    to the True/False value for the constant.'''

    def __init__(self):
        self.const_defns = {}

    def _record(self, target, bool_const):
        # Only record targets that actually bind a variable.
        if hasattr(target, 'variable'):
            self.const_defns[target] = bool_const

    def visit_alias(self, node):
        bool_const = const_value(node.value)
        if bool_const is not None:
            self._record(node.asname, bool_const)

    def visit_Assign(self, node):
        bool_const = const_value(node.value)
        if bool_const is not None:
            for target in node.targets:
                self._record(target, bool_const)
def _comparison(test):
    '''Return the constraint implied by a simple comparison of a variable
    against None or an integer literal, or None if `test` is not such a
    comparison.'''
    # Comparisons to None or ints
    if not isinstance(test, ast.Compare) or len(test.ops) != 1:
        return None
    left = test.left
    right = test.comparators[0]
    # The left-hand side must be a variable use.
    if not hasattr(left, "variable"):
        return None
    op = test.ops[0]
    if isinstance(right, ast.Name) and right.id == "None":
        if isinstance(op, ast.Is):
            return VAR_IS_NONE
        if isinstance(op, ast.IsNot):
            return VAR_IS_NOT_NONE
        return None
    if isinstance(right, ast.Num) and isinstance(right.n, INT_TYPES):
        return ComparedToConst(type(op), right.n)
    return None
def comparison_kind(graph, test):
    '''Return (constraint, use-node) for the test at CFG node `test`, or
    None if no constraint can be derived.'''
    # Comparisons to None or ints
    val = _comparison(test.node)
    if val is None:
        # A bare variable use is itself a truth test.
        if hasattr(test.node, "variable"):
            return VAR_IS_TRUE, test
        return None
    # For a comparison, the variable use is two CFG steps back
    # (the comparator literal, then the variable) — it must be unambiguous.
    use_set = graph.pred[graph.pred[test][0]]
    if len(use_set) != 1:
        return None
    use = use_set[0]
    return val, use
def const_value(ast_node):
    '''Returns the constraint describing the value of a boolean or numeric
    constant AST node, or None if the node is not a constant.
    NaN is not a constant.'''
    if isinstance(ast_node, ast.Name):
        if ast_node.id in ("True", "False", "None"):
            return NAME_CONSTS[ast_node.id]
        else:
            return None
    if isinstance(ast_node, ast.ImportExpr):
        #Modules always evaluate True
        return VAR_IS_TRUE
    if isinstance(ast_node, ast.Num):
        n = ast_node.n
    elif isinstance(ast_node, ast.UnaryOp):
        if isinstance(ast_node.op, ast.USub) and isinstance(ast_node.operand, ast.Num):
            # BUG FIX: the negation must be applied, otherwise `x = -1` is
            # recorded as the constant 1, and the contradiction check in
            # ComparedToConst can then wrongly prune live edges such as the
            # true edge of `if x != 1:` after `x = -1`. Truthiness (k != 0)
            # is unaffected by the sign, so prior pruning on bool(x) is
            # unchanged.
            n = -ast_node.operand.n
        elif isinstance(ast_node.op, ast.Not):
            not_value = const_value(ast_node.operand)
            if not_value is None:
                return None
            return not_value.invert()
        else:
            return None
    else:
        return None
    #Check for NaN, but be careful not to overflow
    #Handle integers first as they may overflow cmath.isnan()
    if not isinstance(n, INT_TYPES) and cmath.isnan(n):
        return None
    #Now have an int or a normal float or complex
    return ComparedToConst(Eq, n)

View File

@@ -0,0 +1,384 @@
'''
Split the flow-graph to allow tests to dominate all parts of the code that depends on them.
We split on `if`s and `try`s. Either because of several tests on the same condition or
subsequent tests on a constant determined by the first condition.
E.g.
if a:
A
B
if a:
C
becomes
if a:
A
B
C
else:
B
ensuring that A dominates C.
or...
try:
import foo
except:
foo = None
X
if foo:
Y
becomes
try:
import foo
X
Y
except:
foo = None
X
To split on CFG node N we require that there exists nodes H1..Hn and N2 such that:
N and N2 are tests or conditional assignments to the same variable.
N dominates H1 .. Hn and N2
There is no assignment to the variable between N and N2
H1..Hn are the "split heads" of N, that is:
if N is a test, H1 and H2 are its true and false successors (there is no H3).
if N is a `try` then H1 .. Hn-1 are exists from the try body and Hn is the CFG node for the first (and only) `except` statement.
Within the region strictly dominated by N, N2 must reachable from all of H1..Hn
For simplicity we limit n (as in Hn) to 2, but that is not required for correctness.
'''
from collections import defaultdict
from semmle.python import ast
from semmle.python.passes.ast_pass import iter_fields
from operator import itemgetter
from semmle.graph import FlowGraph
# Upper bound on split points chosen per scope, to limit graph blow-up.
MAX_SPLITS = 2
def do_split(ast_root, graph: FlowGraph):
    '''Split the flow graph, using the AST to determine split points.'''
    # Label interesting AST nodes, project those labels onto the CFG,
    # then let the graph split itself at the chosen points.
    labels_by_ast_node = label_ast(ast_root)
    labels_by_cfg_node = label_cfg(graph, labels_by_ast_node)
    graph.split(choose_split_points(graph, labels_by_cfg_node))
class ScopedAstLabellingVisitor(object):
    '''Visitor for labelling AST nodes in scope.
    Does not visit nodes belonging to inner scopes (methods, etc)
    '''

    def __init__(self, labels):
        # `labels` is a caller-supplied collection; subclasses use it as a
        # dict or a set depending on what they collect.
        self.labels = labels
        # Monotonically increasing counter used to order labels.
        self.priority = 0

    def visit(self, node):
        """Visit a node."""
        # Dynamic dispatch on the node's class name.
        method = 'visit_' + node.__class__.__name__
        getattr(self, method, self.generic_visit)(node)

    def generic_visit(self, node):
        # Recurse into all child fields of an AST node.
        if isinstance(node, ast.AstBase):
            for _, _, value in iter_fields(node):
                self.visit(value)

    def visit_Class(self, node):
        #Do not visit sub-scopes
        return

    visit_Function = visit_Class

    def visit_list(self, the_list):
        # Lists of nodes are dispatched element-wise.
        for item in the_list:
            method = 'visit_' + item.__class__.__name__
            getattr(self, method, self.generic_visit)(item)

    #Helper methods
    @staticmethod
    def get_variable(expr):
        '''Returns the variable of this expr. Returns None if no variable.'''
        if hasattr(expr, "variable"):
            return expr.variable
        else:
            return None

    @staticmethod
    def is_const(expr):
        # A named constant, a (possibly nested) unary op over a constant,
        # or a numeric/string literal.
        if isinstance(expr, ast.Name):
            return expr.variable.id in ("None", "True", "False")
        elif isinstance(expr, ast.UnaryOp):
            return ScopedAstLabellingVisitor.is_const(expr.operand)
        return isinstance(expr, (ast.Num, ast.Str))
class AstLabeller(ScopedAstLabellingVisitor):
    '''Visitor to label tests and assignments
    for later scanning to determine split points.
    Labels are (variable, kind, priority) triples, where kind is a constant
    node, "assign"/"define", or None for a bare truth test.
    '''

    def __init__(self, *args):
        ScopedAstLabellingVisitor.__init__(self, *args)
        # Depth of nesting inside `if` tests.
        self.in_test = 0

    def _label_for_compare(self, cmp):
        # A single comparison of a variable against a constant, either way
        # round, yields a label; anything else yields None.
        if len(cmp.ops) != 1:
            return None
        var = self.get_variable(cmp.left)
        if var is None:
            var = self.get_variable(cmp.comparators[0])
            k = cmp.left
        else:
            k = cmp.comparators[0]
        if var is not None and self.is_const(k):
            self.priority += 1
            return (var, k, self.priority)
        return None

    def visit_Compare(self, cmp):
        label = self._label_for_compare(cmp)
        if label:
            self.labels[cmp].append(label)

    def visit_Name(self, name):
        self.priority += 1
        if isinstance(name.ctx, ast.Store):
            # An assignment to the variable.
            self.labels[name].append((name.variable, "assign", self.priority))
        elif self.in_test:
            # A bare truth test of the variable.
            self.labels[name].append((name.variable, None, self.priority))

    def _label_for_unary_operand(self, op):
        # Unwrap chains of `not` down to a name or comparison.
        if not isinstance(op.op, ast.Not):
            return None
        if isinstance(op.operand, ast.UnaryOp):
            return self._label_for_unary_operand(op.operand)
        elif isinstance(op.operand, ast.Name):
            self.priority += 1
            return (op.operand.variable, None, self.priority)
        elif isinstance(op.operand, ast.Compare):
            return self._label_for_compare(op.operand)
        return None

    def visit_UnaryOp(self, op):
        if not self.in_test:
            return
        label = self._label_for_unary_operand(op)
        if label:
            self.labels[op].append(label)
        else:
            self.visit(op.operand)

    def visit_If(self, ifstmt):
        # Looking for the pattern:
        # if x: k = K0 else: k = K1
        # the test is the split point, but the variable is `k`
        self.in_test += 1
        self.visit(ifstmt.test)
        self.in_test -= 1
        self.visit(ifstmt.body)
        self.visit(ifstmt.orelse)
        k1 = {}
        ConstantAssignmentVisitor(k1).visit(ifstmt.body)
        k2 = {}
        ConstantAssignmentVisitor(k2).visit(ifstmt.orelse)
        k = set(k1.keys()).union(k2.keys())
        self.priority += 1
        for var in k:
            # The constant from whichever branch assigned it.
            val = k1[var] if var in k1 else k2[var]
            self.labels[ifstmt.test].append((var, val, self.priority))

    def visit_Try(self, stmt):
        # Looking for the pattern:
        # if try: k = K0 except: k = K1
        # the try is the split point, and the variable is `k`
        self.generic_visit(stmt)
        # Only a single handler is supported.
        if not stmt.handlers or len(stmt.handlers) > 1:
            return
        k1 = {}
        ConstantAssignmentVisitor(k1).visit(stmt.body)
        k2 = {}
        ConstantAssignmentVisitor(k2).visit(stmt.handlers[0])
        k = set(k1.keys()).union(k2.keys())
        self.priority += 1
        for var in k:
            val = k1[var] if var in k1 else k2[var]
            self.labels[stmt].append((var, val, self.priority))

    def visit_ClassExpr(self, node):
        # Don't split over class definitions,
        # as the presence of multiple ClassObjects for a
        # single class can be confusing.
        # The same applies to function definitions.
        self.priority += 1
        self.labels[node].append((None, "define", self.priority))

    visit_FunctionExpr = visit_ClassExpr
class TryBodyAndHandlerVisitor(ScopedAstLabellingVisitor):
    '''Gathers every AST node below the visited node into `self.labels`
    (used as a set), including `ExceptStmt` nodes themselves but nothing
    beneath them.'''

    def generic_visit(self, node):
        if not isinstance(node, ast.AstBase):
            return
        self.labels.add(node)
        for _, _, child in iter_fields(node):
            self.visit(child)

    def visit_ExceptStmt(self, node):
        # Record the handler itself, but do not descend into it.
        self.labels.add(node)
class ConstantAssignmentVisitor(ScopedAstLabellingVisitor):
    '''Records assignments whose right-hand side is a constant.
    `self.labels` is used as a mapping from variable to the constant node.'''

    def visit_Assign(self, asgn):
        if self.is_const(asgn.value):
            for target in asgn.targets:
                # Only targets that bind a variable are recorded.
                if hasattr(target, "variable"):
                    self.labels[target.variable] = asgn.value
def label_ast(ast_root):
    '''Visits the AST, returning the labels'''
    # Mapping AST node -> list of (variable, kind, priority) labels.
    result = defaultdict(list)
    AstLabeller(result).generic_visit(ast_root)
    return result
def _is_branch(node, graph: FlowGraph):
    '''Holds if `node` (in `graph`) is a branch point.'''
    # A two-way successor, or a `try`, is a branch directly.
    if len(graph.succ[node]) == 2 or isinstance(node.node, ast.Try):
        return True
    if len(graph.succ[node]) != 1:
        return False
    # A node feeding straight into a `not` counts as a branch if the
    # unary-op chain eventually branches.
    succ = graph.succ[node][0]
    if not isinstance(succ.node, ast.UnaryOp):
        return False
    return _is_branch(succ, graph)
def label_cfg(graph: FlowGraph, ast_labels):
    '''Copies labels from AST to CFG for branches and assignments.'''
    result = {}
    for node, _ in graph.nodes():
        # .get avoids inserting empty lists into a defaultdict.
        labels = ast_labels.get(node.node)
        if not labels:
            continue
        # Keep labels on actual branch points, and on assignment/definition/
        # loop labels regardless of branching.
        if _is_branch(node, graph) or labels[0][1] in ("assign", "define", "loop"):
            result[node] = labels
    return result
def usefully_comparable_types(o1, o2):
    '''Holds if a test against object o1 can provide any
    meaningful information w.r.t. to a test against o2.
    '''
    # An unknown constant (None) is comparable with anything; otherwise the
    # two constants must be of the same concrete type.
    return o1 is None or o2 is None or type(o1) is type(o2)
def exits_from_subtree(head, subtree, graph: FlowGraph):
    '''Returns all nodes in `subtree`, that exit
    the subtree and are reachable from `head`
    '''
    exits = set()
    visited = set()
    worklist = [head]
    while worklist:
        node = worklist.pop()
        if node in visited:
            continue
        visited.add(node)
        successors = graph.succ[node]
        if not successors:
            # Terminal nodes never leave the subtree.
            continue
        leaves_subtree = True
        for succ in successors:
            if succ.node in subtree:
                worklist.append(succ)
                leaves_subtree = False
        if leaves_subtree:
            # Every successor lies outside the subtree.
            exits.add(node)
    return exits
def get_split_heads(head, graph: FlowGraph):
    '''Compute the split tails for the node `head`
    That is, the set of nodes from which splitting should commence.
    '''
    if isinstance(head.node, ast.Try):
        # For a `try`, the heads are the exits from the try body (plus the
        # single handler, which is added to the subtree so it is traversed).
        try_body = set()
        TryBodyAndHandlerVisitor(try_body).visit(head.node)
        if head.node.handlers:
            try_body.add(head.node.handlers[0])
        try_split_tails = exits_from_subtree(head, try_body, graph)
        return try_split_tails
    else:
        # For a test, the heads are its (true/false) successors.
        return graph.succ[head]
def choose_split_points(graph: FlowGraph, cfg_labels):
    '''Select the set of nodes to be the split heads for the graph,
    from the given labels. A maximum of two points are chosen to avoid
    excessive blow up.
    '''
    candidates = []
    #Find pairs -- N1, N2 where N1 and N2 are tests on the same variable and the tests are similar.
    # Flatten to (node, variable, kind, priority) tuples, ordered by priority.
    labels = []
    for node, label_list in cfg_labels.items():
        for label in label_list:
            labels.append((node, label[0], label[1], label[2]))
    labels.sort(key=itemgetter(3))
    for first_node, first_var, first_type, first_priority in labels:
        # Only tests can be split points, not assignments/definitions.
        if first_type in ("assign", "define"):
            continue
        #Avoid splitting if any class or function is defined later in scope.
        # NOTE(review): the comprehension variable shadows the builtin `type`
        # (harmless — comprehension scope — but worth renaming).
        if 'define' in [type for (_, _, type, priority) in labels if priority > first_priority]:
            break
        for second_node, second_var, second_type, second_priority in labels:
            if second_var != first_var:
                continue
            # First node must dominate second node to be a viable splitting candidate.
            # Quick check to avoid doing pointless dominance checks.
            if first_priority >= second_priority:
                continue
            #Avoid splitting if variable is reassigned
            if second_type == "assign":
                break
            if not graph.strictly_dominates(first_node, second_node):
                continue
            if not usefully_comparable_types(first_type, second_type):
                continue
            split_heads = get_split_heads(first_node, graph)
            # Limit to exactly two heads (see module docstring).
            if len(split_heads) != 2:
                continue
            # Unless both of the split heads reach the second node,
            # then there is no benefit to splitting.
            for head in split_heads:
                if not graph.strictly_dominates(first_node, head):
                    break
                if not graph.reaches_while_dominated(head, second_node, first_node):
                    break
            else:
                # for/else: both heads passed both checks.
                candidates.append((first_node, split_heads, first_var, first_priority))
    #Candidates is a list of (node, split-heads, variable, priority) tuples.
    #Remove any duplicate nodes
    candidates = deduplicate(candidates, 0, 3)
    #Remove repeated splits on the same variable if more than MAX_SPLITS split and more than one variable.
    if len(candidates) > MAX_SPLITS and len({c[2] for c in candidates}) > 1:
        candidates = deduplicate(candidates, 2, 3)
    # Return best two results, but must return in reverse priority order,
    # so that splitting on one node does not remove a later one.
    return [c[:2] for c in candidates[MAX_SPLITS-1::-1]]
def deduplicate(lst, col, sort_col):
    '''De-duplicate list `lst` of tuples removing all but the first tuple containing
    duplicates of `col`. Sort the result on `sort_col`.'''
    # Iterating in reverse means the first occurrence (in original order)
    # of each key is the one that survives in the dict.
    first_per_key = {t[col]: t for t in reversed(lst)}
    return sorted(first_per_key.values(), key=itemgetter(sort_col))

View File

@@ -0,0 +1,181 @@
'''
Unroll loops in the flow-graph once if we know that the iterator is not empty.
E.g.
if seq:
for x in seq:
y = x
y # y is defined here
or
if not seq:
raise
for x in seq:
y = x
y # y is defined here
This is broadly analagous to splitting.
If the edge leaving the test that signifies a non-empty container dominates the loop, then we want to unroll the loop once.
Loop unrolling will transform
A (loop header), B (loop body) --> A(first loop header), B(first loop body), C(second loop header), D(second loop body)
and is done as follows:
Make a copy of A as C and make a copy of B as D.
Convert all edges from B to A into edges from B to C.
Convert edge from C to B to an edge from C to D.
Convert all edges from D to A into edges from D to C.
Subsequent pruning will then remove any dead edges for iterables known to be empty or non-empty.
'''
from collections import defaultdict, namedtuple
from operator import itemgetter
from semmle.python import ast
from semmle.python.passes.splitter import ScopedAstLabellingVisitor, label_cfg
from semmle.util import EXHAUSTED_EDGE
class HasDefinitionInLoop(ScopedAstLabellingVisitor):
    '''Check to see if a class or function definition occurs
    in a loop. Note that this will prevent unrolling of a loop
    if a definition occurs in any loop in scope, not just the one
    to be unrolled.
    '''

    def __init__(self):
        ScopedAstLabellingVisitor.__init__(self, None)
        self._found = False
        self._inside_loop_body = False

    def visit_For(self, loop):
        # The iterator expression is evaluated outside the loop body.
        self.visit(loop.iter)
        self._inside_loop_body = True
        self.visit(loop.body)
        self._inside_loop_body = False

    def visit_ClassExpr(self, node):
        # A definition inside a loop body rules out unrolling: multiple
        # ClassObjects (or FunctionObjects) for one definition are confusing.
        if self._inside_loop_body:
            self._found = True

    visit_FunctionExpr = visit_ClassExpr

    def __bool__(self):
        return self._found
# Label attached to an AST node: the variable it concerns, whether the
# occurrence is a "test" (inside an if-condition) or a "loop" (the iterable
# of a for-loop), and a monotonically increasing priority recording the
# order in which labels were created.
AstLabel = namedtuple("AstLabel", "variable type priority")
# The same information lifted onto a control-flow-graph node.
CfgLabel = namedtuple("CfgLabel", "node variable type priority")
class Labeller(ScopedAstLabellingVisitor):
    '''Attach AstLabels to names that are either tested for truthiness
    (inside an ``if`` condition) or iterated over (as the iterable of a
    ``for`` loop). The labels are later lifted onto the flow graph and
    paired up to select loops worth unrolling.
    '''
    def __init__(self, *args):
        ScopedAstLabellingVisitor.__init__(self, *args)
        # Depth of nested if-conditions currently being visited.
        self.in_test = 0
        # True only while visiting a for-loop's iterable expression.
        self.in_loop = False
    def visit_If(self, ifstmt):
        # Only the condition itself is a "test" context; the branches are not.
        self.in_test += 1
        self.visit(ifstmt.test)
        self.in_test -= 1
        for branch in (ifstmt.body, ifstmt.orelse):
            self.visit(branch)
    def visit_UnaryOp(self, op):
        if self.in_test == 0:
            return
        found = self._label_for_unary_operand(op)
        if found is None:
            self.visit(op.operand)
        else:
            self.labels[op].append(found)
    def _label_for_unary_operand(self, op):
        # Only negations (`not x`) are interesting as emptiness tests.
        if not isinstance(op.op, ast.Not):
            return None
        operand = op.operand
        if isinstance(operand, ast.UnaryOp):
            # `not not x` and deeper chains: recurse through the nesting.
            return self._label_for_unary_operand(operand)
        if isinstance(operand, ast.Name):
            self.priority += 1
            return AstLabel(operand.variable, "test", self.priority)
        if isinstance(operand, ast.Call):
            return self._label_for_call(operand)
        return None
    def visit_Call(self, call):
        if self.in_test == 0:
            return
        found = self._label_for_call(call)
        if found:
            self.labels[call].append(found)
    def _label_for_call(self, call):
        #TO DO -- Check for calls to len()
        pass
    def visit_For(self, loop):
        # Only the iterable expression counts as a "loop" context; the
        # body is visited with both flags off.
        self.in_loop = True
        self.visit(loop.iter)
        self.in_loop = False
        self.visit(loop.body)
    def visit_Name(self, name):
        # Every name advances the priority counter, so that labels are
        # totally ordered by occurrence even when none is recorded here.
        self.priority += 1
        if self.in_test:
            kind = "test"
        elif self.in_loop:
            kind = "loop"
        else:
            return
        self.labels[name].append(AstLabel(name.variable, kind, self.priority))
def do_unrolling(ast_root, graph):
    '''Unroll once each loop in `graph` whose iterable has a dominating
    non-emptiness test, as described in the module docstring.
    Mutates `graph` in place.
    '''
    # Bail out entirely if any loop in this scope contains a class or
    # function definition: unrolling would duplicate the definition.
    definition_check = HasDefinitionInLoop()
    definition_check.generic_visit(ast_root)
    if definition_check:
        return
    # Label the interesting AST nodes, lift those labels onto the flow
    # graph, then pick and perform the unrollings.
    labels_by_node = defaultdict(list)
    Labeller(labels_by_node).generic_visit(ast_root)
    cfg_labels = label_cfg(graph, labels_by_node)
    for head, body in choose_loops_to_unroll(graph, cfg_labels):
        graph.unroll(head, body)
def choose_loops_to_unroll(graph, cfg_labels):
    '''Select the (loop-header, loop-body) pairs to unroll.

    A loop is a candidate if the variable it iterates over is also
    tested (labelled "test") at a flow node that strictly dominates the
    loop's iterator node: on the branch where the test established the
    container is non-empty, the unrolled first iteration can then be
    pruned accordingly.
    '''
    #Find pairs -- N1, N2 where N1 is a test on the variable and N2 is a loop over it.
    labels = []
    for node, label_list in cfg_labels.items():
        for label in label_list:
            labels.append(CfgLabel(node, label.variable, label.type, label.priority))
    # Priorities record occurrence order, so sorting by them lets the
    # `first_priority >= second_priority` check below act as a cheap
    # pre-filter before the (expensive) dominance test.
    labels.sort(key=itemgetter(3))
    candidates = []
    for first_node, first_var, first_type, first_priority in labels:
        # The first node must be a test on the variable, not a loop.
        if first_type == "loop":
            continue
        for second_node, second_var, second_type, second_priority in labels:
            if second_var != first_var:
                continue
            # First node must dominate second node to be a viable unrolling candidate.
            # Quick check to avoid doing pointless dominance checks.
            if first_priority >= second_priority:
                continue
            #Avoid if second use is not a loop
            if second_type != "loop":
                continue
            if not graph.strictly_dominates(first_node, second_node):
                continue
            candidates.append((second_node, second_priority))
    result = []
    # Process candidates in descending priority order (later loops first).
    # Several dominating tests can nominate the same loop, so deduplicate:
    # without this the same loop would be unrolled more than once.
    seen = set()
    for loop_iter, _priority in reversed(sorted(candidates, key=itemgetter(1))):
        if loop_iter in seen:
            continue
        seen.add(loop_iter)
        # The loop header is the iterator node's first successor.
        head = graph.succ[loop_iter][0]
        for body in graph.succ[head]:
            # Skip the edge taken when the iterator is exhausted; the
            # remaining successor is the loop body.
            if graph.edge_annotations[head, body] != EXHAUSTED_EDGE:
                result.append((head, body))
                break
    return result