mirror of
https://github.com/github/codeql.git
synced 2026-01-16 16:04:45 +01:00
385 lines
13 KiB
Python
Executable File
385 lines
13 KiB
Python
Executable File
'''
|
|
Split the flow-graph to allow tests to dominate all parts of the code that depends on them.
|
|
We split on `if`s and `try`s, either because of several tests on the same condition or
|
|
subsequent tests on a constant determined by the first condition.
|
|
E.g.
|
|
|
|
if a:
|
|
A
|
|
B
|
|
if a:
|
|
C
|
|
becomes
|
|
if a:
|
|
A
|
|
B
|
|
C
|
|
else:
|
|
B
|
|
ensuring that A dominates C.
|
|
|
|
or...
|
|
|
|
try:
|
|
import foo
|
|
except:
|
|
foo = None
|
|
X
|
|
if foo:
|
|
Y
|
|
becomes
|
|
try:
|
|
import foo
|
|
X
|
|
Y
|
|
except:
|
|
foo = None
|
|
X
|
|
|
|
To split on CFG node N we require that there exist nodes H1..Hn and N2 such that:
|
|
N and N2 are tests or conditional assignments to the same variable.
|
|
N dominates H1 .. Hn and N2
|
|
There is no assignment to the variable between N and N2
|
|
H1..Hn are the "split heads" of N, that is:
|
|
if N is a test, H1 and H2 are its true and false successors (there is no H3).
|
|
if N is a `try` then H1 .. Hn-1 are exits from the try body and Hn is the CFG node for the first (and only) `except` statement.
|
|
Within the region strictly dominated by N, N2 must be reachable from all of H1..Hn
|
|
|
|
For simplicity we limit n (as in Hn) to 2, but that is not required for correctness.
|
|
'''
|
|
|
|
from collections import defaultdict
|
|
|
|
from semmle.python import ast
|
|
from semmle.python.passes.ast_pass import iter_fields
|
|
from operator import itemgetter
|
|
from semmle.graph import FlowGraph
|
|
|
|
# Upper bound on the number of split points returned by `choose_split_points`.
# Each split can duplicate parts of the flow graph (see the module docstring),
# so this is kept small to avoid excessive blow up.
MAX_SPLITS = 2
|
|
|
|
def do_split(ast_root, graph: FlowGraph):
    '''Split the flow graph, using the AST to determine split points.

    The steps are: label the AST, copy the relevant labels onto CFG nodes,
    choose the split points from those labels, and finally hand them to
    `graph.split`.
    '''
    labelled_cfg = label_cfg(graph, label_ast(ast_root))
    graph.split(choose_split_points(graph, labelled_cfg))
|
|
|
|
class ScopedAstLabellingVisitor(object):
    '''Base visitor for labelling the AST nodes of a single scope.

    Nested scopes (class and function bodies) are deliberately not entered,
    so every scope is labelled on its own.
    '''

    def __init__(self, labels):
        # Container that subclasses fill in; its concrete type (dict,
        # defaultdict, set, ...) is chosen by whoever creates the visitor.
        self.labels = labels
        # Counter that subclasses bump to order labels by visitation.
        self.priority = 0

    def visit(self, node):
        """Dispatch `node` to `visit_<ClassName>`, falling back to `generic_visit`."""
        handler = getattr(self, 'visit_' + node.__class__.__name__, self.generic_visit)
        handler(node)

    def generic_visit(self, node):
        # Only AST nodes have children; anything else is treated as a leaf.
        if not isinstance(node, ast.AstBase):
            return
        for _, _, child in iter_fields(node):
            self.visit(child)

    def visit_Class(self, node):
        # Do not descend into nested scopes.
        return

    visit_Function = visit_Class

    def visit_list(self, the_list):
        # Child lists (statement bodies etc.) are visited element by element,
        # each element going through the normal `visit` dispatch.
        for element in the_list:
            self.visit(element)

    # Helper methods

    @staticmethod
    def get_variable(expr):
        '''Return `expr.variable`, or None when `expr` has no such attribute.'''
        return getattr(expr, "variable", None)

    @staticmethod
    def is_const(expr):
        '''Holds if `expr` is a constant: a number, a string, one of the names
        None/True/False, or any of those wrapped in unary operators.'''
        # Peel off any number of unary operators (`-1`, `not None`, ...).
        while isinstance(expr, ast.UnaryOp):
            expr = expr.operand
        if isinstance(expr, ast.Name):
            return expr.variable.id in ("None", "True", "False")
        return isinstance(expr, (ast.Num, ast.Str))
|
|
|
|
|
|
|
|
class AstLabeller(ScopedAstLabellingVisitor):
    '''Visitor to label tests and assignments
    for later scanning to determine split points.

    Each label is a ``(variable, kind, priority)`` triple appended to
    ``self.labels[node]``, where ``kind`` is one of:

    * ``"assign"`` -- the node stores to ``variable``;
    * ``"define"`` -- the node is a class/function expression
      (``variable`` is None);
    * ``None`` -- the node tests ``variable`` for truthiness;
    * a constant AST node -- the node compares ``variable`` with that
      constant, or (for ``if``/``try``) some branch assigns that constant
      to ``variable``.

    ``priority`` comes from a counter that only ever increases while visiting.
    '''

    def __init__(self, *args):
        ScopedAstLabellingVisitor.__init__(self, *args)
        # Depth of nesting inside an `if` test; non-zero while visiting one.
        self.in_test = 0

    def _label_for_compare(self, cmp):
        '''Return a ``(variable, constant, priority)`` label for a single
        comparison between a variable and a constant (either side), else None.'''
        if len(cmp.ops) != 1:
            return None
        var = self.get_variable(cmp.left)
        if var is None:
            var = self.get_variable(cmp.comparators[0])
            k = cmp.left
        else:
            # NOTE(review): when the left side has a variable, the right side is
            # assumed to be the constant; the reverse pairing is not tried.
            k = cmp.comparators[0]
        if var is not None and self.is_const(k):
            self.priority += 1
            return (var, k, self.priority)
        return None

    def visit_Compare(self, cmp):
        # Operands of the comparison are not visited further.
        label = self._label_for_compare(cmp)
        if label:
            self.labels[cmp].append(label)

    def visit_Name(self, name):
        # Every name advances the priority; stores are "assign" labels and
        # loads inside an `if` test are truthiness-test labels (kind None).
        self.priority += 1
        if isinstance(name.ctx, ast.Store):
            self.labels[name].append((name.variable, "assign", self.priority))
        elif self.in_test:
            self.labels[name].append((name.variable, None, self.priority))

    def _label_for_unary_operand(self, op):
        '''Return a label for `not <name>`, `not <compare>` or nested `not`s,
        or None for any other unary expression.'''
        if not isinstance(op.op, ast.Not):
            return None
        if isinstance(op.operand, ast.UnaryOp):
            return self._label_for_unary_operand(op.operand)
        elif isinstance(op.operand, ast.Name):
            self.priority += 1
            return (op.operand.variable, None, self.priority)
        elif isinstance(op.operand, ast.Compare):
            return self._label_for_compare(op.operand)
        return None

    def visit_UnaryOp(self, op):
        # NOTE(review): outside tests the operand is not visited at all, so a
        # store nested inside a unary operand gets no "assign" label — confirm
        # that is acceptable.
        if not self.in_test:
            return
        label = self._label_for_unary_operand(op)
        if label:
            self.labels[op].append(label)
        else:
            self.visit(op.operand)

    def _label_constant_assignments(self, split_node, first_branch, second_branch):
        '''Label `split_node` with every variable assigned a constant in
        either branch.

        When both branches assign a variable, the constant from `first_branch`
        is used. Labels are appended in a deterministic order: variables from
        `first_branch` in order of first assignment, then variables found only
        in `second_branch`. Iterating a set here would make the order (and
        thus the split points chosen later) depend on hash values.
        '''
        first = {}
        ConstantAssignmentVisitor(first).visit(first_branch)
        second = {}
        ConstantAssignmentVisitor(second).visit(second_branch)
        merged = dict(first)
        for var, val in second.items():
            merged.setdefault(var, val)
        self.priority += 1
        for var, val in merged.items():
            self.labels[split_node].append((var, val, self.priority))

    def visit_If(self, ifstmt):
        # Looking for the pattern:
        #     if x: k = K0 else: k = K1
        # the test is the split point, but the variable is `k`.
        self.in_test += 1
        self.visit(ifstmt.test)
        self.in_test -= 1
        self.visit(ifstmt.body)
        self.visit(ifstmt.orelse)
        self._label_constant_assignments(ifstmt.test, ifstmt.body, ifstmt.orelse)

    def visit_Try(self, stmt):
        # Looking for the pattern:
        #     try: k = K0 except: k = K1
        # the try is the split point, and the variable is `k`.
        self.generic_visit(stmt)
        # Only a try with exactly one handler is a candidate.
        if not stmt.handlers or len(stmt.handlers) > 1:
            return
        self._label_constant_assignments(stmt, stmt.body, stmt.handlers[0])

    def visit_ClassExpr(self, node):
        # Don't split over class definitions,
        # as the presence of multiple ClassObjects for a
        # single class can be confusing.
        # The same applies to function definitions.
        self.priority += 1
        self.labels[node].append((None, "define", self.priority))

    visit_FunctionExpr = visit_ClassExpr
|
|
|
|
|
|
class TryBodyAndHandlerVisitor(ScopedAstLabellingVisitor):
    '''Collect, into the set `self.labels`, every AST node under the visited
    node. `ExceptStmt` nodes are collected but nothing beneath them is.

    Class and function scope nodes (and their contents) are not collected,
    because the inherited `visit_Class`/`visit_Function` skip them.'''

    def generic_visit(self, node):
        if not isinstance(node, ast.AstBase):
            return
        self.labels.add(node)
        for _, _, child in iter_fields(node):
            self.visit(child)

    def visit_ExceptStmt(self, node):
        # Record the handler itself, but do not descend into it.
        self.labels.add(node)
|
|
|
|
|
|
class ConstantAssignmentVisitor(ScopedAstLabellingVisitor):
    '''Record in the mapping `self.labels` a ``variable -> constant`` entry for
    each assignment whose right-hand side is a constant (per `is_const`).

    A later assignment to the same variable overwrites the earlier value.'''

    def visit_Assign(self, asgn):
        rhs = asgn.value
        if not self.is_const(rhs):
            return
        assigned = (t.variable for t in asgn.targets if hasattr(t, "variable"))
        for variable in assigned:
            self.labels[variable] = rhs
|
|
|
|
def label_ast(ast_root):
    '''Label the scope rooted at `ast_root`.

    Returns a ``defaultdict(list)`` mapping each labelled AST node to its
    list of ``(variable, kind, priority)`` labels, as produced by `AstLabeller`.
    '''
    result = defaultdict(list)
    # Enter through generic_visit so the root's own children are traversed
    # even when the root is a scope node (class/function) that `visit` would skip.
    AstLabeller(result).generic_visit(ast_root)
    return result
|
|
|
|
def _is_branch(node, graph: FlowGraph):
    '''Holds if `node` (in `graph`) is a branch point.

    A node is a branch point if it has exactly two successors or is a `try`,
    or if its single successor is a unary-operator node that is itself a
    branch point (so `not x` style tests are seen through).
    '''
    current = node
    while True:
        successors = graph.succ[current]
        if len(successors) == 2 or isinstance(current.node, ast.Try):
            return True
        if len(successors) != 1:
            return False
        current = successors[0]
        if not isinstance(current.node, ast.UnaryOp):
            return False
|
|
|
|
|
|
def label_cfg(graph: FlowGraph, ast_labels):
    '''Copy labels from AST nodes onto the CFG nodes for branches and assignments.

    A CFG node receives its AST node's (non-empty) label list when the node is
    a branch point (see `_is_branch`), or when the kind of its first label is
    "assign", "define" or "loop".

    Returns a dict mapping CFG node -> list of labels.
    '''
    kept_kinds = ("assign", "define", "loop")
    result = {}
    for flow_node, _ in graph.nodes():
        node_labels = ast_labels.get(flow_node.node)
        if not node_labels:
            continue
        if _is_branch(flow_node, graph) or node_labels[0][1] in kept_kinds:
            result[flow_node] = node_labels
    return result
|
|
|
|
def usefully_comparable_types(o1, o2):
    '''Holds if a test against object `o1` can provide any meaningful
    information with respect to a test against `o2`.

    That is always the case when either object is None (a `None` label kind
    is a plain truthiness test); otherwise both must have exactly the same type.
    '''
    if o1 is not None and o2 is not None:
        return type(o1) is type(o2)
    return True
|
|
|
|
def exits_from_subtree(head, subtree, graph: FlowGraph):
    '''Return the set of CFG nodes that exit `subtree` and are reachable from `head`.

    Starting at `head`, only successors whose AST node is in `subtree` are
    followed. A visited node is an exit when it has at least one successor
    and none of its successors are in `subtree`; nodes with no successors
    are never exits.
    '''
    visited = {head}
    stack = [head]
    exits = set()
    while stack:
        current = stack.pop()
        successors = graph.succ[current]
        if not successors:
            continue
        inside = [s for s in successors if s.node in subtree]
        if not inside:
            exits.add(current)
        for s in inside:
            if s not in visited:
                visited.add(s)
                stack.append(s)
    return exits
