MaD generator: add single file mode

This commit is contained in:
Paolo Tranquilli
2025-06-19 12:54:30 +02:00
parent 2818e6ee17
commit 261c129555
3 changed files with 47 additions and 37 deletions

View File

@@ -5,7 +5,7 @@ Experimental script for bulk generation of MaD models based on a list of project
Note: This file must be formatted using the Black Python formatter. Note: This file must be formatted using the Black Python formatter.
""" """
import os.path import pathlib
import subprocess import subprocess
import sys import sys
from typing import Required, TypedDict, List, Callable, Optional from typing import Required, TypedDict, List, Callable, Optional
@@ -41,7 +41,7 @@ gitroot = (
.decode("utf-8") .decode("utf-8")
.strip() .strip()
) )
build_dir = os.path.join(gitroot, "mad-generation-build") build_dir = pathlib.Path(gitroot, "mad-generation-build")
# A project to generate models for # A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
git_tag = project.get("git-tag") git_tag = project.get("git-tag")
# Determine target directory # Determine target directory
target_dir = os.path.join(build_dir, name) target_dir = build_dir / name
# Clone only if directory doesn't already exist # Clone only if directory doesn't already exist
if not os.path.exists(target_dir): if not target_dir.exists():
if git_tag: if git_tag:
print(f"Cloning {name} from {repo_url} at tag {git_tag}") print(f"Cloning {name} from {repo_url} at tag {git_tag}")
else: else:
@@ -191,10 +191,10 @@ def build_database(
name = project["name"] name = project["name"]
# Create database directory path # Create database directory path
database_dir = os.path.join(build_dir, f"{name}-db") database_dir = build_dir / f"{name}-db"
# Only build the database if it doesn't already exist # Only build the database if it doesn't already exist
if not os.path.exists(database_dir): if not database_dir.exists():
print(f"Building CodeQL database for {name}...") print(f"Building CodeQL database for {name}...")
extractor_options = [option for x in extractor_options for option in ("-O", x)] extractor_options = [option for x in extractor_options for option in ("-O", x)]
try: try:
@@ -241,7 +241,11 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
generator.with_summaries = should_generate_summaries(project) generator.with_summaries = should_generate_summaries(project)
generator.threads = args.codeql_threads generator.threads = args.codeql_threads
generator.ram = args.codeql_ram generator.ram = args.codeql_ram
generator.setenvironment(database=database_dir, folder=name) if config.get("single-file", False):
generator.single_file = name
else:
generator.folder = name
generator.setenvironment(database=database_dir)
generator.run() generator.run()
@@ -312,7 +316,7 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to download file. Status code: {response.status_code}") print(f"Failed to download file. Status code: {response.status_code}")
sys.exit(1) sys.exit(1)
target_zip = os.path.join(build_dir, zipName) target_zip = build_dir / zipName
with open(target_zip, "wb") as file: with open(target_zip, "wb") as file:
for chunk in response.iter_content(chunk_size=8192): for chunk in response.iter_content(chunk_size=8192):
file.write(chunk) file.write(chunk)
@@ -320,12 +324,6 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
return target_zip return target_zip
def remove_extension(filename: str) -> str:
while "." in filename:
filename, _ = os.path.splitext(filename)
return filename
def pretty_name_from_artifact_name(artifact_name: str) -> str: def pretty_name_from_artifact_name(artifact_name: str) -> str:
return artifact_name.split("___")[1] return artifact_name.split("___")[1]
@@ -399,19 +397,17 @@ def download_dca_databases(
# The database is in a zip file, which contains a tar.gz file with the DB # The database is in a zip file, which contains a tar.gz file with the DB
# First we open the zip file # First we open the zip file
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref: with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
artifact_unzipped_location = os.path.join(build_dir, artifact_name) artifact_unzipped_location = build_dir / artifact_name
# clean up any remnants of previous runs # clean up any remnants of previous runs
shutil.rmtree(artifact_unzipped_location, ignore_errors=True) shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
# And then we extract it to build_dir/artifact_name # And then we extract it to build_dir/artifact_name
zip_ref.extractall(artifact_unzipped_location) zip_ref.extractall(artifact_unzipped_location)
# And then we extract the language tar.gz file inside it # And then we extract the language tar.gz file inside it
artifact_tar_location = os.path.join( artifact_tar_location = artifact_unzipped_location / f"{language}.tar.gz"
artifact_unzipped_location, f"{language}.tar.gz"
)
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref: with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
# And we just untar it to the same directory as the zip file # And we just untar it to the same directory as the zip file
tar_ref.extractall(artifact_unzipped_location) tar_ref.extractall(artifact_unzipped_location)
ret = os.path.join(artifact_unzipped_location, language) ret = artifact_unzipped_location / language
print(f"Decompression complete: {ret}") print(f"Decompression complete: {ret}")
return ret return ret
@@ -431,8 +427,16 @@ def download_dca_databases(
return [(project_map[n], r) for n, r in zip(analyzed_databases, results)] return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
def get_mad_destination_for_project(config, name: str) -> str: def clean_up_mad_destination_for_project(config, name: str):
return os.path.join(config["destination"], name) target = pathlib.Path(config["destination"], name)
if config.get("single-file", False):
target = target.with_suffix(".model.yml")
if target.exists():
print(f"Deleting existing MaD file at {target}")
target.unlink()
elif target.exists():
print(f"Deleting existing MaD directory at {target}")
shutil.rmtree(target, ignore_errors=True)
def get_strategy(config) -> str: def get_strategy(config) -> str:
@@ -454,8 +458,7 @@ def main(config, args) -> None:
language = config["language"] language = config["language"]
# Create build directory if it doesn't exist # Create build directory if it doesn't exist
if not os.path.exists(build_dir): build_dir.mkdir(parents=True, exist_ok=True)
os.makedirs(build_dir)
database_results = [] database_results = []
match get_strategy(config): match get_strategy(config):
@@ -475,7 +478,7 @@ def main(config, args) -> None:
if args.pat is None: if args.pat is None:
print("ERROR: --pat argument is required for DCA strategy") print("ERROR: --pat argument is required for DCA strategy")
sys.exit(1) sys.exit(1)
if not os.path.exists(args.pat): if not args.pat.exists():
print(f"ERROR: Personal Access Token file '{pat}' does not exist.") print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
sys.exit(1) sys.exit(1)
with open(args.pat, "r") as f: with open(args.pat, "r") as f:
@@ -499,12 +502,9 @@ def main(config, args) -> None:
) )
sys.exit(1) sys.exit(1)
# Delete the MaD directory for each project # clean up existing MaD data for the projects
for project, database_dir in database_results: for project, _ in database_results:
mad_dir = get_mad_destination_for_project(config, project["name"]) clean_up_mad_destination_for_project(config, project["name"])
if os.path.exists(mad_dir):
print(f"Deleting existing MaD directory at {mad_dir}")
subprocess.check_call(["rm", "-rf", mad_dir])
for project, database_dir in database_results: for project, database_dir in database_results:
if database_dir is not None: if database_dir is not None:
@@ -514,7 +514,10 @@ def main(config, args) -> None:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--config", type=str, help="Path to the configuration file.", required=True "--config",
type=pathlib.Path,
help="Path to the configuration file.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--dca", "--dca",
@@ -525,7 +528,7 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--pat", "--pat",
type=str, type=pathlib.Path,
help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)", help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)",
) )
parser.add_argument( parser.add_argument(
@@ -544,7 +547,7 @@ if __name__ == "__main__":
# Load config file # Load config file
config = {} config = {}
if not os.path.exists(args.config): if not args.config.exists():
print(f"ERROR: Config file '{args.config}' does not exist.") print(f"ERROR: Config file '{args.config}' does not exist.")
sys.exit(1) sys.exit(1)
try: try:

View File

@@ -53,12 +53,13 @@ class Generator:
ram = None ram = None
threads = 0 threads = 0
folder = "" folder = ""
single_file = None
def __init__(self, language=None): def __init__(self, language=None):
self.language = language self.language = language
def setenvironment(self, database=None, folder=None): def setenvironment(self, database=None, folder=None):
self.codeQlRoot = ( self.codeql_root = (
subprocess.check_output(["git", "rev-parse", "--show-toplevel"]) subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("utf-8") .decode("utf-8")
.strip() .strip()
@@ -66,7 +67,7 @@ class Generator:
self.database = database or self.database self.database = database or self.database
self.folder = folder or self.folder self.folder = folder or self.folder
self.generated_frameworks = os.path.join( self.generated_frameworks = os.path.join(
self.codeQlRoot, f"{self.language}/ql/lib/ext/generated/{self.folder}" self.codeql_root, f"{self.language}/ql/lib/ext/generated/{self.folder}"
) )
self.workDir = tempfile.mkdtemp() self.workDir = tempfile.mkdtemp()
if self.ram is None: if self.ram is None:
@@ -134,6 +135,10 @@ class Generator:
type=int, type=int,
help="Amount of RAM to use for CodeQL queries in MB. Default is to use 2048 MB per thread.", help="Amount of RAM to use for CodeQL queries in MB. Default is to use 2048 MB per thread.",
) )
p.add_argument(
"--single-file",
help="Generate a single file with all models instead of separate files for each namespace, using provided argument as the base filename.",
)
generator = p.parse_args(namespace=Generator()) generator = p.parse_args(namespace=Generator())
if ( if (
@@ -154,7 +159,7 @@ class Generator:
def runQuery(self, query): def runQuery(self, query):
print("########## Querying " + query + "...") print("########## Querying " + query + "...")
queryFile = os.path.join( queryFile = os.path.join(
self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query self.codeql_root, f"{self.language}/ql/src/utils/{self.dirname}", query
) )
resultBqrs = os.path.join(self.workDir, "out.bqrs") resultBqrs = os.path.join(self.workDir, "out.bqrs")
@@ -187,6 +192,8 @@ class Generator:
def getAddsTo(self, query, predicate): def getAddsTo(self, query, predicate):
data = self.runQuery(query) data = self.runQuery(query)
rows = parseData(data) rows = parseData(data)
if self.single_file and rows:
rows = {self.single_file: "".join(rows.values())}
return self.asAddsTo(rows, predicate) return self.asAddsTo(rows, predicate)
def makeContent(self): def makeContent(self):

View File

@@ -22,7 +22,7 @@ def remove_dir(dirName):
def run_cmd(cmd, msg="Failed to run command"): def run_cmd(cmd, msg="Failed to run command"):
print("Running " + " ".join(cmd)) print("Running " + " ".join(map(str, cmd)))
if subprocess.check_call(cmd): if subprocess.check_call(cmd):
print(msg) print(msg)
exit(1) exit(1)