# codeql/python/extractor/semmle/worker.py

import sys
import os
from collections import deque, defaultdict
import time
import multiprocessing as mp
import json
from queue import Empty as _Empty
from queue import Full as _Full
from semmle.extractors import SuperExtractor, ModulePrinter, SkippedBuiltin
from semmle.profiling import get_profiler
from semmle.path_rename import renamer_from_options_and_env
from semmle.logging import WARN, recursion_error_message, internal_error_message, Logger
from semmle.util import FileExtractable, FolderExtractable


class ExtractorFailure(Exception):
    'Generic exception representing the failure of an extractor.'


class ModuleImportGraph(object):
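    '''Graph of modules and the import depth at which they were first seen.

    Tracks which modules still need to be extracted (`todo`) and which have
    already been handed out (`done`), and lowers a module's recorded depth
    when a shorter import chain to it is discovered.
    '''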
def __init__(self, max_depth, logger: Logger):
self.modules = {}
self.succ = defaultdict(set)
self.todo = set()
self.done = set()
self.max_depth = max_depth
self.logger = logger
# During overlay extraction, only traverse the files that were changed.
self.overlay_changes = None
if 'CODEQL_EXTRACTOR_PYTHON_OVERLAY_CHANGES' in os.environ:
overlay_changes_file = os.environ['CODEQL_EXTRACTOR_PYTHON_OVERLAY_CHANGES']
logger.info("Overlay extraction mode: only extracting files changed according to '%s'", overlay_changes_file)
try:
with open(overlay_changes_file, 'r', encoding='utf-8') as f:
data = json.load(f)
changed_paths = data.get('changes', [])
self.overlay_changes = { os.path.abspath(p) for p in changed_paths }
except (IOError, ValueError) as e:
logger.warn("Failed to read overlay changes from '%s' (falling back to full extraction): %s", overlay_changes_file, e)
self.overlay_changes = None
def add_root(self, mod):
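        '''Add a module found by the traverser, at depth 0.'''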
self.modules[mod] = 0
if mod not in self.done:
self.add_todo(mod)
def add_import(self, mod, imported):
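        '''Record that `mod` imports `imported`. A newly seen module is
        scheduled for extraction if `mod` is within the maximum import depth;
        a module already seen via a longer chain has its depth reduced.'''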
assert mod in self.modules
self.succ[mod].add(imported)
if imported in self.modules:
if self.modules[imported] > self.modules[mod] + 1:
self._reduce_depth(imported, self.modules[mod] + 1)
else:
if self.modules[mod] < self.max_depth and imported not in self.done:
self.add_todo(imported)
self.modules[imported] = self.modules[mod] + 1
def _reduce_depth(self, mod, depth):
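        '''Propagate a newly discovered, smaller depth to `mod` and,
        transitively, to its successors in the import graph.'''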
if self.modules[mod] <= depth:
return
if depth > self.max_depth:
return
if mod not in self.done:
self.add_todo(mod)
self.modules[mod] = depth
for imp in self.succ[mod]:
self._reduce_depth(imp, depth+1)
def get(self):
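        '''Pop the next module to extract, marking it as done.'''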
mod = self.todo.pop()
        assert mod not in self.done and self.modules[mod] <= self.max_depth
self.done.add(mod)
return mod
def push_back(self, mod):
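        '''Return a module to the todo set, e.g. because the work queue was full.'''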
self.done.remove(mod)
self.add_todo(mod)
def empty(self):
return not self.todo
def add_todo(self, mod):
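        '''Schedule a module for extraction, unless overlay extraction excludes it.'''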
if not self._module_in_overlay_changes(mod):
self.logger.debug("Skipping module '%s' as it was not changed in overlay extraction.", mod)
return
self.todo.add(mod)
def _module_in_overlay_changes(self, mod):
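        '''Whether `mod` should be extracted, given the overlay changes (if any).'''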
if self.overlay_changes is not None:
if isinstance(mod, FileExtractable):
return mod.path in self.overlay_changes
if isinstance(mod, FolderExtractable):
            return os.path.join(mod.path, '__init__.py') in self.overlay_changes
return True


class ExtractorPool(object):
'''Pool of worker processes running extractors'''
def __init__(self, outdir, archive, proc_count, options, logger: Logger):
if proc_count < 1:
raise ValueError("Number of processes must be at least one.")
self.verbose = options.verbose
self.outdir = outdir
self.max_import_depth = options.max_import_depth
# macOS does not support `fork` properly, so we must use `spawn` instead.
method = 'spawn' if sys.platform == "darwin" else None
try:
ctx = mp.get_context(method)
except AttributeError:
# `get_context` doesn't exist -- we must be running an old version of Python.
ctx = mp
        # Keep the module queue short to minimise delay when stopping.
self.module_queue = ctx.Queue(proc_count*2)
self.reply_queue = ctx.Queue(proc_count*20)
self.archive = archive
self.local_queue = deque()
self.enqueued = set()
self.done = set()
self.requirements = {}
self.import_graph = ModuleImportGraph(options.max_import_depth, logger)
logger.debug("Source archive: %s", archive)
self.logger = logger
DiagnosticsWriter.create_output_dir()
args = (self.module_queue, outdir, archive, options, self.reply_queue, logger)
self.procs = [
ctx.Process(target=_extract_loop, args=(n+1,) + args + (n == 0,)) for n in range(proc_count)
]
for p in self.procs:
p.start()
self.start_time = time.time()
def extract(self, the_traverser):
'''Extract all the files from the given traverser,
and all the imported files up to the depth specified
by the options.
'''
self.logger.trace("Starting traversal")
for mod in the_traverser:
self.import_graph.add_root(mod)
self.try_to_send()
self.receive(False)
#Prime the queue
while self.try_to_send():
pass
while self.enqueued or not self.import_graph.empty():
self.try_to_send()
self.receive(True)
def try_to_send(self):
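        '''Move one module from the import graph to the worker queue, without
        blocking. Returns True if a module was enqueued.'''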
if self.import_graph.empty():
return False
module = self.import_graph.get()
try:
self.module_queue.put(module, False)
self.enqueued.add(module)
self.logger.debug("Enqueued %s", module)
return True
except _Full:
self.import_graph.push_back(module)
return False
def receive(self, block=False):
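        '''Process one reply from the workers, if available.

        Replies are (kind, module, import) tuples, where kind is one of
        "INTERRUPT", "UNRECOVERABLE_FAILURE", "FAILURE", "SUCCESS" or "IMPORT".
        '''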
try:
what, mod, imp = self.reply_queue.get(block)
if what == "INTERRUPT":
self.logger.debug("Main process received interrupt")
raise KeyboardInterrupt
elif what == "UNRECOVERABLE_FAILURE":
raise ExtractorFailure(str(mod))
elif what == "FAILURE":
self.enqueued.remove(mod)
elif what == "SUCCESS":
self.enqueued.remove(mod)
else:
assert what == "IMPORT"
assert mod is not None
if imp is None:
self.logger.warning("Unexpected None as import.")
else:
self.import_graph.add_import(mod, imp)
except _Empty:
#Nothing in reply queue.
pass
def close(self):
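        '''Shut down the workers cleanly once all modules have been processed,
        writing overlay-base metadata if requested.'''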
self.logger.debug("Closing down workers")
assert not self.enqueued
for p in self.procs:
self.module_queue.put(None)
for p in self.procs:
p.join()
if 'CODEQL_EXTRACTOR_PYTHON_OVERLAY_BASE_METADATA_OUT' in os.environ:
with open(os.environ['CODEQL_EXTRACTOR_PYTHON_OVERLAY_BASE_METADATA_OUT'], 'w', encoding='utf-8') as f:
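                # No overlay-base metadata is currently needed, so write an empty object.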
metadata = {}
json.dump(metadata, f)
self.logger.info("Processed %d modules in %0.2fs", len(self.import_graph.done), time.time() - self.start_time)
def stop(self, timeout=2.0):
'''Stop the worker pool, reasonably promptly and as cleanly as possible.'''
try:
_drain_queue(self.module_queue)
for p in self.procs:
self.module_queue.put(None)
_drain_queue(self.reply_queue)
end = time.time() + timeout
running = set(self.procs)
while running and time.time() < end:
time.sleep(0.1)
_drain_queue(self.reply_queue)
running = {p for p in running if p.is_alive()}
if running:
for index, proc in enumerate(self.procs, 1):
if proc.is_alive():
self.logger.error("Process %d timed out", index)
except Exception as ex:
self.logger.error("Unexpected error when stopping %s", ex)
@staticmethod
def from_options(options, trap_dir, archive, logger: Logger):
'''Convenience method to create extractor pool from options.'''
cpus = mp.cpu_count()
procs = options.max_procs
if procs == 'all':
procs = cpus
elif procs is None or procs == 'half':
procs = (cpus+1)//2
else:
procs = int(procs)
return ExtractorPool(trap_dir, archive, procs, options, logger)


def _drain_queue(queue):
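    '''Discard all items currently in `queue`, without blocking.'''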
try:
while True:
queue.get(False)
except _Empty:
#Emptied queue as best we can.
pass


class DiagnosticsWriter(object):
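    '''Writes diagnostic messages, one JSON object per line, to a per-worker
    file in CODEQL_EXTRACTOR_PYTHON_DIAGNOSTIC_DIR (if that directory is set).'''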
def __init__(self, proc_id):
self.proc_id = proc_id
def write(self, message):
dir = os.environ.get("CODEQL_EXTRACTOR_PYTHON_DIAGNOSTIC_DIR")
if dir:
with open(os.path.join(dir, "worker-%d.jsonl" % self.proc_id), "a") as output_file:
output_file.write(json.dumps(message.to_dict()) + "\n")
@staticmethod
def create_output_dir():
dir = os.environ.get("CODEQL_EXTRACTOR_PYTHON_DIAGNOSTIC_DIR")
if dir:
            os.makedirs(dir, exist_ok=True)


# Function run by worker processes
def _extract_loop(proc_id, queue, trap_dir, archive, options, reply_queue, logger: Logger, write_global_data):
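    '''Main loop of a worker process: extract modules taken from `queue`,
    reporting discovered imports and successes/failures on `reply_queue`.
    A `None` on the queue signals the worker to shut down; the worker with
    `write_global_data` set additionally writes global data before exiting.
    '''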
diagnostics_writer = DiagnosticsWriter(proc_id)
send_time = 0
recv_time = 0
extraction_time = 0
    # Use UTF-8 as the character encoding for stdout/stderr, so that we can
    # properly log/print things on systems that use bad default encodings (Windows).
sys.stdout.reconfigure(encoding='utf-8')
sys.stderr.reconfigure(encoding='utf-8')
try:
renamer = renamer_from_options_and_env(options, logger)
except Exception as ex:
logger.error("Exception: %s", ex)
reply_queue.put(("INTERRUPT", None, None))
sys.exit(2)
logger.set_process_id(proc_id)
try:
if options.trace_only:
extractor = ModulePrinter(options, trap_dir, archive, renamer, logger)
else:
extractor = SuperExtractor(options, trap_dir, archive, renamer, logger, diagnostics_writer)
        profiler = get_profiler(options, proc_id, logger)
with profiler:
while True:
start_recv = time.time()
unit = queue.get()
recv_time += time.time() - start_recv
if unit is None:
if write_global_data:
extractor.write_global_data()
extractor.close()
return
try:
start = time.time()
imports = extractor.process(unit)
end_time = time.time()
extraction_time += end_time - start
if imports is SkippedBuiltin:
logger.info("Skipped built-in %s", unit)
else:
for imp in imports:
reply_queue.put(("IMPORT", unit, imp))
send_time += time.time() - end_time
logger.info("Extracted %s in %0.0fms", unit, (end_time-start)*1000)
                except SyntaxError:
# Syntax errors have already been handled in extractor.py
reply_queue.put(("FAILURE", unit, None))
except RecursionError as ex:
logger.error("Failed to extract %s: %s", unit, ex)
logger.traceback(WARN)
try:
error = recursion_error_message(ex, unit)
diagnostics_writer.write(error)
except Exception as ex:
logger.warning("Failed to write diagnostics: %s", ex)
logger.traceback(WARN)
reply_queue.put(("FAILURE", unit, None))
except Exception as ex:
logger.error("Failed to extract %s: %s", unit, ex)
logger.traceback(WARN)
try:
error = internal_error_message(ex, unit)
diagnostics_writer.write(error)
except Exception as ex:
logger.warning("Failed to write diagnostics: %s", ex)
logger.traceback(WARN)
reply_queue.put(("FAILURE", unit, None))
else:
reply_queue.put(("SUCCESS", unit, None))
except KeyboardInterrupt:
logger.debug("Worker process received interrupt")
reply_queue.put(("INTERRUPT", None, None))
except Exception as ex:
logger.error("Exception: %s", ex)
reply_queue.put(("INTERRUPT", None, None))
# Avoid deadlock and speed up termination by clearing queue.
try:
while True:
msg = queue.get(False)
if msg is None:
break
except _Empty:
#Cleared queue enough to avoid deadlock.
pass
sys.exit(2)