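'''Parallel extraction of Python modules for the CodeQL Python extractor.

A pool of worker processes extracts modules and reports the imports they
find back to the parent process, which schedules further extraction up to
a configurable import depth.
'''
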
import sys
import os
import time
import json
import multiprocessing as mp
from collections import deque, defaultdict
from queue import Empty as _Empty
from queue import Full as _Full

from semmle.extractors import SuperExtractor, ModulePrinter, SkippedBuiltin
from semmle.profiling import get_profiler
from semmle.path_rename import renamer_from_options_and_env
from semmle.logging import WARN, recursion_error_message, internal_error_message, Logger


class ExtractorFailure(Exception):
    'Generic exception representing the failure of an extractor.'
    pass


class ModuleImportGraph(object):
    '''Graph of modules to be extracted and the imports between them.

    Records the import depth of each module and limits traversal to
    `max_depth`, lowering a module's recorded depth (and re-queueing it)
    if a shorter import path to it is found later.
    '''

    def __init__(self, max_depth):
        self.modules = {}
        self.succ = defaultdict(set)
        self.todo = set()
        self.done = set()
        self.max_depth = max_depth

    def add_root(self, mod):
        self.modules[mod] = 0
        if mod not in self.done:
            self.todo.add(mod)

    def add_import(self, mod, imported):
        assert mod in self.modules
        self.succ[mod].add(imported)
        if imported in self.modules:
            # Already known; lower its depth if we found a shorter path.
            if self.modules[imported] > self.modules[mod] + 1:
                self._reduce_depth(imported, self.modules[mod] + 1)
        else:
            if self.modules[mod] < self.max_depth and imported not in self.done:
                self.todo.add(imported)
            self.modules[imported] = self.modules[mod] + 1

    def _reduce_depth(self, mod, depth):
        if self.modules[mod] <= depth:
            return
        if depth > self.max_depth:
            return
        if mod not in self.done:
            self.todo.add(mod)
        self.modules[mod] = depth
        for imp in self.succ[mod]:
            self._reduce_depth(imp, depth + 1)

    def get(self):
        mod = self.todo.pop()
        assert mod not in self.done and self.modules[mod] <= self.max_depth
        self.done.add(mod)
        return mod

    def push_back(self, mod):
        self.done.remove(mod)
        self.todo.add(mod)

    def empty(self):
        return not self.todo
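
# A minimal sketch of how the import graph drives the traversal (illustrative
# only; the module names are hypothetical stand-ins for real file paths):
#
#   graph = ModuleImportGraph(max_depth=2)
#   graph.add_root("app.py")              # recorded at depth 0
#   mod = graph.get()                     # -> "app.py", now marked done
#   graph.add_import(mod, "util.py")      # depth 1, queued for extraction
#   graph.add_import(mod, "vendor.py")    # depth 1, queued for extraction
#   # Imports found at depth 2 would still be queued; depth 3 would only be
#   # recorded, not queued, unless a shorter path to them is found later.

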
class ExtractorPool(object):
    '''Pool of worker processes running extractors.'''

    def __init__(self, outdir, archive, proc_count, options, logger: Logger):
        if proc_count < 1:
            raise ValueError("Number of processes must be at least one.")
        self.verbose = options.verbose
        self.outdir = outdir
        self.max_import_depth = options.max_import_depth
        # macOS does not support `fork` properly, so we must use `spawn` instead.
        method = 'spawn' if sys.platform == "darwin" else None
        try:
            ctx = mp.get_context(method)
        except AttributeError:
            # `get_context` doesn't exist -- we must be running an old version of Python.
            ctx = mp
        # Keep the queue short to minimise the delay when stopping.
        self.module_queue = ctx.Queue(proc_count * 2)
        self.reply_queue = ctx.Queue(proc_count * 20)
        self.archive = archive
        self.local_queue = deque()
        self.enqueued = set()
        self.done = set()
        self.requirements = {}
        self.import_graph = ModuleImportGraph(options.max_import_depth)
        logger.debug("Source archive: %s", archive)
        self.logger = logger
        DiagnosticsWriter.create_output_dir()
        args = (self.module_queue, outdir, archive, options, self.reply_queue, logger)
        # Only the first worker (n == 0) writes the global data.
        self.procs = [
            ctx.Process(target=_extract_loop, args=(n + 1,) + args + (n == 0,))
            for n in range(proc_count)
        ]
        for p in self.procs:
            p.start()
        self.start_time = time.time()

    def extract(self, the_traverser):
        '''Extract all the files from the given traverser, and all the
        files they import, up to the depth specified by the options.
        '''
        self.logger.trace("Starting traversal")
        for mod in the_traverser:
            self.import_graph.add_root(mod)
            self.try_to_send()
            self.receive(False)
        # Prime the queue.
        while self.try_to_send():
            pass
        # Keep pumping until all enqueued work is done and no modules remain.
        while self.enqueued or not self.import_graph.empty():
            self.try_to_send()
            self.receive(True)

    def try_to_send(self):
        if self.import_graph.empty():
            return False
        module = self.import_graph.get()
        try:
            self.module_queue.put(module, False)
            self.enqueued.add(module)
            self.logger.debug("Enqueued %s", module)
            return True
        except _Full:
            # The queue is full; put the module back so we retry it later.
            self.import_graph.push_back(module)
            return False
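
    # Workers reply on `reply_queue` with `(what, mod, imp)` tuples, where
    # `what` is one of "SUCCESS", "FAILURE", "IMPORT", "INTERRUPT" or
    # "UNRECOVERABLE_FAILURE"; `imp` is only meaningful for "IMPORT"
    # messages. `receive` dispatches on these.
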
    def receive(self, block=False):
        try:
            what, mod, imp = self.reply_queue.get(block)
            if what == "INTERRUPT":
                self.logger.debug("Main process received interrupt")
                raise KeyboardInterrupt
            elif what == "UNRECOVERABLE_FAILURE":
                raise ExtractorFailure(str(mod))
            elif what == "FAILURE":
                self.enqueued.remove(mod)
            elif what == "SUCCESS":
                self.enqueued.remove(mod)
            else:
                assert what == "IMPORT"
                assert mod is not None
                if imp is None:
                    self.logger.warning("Unexpected None as import.")
                else:
                    self.import_graph.add_import(mod, imp)
        except _Empty:
            # Nothing in the reply queue.
            pass

    def close(self):
        self.logger.debug("Closing down workers")
        assert not self.enqueued
        # One `None` sentinel per worker tells each process to shut down.
        for p in self.procs:
            self.module_queue.put(None)
        for p in self.procs:
            p.join()
        self.logger.info("Processed %d modules in %0.2fs",
                         len(self.import_graph.done), time.time() - self.start_time)

    def stop(self, timeout=2.0):
        '''Stop the worker pool, reasonably promptly and as cleanly as possible.'''
        try:
            _drain_queue(self.module_queue)
            for p in self.procs:
                self.module_queue.put(None)
            _drain_queue(self.reply_queue)
            end = time.time() + timeout
            running = set(self.procs)
            while running and time.time() < end:
                time.sleep(0.1)
                _drain_queue(self.reply_queue)
                running = {p for p in running if p.is_alive()}
            if running:
                for index, proc in enumerate(self.procs, 1):
                    if proc.is_alive():
                        self.logger.error("Process %d timed out", index)
        except Exception as ex:
            self.logger.error("Unexpected error when stopping: %s", ex)

    @staticmethod
    def from_options(options, trap_dir, archive, logger: Logger):
        '''Convenience method to create an extractor pool from options.'''
        cpus = mp.cpu_count()
        procs = options.max_procs
        if procs == 'all':
            procs = cpus
        elif procs is None or procs == 'half':
            # Default to (roughly) half of the available CPUs.
            procs = (cpus + 1) // 2
        else:
            procs = int(procs)
        return ExtractorPool(trap_dir, archive, procs, options, logger)
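
    # A minimal usage sketch (illustrative only; `options`, `trap_dir`,
    # `archive` and `traverser` stand in for the real objects built by the
    # extractor's command-line driver):
    #
    #   pool = ExtractorPool.from_options(options, trap_dir, archive, logger)
    #   try:
    #       pool.extract(traverser)   # traverser yields the root modules
    #       pool.close()              # clean shutdown; waits for workers
    #   except (KeyboardInterrupt, ExtractorFailure):
    #       pool.stop()               # prompt shutdown on failure

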
def _drain_queue(queue):
    try:
        while True:
            queue.get(False)
    except _Empty:
        # Emptied the queue as best we can.
        pass


class DiagnosticsWriter(object):
    '''Writes diagnostic messages, one JSON object per line, to a per-worker file.'''

    def __init__(self, proc_id):
        self.proc_id = proc_id

    def write(self, message):
        dir = os.environ.get("CODEQL_EXTRACTOR_PYTHON_DIAGNOSTIC_DIR")
        if dir:
            with open(os.path.join(dir, "worker-%d.jsonl" % self.proc_id), "a") as output_file:
                output_file.write(json.dumps(message.to_dict()) + "\n")
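
    # Each worker appends one JSON object per line to its own
    # worker-<N>.jsonl file; `message` is assumed to be any object with a
    # `to_dict()` method, such as the diagnostics produced by
    # `recursion_error_message` and `internal_error_message` below.
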
    @staticmethod
    def create_output_dir():
        dir = os.environ.get("CODEQL_EXTRACTOR_PYTHON_DIAGNOSTIC_DIR")
        if dir:
            os.makedirs(dir, exist_ok=True)


# Function run by the worker processes.
def _extract_loop(proc_id, queue, trap_dir, archive, options, reply_queue, logger: Logger, write_global_data):
    diagnostics_writer = DiagnosticsWriter(proc_id)
    send_time = 0
    recv_time = 0
    extraction_time = 0

    # Use utf-8 as the character encoding for stdout/stderr, to be able to
    # properly log/print things on systems that use bad default encodings
    # (Windows).
    sys.stdout.reconfigure(encoding='utf-8')
    sys.stderr.reconfigure(encoding='utf-8')

    try:
        renamer = renamer_from_options_and_env(options, logger)
    except Exception as ex:
        logger.error("Exception: %s", ex)
        reply_queue.put(("INTERRUPT", None, None))
        sys.exit(2)
    logger.set_process_id(proc_id)
    try:
        if options.trace_only:
            extractor = ModulePrinter(options, trap_dir, archive, renamer, logger)
        else:
            extractor = SuperExtractor(options, trap_dir, archive, renamer, logger, diagnostics_writer)
        profiler = get_profiler(options, proc_id, logger)
        with profiler:
            while True:
                start_recv = time.time()
                unit = queue.get()
                recv_time += time.time() - start_recv
                if unit is None:
                    # `None` is the shutdown sentinel sent by the parent.
                    if write_global_data:
                        extractor.write_global_data()
                    extractor.close()
                    return
                try:
                    start = time.time()
                    imports = extractor.process(unit)
                    end_time = time.time()
                    extraction_time += end_time - start
                    if imports is SkippedBuiltin:
                        logger.info("Skipped built-in %s", unit)
                    else:
                        for imp in imports:
                            reply_queue.put(("IMPORT", unit, imp))
                    send_time += time.time() - end_time
                    logger.info("Extracted %s in %0.0fms", unit, (end_time - start) * 1000)
                except SyntaxError:
                    # Syntax errors have already been handled in extractor.py.
                    reply_queue.put(("FAILURE", unit, None))
                except RecursionError as ex:
                    logger.error("Failed to extract %s: %s", unit, ex)
                    logger.traceback(WARN)
                    try:
                        error = recursion_error_message(ex, unit)
                        diagnostics_writer.write(error)
                    except Exception as ex:
                        logger.warning("Failed to write diagnostics: %s", ex)
                        logger.traceback(WARN)
                    reply_queue.put(("FAILURE", unit, None))
                except Exception as ex:
                    logger.error("Failed to extract %s: %s", unit, ex)
                    logger.traceback(WARN)
                    try:
                        error = internal_error_message(ex, unit)
                        diagnostics_writer.write(error)
                    except Exception as ex:
                        logger.warning("Failed to write diagnostics: %s", ex)
                        logger.traceback(WARN)
                    reply_queue.put(("FAILURE", unit, None))
                else:
                    reply_queue.put(("SUCCESS", unit, None))
    except KeyboardInterrupt:
        logger.debug("Worker process received interrupt")
        reply_queue.put(("INTERRUPT", None, None))
    except Exception as ex:
        logger.error("Exception: %s", ex)
        reply_queue.put(("INTERRUPT", None, None))
    # Avoid deadlock and speed up termination by clearing the queue.
    try:
        while True:
            msg = queue.get(False)
            if msg is None:
                break
    except _Empty:
        # Cleared the queue enough to avoid deadlock.
        pass
    sys.exit(2)