MaD generator: add single file mode

This commit is contained in:
Paolo Tranquilli
2025-06-19 12:54:30 +02:00
parent 2818e6ee17
commit 261c129555
3 changed files with 47 additions and 37 deletions

View File

@@ -5,7 +5,7 @@ Experimental script for bulk generation of MaD models based on a list of project
Note: This file must be formatted using the Black Python formatter.
"""
import os.path
import pathlib
import subprocess
import sys
from typing import Required, TypedDict, List, Callable, Optional
@@ -41,7 +41,7 @@ gitroot = (
.decode("utf-8")
.strip()
)
build_dir = os.path.join(gitroot, "mad-generation-build")
build_dir = pathlib.Path(gitroot, "mad-generation-build")
# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
git_tag = project.get("git-tag")
# Determine target directory
target_dir = os.path.join(build_dir, name)
target_dir = build_dir / name
# Clone only if directory doesn't already exist
if not os.path.exists(target_dir):
if not target_dir.exists():
if git_tag:
print(f"Cloning {name} from {repo_url} at tag {git_tag}")
else:
@@ -191,10 +191,10 @@ def build_database(
name = project["name"]
# Create database directory path
database_dir = os.path.join(build_dir, f"{name}-db")
database_dir = build_dir / f"{name}-db"
# Only build the database if it doesn't already exist
if not os.path.exists(database_dir):
if not database_dir.exists():
print(f"Building CodeQL database for {name}...")
extractor_options = [option for x in extractor_options for option in ("-O", x)]
try:
@@ -241,7 +241,11 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
generator.with_summaries = should_generate_summaries(project)
generator.threads = args.codeql_threads
generator.ram = args.codeql_ram
generator.setenvironment(database=database_dir, folder=name)
if config.get("single-file", False):
generator.single_file = name
else:
generator.folder = name
generator.setenvironment(database=database_dir)
generator.run()
@@ -312,7 +316,7 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
if response.status_code != 200:
print(f"Failed to download file. Status code: {response.status_code}")
sys.exit(1)
target_zip = os.path.join(build_dir, zipName)
target_zip = build_dir / zipName
with open(target_zip, "wb") as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
@@ -320,12 +324,6 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
return target_zip
def remove_extension(filename: str) -> str:
    """Strip every trailing extension from ``filename``.

    ``"db.tar.gz"`` becomes ``"db"``. Extensions are removed with
    ``os.path.splitext`` until it reports nothing more to strip.

    The loop is driven by the extension that ``splitext`` returns rather than
    by the presence of a ``"."`` in the string: names such as ``".hidden"`` or
    ``"dir.d/file"`` contain a dot that ``splitext`` never removes, so testing
    for a dot would loop forever.
    """
    stem, extension = os.path.splitext(filename)
    while extension:
        filename = stem
        stem, extension = os.path.splitext(filename)
    return filename
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """Return the second ``"___"``-separated field of ``artifact_name``.

    Raises ``IndexError`` when the name contains no ``"___"`` separator.
    """
    fields = artifact_name.split("___")
    return fields[1]
@@ -399,19 +397,17 @@ def download_dca_databases(
# The database is in a zip file, which contains a tar.gz file with the DB
# First we open the zip file
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
artifact_unzipped_location = os.path.join(build_dir, artifact_name)
artifact_unzipped_location = build_dir / artifact_name
# clean up any remnants of previous runs
shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
# And then we extract it to build_dir/artifact_name
zip_ref.extractall(artifact_unzipped_location)
# And then we extract the language tar.gz file inside it
artifact_tar_location = os.path.join(
artifact_unzipped_location, f"{language}.tar.gz"
)
artifact_tar_location = artifact_unzipped_location / f"{language}.tar.gz"
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
# And we just untar it to the same directory as the zip file
tar_ref.extractall(artifact_unzipped_location)
ret = os.path.join(artifact_unzipped_location, language)
ret = artifact_unzipped_location / language
print(f"Decompression complete: {ret}")
return ret
@@ -431,8 +427,16 @@ def download_dca_databases(
return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
def get_mad_destination_for_project(config, name: str) -> str:
    """Return the path for project ``name`` under ``config["destination"]``."""
    destination_root = config["destination"]
    return os.path.join(destination_root, name)
def clean_up_mad_destination_for_project(config, name: str):
    """Delete previously generated MaD output for project ``name``.

    config: mapping providing ``"destination"`` (the output root) and,
        optionally, ``"single-file"``.
    name: the project name.

    When ``config["single-file"]`` is truthy the target is the file
    ``<destination>/<name>.model.yml``; otherwise it is the directory
    ``<destination>/<name>``. A target that does not exist is left alone.
    """
    destination = pathlib.Path(config["destination"])
    if config.get("single-file", False):
        # Append the extension instead of using with_suffix(): with_suffix()
        # replaces an existing suffix, so a dotted project name such as
        # "foo.bar" would resolve to "foo.model.yml" and the wrong file would
        # be targeted.
        target = destination / f"{name}.model.yml"
        if target.exists():
            print(f"Deleting existing MaD file at {target}")
            target.unlink()
        return
    target = destination / name
    if target.exists():
        print(f"Deleting existing MaD directory at {target}")
        shutil.rmtree(target, ignore_errors=True)
def get_strategy(config) -> str:
@@ -454,8 +458,7 @@ def main(config, args) -> None:
language = config["language"]
# Create build directory if it doesn't exist
if not os.path.exists(build_dir):
os.makedirs(build_dir)
build_dir.mkdir(parents=True, exist_ok=True)
database_results = []
match get_strategy(config):
@@ -475,7 +478,7 @@ def main(config, args) -> None:
if args.pat is None:
print("ERROR: --pat argument is required for DCA strategy")
sys.exit(1)
if not os.path.exists(args.pat):
if not args.pat.exists():
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
sys.exit(1)
with open(args.pat, "r") as f:
@@ -499,12 +502,9 @@ def main(config, args) -> None:
)
sys.exit(1)
# Delete the MaD directory for each project
for project, database_dir in database_results:
mad_dir = get_mad_destination_for_project(config, project["name"])
if os.path.exists(mad_dir):
print(f"Deleting existing MaD directory at {mad_dir}")
subprocess.check_call(["rm", "-rf", mad_dir])
# clean up existing MaD data for the projects
for project, _ in database_results:
clean_up_mad_destination_for_project(config, project["name"])
for project, database_dir in database_results:
if database_dir is not None:
@@ -514,7 +514,10 @@ def main(config, args) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, help="Path to the configuration file.", required=True
"--config",
type=pathlib.Path,
help="Path to the configuration file.",
required=True,
)
parser.add_argument(
"--dca",
@@ -525,7 +528,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--pat",
type=str,
type=pathlib.Path,
help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)",
)
parser.add_argument(
@@ -544,7 +547,7 @@ if __name__ == "__main__":
# Load config file
config = {}
if not os.path.exists(args.config):
if not args.config.exists():
print(f"ERROR: Config file '{args.config}' does not exist.")
sys.exit(1)
try:

View File

@@ -53,12 +53,13 @@ class Generator:
ram = None
threads = 0
folder = ""
single_file = None
def __init__(self, language=None):
self.language = language
def setenvironment(self, database=None, folder=None):
self.codeQlRoot = (
self.codeql_root = (
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("utf-8")
.strip()
@@ -66,7 +67,7 @@ class Generator:
self.database = database or self.database
self.folder = folder or self.folder
self.generated_frameworks = os.path.join(
self.codeQlRoot, f"{self.language}/ql/lib/ext/generated/{self.folder}"
self.codeql_root, f"{self.language}/ql/lib/ext/generated/{self.folder}"
)
self.workDir = tempfile.mkdtemp()
if self.ram is None:
@@ -134,6 +135,10 @@ class Generator:
type=int,
help="Amount of RAM to use for CodeQL queries in MB. Default is to use 2048 MB per thread.",
)
p.add_argument(
"--single-file",
help="Generate a single file with all models instead of separate files for each namespace, using provided argument as the base filename.",
)
generator = p.parse_args(namespace=Generator())
if (
@@ -154,7 +159,7 @@ class Generator:
def runQuery(self, query):
print("########## Querying " + query + "...")
queryFile = os.path.join(
self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query
self.codeql_root, f"{self.language}/ql/src/utils/{self.dirname}", query
)
resultBqrs = os.path.join(self.workDir, "out.bqrs")
@@ -187,6 +192,8 @@ class Generator:
def getAddsTo(self, query, predicate):
    """Run ``query`` and convert its parsed rows via ``asAddsTo``.

    When ``self.single_file`` is set and the query produced rows, the values
    of all entries are concatenated (in iteration order) under the single key
    ``self.single_file`` before conversion.
    """
    # NOTE(review): parseData presumably returns a dict of str -> str; confirm.
    rows = parseData(self.runQuery(query))
    if self.single_file and rows:
        merged = "".join(rows.values())
        rows = {self.single_file: merged}
    return self.asAddsTo(rows, predicate)
def makeContent(self):

View File

@@ -22,7 +22,7 @@ def remove_dir(dirName):
def run_cmd(cmd, msg="Failed to run command"):
    """Run ``cmd`` and terminate the script with status 1 if it fails.

    cmd: sequence of command-line arguments. Each element is converted with
        ``str()`` for the log line, so path-like objects are accepted.
    msg: message printed when the command exits with a non-zero status.

    Raises ``SystemExit(1)`` when the command returns a non-zero exit status.
    """
    print("Running " + " ".join(map(str, cmd)))
    # subprocess.call returns the exit status. check_call raises on failure
    # and otherwise returns 0, which would make the branch below unreachable.
    if subprocess.call(cmd):
        print(msg)
        raise SystemExit(1)