|
|
|
|
def get_split_heads(head, graph: FlowGraph):
    '''Compute the split heads for the node `head`.

    That is, the set of nodes from which splitting should commence. For a
    non-`try` node these are simply its successors in `graph`. For a `try`
    they are the exits (see `exits_from_subtree`) of the region made up of
    the try statement's nodes (stopping at handlers) plus its first handler.
    '''
    if not isinstance(head.node, ast.Try):
        return graph.succ[head]
    region = set()
    TryBodyAndHandlerVisitor(region).visit(head.node)
    handlers = head.node.handlers
    if handlers:
        region.add(handlers[0])
    return exits_from_subtree(head, region, graph)
|
|
|
|
|
|
def choose_split_points(graph: FlowGraph, cfg_labels):
    '''Select the set of nodes to be the split heads for the graph,
    from the given labels. At most `MAX_SPLITS` points are chosen to avoid
    excessive blow up.

    Args:
        graph: the flow graph to be split.
        cfg_labels: mapping from CFG node to its list of
            ``(variable, kind, priority)`` labels (as built by `label_cfg`).

    Returns:
        A list of ``(node, split_heads)`` tuples in reverse priority order,
        so that splitting on one node does not remove a later one.
    '''
    # Flatten to (node, variable, kind, priority) tuples ordered by priority.
    # The sort is stable, so equal priorities keep their insertion order.
    labels = [
        (node, label[0], label[1], label[2])
        for node, label_list in cfg_labels.items()
        for label in label_list
    ]
    labels.sort(key=itemgetter(3))

    # Highest priority of any class/function definition. Computing it once
    # replaces a full rescan of `labels` for every candidate first node.
    define_priorities = [priority for (_, _, kind, priority) in labels if kind == "define"]
    last_define = max(define_priorities) if define_priorities else None

    candidates = []
    # Find pairs -- N1, N2 where N1 and N2 are tests on the same variable and the tests are similar.
    for first_node, first_var, first_kind, first_priority in labels:
        if first_kind in ("assign", "define"):
            continue
        # Avoid splitting if any class or function is defined later in scope.
        # NOTE: `break` (not `continue`) ends the whole search, so later tests
        # that follow the last definition are not considered either.
        if last_define is not None and last_define > first_priority:
            break
        # Split heads depend only on `first_node`; compute them at most once.
        split_heads = None
        for second_node, second_var, second_kind, second_priority in labels:
            if second_var != first_var:
                continue
            # First node must dominate second node to be a viable splitting candidate.
            # Quick check to avoid doing pointless dominance checks.
            if first_priority >= second_priority:
                continue
            # Avoid splitting if variable is reassigned.
            if second_kind == "assign":
                break
            if not graph.strictly_dominates(first_node, second_node):
                continue
            if not usefully_comparable_types(first_kind, second_kind):
                continue
            if split_heads is None:
                split_heads = get_split_heads(first_node, graph)
            if len(split_heads) != 2:
                continue
            # Unless both of the split heads reach the second node,
            # then there is no benefit to splitting.
            if all(graph.strictly_dominates(first_node, head)
                   and graph.reaches_while_dominated(head, second_node, first_node)
                   for head in split_heads):
                candidates.append((first_node, split_heads, first_var, first_priority))

    # Candidates is a list of (node, split-heads, variable, priority) tuples.
    # Remove any duplicate nodes.
    candidates = deduplicate(candidates, 0, 3)
    # Remove repeated splits on the same variable if there are more than
    # MAX_SPLITS candidates and more than one variable.
    if len(candidates) > MAX_SPLITS and len({c[2] for c in candidates}) > 1:
        candidates = deduplicate(candidates, 2, 3)
    # Return the first MAX_SPLITS results, but in reverse priority order,
    # so that splitting on one node does not remove a later one.
    return [c[:2] for c in reversed(candidates[:MAX_SPLITS])]
|
|
|
|
def deduplicate(lst, col, sort_col):
    '''De-duplicate the list `lst` of tuples, keeping only the first tuple
    for each distinct value in position `col`, and return the survivors
    sorted on position `sort_col`.'''
    # Walking the list backwards means the earliest tuple for each key is
    # written last, and so is the one that survives.
    survivors = {row[col]: row for row in reversed(lst)}
    return sorted(survivors.values(), key=itemgetter(sort_col))
|