diff --git a/misc/codegen/generators/dbschemegen.py b/misc/codegen/generators/dbschemegen.py index 8266eb5be0f..f861972cdd6 100755 --- a/misc/codegen/generators/dbschemegen.py +++ b/misc/codegen/generators/dbschemegen.py @@ -69,11 +69,12 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a ) # use property-specific tables for 1-to-many and 1-to-at-most-1 properties for f in cls.properties: + overridden_table_name = f.pragmas.get("ql_db_table_name") if f.synth: continue if f.is_unordered: yield Table( - name=inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)), @@ -83,7 +84,7 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a elif f.is_repeated: yield Table( keyset=KeySet(["id", "index"]), - name=inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), Column("index", type="int"), @@ -94,7 +95,7 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a elif f.is_optional: yield Table( keyset=KeySet(["id"]), - name=inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), Column(f.name, dbtype(f.type, add_or_none_except)), @@ -104,7 +105,7 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a elif f.is_predicate: yield Table( keyset=KeySet(["id"]), - name=inflection.underscore(f"{cls.name}_{f.name}"), + name=overridden_table_name or inflection.underscore(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), ], @@ -118,7 +119,8 @@ def check_name_conflicts(decls: list[Table | Union]): match decl: case Table(name=name): if name in names: - raise Error(f"Duplicate table name: {name}") + raise Error(f"Duplicate table name: { + name}, you can use `@ql.db_table_name` on a property to resolve this") names.add(name) diff --git a/misc/codegen/lib/schemadefs.py b/misc/codegen/lib/schemadefs.py index dc8bd2aab4d..b0cf2b038a8 100644 --- a/misc/codegen/lib/schemadefs.py +++ b/misc/codegen/lib/schemadefs.py @@ -248,6 +248,7 @@ ql.add(_Parametrized(_ClassPragma("default_doc_name"), factory=lambda doc: doc)) ql.add(_ClassPragma("hideable", inherited=True)) ql.add(_Pragma("internal")) ql.add(_Parametrized(_Pragma("name"), factory=lambda name: name)) +ql.add(_Parametrized(_PropertyPragma("db_table_name"), factory=lambda name: name)) cpp.add(_Pragma("skip")) diff --git a/misc/codegen/test/test_dbschemegen.py b/misc/codegen/test/test_dbschemegen.py index 5b1bd7e73dc..653ad7fc8a3 100644 --- a/misc/codegen/test/test_dbschemegen.py +++ b/misc/codegen/test/test_dbschemegen.py @@ -603,5 +603,68 @@ def test_table_conflict(generate): ]) +def test_table_name_overrides(generate): + assert generate([ + schema.Class("Obj", properties=[ + schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}), + schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}), + schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}), + schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}), + schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}), + ]), + ]) == dbscheme.Scheme( + src=schema_file.name, + includes=[], + declarations=[ + dbscheme.Table( + name="objs", + columns=[ + dbscheme.Column("id", "@obj", binding=True), + ], + ), + dbscheme.Table( + name="foo", + keyset=dbscheme.KeySet(["id"]), + columns=[ + dbscheme.Column("id", "@obj"), + dbscheme.Column("x", "a"), + ], + ), + dbscheme.Table( + name="bar", + keyset=dbscheme.KeySet(["id", "index"]), + columns=[ + dbscheme.Column("id", "@obj"), + dbscheme.Column("index", "int"), + dbscheme.Column("y", "b"), + ], + ), + dbscheme.Table( + name="baz", + keyset=dbscheme.KeySet(["id", "index"]), + columns=[ + dbscheme.Column("id", "@obj"), + dbscheme.Column("index", "int"), + dbscheme.Column("z", "c"), + ], + ), + dbscheme.Table( + name="hello", + keyset=dbscheme.KeySet(["id"]), + columns=[ + dbscheme.Column("id", "@obj"), + ], + ), + dbscheme.Table( + name="world", + columns=[ + dbscheme.Column("id", "@obj"), + dbscheme.Column("q", "d"), + ], + ), + ], + ) + + if __name__ == '__main__': sys.exit(pytest.main([__file__] + sys.argv[1:]))