mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
Swift: add unit tests to code generation
Tests can be run with ``` bazel test //swift/codegen:tests ``` Coverage can be checked installing `pytest-cov` and running ``` pytest --cov=swift/codegen swift/codegen/test ```
This commit is contained in:
@@ -1,129 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, ClassVar
|
||||
|
||||
import inflection
|
||||
|
||||
from lib import schema, paths, generator
|
||||
from swift.codegen.lib import schema, paths, generator, ql
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlParam:
|
||||
param: str
|
||||
type: str = None
|
||||
first: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlProperty:
|
||||
singular: str
|
||||
type: str
|
||||
tablename: str
|
||||
tableparams: List[QlParam]
|
||||
plural: str = None
|
||||
params: List[QlParam] = field(default_factory=list)
|
||||
first: bool = False
|
||||
local_var: str = "x"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.params:
|
||||
self.params[0].first = True
|
||||
while self.local_var in (p.param for p in self.params):
|
||||
self.local_var += "_"
|
||||
assert self.tableparams
|
||||
if self.type_is_class:
|
||||
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
|
||||
self.tableparams = [QlParam(x) for x in self.tableparams]
|
||||
self.tableparams[0].first = True
|
||||
|
||||
@property
|
||||
def indefinite_article(self):
|
||||
if self.plural:
|
||||
return "An" if self.singular[0] in "AEIO" else "A"
|
||||
|
||||
@property
|
||||
def type_is_class(self):
|
||||
return self.type[0].isupper()
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlClass:
|
||||
template: ClassVar = 'ql_class'
|
||||
|
||||
name: str
|
||||
bases: List[str]
|
||||
final: bool
|
||||
properties: List[QlProperty]
|
||||
dir: pathlib.Path
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
self.bases = sorted(self.bases)
|
||||
if self.properties:
|
||||
self.properties[0].first = True
|
||||
|
||||
@property
|
||||
def db_id(self):
|
||||
return "@" + inflection.underscore(self.name)
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
return not self.bases
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self.dir / self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlStub:
|
||||
template: ClassVar = 'ql_stub'
|
||||
|
||||
name: str
|
||||
base_import: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlImportList:
|
||||
template: ClassVar = 'ql_imports'
|
||||
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def get_ql_property(cls: schema.Class, prop: schema.Property):
|
||||
if prop.is_single:
|
||||
return QlProperty(
|
||||
return ql.QlProperty(
|
||||
singular=inflection.camelize(prop.name),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(cls.name),
|
||||
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
|
||||
)
|
||||
elif prop.is_optional:
|
||||
return QlProperty(
|
||||
return ql.QlProperty(
|
||||
singular=inflection.camelize(prop.name),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
|
||||
tableparams=["this", "result"],
|
||||
)
|
||||
elif prop.is_repeated:
|
||||
return QlProperty(
|
||||
return ql.QlProperty(
|
||||
singular=inflection.singularize(inflection.camelize(prop.name)),
|
||||
plural=inflection.pluralize(inflection.camelize(prop.name)),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
|
||||
tableparams=["this", "index", "result"],
|
||||
params=[QlParam("index", type="int")],
|
||||
params=[ql.QlParam("index", type="int")],
|
||||
)
|
||||
|
||||
|
||||
def get_ql_class(cls: schema.Class):
|
||||
return QlClass(
|
||||
return ql.QlClass(
|
||||
name=cls.name,
|
||||
bases=cls.bases,
|
||||
final=not cls.derived,
|
||||
@@ -137,7 +51,7 @@ def get_import(file):
|
||||
return str(stem).replace("/", ".")
|
||||
|
||||
|
||||
def get_types_used_by(cls: QlClass):
|
||||
def get_types_used_by(cls: ql.QlClass):
|
||||
for b in cls.bases:
|
||||
yield b
|
||||
for p in cls.properties:
|
||||
@@ -146,7 +60,7 @@ def get_types_used_by(cls: QlClass):
|
||||
yield param.type
|
||||
|
||||
|
||||
def get_classes_used_by(cls: QlClass):
|
||||
def get_classes_used_by(cls: ql.QlClass):
|
||||
return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper()))
|
||||
|
||||
|
||||
@@ -164,34 +78,32 @@ def format(codeql, files):
|
||||
|
||||
|
||||
def generate(opts, renderer):
|
||||
input = opts.schema.resolve()
|
||||
out = opts.ql_output.resolve()
|
||||
stub_out = opts.ql_stub_output.resolve()
|
||||
input = opts.schema
|
||||
out = opts.ql_output
|
||||
stub_out = opts.ql_stub_output
|
||||
existing = {q for q in out.rglob("*.qll")}
|
||||
existing |= {q for q in stub_out.rglob("*.qll") if is_generated(q)}
|
||||
|
||||
with open(input) as src:
|
||||
data = schema.load(src)
|
||||
data = schema.load(input)
|
||||
|
||||
classes = [get_ql_class(cls) for cls in data.classes.values()]
|
||||
classes = [get_ql_class(cls) for cls in data.classes]
|
||||
imports = {}
|
||||
|
||||
for c in classes:
|
||||
imports[c.name] = get_import(stub_out / c.path)
|
||||
|
||||
for c in classes:
|
||||
assert not c.final or c.bases, c.name
|
||||
qll = (out / c.path).with_suffix(".qll")
|
||||
c.imports = [imports[t] for t in get_classes_used_by(c)]
|
||||
renderer.render(c, qll)
|
||||
stub_file = (stub_out / c.path).with_suffix(".qll")
|
||||
if not stub_file.is_file() or is_generated(stub_file):
|
||||
stub = QlStub(name=c.name, base_import=get_import(qll))
|
||||
stub = ql.QlStub(name=c.name, base_import=get_import(qll))
|
||||
renderer.render(stub, stub_file)
|
||||
|
||||
# for example path/to/syntax/generated -> path/to/syntax.qll
|
||||
include_file = stub_out.with_suffix(".qll")
|
||||
all_imports = QlImportList(v for _, v in sorted(imports.items()))
|
||||
all_imports = ql.QlImportList([v for _, v in sorted(imports.items())])
|
||||
renderer.render(all_imports, include_file)
|
||||
|
||||
renderer.cleanup(existing)
|
||||
|
||||
Reference in New Issue
Block a user