mirror of
https://github.com/github/codeql.git
synced 2026-05-04 21:25:44 +02:00
Rust: make File usable in codegen
This commit is contained in:
@@ -110,7 +110,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
|
||||
|
||||
def get_declarations(data: schema.Schema):
|
||||
add_or_none_except = data.root_class.name if data.null else None
|
||||
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)]
|
||||
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
|
||||
data.classes, add_or_none_except)]
|
||||
if data.null:
|
||||
property_classes = {
|
||||
prop.type for cls in data.classes.values() for prop in cls.properties
|
||||
|
||||
@@ -104,8 +104,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
|
||||
return f"{prop_name} of this {class_name}"
|
||||
|
||||
|
||||
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.Class],
|
||||
def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> bool:
|
||||
if t in lookup:
|
||||
match lookup[t]:
|
||||
case schema.Class() as cls:
|
||||
return "ql_hideable" in cls.pragmas
|
||||
return False
|
||||
|
||||
|
||||
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
|
||||
prev_child: str = "") -> ql.Property:
|
||||
|
||||
args = dict(
|
||||
type=prop.type if not prop.is_predicate else "predicate",
|
||||
qltest_skip="qltest_skip" in prop.pragmas,
|
||||
@@ -115,7 +124,8 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
|
||||
is_unordered=prop.is_unordered,
|
||||
description=prop.description,
|
||||
synth=bool(cls.synth) or prop.synth,
|
||||
type_is_hideable="ql_hideable" in lookup[prop.type].pragmas if prop.type in lookup else False,
|
||||
type_is_hideable=_type_is_hideable(prop.type, lookup),
|
||||
type_is_codegen_class=prop.type in lookup and not lookup[prop.type].imported,
|
||||
internal="ql_internal" in prop.pragmas,
|
||||
)
|
||||
ql_name = prop.pragmas.get("ql_name", prop.name)
|
||||
@@ -154,7 +164,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
|
||||
return ql.Property(**args)
|
||||
|
||||
|
||||
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> ql.Class:
|
||||
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
|
||||
if "ql_name" in cls.pragmas:
|
||||
raise Error("ql_name is not supported yet for classes, only for properties")
|
||||
prev_child = ""
|
||||
@@ -391,14 +401,15 @@ def generate(opts, renderer):
|
||||
|
||||
data = schemaloader.load_file(input)
|
||||
|
||||
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
|
||||
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
|
||||
if not classes:
|
||||
raise NoClasses
|
||||
root = next(iter(classes.values()))
|
||||
if root.has_children:
|
||||
raise RootElementHasChildren(root)
|
||||
|
||||
imports = {}
|
||||
pre_imports = {n: cls.module for n, cls in data.classes.items() if cls.imported}
|
||||
imports = dict(pre_imports)
|
||||
imports_impl = {}
|
||||
classes_used_by = {}
|
||||
cfg_classes = []
|
||||
@@ -410,7 +421,7 @@ def generate(opts, renderer):
|
||||
force=opts.force) as renderer:
|
||||
|
||||
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
|
||||
renderer.render(ql.DbClasses(db_classes), out / "Raw.qll")
|
||||
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")
|
||||
|
||||
classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
|
||||
for c in classes_by_dir_and_name:
|
||||
@@ -439,6 +450,8 @@ def generate(opts, renderer):
|
||||
renderer.render(cfg_classes_val, cfg_qll)
|
||||
|
||||
for c in data.classes.values():
|
||||
if c.imported:
|
||||
continue
|
||||
path = _get_path(c)
|
||||
path_impl = _get_path_impl(c)
|
||||
stub_file = stub_out / path_impl
|
||||
@@ -457,7 +470,7 @@ def generate(opts, renderer):
|
||||
renderer.render(class_public, class_public_file)
|
||||
|
||||
# for example path/to/elements -> path/to/elements.qll
|
||||
renderer.render(ql.ImportList([i for name, i in imports.items() if not classes[name].internal]),
|
||||
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
|
||||
include_file)
|
||||
|
||||
elements_module = get_import(include_file, opts.root_dir)
|
||||
@@ -465,12 +478,15 @@ def generate(opts, renderer):
|
||||
renderer.render(
|
||||
ql.GetParentImplementation(
|
||||
classes=list(classes.values()),
|
||||
imports=[elements_module] + [i for name, i in imports.items() if classes[name].internal],
|
||||
imports=[elements_module] + [i for name,
|
||||
i in imports.items() if name in classes and classes[name].internal],
|
||||
),
|
||||
out / 'ParentChild.qll')
|
||||
|
||||
if test_out:
|
||||
for c in data.classes.values():
|
||||
if c.imported:
|
||||
continue
|
||||
if should_skip_qltest(c, data.classes):
|
||||
continue
|
||||
test_with_name = c.pragmas.get("qltest_test_with")
|
||||
@@ -500,7 +516,8 @@ def generate(opts, renderer):
|
||||
constructor_imports = []
|
||||
synth_constructor_imports = []
|
||||
stubs = {}
|
||||
for cls in sorted(data.classes.values(), key=lambda cls: (cls.group, cls.name)):
|
||||
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
|
||||
key=lambda cls: (cls.group, cls.name)):
|
||||
synth_type = get_ql_synth_class(cls)
|
||||
if synth_type.is_final:
|
||||
final_synth_types.append(synth_type)
|
||||
|
||||
@@ -49,7 +49,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
|
||||
|
||||
|
||||
def _get_properties(
|
||||
cls: schema.Class, lookup: dict[str, schema.Class],
|
||||
cls: schema.Class, lookup: dict[str, schema.ClassBase],
|
||||
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
|
||||
for b in cls.bases:
|
||||
yield from _get_properties(lookup[b], lookup)
|
||||
@@ -58,12 +58,14 @@ def _get_properties(
|
||||
|
||||
|
||||
def _get_ancestors(
|
||||
cls: schema.Class, lookup: dict[str, schema.Class]
|
||||
cls: schema.Class, lookup: dict[str, schema.ClassBase]
|
||||
) -> typing.Iterable[schema.Class]:
|
||||
for b in cls.bases:
|
||||
base = lookup[b]
|
||||
yield base
|
||||
yield from _get_ancestors(base, lookup)
|
||||
if not base.imported:
|
||||
base = typing.cast(schema.Class, base)
|
||||
yield base
|
||||
yield from _get_ancestors(base, lookup)
|
||||
|
||||
|
||||
class Processor:
|
||||
@@ -71,7 +73,7 @@ class Processor:
|
||||
self._classmap = data.classes
|
||||
|
||||
def _get_class(self, name: str) -> rust.Class:
|
||||
cls = self._classmap[name]
|
||||
cls = typing.cast(schema.Class, self._classmap[name])
|
||||
properties = [
|
||||
(c, p)
|
||||
for c, p in _get_properties(cls, self._classmap)
|
||||
@@ -101,8 +103,10 @@ class Processor:
|
||||
def get_classes(self):
|
||||
ret = {"": []}
|
||||
for k, cls in self._classmap.items():
|
||||
if not cls.synth:
|
||||
if not cls.imported and not cls.synth:
|
||||
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
|
||||
elif cls.imported:
|
||||
ret[""].append(rust.Class(name=cls.name))
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ def generate(opts, renderer):
|
||||
registry=opts.ql_test_output / ".generated_tests.list",
|
||||
force=opts.force) as renderer:
|
||||
for cls in schema.classes.values():
|
||||
if cls.imported:
|
||||
continue
|
||||
if (qlgen.should_skip_qltest(cls, schema.classes) or
|
||||
"rust_skip_doc_test" in cls.pragmas):
|
||||
continue
|
||||
|
||||
@@ -44,6 +44,7 @@ class Property:
|
||||
doc_plural: Optional[str] = None
|
||||
synth: bool = False
|
||||
type_is_hideable: bool = False
|
||||
type_is_codegen_class: bool = False
|
||||
internal: bool = False
|
||||
cfg: bool = False
|
||||
|
||||
@@ -66,10 +67,6 @@ class Property:
|
||||
article = "An" if self.singular[0] in "AEIO" else "A"
|
||||
return f"get{article}{self.singular}"
|
||||
|
||||
@property
|
||||
def type_is_class(self):
|
||||
return bool(self.type) and self.type[0].isupper()
|
||||
|
||||
@property
|
||||
def is_repeated(self):
|
||||
return bool(self.plural)
|
||||
@@ -191,6 +188,7 @@ class DbClasses:
|
||||
template: ClassVar = 'ql_db'
|
||||
|
||||
classes: List[Class] = field(default_factory=list)
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,7 +3,7 @@ import abc
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Set, Union, Dict, Optional
|
||||
from typing import List, Set, Union, Dict, Optional, FrozenSet
|
||||
from enum import Enum, auto
|
||||
import functools
|
||||
|
||||
@@ -87,8 +87,22 @@ class SynthInfo:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Class:
|
||||
class ClassBase:
|
||||
imported: typing.ClassVar[bool]
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportedClass(ClassBase):
|
||||
imported: typing.ClassVar[bool] = True
|
||||
|
||||
module: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Class(ClassBase):
|
||||
imported: typing.ClassVar[bool] = False
|
||||
|
||||
bases: List[str] = field(default_factory=list)
|
||||
derived: Set[str] = field(default_factory=set)
|
||||
properties: List[Property] = field(default_factory=list)
|
||||
@@ -133,7 +147,7 @@ class Class:
|
||||
|
||||
@dataclass
|
||||
class Schema:
|
||||
classes: Dict[str, Class] = field(default_factory=dict)
|
||||
classes: Dict[str, ClassBase] = field(default_factory=dict)
|
||||
includes: List[str] = field(default_factory=list)
|
||||
null: Optional[str] = None
|
||||
|
||||
@@ -155,7 +169,7 @@ class Schema:
|
||||
|
||||
predicate_marker = object()
|
||||
|
||||
TypeRef = Union[type, str]
|
||||
TypeRef = type | str | ImportedClass
|
||||
|
||||
|
||||
def get_type_name(arg: TypeRef) -> str:
|
||||
@@ -164,6 +178,8 @@ def get_type_name(arg: TypeRef) -> str:
|
||||
return arg.__name__
|
||||
case str():
|
||||
return arg
|
||||
case ImportedClass():
|
||||
return arg.name
|
||||
case _:
|
||||
raise Error(f"Not a schema type or string ({arg})")
|
||||
|
||||
@@ -172,9 +188,9 @@ def _make_property(arg: object) -> Property:
|
||||
match arg:
|
||||
case _ if arg is predicate_marker:
|
||||
return PredicateProperty()
|
||||
case str() | type():
|
||||
case (str() | type() | ImportedClass()) as arg:
|
||||
return SingleProperty(type=get_type_name(arg))
|
||||
case Property():
|
||||
case Property() as arg:
|
||||
return arg
|
||||
case _:
|
||||
raise Error(f"Illegal property specifier {arg}")
|
||||
|
||||
@@ -8,8 +8,6 @@ from misc.codegen.lib import schema as _schema
|
||||
import inspect as _inspect
|
||||
from dataclasses import dataclass as _dataclass
|
||||
|
||||
from misc.codegen.lib.schema import Property
|
||||
|
||||
_set = set
|
||||
|
||||
|
||||
@@ -69,6 +67,9 @@ def include(source: str):
|
||||
_inspect.currentframe().f_back.f_locals.setdefault("includes", []).append(source)
|
||||
|
||||
|
||||
imported = _schema.ImportedClass
|
||||
|
||||
|
||||
@_dataclass
|
||||
class _Namespace:
|
||||
""" simple namespacing mechanism """
|
||||
@@ -264,7 +265,7 @@ class _PropertyModifierList(_schema.PropertyModifier):
|
||||
def __or__(self, other: _schema.PropertyModifier):
|
||||
return _PropertyModifierList(self._mods + (other,))
|
||||
|
||||
def modify(self, prop: Property):
|
||||
def modify(self, prop: _schema.Property):
|
||||
for m in self._mods:
|
||||
m.modify(prop)
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
|
||||
def load(m: types.ModuleType) -> schema.Schema:
|
||||
includes = set()
|
||||
classes = {}
|
||||
imported_classes = {}
|
||||
known = {"int", "string", "boolean"}
|
||||
known.update(n for n in m.__dict__ if not n.startswith("__"))
|
||||
import misc.codegen.lib.schemadefs as defs
|
||||
@@ -146,6 +147,9 @@ def load(m: types.ModuleType) -> schema.Schema:
|
||||
continue
|
||||
if isinstance(data, types.ModuleType):
|
||||
continue
|
||||
if isinstance(data, schema.ImportedClass):
|
||||
imported_classes[name] = data
|
||||
continue
|
||||
cls = _get_class(data)
|
||||
if classes and not cls.bases:
|
||||
raise schema.Error(
|
||||
@@ -162,7 +166,7 @@ def load(m: types.ModuleType) -> schema.Schema:
|
||||
_fill_hideable_information(classes)
|
||||
_check_test_with(classes)
|
||||
|
||||
return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
|
||||
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)
|
||||
|
||||
|
||||
def load_file(path: pathlib.Path) -> schema.Schema:
|
||||
|
||||
@@ -113,7 +113,7 @@ module Generated {
|
||||
*/
|
||||
{{type}} {{getter}}({{#is_indexed}}int index{{/is_indexed}}) {
|
||||
{{^synth}}
|
||||
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_class}}Synth::convert{{type}}FromRaw({{/type_is_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_class}}){{/type_is_class}}
|
||||
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_codegen_class}}Synth::convert{{type}}FromRaw({{/type_is_codegen_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_codegen_class}}){{/type_is_codegen_class}}
|
||||
{{/synth}}
|
||||
{{#synth}}
|
||||
none()
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
* This module holds thin fully generated class definitions around DB entities.
|
||||
*/
|
||||
module Raw {
|
||||
{{#imports}}
|
||||
private import {{.}}
|
||||
{{/imports}}
|
||||
|
||||
{{#classes}}
|
||||
/**
|
||||
* INTERNAL: Do not use.
|
||||
|
||||
@@ -12,21 +12,6 @@ def test_property_has_first_table_param_marked():
|
||||
assert [p.param for p in prop.tableparams] == tableparams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,expected", [
|
||||
("Foo", True),
|
||||
("Bar", True),
|
||||
("foo", False),
|
||||
("bar", False),
|
||||
(None, False),
|
||||
])
|
||||
def test_property_is_a_class(type, expected):
|
||||
tableparams = ["a", "result", "b"]
|
||||
expected_tableparams = ["a", "result" if expected else "result", "b"]
|
||||
prop = ql.Property("Prop", type, tableparams=tableparams)
|
||||
assert prop.type_is_class is expected
|
||||
assert [p.param for p in prop.tableparams] == expected_tableparams
|
||||
|
||||
|
||||
indefinite_getters = [
|
||||
("Argument", "getAnArgument"),
|
||||
("Element", "getAnElement"),
|
||||
|
||||
@@ -448,7 +448,8 @@ def test_single_class_property(generate_classes, is_child, prev_child):
|
||||
ql.Property(singular="Foo", type="Bar", tablename="my_objects",
|
||||
tableparams=[
|
||||
"this", "result"],
|
||||
prev_child=prev_child, doc="foo of this my object"),
|
||||
prev_child=prev_child, doc="foo of this my object",
|
||||
type_is_codegen_class=True),
|
||||
],
|
||||
)),
|
||||
"Bar.qll": (a_ql_class_public(name="Bar"), a_ql_stub(name="Bar"), a_ql_class(name="Bar", final=True, imports=[stub_import_prefix + "Bar"])),
|
||||
@@ -1006,6 +1007,7 @@ def test_hideable_property(generate_classes):
|
||||
final=True, properties=[
|
||||
ql.Property(singular="X", type="MyObject", tablename="others",
|
||||
type_is_hideable=True,
|
||||
type_is_codegen_class=True,
|
||||
tableparams=["this", "result"], doc="x of this other"),
|
||||
])),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user