Merge pull request #11131 from github/redsun82/swift-incomplete-ast

Swift: deal with incomplete ASTs
This commit is contained in:
Paolo Tranquilli
2022-11-08 14:01:58 +01:00
committed by GitHub
36 changed files with 1025 additions and 224 deletions

View File

@@ -22,6 +22,10 @@ def codeql_workspace(repository_name = "codeql"):
_swift_prebuilt_version,
repo_arch,
),
patches = [
"@%s//swift/third_party/swift-llvm-support:patches/remove_getFallthrougDest_assert.patch" % repository_name,
],
patch_args = ["-p1"],
build_file = "@%s//swift/third_party/swift-llvm-support:BUILD.swift-prebuilt.bazel" % repository_name,
sha256 = sha256,
)

View File

@@ -12,15 +12,14 @@ Each class in the schema gets a corresponding `struct` in `TrapClasses.h`, where
"""
import functools
import pathlib
from typing import Dict
import typing
import inflection
from swift.codegen.lib import cpp, schema
def _get_type(t: str) -> str:
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
if t is None:
# this is a predicate
return "bool"
@@ -29,11 +28,15 @@ def _get_type(t: str) -> str:
if t == "boolean":
return "bool"
if t[0].isupper():
return f"TrapLabel<{t}Tag>"
if add_or_none_except is not None and t != add_or_none_except:
suffix = "OrNone"
else:
suffix = ""
return f"TrapLabel<{t}{suffix}Tag>"
return t
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
trap_name = None
if not p.is_single:
trap_name = inflection.camelize(f"{cls.name}_{p.name}")
@@ -41,7 +44,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
trap_name = inflection.pluralize(trap_name)
args = dict(
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
type=_get_type(p.type),
base_type=_get_type(p.type, add_or_none_except),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
is_predicate=p.is_predicate,
@@ -52,8 +55,13 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
class Processor:
def __init__(self, data: Dict[str, schema.Class]):
self._classmap = data
def __init__(self, data: schema.Schema):
self._classmap = data.classes
if data.null:
root_type = next(iter(data.classes))
self._add_or_none_except = root_type
else:
self._add_or_none_except = None
@functools.lru_cache(maxsize=None)
def _get_class(self, name: str) -> cpp.Class:
@@ -64,7 +72,10 @@ class Processor:
return cpp.Class(
name=name,
bases=[self._get_class(b) for b in cls.bases],
fields=[_get_field(cls, p) for p in cls.properties if "cpp_skip" not in p.pragmas],
fields=[
_get_field(cls, p, self._add_or_none_except)
for p in cls.properties if "cpp_skip" not in p.pragmas
],
final=not cls.derived,
trap_name=trap_name,
)
@@ -78,8 +89,8 @@ class Processor:
def generate(opts, renderer):
assert opts.cpp_output
processor = Processor(schema.load_file(opts.schema).classes)
processor = Processor(schema.load_file(opts.schema))
out = opts.cpp_output
for dir, classes in processor.get_classes().items():
renderer.render(cpp.ClassList(classes, opts.schema,
include_parent=bool(dir)), out / dir / "TrapClasses")
include_parent=bool(dir)), out / dir / "TrapClasses")

View File

@@ -13,6 +13,7 @@ Moreover:
as columns
The type hierarchy will be translated to corresponding `union` declarations.
"""
import typing
import inflection
@@ -23,14 +24,21 @@ from typing import Set, List
log = logging.getLogger(__name__)
def dbtype(typename):
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes """
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
For class types, appends an underscore followed by `null` if provided
"""
if typename[0].isupper():
return "@" + inflection.underscore(typename)
underscored = inflection.underscore(typename)
if add_or_none_except is not None and typename != add_or_none_except:
suffix = "_or_none"
else:
suffix = ""
return f"@{underscored}{suffix}"
return typename
def cls_to_dbscheme(cls: schema.Class):
def cls_to_dbscheme(cls: schema.Class, add_or_none_except: typing.Optional[str] = None):
""" Yield all dbscheme entities needed to model class `cls` """
if cls.derived:
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
@@ -48,7 +56,7 @@ def cls_to_dbscheme(cls: schema.Class):
columns=[
Column("id", type=dbtype(cls.name), binding=binding),
] + [
Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
Column(f.name, dbtype(f.type, add_or_none_except)) for f in cls.properties if f.is_single
],
dir=dir,
)
@@ -61,7 +69,7 @@ def cls_to_dbscheme(cls: schema.Class):
columns=[
Column("id", type=dbtype(cls.name)),
Column("index", type="int"),
Column(inflection.singularize(f.name), dbtype(f.type)),
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
],
dir=dir,
)
@@ -71,7 +79,7 @@ def cls_to_dbscheme(cls: schema.Class):
name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(f.name, dbtype(f.type)),
Column(f.name, dbtype(f.type, add_or_none_except)),
],
dir=dir,
)
@@ -87,7 +95,17 @@ def cls_to_dbscheme(cls: schema.Class):
def get_declarations(data: schema.Schema):
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
add_or_none_except = data.root_class.name if data.null else None
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, add_or_none_except)]
if data.null:
property_classes = {
prop.type for cls in data.classes.values() for prop in cls.properties
if cls.name != data.null and prop.type and prop.type[0].isupper()
}
declarations += [
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
]
return declarations
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):

View File

@@ -147,7 +147,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str =
return ql.Property(**args)
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def get_ql_class(cls: schema.Class):
pragmas = {k: True for k in cls.pragmas if k.startswith("ql")}
prev_child = ""
properties = []
@@ -314,7 +314,7 @@ def generate(opts, renderer):
data = schema.load_file(input)
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
if not classes:
raise NoClasses
root = next(iter(classes.values()))

View File

@@ -41,10 +41,10 @@ def get_cpp_type(schema_type: str):
def get_field(c: dbscheme.Column):
args = {
"field_name": c.schema_name,
"type": c.type,
"base_type": c.type,
}
args.update(cpp.get_field_override(c.schema_name))
args["type"] = get_cpp_type(args["type"])
args["base_type"] = get_cpp_type(args["base_type"])
return cpp.Field(**args)

View File

@@ -16,7 +16,7 @@ cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "
"xor", "xor_eq"}
_field_overrides = [
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"type": "unsigned"}),
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
@@ -32,7 +32,7 @@ def get_field_override(field: str):
@dataclass
class Field:
field_name: str
type: str
base_type: str
is_optional: bool = False
is_repeated: bool = False
is_predicate: bool = False
@@ -40,13 +40,18 @@ class Field:
first: bool = False
def __post_init__(self):
if self.is_optional:
self.type = f"std::optional<{self.type}>"
if self.is_repeated:
self.type = f"std::vector<{self.type}>"
if self.field_name in cpp_keywords:
self.field_name += "_"
@property
def type(self) -> str:
type = self.base_type
if self.is_optional:
type = f"std::optional<{type}>"
if self.is_repeated:
type = f"std::vector<{type}>"
return type
# using @property breaks pystache internals here
def get_streamer(self):
if self.type == "std::string":
@@ -60,6 +65,10 @@ class Field:
def is_single(self):
return not (self.is_optional or self.is_repeated or self.is_predicate)
@property
def is_label(self):
return self.base_type.startswith("TrapLabel<")
@dataclass
class Trap:

View File

@@ -115,6 +115,8 @@ child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
use_for_null = _annotate(null=True)
qltest = _Namespace(
skip=_Pragma("qltest_skip"),
collapse_hierarchy=_Pragma("qltest_collapse_hierarchy"),

View File

@@ -55,6 +55,14 @@ class Property:
def is_predicate(self) -> bool:
return self.kind == self.Kind.PREDICATE
@property
def has_class_type(self) -> bool:
return bool(self.type) and self.type[0].isupper()
@property
def has_builtin_type(self) -> bool:
return bool(self.type) and self.type[0].islower()
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
@@ -104,6 +112,16 @@ class Class:
class Schema:
classes: Dict[str, Class] = field(default_factory=dict)
includes: Set[str] = field(default_factory=set)
null: Optional[str] = None
@property
def root_class(self):
# always the first in the dictionary
return next(iter(self.classes.values()))
@property
def null_class(self):
return self.classes[self.null] if self.null else None
predicate_marker = object()
@@ -195,6 +213,8 @@ def _get_class(cls: type) -> Class:
raise Error(f"Class name must be capitalized, found {cls.__name__}")
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
raise Error(f"Bases with mixed groups for {cls.__name__}")
if any(getattr(b, "_null", False) for b in cls.__bases__):
raise Error(f"Null class cannot be derived")
return Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived={d.__name__ for d in cls.__subclasses__()},
@@ -233,6 +253,7 @@ def load(m: types.ModuleType) -> Schema:
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import swift.codegen.lib.schema.defs as defs
null = None
for name, data in m.__dict__.items():
if hasattr(defs, name):
continue
@@ -247,8 +268,13 @@ def load(m: types.ModuleType) -> Schema:
f"Only one root class allowed, found second root {name}")
cls.check_types(known)
classes[name] = cls
if getattr(data, "_null", False):
if null is not None:
raise Error(f"Null class {null} already defined, second null class {name} not allowed")
null = name
cls.is_null_class = True
return Schema(includes=includes, classes=_toposort_classes_by_group(classes))
return Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> Schema:

View File

@@ -17,6 +17,8 @@ namespace codeql {
{{#classes}}
struct {{name}}{{#has_bases}} : {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}} {
static constexpr const char* NAME = "{{name}}";
{{#final}}
explicit {{name}}(TrapLabel<{{name}}Tag> id) : id{id} {}
@@ -33,6 +35,41 @@ struct {{name}}{{#has_bases}} : {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/b
}
{{/final}}
{{^final}}
protected:
{{/final}}
template <typename F>
void forEachLabel(F f) {
{{#final}}
f("id", -1, id);
{{/final}}
{{#bases}}
{{ref.name}}::forEachLabel(f);
{{/bases}}
{{#fields}}
{{#is_label}}
{{#is_repeated}}
for (auto i = 0u; i < {{field_name}}.size(); ++i) {
{{#is_optional}}
if ({{field_name}}[i]) f("{{field_name}}", i, *{{field_name}}[i]);
{{/is_optional}}
{{^is_optional}}
f("{{field_name}}", i, {{field_name}}[i]);
{{/is_optional}}
}
{{/is_repeated}}
{{^is_repeated}}
{{#is_optional}}
if ({{field_name}}) f("{{field_name}}", -1, *{{field_name}});
{{/is_optional}}
{{^is_optional}}
f("{{field_name}}", -1, {{field_name}});
{{/is_optional}}
{{/is_repeated}}
{{/is_label}}
{{/fields}}
}
protected:
void emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const;
};

View File

@@ -14,12 +14,24 @@ namespace codeql {
// {{table_name}}
struct {{name}}Trap {
{{#fields}}
static constexpr const char* NAME = "{{name}}Trap";
{{#fields}}
{{type}} {{field_name}}{};
{{/fields}}
{{/fields}}
template <typename F>
void forEachLabel(F f) {
{{#fields}}
{{#is_label}}
f("{{field_name}}", -1, {{field_name}});
{{/is_label}}
{{/fields}}
}
};
std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e);
{{#id}}
namespace detail {

View File

@@ -80,6 +80,25 @@ def test_class_with_field(generate, type, expected, property_cls, optional, repe
]
def test_class_field_with_null(generate, input):
input.null = "Null"
a = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"], properties=[
schema.SingleProperty("x", "A"),
schema.SingleProperty("y", "B"),
])
]) == [
a,
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
fields=[
cpp.Field("x", "TrapLabel<ATag>"),
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
]),
]
def test_class_with_predicate(generate):
assert generate([
schema.Class(name="MyClass", properties=[

View File

@@ -18,8 +18,9 @@ def dir_param(request):
@pytest.fixture
def generate(opts, input, renderer):
def func(classes):
def func(classes, null=None):
input.classes = {cls.name: cls for cls in classes}
input.null = null
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme
return data
@@ -359,5 +360,114 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
)
def test_null_class(generate):
assert generate([
schema.Class(
name="Base",
derived={"W", "X", "Y", "Z", "Null"},
),
schema.Class(
name="W",
bases=["Base"],
properties=[
schema.SingleProperty("w", "W"),
schema.SingleProperty("x", "X"),
schema.OptionalProperty("y", "Y"),
schema.RepeatedProperty("z", "Z"),
]
),
schema.Class(
name="X",
bases=["Base"],
),
schema.Class(
name="Y",
bases=["Base"],
),
schema.Class(
name="Z",
bases=["Base"],
),
schema.Class(
name="Null",
bases=["Base"],
),
], null="Null") == dbscheme.Scheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.Union(
lhs="@base",
rhs=["@null", "@w", "@x", "@y", "@z"],
),
dbscheme.Table(
name="ws",
columns=[
dbscheme.Column('id', '@w', binding=True),
dbscheme.Column('w', '@w_or_none'),
dbscheme.Column('x', '@x_or_none'),
],
),
dbscheme.Table(
name="w_ies",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('y', '@y_or_none'),
],
),
dbscheme.Table(
name="w_zs",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('index', 'int'),
dbscheme.Column('z', '@z_or_none'),
],
),
dbscheme.Table(
name="xes",
columns=[
dbscheme.Column('id', '@x', binding=True),
],
),
dbscheme.Table(
name="ys",
columns=[
dbscheme.Column('id', '@y', binding=True),
],
),
dbscheme.Table(
name="zs",
columns=[
dbscheme.Column('id', '@z', binding=True),
],
),
dbscheme.Table(
name="nulls",
columns=[
dbscheme.Column('id', '@null', binding=True),
],
),
dbscheme.Union(
lhs="@w_or_none",
rhs=["@w", "@null"],
),
dbscheme.Union(
lhs="@x_or_none",
rhs=["@x", "@null"],
),
dbscheme.Union(
lhs="@y_or_none",
rhs=["@y", "@null"],
),
dbscheme.Union(
lhs="@z_or_none",
rhs=["@z", "@null"],
),
],
)
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -13,6 +13,8 @@ def test_empty_schema():
assert data.classes == {}
assert data.includes == set()
assert data.null is None
assert data.null_class is None
def test_one_empty_class():
@@ -24,6 +26,7 @@ def test_one_empty_class():
assert data.classes == {
'MyClass': schema.Class('MyClass'),
}
assert data.root_class is data.classes['MyClass']
def test_two_empty_classes():
@@ -39,6 +42,7 @@ def test_two_empty_classes():
'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
}
assert data.root_class is data.classes['MyClass1']
def test_no_external_bases():
@@ -452,7 +456,8 @@ def test_property_docstring_newline():
property.""")
assert data.classes == {
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
'A': schema.Class('A',
properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
}
@@ -566,5 +571,54 @@ def test_class_default_doc_name():
}
def test_null_class():
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null(Root):
pass
assert data.classes == {
'Root': schema.Class('Root', derived={'Null'}),
'Null': schema.Class('Null', bases=['Root']),
}
assert data.null == 'Null'
assert data.null_class is data.classes[data.null]
def test_null_class_cannot_be_derived():
with pytest.raises(schema.Error):
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null(Root):
pass
class Impossible(Null):
pass
def test_null_class_cannot_be_defined_multiple_times():
with pytest.raises(schema.Error):
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null1(Root):
pass
@defs.use_for_null
class Null2(Root):
pass
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -60,20 +60,39 @@ class SwiftDispatcher {
}
template <typename Entry>
void emit(const Entry& entry) {
trap.emit(entry);
void emit(Entry&& entry) {
bool valid = true;
entry.forEachLabel([&valid, &entry, this](const char* field, int index, auto& label) {
using Label = std::remove_reference_t<decltype(label)>;
if (!label.valid()) {
std::cerr << entry.NAME << " has undefined " << field;
if (index >= 0) {
std::cerr << '[' << index << ']';
}
if constexpr (std::is_base_of_v<typename Label::Tag, UnspecifiedElementTag>) {
std::cerr << ", replacing with unspecified element\n";
label = emitUnspecified(idOf(entry), field, index);
} else {
std::cerr << ", skipping emission\n";
valid = false;
}
}
});
if (valid) {
trap.emit(entry);
}
}
template <typename Entry>
void emit(const std::optional<Entry>& entry) {
void emit(std::optional<Entry>&& entry) {
if (entry) {
emit(*entry);
emit(std::move(*entry));
}
}
template <typename... Cases>
void emit(const std::variant<Cases...>& entry) {
std::visit([this](const auto& e) { this->emit(e); }, entry);
void emit(std::variant<Cases...>&& entry) {
std::visit([this](auto&& e) { this->emit(std::move(e)); }, std::move(entry));
}
// This is a helper method to emit TRAP entries for AST nodes that we don't fully support yet.
@@ -88,13 +107,39 @@ class SwiftDispatcher {
emit(ElementIsUnknownTrap{label});
}
TrapLabel<UnspecifiedElementTag> emitUnspecified(std::optional<TrapLabel<ElementTag>>&& parent,
const char* property,
int index) {
UnspecifiedElement entry{trap.createLabel<UnspecifiedElementTag>()};
entry.error = "element was unspecified by the extractor";
entry.parent = std::move(parent);
entry.property = property;
if (index >= 0) {
entry.index = index;
}
trap.emit(entry);
return entry.id;
}
template <typename E>
std::optional<TrapLabel<ElementTag>> idOf(const E& entry) {
if constexpr (HasId<E>::value) {
return entry.id;
} else {
return std::nullopt;
}
}
// This method gives a TRAP label for already emitted AST node.
// If the AST node was not emitted yet, then the emission is dispatched to a corresponding
// visitor (see `visit(T *)` methods below).
template <typename E, typename... Args, std::enable_if_t<IsStorable<E>>* = nullptr>
TrapLabelOf<E> fetchLabel(const E& e, Args&&... args) {
if constexpr (std::is_constructible_v<bool, const E&>) {
assert(e && "fetching a label on a null entity, maybe fetchOptionalLabel is to be used?");
if (!e) {
// this will be treated on emission
return undefined_label;
}
}
// this is required so we avoid any recursive loop: a `fetchLabel` during the visit of `e` might
// end up calling `fetchLabel` on `e` itself, so we want the visit of `e` to call `fetchLabel`
@@ -205,17 +250,18 @@ class SwiftDispatcher {
return std::nullopt;
}
// map `fetchLabel` on the iterable `arg`, returning a vector of all labels
// map `fetchLabel` on the iterable `arg`
// universal reference `Arg&&` is used to catch both temporary and non-const references, not
// for perfect forwarding
template <typename Iterable>
auto fetchRepeatedLabels(Iterable&& arg) {
std::vector<decltype(fetchLabel(*arg.begin()))> ret;
using Label = decltype(fetchLabel(*arg.begin()));
TrapLabelVectorWrapper<typename Label::Tag> ret;
if constexpr (HasSize<Iterable>::value) {
ret.reserve(arg.size());
ret.data.reserve(arg.size());
}
for (auto&& e : arg) {
ret.push_back(fetchLabel(e));
ret.data.push_back(fetchLabel(e));
}
return ret;
}
@@ -270,6 +316,12 @@ class SwiftDispatcher {
template <typename T>
struct HasSize<T, decltype(std::declval<T>().size(), void())> : std::true_type {};
template <typename T, typename = void>
struct HasId : std::false_type {};
template <typename T>
struct HasId<T, decltype(std::declval<T>().id, void())> : std::true_type {};
void attachLocation(swift::SourceLoc start,
swift::SourceLoc end,
TrapLabel<LocatableTag> locatableLabel) {
@@ -293,19 +345,20 @@ class SwiftDispatcher {
TrapLabel<Tag> fetchLabelFromUnion(const llvm::PointerUnion<Ts...> u) {
TrapLabel<Tag> ret{};
// with logical op short-circuiting, this will stop trying on the first successful fetch
// don't feel tempted to replace the variable with the expression inside the `assert`, or
// building with `NDEBUG` will not trigger the fetching
bool unionCaseFound = (... || fetchLabelFromUnionCase<Tag, Ts>(u, ret));
assert(unionCaseFound && "llvm::PointerUnion not set to a known case");
if (!unionCaseFound) {
// TODO emit error/warning here
return undefined_label;
}
return ret;
}
template <typename Tag, typename T, typename... Ts>
bool fetchLabelFromUnionCase(const llvm::PointerUnion<Ts...> u, TrapLabel<Tag>& output) {
// we rely on the fact that when we extract `ASTNode` instances (which only happens
// on `BraceStmt` elements), we cannot encounter a standalone `TypeRepr` there, so we skip
// this case; extracting `TypeRepr`s here would be problematic as we would not be able to
// provide the corresponding type
// on `BraceStmt`/`IfConfigDecl` elements), we cannot encounter a standalone `TypeRepr` there,
// so we skip this case; extracting `TypeRepr`s here would be problematic as we would not be
// able to provide the corresponding type
if constexpr (!std::is_same_v<T, swift::TypeRepr*>) {
if (auto e = u.template dyn_cast<T>()) {
output = fetchLabel(e);

View File

@@ -4,9 +4,14 @@
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
namespace codeql {
struct UndefinedTrapLabel {};
constexpr UndefinedTrapLabel undefined_label{};
class UntypedTrapLabel {
uint64_t id_;
@@ -18,14 +23,17 @@ class UntypedTrapLabel {
protected:
UntypedTrapLabel() : id_{undefined} {}
UntypedTrapLabel(uint64_t id) : id_{id} {}
UntypedTrapLabel(uint64_t id) : id_{id} { assert(id != undefined); }
public:
bool valid() const { return id_ != undefined; }
explicit operator bool() const { return valid(); }
friend std::ostream& operator<<(std::ostream& out, UntypedTrapLabel l) {
// TODO: this is a temporary fix to catch us from outputting undefined labels to trap
// this should be moved to a validity check, probably aided by code generation and carried out
// by `SwiftDispatcher`
assert(l.id_ != undefined && "outputting an undefined label!");
assert(l && "outputting an undefined label!");
out << '#' << std::hex << l.id_ << std::dec;
return out;
}
@@ -44,14 +52,37 @@ class TrapLabel : public UntypedTrapLabel {
using Tag = TagParam;
TrapLabel() = default;
TrapLabel(UndefinedTrapLabel) : TrapLabel() {}
TrapLabel& operator=(UndefinedTrapLabel) {
*this = TrapLabel{};
return *this;
}
// The caller is responsible for ensuring ID uniqueness.
static TrapLabel unsafeCreateFromExplicitId(uint64_t id) { return {id}; }
static TrapLabel unsafeCreateFromUntyped(UntypedTrapLabel label) { return {label.id_}; }
template <typename OtherTag>
TrapLabel(const TrapLabel<OtherTag>& other) : UntypedTrapLabel(other) {
static_assert(std::is_base_of_v<Tag, OtherTag>, "wrong label assignment!");
template <typename SourceTag>
TrapLabel(const TrapLabel<SourceTag>& other) : UntypedTrapLabel(other) {
static_assert(std::is_base_of_v<Tag, SourceTag>, "wrong label assignment!");
}
};
// wrapper class to allow directly assigning a vector of TrapLabel<A> to a vector of
// TrapLabel<B> if B is a base of A, using move semantics rather than copying
template <typename TagParam>
struct TrapLabelVectorWrapper {
using Tag = TagParam;
std::vector<TrapLabel<TagParam>> data;
template <typename DestinationTag>
operator std::vector<TrapLabel<DestinationTag>>() && {
static_assert(std::is_base_of_v<DestinationTag, Tag>, "wrong label assignment!");
// reinterpret_cast is safe because TrapLabel instances differ only on the type, not the
// underlying data
return std::move(reinterpret_cast<std::vector<TrapLabel<DestinationTag>>&>(data));
}
};

View File

@@ -11,6 +11,7 @@ import codeql.swift.elements.Location
import codeql.swift.elements.UnknownFile
import codeql.swift.elements.UnknownLocation
import codeql.swift.elements.UnresolvedElement
import codeql.swift.elements.UnspecifiedElement
import codeql.swift.elements.decl.AbstractFunctionDecl
import codeql.swift.elements.decl.AbstractStorageDecl
import codeql.swift.elements.decl.AbstractTypeParamDecl

View File

@@ -0,0 +1,22 @@
private import codeql.swift.generated.UnspecifiedElement
import codeql.swift.elements.Location
class UnspecifiedElement extends Generated::UnspecifiedElement {
override string toString() {
exists(string source, string index |
(
source = " from " + this.getParent().getPrimaryQlClasses()
or
not this.hasParent() and source = ""
) and
(
index = "[" + this.getIndex() + "]"
or
not this.hasIndex() and index = ""
) and
result = "missing " + this.getProperty() + index + source
)
}
override Location getImmediateLocation() { result = this.getParent().(Locatable).getLocation() }
}

View File

@@ -0,0 +1,4 @@
// generated by codegen/codegen.py, remove this comment if you wish to edit this file
private import codeql.swift.generated.Raw
predicate constructUnspecifiedElement(Raw::UnspecifiedElement id) { any() }

View File

@@ -165,6 +165,21 @@ private module Impl {
)
}
private Element getImmediateChildOfUnspecifiedElement(
UnspecifiedElement e, int index, string partialPredicateCall
) {
exists(int b, int bLocatable, int n |
b = 0 and
bLocatable = b + 1 + max(int i | i = -1 or exists(getImmediateChildOfLocatable(e, i, _)) | i) and
n = bLocatable and
(
none()
or
result = getImmediateChildOfLocatable(e, index - b, partialPredicateCall)
)
)
}
private Element getImmediateChildOfDecl(Decl e, int index, string partialPredicateCall) {
exists(int b, int bAstNode, int n |
b = 0 and
@@ -4871,6 +4886,8 @@ private module Impl {
or
result = getImmediateChildOfUnknownLocation(e, index, partialAccessor)
or
result = getImmediateChildOfUnspecifiedElement(e, index, partialAccessor)
or
result = getImmediateChildOfEnumCaseDecl(e, index, partialAccessor)
or
result = getImmediateChildOfExtensionDecl(e, index, partialAccessor)

View File

@@ -51,6 +51,18 @@ module Raw {
override string toString() { result = "DbLocation" }
}
class UnspecifiedElement extends @unspecified_element, Locatable {
override string toString() { result = "UnspecifiedElement" }
Element getParent() { unspecified_element_parents(this, result) }
string getProperty() { unspecified_elements(this, result, _) }
int getIndex() { unspecified_element_indices(this, result) }
string getError() { unspecified_elements(this, _, result) }
}
class Decl extends @decl, AstNode {
ModuleDecl getModule() { decls(this, result) }
}

View File

@@ -10,6 +10,7 @@ module Synth {
TDbLocation(Raw::DbLocation id) { constructDbLocation(id) } or
TUnknownFile() or
TUnknownLocation() or
TUnspecifiedElement(Raw::UnspecifiedElement id) { constructUnspecifiedElement(id) } or
TAccessorDecl(Raw::AccessorDecl id) { constructAccessorDecl(id) } or
TAssociatedTypeDecl(Raw::AssociatedTypeDecl id) { constructAssociatedTypeDecl(id) } or
TClassDecl(Raw::ClassDecl id) { constructClassDecl(id) } or
@@ -321,7 +322,7 @@ module Synth {
class TFile = TDbFile or TUnknownFile;
class TLocatable = TArgument or TAstNode or TComment;
class TLocatable = TArgument or TAstNode or TComment or TUnspecifiedElement;
class TLocation = TDbLocation or TUnknownLocation;
@@ -499,6 +500,11 @@ module Synth {
cached
TUnknownLocation convertUnknownLocationFromRaw(Raw::Element e) { none() }
cached
TUnspecifiedElement convertUnspecifiedElementFromRaw(Raw::Element e) {
result = TUnspecifiedElement(e)
}
cached
TAccessorDecl convertAccessorDeclFromRaw(Raw::Element e) { result = TAccessorDecl(e) }
@@ -1464,6 +1470,8 @@ module Synth {
result = convertAstNodeFromRaw(e)
or
result = convertCommentFromRaw(e)
or
result = convertUnspecifiedElementFromRaw(e)
}
cached
@@ -2199,6 +2207,11 @@ module Synth {
cached
Raw::Element convertUnknownLocationToRaw(TUnknownLocation e) { none() }
cached
Raw::Element convertUnspecifiedElementToRaw(TUnspecifiedElement e) {
e = TUnspecifiedElement(result)
}
cached
Raw::Element convertAccessorDeclToRaw(TAccessorDecl e) { e = TAccessorDecl(result) }
@@ -3162,6 +3175,8 @@ module Synth {
result = convertAstNodeToRaw(e)
or
result = convertCommentToRaw(e)
or
result = convertUnspecifiedElementToRaw(e)
}
cached

View File

@@ -2,6 +2,7 @@
import codeql.swift.elements.CommentConstructor
import codeql.swift.elements.DbFileConstructor
import codeql.swift.elements.DbLocationConstructor
import codeql.swift.elements.UnspecifiedElementConstructor
import codeql.swift.elements.decl.AccessorDeclConstructor
import codeql.swift.elements.decl.AssociatedTypeDeclConstructor
import codeql.swift.elements.decl.ClassDeclConstructor

View File

@@ -0,0 +1,60 @@
// generated by codegen/codegen.py
private import codeql.swift.generated.Synth
private import codeql.swift.generated.Raw
import codeql.swift.elements.Element
import codeql.swift.elements.Locatable
module Generated {
class UnspecifiedElement extends Synth::TUnspecifiedElement, Locatable {
override string getAPrimaryQlClass() { result = "UnspecifiedElement" }
/**
* Gets the parent of this unspecified element, if it exists.
*
* This includes nodes from the "hidden" AST. It can be overridden in subclasses to change the
* behavior of both the `Immediate` and non-`Immediate` versions.
*/
Element getImmediateParent() {
result =
Synth::convertElementFromRaw(Synth::convertUnspecifiedElementToRaw(this)
.(Raw::UnspecifiedElement)
.getParent())
}
/**
* Gets the parent of this unspecified element, if it exists.
*/
final Element getParent() { result = getImmediateParent().resolve() }
/**
* Holds if `getParent()` exists.
*/
final predicate hasParent() { exists(getParent()) }
/**
* Gets the property of this unspecified element.
*/
string getProperty() {
result = Synth::convertUnspecifiedElementToRaw(this).(Raw::UnspecifiedElement).getProperty()
}
/**
* Gets the index of this unspecified element, if it exists.
*/
int getIndex() {
result = Synth::convertUnspecifiedElementToRaw(this).(Raw::UnspecifiedElement).getIndex()
}
/**
* Holds if `getIndex()` exists.
*/
final predicate hasIndex() { exists(getIndex()) }
/**
* Gets the error of this unspecified element.
*/
string getError() {
result = Synth::convertUnspecifiedElementToRaw(this).(Raw::UnspecifiedElement).getError()
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -29,4 +29,6 @@ predicate toBeTested(Element e) {
)
)
)
or
toBeTested(e.(UnspecifiedElement).getParent())
}

View File

@@ -0,0 +1,9 @@
| wrong.swift:3:1:3:23 | missing extended_type_decl from ExtensionDecl | getProperty: | extended_type_decl | getError: | element was unspecified by the extractor |
| wrong.swift:9:9:9:9 | missing fallthrough_dest from FallthroughStmt | getProperty: | fallthrough_dest | getError: | element was unspecified by the extractor |
| wrong.swift:9:9:9:9 | missing fallthrough_source from FallthroughStmt | getProperty: | fallthrough_source | getError: | element was unspecified by the extractor |
| wrong.swift:12:18:12:21 | missing element from EnumElementPattern | getProperty: | element | getError: | element was unspecified by the extractor |
| wrong.swift:14:18:14:26 | missing element from EnumElementPattern | getProperty: | element | getError: | element was unspecified by the extractor |
| wrong.swift:19:18:19:19 | missing element from EnumElementPattern | getProperty: | element | getError: | element was unspecified by the extractor |
| wrong.swift:22:13:22:13 | missing fallthrough_dest from FallthroughStmt | getProperty: | fallthrough_dest | getError: | element was unspecified by the extractor |
| wrong.swift:22:13:22:13 | missing fallthrough_source from FallthroughStmt | getProperty: | fallthrough_source | getError: | element was unspecified by the extractor |
| wrong.swift:26:18:26:19 | missing element from EnumElementPattern | getProperty: | element | getError: | element was unspecified by the extractor |

View File

@@ -0,0 +1,11 @@
// generated by codegen/codegen.py
import codeql.swift.elements
import TestUtils
from UnspecifiedElement x, string getProperty, string getError
where
toBeTested(x) and
not x.isUnknown() and
getProperty = x.getProperty() and
getError = x.getError()
select x, "getProperty:", getProperty, "getError:", getError

View File

@@ -0,0 +1,7 @@
// generated by codegen/codegen.py
import codeql.swift.elements
import TestUtils
from UnspecifiedElement x
where toBeTested(x) and not x.isUnknown()
select x, x.getIndex()

View File

@@ -0,0 +1,9 @@
| wrong.swift:3:1:3:23 | missing extended_type_decl from ExtensionDecl | extension |
| wrong.swift:9:9:9:9 | missing fallthrough_dest from FallthroughStmt | fallthrough |
| wrong.swift:9:9:9:9 | missing fallthrough_source from FallthroughStmt | fallthrough |
| wrong.swift:12:18:12:21 | missing element from EnumElementPattern | (no string representation) |
| wrong.swift:14:18:14:26 | missing element from EnumElementPattern | (no string representation) |
| wrong.swift:19:18:19:19 | missing element from EnumElementPattern | (no string representation) |
| wrong.swift:22:13:22:13 | missing fallthrough_dest from FallthroughStmt | fallthrough |
| wrong.swift:22:13:22:13 | missing fallthrough_source from FallthroughStmt | fallthrough |
| wrong.swift:26:18:26:19 | missing element from EnumElementPattern | (no string representation) |

View File

@@ -0,0 +1,7 @@
// generated by codegen/codegen.py
import codeql.swift.elements
import TestUtils
from UnspecifiedElement x
where toBeTested(x) and not x.isUnknown()
select x, x.getParent()

View File

@@ -0,0 +1,30 @@
//codeql-extractor-expected-status: 1
extension Undefined { }
enum Enum {
case A, B
func test(e: Enum) {
fallthrough
switch e {
case .A():
break
case .B(let x):
let _ = x
break
case Int:
break
case .C:
break
default:
fallthrough
}
switch undefined {
case .Whatever:
break
}
}
}

View File

@@ -0,0 +1 @@
extractor-tests/generated/UnspecifiedElement/UnspecifiedElement.ql

View File

@@ -37,6 +37,13 @@ class Location(Element):
class Locatable(Element):
location: optional[Location] | cpp.skip | doc("location associated with this element in the code")
@use_for_null
class UnspecifiedElement(Locatable):
parent: optional[Element]
property: string
index: optional[int]
error: string
class Comment(Locatable):
text: string

View File

@@ -0,0 +1,11 @@
diff -ru a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h
--- a/include/swift/AST/Stmt.h 2022-09-21 12:56:54.000000000 +0200
+++ b/include/swift/AST/Stmt.h 2022-11-04 14:39:18.407971007 +0100
@@ -920,7 +920,6 @@
/// Get the CaseStmt block to which the fallthrough transfers control.
/// Set during Sema.
CaseStmt *getFallthroughDest() const {
- assert(FallthroughDest && "fallthrough dest is not set until Sema");
return FallthroughDest;
}
void setFallthroughDest(CaseStmt *C) {