MaD generator: apply black formatting to all sources

This commit is contained in:
Paolo Tranquilli
2025-06-13 08:47:07 +02:00
parent 1a36374718
commit 5df292c286
3 changed files with 63 additions and 24 deletions

View File

@@ -18,7 +18,7 @@ repos:
rev: 25.1.0
hooks:
- id: black
files: ^(misc/codegen/.*|misc/scripts/models-as-data/bulk_generate_mad)\.py$
files: ^(misc/codegen/.*|misc/scripts/models-as-data/.*)\.py$
- repo: local
hooks:

View File

@@ -7,65 +7,86 @@ import subprocess
import sys
import tempfile
def quote_if_needed(v):
    """Render one result column as text for a model row.

    Strings are wrapped in double quotes; any other value (the bool
    column) is rendered with ``str``.
    """
    # string columns
    if isinstance(v, str):
        return '"' + v + '"'
    # bool column
    return str(v)
def parseData(data):
    """Group query result tuples into formatted model rows.

    Returns a two-element list of dicts keyed by the first column of each
    tuple: index 0 collects rows whose last column (the provenance) does
    not end with "generated", index 1 collects the ones that do. Each
    tuple contributes exactly one formatted line to its entry.
    """
    rows = [{}, {}]
    for row in data:
        # Lazy iterator: it is consumed exactly once by the join below.
        d = map(quote_if_needed, row)
        provenance = row[-1]
        targetRows = rows[1] if provenance.endswith("generated") else rows[0]
        helpers.insert_update(targetRows, row[0], " - [" + ", ".join(d) + "]\n")
    return rows
class Converter:
def __init__(self, language, dbDir):
    """Set up paths for converting models of ``language`` using ``dbDir``.

    Runs ``git rev-parse --show-toplevel`` once to locate the repository
    root and creates a fresh temporary working directory.
    """
    self.language = language
    self.dbDir = dbDir
    # Resolve the checkout root a single time; every other path hangs off it.
    self.codeQlRoot = (
        subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
        .decode("utf-8")
        .strip()
    )
    self.extDir = os.path.join(self.codeQlRoot, f"{self.language}/ql/lib/ext/")
    self.dirname = "modelconverter"
    self.modelFileExtension = ".model.yml"
    self.workDir = tempfile.mkdtemp()
def runQuery(self, query):
    """Run ``query`` against the database and return its decoded data.

    The query file is looked up under
    ``<root>/<language>/ql/src/utils/<dirname>/``; the BQRS output is
    written to ``out.bqrs`` in the working directory and then handed to
    ``helpers.readData``. The query is executed exactly once.
    """
    print("########## Querying: ", query)
    queryFile = os.path.join(
        self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query
    )
    resultBqrs = os.path.join(self.workDir, "out.bqrs")
    helpers.run_cmd(
        [
            "codeql",
            "query",
            "run",
            queryFile,
            "--database",
            self.dbDir,
            "--output",
            resultBqrs,
        ],
        "Failed to generate " + query,
    )
    return helpers.readData(self.workDir, resultBqrs)
def asAddsTo(self, rows, predicate):
    """Wrap each entry of ``rows`` in the ``addsTo`` extension template.

    ``rows`` is a two-element list of dicts; the result has the same
    shape, with every value formatted through ``helpers.addsToTemplate``
    using the ``codeql/<language>-all`` pack name and ``predicate``.
    """
    extensions = [{}, {}]
    for i in range(2):
        for key in rows[i]:
            extensions[i][key] = helpers.addsToTemplate.format(
                f"codeql/{self.language}-all", predicate, rows[i][key]
            )
    # Return only after both groups have been converted.
    return extensions
def getAddsTo(self, query, predicate):
    """Run ``query``, parse its tuples and wrap them for ``predicate``."""
    parsed = parseData(self.runQuery(query))
    return self.asAddsTo(parsed, predicate)
def makeContent(self):
    """Extract all model kinds and merge them per group.

    Runs the summary, source, sink and neutral extraction queries (in
    that order) and returns a two-element list: the merged results at
    index 0 of each query, then the merged results at index 1.
    """
    summaries = self.getAddsTo("ExtractSummaries.ql", helpers.summaryModelPredicate)
    sources = self.getAddsTo("ExtractSources.ql", helpers.sourceModelPredicate)
    sinks = self.getAddsTo("ExtractSinks.ql", helpers.sinkModelPredicate)
    neutrals = self.getAddsTo("ExtractNeutrals.ql", helpers.neutralModelPredicate)
    return [
        helpers.merge(sources[0], sinks[0], summaries[0], neutrals[0]),
        helpers.merge(sources[1], sinks[1], summaries[1], neutrals[1]),
    ]
def save(self, extensions):
# Create directory if it doesn't exist
@@ -77,9 +98,11 @@ class Converter:
for entry in extensions[0]:
with open(self.extDir + "/" + entry + self.modelFileExtension, "w") as f:
f.write(extensionTemplate.format(extensions[0][entry]))
for entry in extensions[1]:
with open(self.extDir + "/generated/" + entry + self.modelFileExtension, "w") as f:
with open(
self.extDir + "/generated/" + entry + self.modelFileExtension, "w"
) as f:
f.write(extensionTemplate.format(extensions[1][entry]))
def run(self):

View File

@@ -14,37 +14,53 @@ addsToTemplate = """ - addsTo:
data:
{2}"""
def remove_dir(dirName):
    """Delete the directory tree at ``dirName`` if it exists and report it."""
    if not os.path.isdir(dirName):
        return
    shutil.rmtree(dirName)
    print("Removed directory:", dirName)
def run_cmd(cmd, msg="Failed to run command"):
    """Run ``cmd`` (an argument list) and abort the script if it fails.

    The command line is echoed first. On a non-zero exit status ``msg``
    is printed and the process exits with status 1.
    """
    import sys

    print("Running " + " ".join(cmd))
    # subprocess.call returns the exit status instead of raising, so the
    # failure message below is actually reachable.
    if subprocess.call(cmd):
        print(msg)
        sys.exit(1)
def readData(workDir, bqrsFile):
    """Decode ``bqrsFile`` to JSON and return the ``#select`` tuples.

    The JSON is written to ``out.json`` inside ``workDir`` by a single
    ``codeql bqrs decode`` invocation. If the decoded document has no
    ``#select``/``tuples`` entry, a message is printed and the process
    exits with status 1.
    """
    import sys

    generatedJson = os.path.join(workDir, "out.json")
    print("Decoding BQRS to JSON.")
    run_cmd(
        [
            "codeql",
            "bqrs",
            "decode",
            bqrsFile,
            "--output",
            generatedJson,
            "--format=json",
        ],
        "Failed to decode BQRS.",
    )
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(generatedJson, encoding="utf-8") as f:
        results = json.load(f)
    try:
        return results["#select"]["tuples"]
    except KeyError:
        print("Unexpected JSON output - no tuples found")
        sys.exit(1)
def insert_update(rows, key, value):
    """Store ``value`` under ``key``, appending with ``+=`` if it exists."""
    if key not in rows:
        rows[key] = value
        return
    rows[key] += value
def merge(*dicts):
merged = {}
for d in dicts: