mirror of https://github.com/github/codeql.git
MaD: make bulk generator DCA strategy download DBs in parallel
@@ -8,7 +8,7 @@ Note: This file must be formatted using the Black Python formatter.
 import os.path
 import subprocess
 import sys
-from typing import NotRequired, TypedDict, List
+from typing import NotRequired, TypedDict, List, Callable, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import time
 import argparse
@@ -111,6 +111,37 @@ def clone_project(project: Project) -> str:
     return target_dir


+def run_in_parallel[T, U](
+    func: Callable[[T], U],
+    items: List[T],
+    *,
+    on_error=lambda item, exc: None,
+    error_summary=lambda failures: None,
+    max_workers=8,
+) -> List[Optional[U]]:
+    if not items:
+        return []
+    max_workers = min(max_workers, len(items))
+    results = [None for _ in range(len(items))]
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks and remember each one's position in the input list
+        futures = {
+            executor.submit(func, item): index for index, item in enumerate(items)
+        }
+        # Process results as they complete
+        for future in as_completed(futures):
+            index = futures[future]
+            try:
+                results[index] = future.result()
+            except Exception as e:
+                on_error(items[index], e)
+    failed = [item for item, result in zip(items, results) if result is None]
+    if failed:
+        error_summary(failed)
+        sys.exit(1)
+    return results
+
+
 def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
     """
     Clone all projects in parallel.
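For orientation, a minimal usage sketch of the new helper (a hypothetical caller, not part of the commit): results come back aligned with the input list, per-item failures are reported through on_error, and any failure ends with error_summary followed by sys.exit(1).

def risky_double(n: int) -> int:
    # Hypothetical work function; raises on bad input.
    if n < 0:
        raise ValueError(f"negative input: {n}")
    return 2 * n

doubled = run_in_parallel(
    risky_double,
    [1, 2, 3],
    on_error=lambda n, exc: print(f"ERROR: {n}: {exc}"),  # invoked once per failed item
    error_summary=lambda failed: print(f"{len(failed)} items failed"),  # invoked once, then exit(1)
)
# doubled == [2, 4, 6]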
@@ -122,40 +153,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
         List of (project, project_dir) pairs in the same order as the input projects
     """
     start_time = time.time()
-    max_workers = min(8, len(projects))  # Use at most 8 threads
-    project_dirs_map = {}  # Map to store results by project name
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        # Start cloning tasks and keep track of them
-        future_to_project = {
-            executor.submit(clone_project, project): project for project in projects
-        }
-
-        # Process results as they complete
-        for future in as_completed(future_to_project):
-            project = future_to_project[future]
-            try:
-                project_dir = future.result()
-                project_dirs_map[project["name"]] = (project, project_dir)
-            except Exception as e:
-                print(f"ERROR: Failed to clone {project['name']}: {e}")
-
-    if len(project_dirs_map) != len(projects):
-        failed_projects = [
-            project["name"]
-            for project in projects
-            if project["name"] not in project_dirs_map
-        ]
-        print(
-            f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
-        )
-        sys.exit(1)
-
-    project_dirs = [project_dirs_map[project["name"]] for project in projects]
-
+    dirs = run_in_parallel(
+        clone_project,
+        projects,
+        on_error=lambda project, exc: print(
+            f"ERROR: Failed to clone project {project['name']}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
+        ),
+    )
    clone_time = time.time() - start_time
     print(f"Cloning completed in {clone_time:.2f} seconds")
-    return project_dirs
+    return list(zip(projects, dirs))


 def build_database(
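The final zip is safe because run_in_parallel stores each result at the index of its input item, so the returned list stays position-aligned with projects even when futures complete out of order. A quick sanity sketch (hypothetical, assuming the helper above is in scope):

import time

def slow_identity(n: int) -> int:
    time.sleep(0.05 * (4 - n))  # earlier items finish later
    return n

# Completion order is roughly reversed, yet results match input order.
assert run_in_parallel(slow_identity, [1, 2, 3, 4]) == [1, 2, 3, 4]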
@@ -352,7 +362,8 @@ def download_dca_databases(

         artifact_map[pretty_name] = analyzed_database

-    for pretty_name, analyzed_database in artifact_map.items():
+    def download(item: tuple[str, dict]) -> str:
+        pretty_name, analyzed_database = item
         artifact_name = analyzed_database["artifact_name"]
         repository = analyzed_database["repository"]
         run_id = analyzed_database["run_id"]
@@ -383,13 +394,22 @@ def download_dca_databases(
             with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
                 # And we just untar it to the same directory as the zip file
                 tar_ref.extractall(artifact_unzipped_location)
-                database_results[pretty_name] = os.path.join(
-                    artifact_unzipped_location, remove_extension(entry)
-                )
+                return os.path.join(artifact_unzipped_location, remove_extension(entry))
+
+    results = run_in_parallel(
+        download,
+        list(artifact_map.items()),
+        on_error=lambda item, exc: print(
+            f"ERROR: Failed to download database for {item[0]}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
+        ),
+    )

-    print(f"\n=== Extracted {len(database_results)} databases ===")
+    print(f"\n=== Extracted {len(results)} databases ===")

-    return [(project, database_results[project["name"]]) for project in projects]
+    return [(project_map[n], r) for n, r in zip(artifact_map, results)]


 def get_mad_destination_for_project(config, name: str) -> str:
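The closing zip(artifact_map, results) leans on dicts preserving insertion order (guaranteed since Python 3.7): iterating artifact_map yields keys in the same order as the list(artifact_map.items()) that was mapped, so each pretty_name pairs with its extracted path. A small illustration with made-up names:

# Made-up names; shows only the ordering guarantee, not the real download logic.
artifact_map = {"proj-a": {"run_id": 1}, "proj-b": {"run_id": 2}}
results = [f"/tmp/dbs/{name}" for name, _meta in list(artifact_map.items())]
assert list(zip(artifact_map, results)) == [
    ("proj-a", "/tmp/dbs/proj-a"),
    ("proj-b", "/tmp/dbs/proj-b"),
]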