MaD: make bulk generator DCA strategy download DBs in parallel

This commit is contained in:
Paolo Tranquilli
2025-06-05 09:29:46 +02:00
parent fbd50583fe
commit 4f47ee2e72

View File

@@ -8,7 +8,7 @@ Note: This file must be formatted using the Black Python formatter.
import os.path
import subprocess
import sys
from typing import NotRequired, TypedDict, List
from typing import NotRequired, TypedDict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import argparse
@@ -111,6 +111,37 @@ def clone_project(project: Project) -> str:
return target_dir
def run_in_parallel[T, U](
func: Callable[[T], U],
items: List[T],
*,
on_error=lambda item, exc: None,
error_summary=lambda failures: None,
max_workers=8,
) -> List[Optional[U]]:
if not items:
return []
max_workers = min(max_workers, len(items))
results = [None for _ in range(len(items))]
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Start cloning tasks and keep track of them
futures = {
executor.submit(func, item): index for index, item in enumerate(items)
}
# Process results as they complete
for future in as_completed(futures):
index = futures[future]
try:
results[index] = future.result()
except Exception as e:
on_error(items[index], e)
failed = [item for item, result in zip(items, results) if result is None]
if failed:
error_summary(failed)
sys.exit(1)
return results
def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
"""
Clone all projects in parallel.
@@ -122,40 +153,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
List of (project, project_dir) pairs in the same order as the input projects
"""
start_time = time.time()
max_workers = min(8, len(projects)) # Use at most 8 threads
project_dirs_map = {} # Map to store results by project name
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Start cloning tasks and keep track of them
future_to_project = {
executor.submit(clone_project, project): project for project in projects
}
# Process results as they complete
for future in as_completed(future_to_project):
project = future_to_project[future]
try:
project_dir = future.result()
project_dirs_map[project["name"]] = (project, project_dir)
except Exception as e:
print(f"ERROR: Failed to clone {project['name']}: {e}")
if len(project_dirs_map) != len(projects):
failed_projects = [
project["name"]
for project in projects
if project["name"] not in project_dirs_map
]
print(
f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
)
sys.exit(1)
project_dirs = [project_dirs_map[project["name"]] for project in projects]
dirs = run_in_parallel(
clone_project,
projects,
on_error=lambda project, exc: print(
f"ERROR: Failed to clone project {project['name']}: {exc}"
),
error_summary=lambda failures: print(
f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
),
)
clone_time = time.time() - start_time
print(f"Cloning completed in {clone_time:.2f} seconds")
return project_dirs
return list(zip(projects, dirs))
def build_database(
@@ -352,7 +362,8 @@ def download_dca_databases(
artifact_map[pretty_name] = analyzed_database
for pretty_name, analyzed_database in artifact_map.items():
def download(item: tuple[str, dict]) -> str:
pretty_name, analyzed_database = item
artifact_name = analyzed_database["artifact_name"]
repository = analyzed_database["repository"]
run_id = analyzed_database["run_id"]
@@ -383,13 +394,22 @@ def download_dca_databases(
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
# And we just untar it to the same directory as the zip file
tar_ref.extractall(artifact_unzipped_location)
database_results[pretty_name] = os.path.join(
artifact_unzipped_location, remove_extension(entry)
)
return os.path.join(artifact_unzipped_location, remove_extension(entry))
results = run_in_parallel(
download,
list(artifact_map.items()),
on_error=lambda item, exc: print(
f"ERROR: Failed to download database for {item[0]}: {exc}"
),
error_summary=lambda failures: print(
f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
),
)
print(f"\n=== Extracted {len(database_results)} databases ===")
return [(project, database_results[project["name"]]) for project in projects]
return [(project_map[n], r) for n, r in zip(artifact_map, results)]
def get_mad_destination_for_project(config, name: str) -> str: