Kotlin: Improve the dbscheme generator

We now work out the supertype relationships based on the sets of leaf
types that are included, rather than simply following the hierarchy of
declarations. This means that we know about more supertype relationships
that exist, so there is less need to cast types.
This commit is contained in:
Ian Lynagh
2021-09-23 19:47:33 +01:00
parent 5aac46f20f
commit bcbcd612a3

View File

@@ -3,35 +3,119 @@
import re
import sys
# Parsed dbscheme contents, filled in by parse_dbscheme().  Type names are
# stored without their leading '@' in enums and unions; column types in
# tables keep it.
#   enums:  case type name -> (kind column name, [(number, member type name)])
#   unions: union type name -> [member type names]
#   tables: relation name -> [(column name, column type, 'ref' or '')]
enums = {}
unions = {}
tables = {}


def parse_dbscheme(filename):
    """Parse the dbscheme file *filename* into enums, unions and tables.

    Comments are stripped first, then the three kinds of declaration are
    picked out with regular expressions.  Results are added to the
    module-level dictionaries; nothing is returned.
    """
    with open(filename, 'r') as f:
        dbscheme = f.read()

    # Remove comments.  The line-comment pattern must run to the end of the
    # line; it must not require a further '/' after the '//' marker, or
    # ordinary comments would be left in and parsed as declarations.
    dbscheme = re.sub(r'/\*.*?\*/', '', dbscheme, flags=re.DOTALL)
    dbscheme = re.sub(r'//[^\r\n]*', '', dbscheme)

    # kind enums: "case @name.kind of 1 = @a | 2 = @b ;"
    for name, kind, body in re.findall(r'case\s+@([^.\s]*)\.([^.\s]*)\s+of\b(.*?);',
                                       dbscheme,
                                       flags=re.DOTALL):
        mapping = []
        for num, typ in re.findall(r'(\d+)\s*=\s*@(\S+)', body):
            mapping.append((int(num), typ))
        enums[name] = (kind, mapping)

    # unions: "@name = @a | @b"
    for name, rhs in re.findall(r'@(\w+)\s*=\s*(@\w+(?:\s*\|\s*@\w+)*)',
                                dbscheme,
                                flags=re.DOTALL):
        typs = re.findall(r'@(\w+)', rhs)
        unions[name] = typs

    # tables: a relation name at the start of a line followed by a
    # parenthesised column list.  Each column becomes a
    # (name, type, 'ref' or '') tuple.
    for relname, body in re.findall(r'\n([\w_]+)(\([^)]*\))',
                                    dbscheme,
                                    flags=re.DOTALL):
        columns = list(re.findall(r'(\S+)\s*:\s*([^\s,]+)(?:\s+(ref)|)', body))
        tables[relname] = columns
parse_dbscheme('../ql/lib/config/semmlecode.dbscheme')

# A union with exactly one member is just another name for that member.
# type_aliases maps each such alias to the type it ultimately stands for.
type_aliases = {}
for alias, typs in unions.items():
    if len(typs) == 1:
        type_aliases[alias] = typs[0]

# Resolve chains of aliases (an alias of an alias) all the way down.  This
# is done in a second pass so that the result does not depend on the order
# in which the aliases were declared in the dbscheme.
for alias in type_aliases:
    real = type_aliases[alias]
    # Names already visited; stops us looping forever if the dbscheme
    # (wrongly) declares a cycle of aliases.
    seen = {alias}
    while real in type_aliases and real not in seen:
        seen.add(real)
        real = type_aliases[real]
    type_aliases[alias] = real
def unalias(t):
    """Return the type that the type name *t* ultimately stands for.

    Names that are not aliases are returned unchanged.  Chains of aliases
    are followed to the end, so the result is not itself an alias even if
    type_aliases holds an entry whose target is another alias.
    """
    # Names already followed; guards against a cyclic alias declaration.
    seen = set()
    while t in type_aliases and t not in seen:
        seen.add(t)
        t = type_aliases[t]
    return t
# type_leaf:  names of the types treated as leaves (not broken down further).
# type_union: name of each case/union type -> set of its direct member types.
type_leaf = set()
type_union = {}

# Every member of a case (enum) type is a leaf, and the case type itself is
# the union of those members.
for enum_name, (_kind, cases) in enums.items():
    members = {member for _num, member in cases}
    type_leaf.update(members)
    type_union[enum_name] = members

# Real unions (more than one member) map to their members, with any aliases
# among the members replaced by what they stand for.  Aliases themselves
# get no entry.
for union_name, members in unions.items():
    if union_name in type_aliases:
        continue
    type_union[union_name] = {unalias(member) for member in members}

# A non-ref column of an '@' type also makes that type a leaf, unless it is
# a case type (whose members were already added as leaves above).
for columns in tables.values():
    for _colname, col_type, ref in columns:
        if ref != '' or not col_type.startswith('@'):
            continue
        bare_name = col_type[1:]
        if bare_name not in enums:
            type_leaf.add(bare_name)
# Memo table: name of a case/union type -> set of all leaf types it contains.
type_union_of_leaves = {}


def to_leaves(t):
    """Record in type_union_of_leaves the set of leaves that make up *t*.

    *t* must be a key of type_union.  Members that are leaves are taken
    directly; any other member is expanded recursively and its leaves are
    merged in.  Does nothing if *t* has already been computed.
    """
    if t in type_union_of_leaves:
        return
    collected = set()
    for member in type_union[t]:
        if member in type_leaf:
            collected.add(member)
            continue
        # Not a leaf: make sure it has been expanded, then take its leaves.
        to_leaves(member)
        collected |= type_union_of_leaves[member]
    type_union_of_leaves[t] = collected
# Expand every case/union type into the set of leaves it contains.
for t in type_union:
    to_leaves(t)

# supertypes: type name -> set of names of the types it is a subtype of.
# A leaf is a subtype of every case/union whose leaf set contains it.
supertypes = {
    leaf: {
        sup
        for sup, sup_leaves in type_union_of_leaves.items()
        if leaf in sup_leaves
    }
    for leaf in type_leaf
}

# A case/union type is a subtype of every *other* case/union whose leaf set
# includes all of its own leaves.  These entries are written second, so
# they replace the leaf-based entry for any name present in both.
supertypes.update({
    name: {
        sup
        for sup, sup_leaves in type_union_of_leaves.items()
        if sup != name and own_leaves <= sup_leaves
    }
    for name, own_leaves in type_union_of_leaves.items()
})
def upperFirst(string):
    """Return *string* with its first character upper-cased.

    The rest of the string is left untouched (unlike str.capitalize, which
    would lower-case it).  An empty string is returned unchanged.
    """
    # Slicing rather than indexing so that an empty string does not raise.
    return string[:1].upper() + string[1:]
with open('../ql/lib/config/semmlecode.dbscheme', 'r') as f:
    dbscheme = f.read()

# Remove comments.  The line-comment pattern must run to the end of the
# line; it must not require a further '/' after the '//' marker, or
# ordinary comments would be left in the text.
dbscheme = re.sub(r'/\*.*?\*/', '', dbscheme, flags=re.DOTALL)
dbscheme = re.sub(r'//[^\r\n]*', '', dbscheme)
# Case type name -> (kind column name, [(number, member type name)]).
enums = {}
# Alias name -> the type name it was declared equal to.
type_aliases = {}
# Type name -> set of names of the types declared directly above it.
type_hierarchy = {}


def unalias(t):
    """Follow type_aliases from *t* until reaching a name that is not an alias."""
    current = t
    while True:
        if current not in type_aliases:
            return current
        current = type_aliases[current]
def genTable(kt, relname, body, enum = None, kind = None, num = None, typ = None):
def genTable(kt, relname, columns, enum = None, kind = None, num = None, typ = None):
kt.write('fun TrapWriter.write' + upperFirst(relname))
if kind is not None:
kt.write('_' + typ)
kt.write('(')
for colname, db_type in re.findall('(\S+)\s*:\s*([^\s,]+)', body):
for colname, db_type, _ in columns:
if colname != kind:
kt.write(colname + ': ')
if db_type == 'int':
# TODO: Do something better if the column is a 'case'
kt.write('Int')
elif db_type == 'float':
kt.write('Double')
@@ -52,7 +136,7 @@ def genTable(kt, relname, body, enum = None, kind = None, num = None, typ = None
kt.write(') {\n')
kt.write(' this.writeTrap("' + relname + '(')
comma = ''
for colname, db_type in re.findall('(\S+)\s*:\s*([^\s,]+)', body):
for colname, db_type, _ in columns:
kt.write(comma)
if colname == kind:
kt.write(str(num))
@@ -70,59 +154,27 @@ with open('src/main/kotlin/KotlinExtractorDbScheme.kt', 'w') as kt:
kt.write('/* Generated by ' + sys.argv[0] + ': Do not edit manually. */\n')
kt.write('package com.github.codeql\n')
# kind enums
for name, kind, body in re.findall(r'case\s+@([^.\s]*)\.([^.\s]*)\s+of\b(.*?);',
dbscheme,
flags=re.DOTALL):
mapping = []
for num, typ in re.findall(r'(\d+)\s*=\s*@(\S+)', body):
s = type_hierarchy.get(typ, set())
s.add(name)
type_hierarchy[typ] = s
mapping.append((int(num), typ))
enums[name] = (kind, mapping)
# unions
for name, unions in re.findall(r'@(\w+)\s*=\s*(@\w+(?:\s*\|\s*@\w+)*)',
dbscheme,
flags=re.DOTALL):
type_hierarchy[name] = type_hierarchy.get(name, set())
typs = re.findall(r'@(\w+)', unions)
if len(typs) == 1:
type_aliases[name] = typs[0]
else:
for typ in typs:
s = type_hierarchy.get(typ, set())
s.add(name)
type_hierarchy[typ] = s
# tables
for relname, body in re.findall('\n([\w_]+)(\([^)]*\))',
dbscheme,
flags=re.DOTALL):
for relname, columns in tables.items():
enum = None
for db_type in re.findall(':\s*@([^\s,]+)\s*(?:,|$)', body):
type_hierarchy[db_type] = type_hierarchy.get(db_type, set())
if db_type in enums:
enum = db_type
for _, db_type, ref in columns:
if db_type[0] == '@' and ref == '':
db_type_name = db_type[1:]
if db_type_name in enums:
enum = db_type_name
if enum is None:
genTable(kt, relname, body)
genTable(kt, relname, columns)
else:
(kind, mapping) = enums[enum]
for num, typ in mapping:
genTable(kt, relname, body, enum, kind, num, typ)
for typ in sorted(type_hierarchy):
if typ in type_aliases:
kt.write('typealias Db' + upperFirst(typ) + ' = Db' + upperFirst(type_aliases[typ]) + '\n')
else:
kt.write('sealed interface Db' + upperFirst(typ))
# This map of unalias avoids duplicates when both T and an
# alias of T appear in the set. Sorting makes the output
# deterministic.
names = sorted(set(map(unalias, type_hierarchy[typ])))
if names:
kt.write(': ')
kt.write(', '.join(map(lambda name: 'Db' + upperFirst(name), names)))
kt.write('\n')
genTable(kt, relname, columns, enum, kind, num, typ)
for typ in sorted(supertypes):
kt.write('sealed interface Db' + upperFirst(typ))
# Sorting makes the output deterministic.
names = sorted(supertypes[typ])
if names:
kt.write(': ')
kt.write(', '.join(map(lambda name: 'Db' + upperFirst(name), names)))
kt.write('\n')
for alias in sorted(type_aliases):
kt.write('typealias Db' + upperFirst(alias) + ' = Db' + upperFirst(type_aliases[alias]) + '\n')