MaD generator: apply black formatting to all sources

This commit is contained in:
Paolo Tranquilli
2025-06-13 08:47:07 +02:00
parent 1a36374718
commit 5df292c286
3 changed files with 63 additions and 24 deletions

View File

@@ -18,7 +18,7 @@ repos:
rev: 25.1.0 rev: 25.1.0
hooks: hooks:
- id: black - id: black
files: ^(misc/codegen/.*|misc/scripts/models-as-data/bulk_generate_mad)\.py$ files: ^(misc/codegen/.*|misc/scripts/models-as-data/.*)\.py$
- repo: local - repo: local
hooks: hooks:

View File

@@ -7,65 +7,86 @@ import subprocess
import sys import sys
import tempfile import tempfile
def quote_if_needed(v): def quote_if_needed(v):
# string columns # string columns
if type(v) is str: if type(v) is str:
return "\"" + v + "\"" return '"' + v + '"'
# bool column # bool column
return str(v) return str(v)
def parseData(data): def parseData(data):
rows = [{ }, { }] rows = [{}, {}]
for row in data: for row in data:
d = map(quote_if_needed, row) d = map(quote_if_needed, row)
provenance = row[-1] provenance = row[-1]
targetRows = rows[1] if provenance.endswith("generated") else rows[0] targetRows = rows[1] if provenance.endswith("generated") else rows[0]
helpers.insert_update(targetRows, row[0], " - [" + ', '.join(d) + ']\n') helpers.insert_update(targetRows, row[0], " - [" + ", ".join(d) + "]\n")
return rows return rows
class Converter: class Converter:
def __init__(self, language, dbDir): def __init__(self, language, dbDir):
self.language = language self.language = language
self.dbDir = dbDir self.dbDir = dbDir
self.codeQlRoot = subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip() self.codeQlRoot = (
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("utf-8")
.strip()
)
self.extDir = os.path.join(self.codeQlRoot, f"{self.language}/ql/lib/ext/") self.extDir = os.path.join(self.codeQlRoot, f"{self.language}/ql/lib/ext/")
self.dirname = "modelconverter" self.dirname = "modelconverter"
self.modelFileExtension = ".model.yml" self.modelFileExtension = ".model.yml"
self.workDir = tempfile.mkdtemp() self.workDir = tempfile.mkdtemp()
def runQuery(self, query): def runQuery(self, query):
print('########## Querying: ', query) print("########## Querying: ", query)
queryFile = os.path.join(self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query) queryFile = os.path.join(
self.codeQlRoot, f"{self.language}/ql/src/utils/{self.dirname}", query
)
resultBqrs = os.path.join(self.workDir, "out.bqrs") resultBqrs = os.path.join(self.workDir, "out.bqrs")
helpers.run_cmd(['codeql', 'query', 'run', queryFile, '--database', self.dbDir, '--output', resultBqrs], "Failed to generate " + query) helpers.run_cmd(
[
"codeql",
"query",
"run",
queryFile,
"--database",
self.dbDir,
"--output",
resultBqrs,
],
"Failed to generate " + query,
)
return helpers.readData(self.workDir, resultBqrs) return helpers.readData(self.workDir, resultBqrs)
def asAddsTo(self, rows, predicate): def asAddsTo(self, rows, predicate):
extensions = [{ }, { }] extensions = [{}, {}]
for i in range(2): for i in range(2):
for key in rows[i]: for key in rows[i]:
extensions[i][key] = helpers.addsToTemplate.format(f"codeql/{self.language}-all", predicate, rows[i][key]) extensions[i][key] = helpers.addsToTemplate.format(
f"codeql/{self.language}-all", predicate, rows[i][key]
return extensions )
return extensions
def getAddsTo(self, query, predicate): def getAddsTo(self, query, predicate):
data = self.runQuery(query) data = self.runQuery(query)
rows = parseData(data) rows = parseData(data)
return self.asAddsTo(rows, predicate) return self.asAddsTo(rows, predicate)
def makeContent(self): def makeContent(self):
summaries = self.getAddsTo("ExtractSummaries.ql", helpers.summaryModelPredicate) summaries = self.getAddsTo("ExtractSummaries.ql", helpers.summaryModelPredicate)
sources = self.getAddsTo("ExtractSources.ql", helpers.sourceModelPredicate) sources = self.getAddsTo("ExtractSources.ql", helpers.sourceModelPredicate)
sinks = self.getAddsTo("ExtractSinks.ql", helpers.sinkModelPredicate) sinks = self.getAddsTo("ExtractSinks.ql", helpers.sinkModelPredicate)
neutrals = self.getAddsTo("ExtractNeutrals.ql", helpers.neutralModelPredicate) neutrals = self.getAddsTo("ExtractNeutrals.ql", helpers.neutralModelPredicate)
return [helpers.merge(sources[0], sinks[0], summaries[0], neutrals[0]), helpers.merge(sources[1], sinks[1], summaries[1], neutrals[1])] return [
helpers.merge(sources[0], sinks[0], summaries[0], neutrals[0]),
helpers.merge(sources[1], sinks[1], summaries[1], neutrals[1]),
]
def save(self, extensions): def save(self, extensions):
# Create directory if it doesn't exist # Create directory if it doesn't exist
@@ -77,9 +98,11 @@ class Converter:
for entry in extensions[0]: for entry in extensions[0]:
with open(self.extDir + "/" + entry + self.modelFileExtension, "w") as f: with open(self.extDir + "/" + entry + self.modelFileExtension, "w") as f:
f.write(extensionTemplate.format(extensions[0][entry])) f.write(extensionTemplate.format(extensions[0][entry]))
for entry in extensions[1]: for entry in extensions[1]:
with open(self.extDir + "/generated/" + entry + self.modelFileExtension, "w") as f: with open(
self.extDir + "/generated/" + entry + self.modelFileExtension, "w"
) as f:
f.write(extensionTemplate.format(extensions[1][entry])) f.write(extensionTemplate.format(extensions[1][entry]))
def run(self): def run(self):

View File

@@ -14,37 +14,53 @@ addsToTemplate = """ - addsTo:
data: data:
{2}""" {2}"""
def remove_dir(dirName): def remove_dir(dirName):
if os.path.isdir(dirName): if os.path.isdir(dirName):
shutil.rmtree(dirName) shutil.rmtree(dirName)
print("Removed directory:", dirName) print("Removed directory:", dirName)
def run_cmd(cmd, msg="Failed to run command"): def run_cmd(cmd, msg="Failed to run command"):
print('Running ' + ' '.join(cmd)) print("Running " + " ".join(cmd))
if subprocess.check_call(cmd): if subprocess.check_call(cmd):
print(msg) print(msg)
exit(1) exit(1)
def readData(workDir, bqrsFile): def readData(workDir, bqrsFile):
generatedJson = os.path.join(workDir, "out.json") generatedJson = os.path.join(workDir, "out.json")
print('Decoding BQRS to JSON.') print("Decoding BQRS to JSON.")
run_cmd(['codeql', 'bqrs', 'decode', bqrsFile, '--output', generatedJson, '--format=json'], "Failed to decode BQRS.") run_cmd(
[
"codeql",
"bqrs",
"decode",
bqrsFile,
"--output",
generatedJson,
"--format=json",
],
"Failed to decode BQRS.",
)
with open(generatedJson) as f: with open(generatedJson) as f:
results = json.load(f) results = json.load(f)
try: try:
return results['#select']['tuples'] return results["#select"]["tuples"]
except KeyError: except KeyError:
print('Unexpected JSON output - no tuples found') print("Unexpected JSON output - no tuples found")
exit(1) exit(1)
def insert_update(rows, key, value): def insert_update(rows, key, value):
if key in rows: if key in rows:
rows[key] += value rows[key] += value
else: else:
rows[key] = value rows[key] = value
def merge(*dicts): def merge(*dicts):
merged = {} merged = {}
for d in dicts: for d in dicts: