Files
codeql/go/extractor/dbscheme/dbscheme.go
2024-10-08 19:23:17 +01:00

427 lines
10 KiB
Go

package dbscheme
import (
"fmt"
"io"
"log"
"reflect"
"strings"
"github.com/github/codeql-go/extractor/trap"
)
// A Type represents a database type
type Type interface {
// def returns the dbscheme declaration of the type, or "" if it has none;
// PrintDbScheme only prints non-empty definitions
def() string
// ref returns the name used to refer to the type, e.g. in a column declaration
ref() string
// repr returns the name of the type used to represent values of this type
repr() string
// valid reports whether val is an acceptable value for a column of this type
valid(val interface{}) bool
}
// A PrimitiveType represents a primitive database type
type PrimitiveType int
// The primitive database types; their numeric values follow declaration order via iota.
const (
// INT represents the primitive database type `int`
INT PrimitiveType = iota
// FLOAT represents the primitive database type `float`
FLOAT
// BOOLEAN represents the primitive database type `boolean`
BOOLEAN
// DATE represents the primitive database type `date`
DATE
// STRING represents the primitive database type `string`
STRING
)
// A PrimaryKeyType represents a database type defined by a primary key column
type PrimaryKeyType struct {
// name is the name by which the type is referred to (returned by ref)
name string
}
// A UnionType represents a database type defined as the union of other database types
type UnionType struct {
// name is the name of the union type itself
name string
// components are the member types; their refs are joined with `|` by def
components []Type
}
// An AliasType represents a database type which is an alias of another database type
type AliasType struct {
// name is the name of the alias
name string
// underlying is the aliased type; repr and valid delegate to it
underlying Type
}
// A CaseType represents a database type defined by a primary key column with a supplementary kind column
type CaseType struct {
// base is the type whose ref appears before the `.` in the `case` declaration
base Type
// column is the name of the column that appears after the `.` in the `case` declaration
column string
// branches are the branches of the case type, in the order they were added
branches []*BranchType
}
// A BranchType represents one branch of a case type
type BranchType struct {
// idx is the numeric index of the branch; NewBranch assigns the number of
// branches already present in the case type
idx int
// name is the name by which the branch type is referred to
name string
}
// def always returns the empty string: a primitive type has no declaration of its own
func (pt PrimitiveType) def() string {
return ""
}
// ref returns the dbscheme keyword that names this primitive type.
// It panics if pt is not one of the declared primitive types.
func (pt PrimitiveType) ref() string {
	// Lookup table indexed by the primitive type constants.
	keywords := [...]string{
		INT:     "int",
		FLOAT:   "float",
		BOOLEAN: "boolean",
		DATE:    "date",
		STRING:  "string",
	}
	if pt < 0 || int(pt) >= len(keywords) {
		panic(fmt.Sprintf("Unexpected primitive type %d", pt))
	}
	return keywords[pt]
}
// repr returns the name of the representation type of this primitive type,
// which is the primitive type's own keyword.
// It panics if pt is not one of the declared primitive types.
func (pt PrimitiveType) repr() string {
	representations := map[PrimitiveType]string{
		INT:     "int",
		FLOAT:   "float",
		BOOLEAN: "boolean",
		DATE:    "date",
		STRING:  "string",
	}
	name, known := representations[pt]
	if !known {
		panic(fmt.Sprintf("Unexpected primitive type %d", pt))
	}
	return name
}
// valid reports whether `value` has the Go type that corresponds to this
// primitive type: int for INT, float64 for FLOAT, bool for BOOLEAN and string
// for STRING. No Go type is accepted for DATE, and values of any other Go type
// are rejected.
func (pt PrimitiveType) valid(value interface{}) bool {
	var expected PrimitiveType
	switch value.(type) {
	case int:
		expected = INT
	case float64:
		expected = FLOAT
	case bool:
		expected = BOOLEAN
	case string:
		expected = STRING
	default:
		return false
	}
	return pt == expected
}
// def always returns the empty string, so PrintDbScheme prints no separate
// declaration for a primary key type
func (pkt PrimaryKeyType) def() string {
return ""
}
// ref returns the name of this primary key type
func (pkt PrimaryKeyType) ref() string {
return pkt.name
}
// repr returns "int", the representation type printed for columns of this type
func (pkt PrimaryKeyType) repr() string {
return "int"
}
// valid reports whether `value` is a trap.Label, the only kind of value
// accepted for a primary key type.
func (pkt PrimaryKeyType) valid(value interface{}) bool {
	switch value.(type) {
	case trap.Label:
		return true
	default:
		return false
	}
}
// def returns the declaration of this union type, in the form
// `name = a | b | c;`. Once the current output line exceeds 100 characters a
// line break followed by len(name) spaces of indentation is inserted before the
// next ` | ` separator; no break is ever inserted before the last component.
func (ut UnionType) def() string {
	const maxLineLen = 100

	var out strings.Builder
	fmt.Fprintf(&out, "%s = ", ut.name)

	continuation := "\n" + strings.Repeat(" ", len(ut.name))
	lineStart := 0 // builder length at the start of the current output line
	last := len(ut.components) - 1
	for idx := range ut.components {
		if idx > 0 {
			if idx < last && out.Len()-lineStart > maxLineLen {
				out.WriteString(continuation)
				lineStart = out.Len()
			}
			out.WriteString(" | ")
		}
		out.WriteString(ut.components[idx].ref())
	}
	out.WriteString(";")
	return out.String()
}
// ref returns the name of this union type
func (ut UnionType) ref() string {
return ut.name
}
// repr returns "int", the representation type printed for columns of this type
func (ut UnionType) repr() string {
return "int"
}
// valid reports whether `value` is a trap.Label, the only kind of value
// accepted for a union type.
func (ut UnionType) valid(value interface{}) bool {
	switch value.(type) {
	case trap.Label:
		return true
	default:
		return false
	}
}
// def returns the declaration of this alias, in the form `name = underlying;`
func (at AliasType) def() string {
return at.name + " = " + at.underlying.ref() + ";"
}
// ref returns the name of this alias (not the name of the underlying type)
func (at AliasType) ref() string {
return at.name
}
// repr returns the representation type of the underlying type
func (at AliasType) repr() string {
return at.underlying.repr()
}
// valid reports whether `value` is valid for the underlying type
func (at AliasType) valid(value interface{}) bool {
return at.underlying.valid(value)
}
// def returns the declaration of this case type: a `case base.column of`
// header followed by one line per branch. The first branch line is prefixed
// with a space and every later one with `| `; the declaration ends with `;`.
func (ct CaseType) def() string {
	var out strings.Builder
	fmt.Fprintf(&out, "case %s.%s of", ct.base.ref(), ct.column)
	for i, br := range ct.branches {
		lead := "| "
		if i == 0 {
			lead = " "
		}
		out.WriteString("\n" + lead + br.def())
	}
	out.WriteString(";")
	return out.String()
}
// ref always panics: a case type has no name and so cannot be referred to
func (ct CaseType) ref() string {
panic("case types do not have a name")
}
// repr returns "int" as the representation type of a case type
func (ct CaseType) repr() string {
return "int"
}
// valid reports whether `value` is a trap.Label, the only kind of value
// accepted for a case type.
func (ct CaseType) valid(value interface{}) bool {
	switch value.(type) {
	case trap.Label:
		return true
	default:
		return false
	}
}
// def returns the branch line `idx = name` that CaseType.def embeds in the
// declaration of the enclosing case type
func (bt BranchType) def() string {
return fmt.Sprintf("%d = %s", bt.idx, bt.name)
}
// ref returns the name of this branch type
func (bt BranchType) ref() string {
return bt.name
}
// repr returns "int", the representation type printed for columns of this type
func (bt BranchType) repr() string {
return "int"
}
// valid reports whether `value` is a trap.Label, the only kind of value
// accepted for a branch type.
func (bt BranchType) valid(value interface{}) bool {
	switch value.(type) {
	case trap.Label:
		return true
	default:
		return false
	}
}
// Index returns the numeric index of this branch type, i.e. the number that
// appears on the left-hand side of the branch's `idx = name` definition
func (bt BranchType) Index() int {
return bt.idx
}
// A Column represents a column in a database table
type Column struct {
// columnName is the name of the column
columnName string
// columnType is the database type of the values held in the column
columnType Type
// unique, if set, makes String prefix the column declaration with `unique`
unique bool
// ref, if set, makes String suffix the column declaration with `ref`
ref bool
}
// String renders the column declaration as `[unique ]repr name: type[ ref]`,
// where the `unique` prefix and the `ref` suffix are only present if the
// corresponding flag is set.
func (col Column) String() string {
	var out strings.Builder
	if col.unique {
		out.WriteString("unique ")
	}
	tp := col.columnType
	fmt.Fprintf(&out, "%s %s: %s", tp.repr(), col.columnName, tp.ref())
	if col.ref {
		out.WriteString(" ref")
	}
	return out.String()
}
// Key returns a new column with the same name and type as this column, but
// with the `unique` flag set to `true` and the `ref` flag set to `false`.
func (col Column) Key() Column {
	// col is a copy (value receiver), so the original column is left untouched.
	col.unique = true
	col.ref = false
	return col
}
// Unique returns a new column that is the same as this column, but has the
// `unique` flag set to `true`; the `ref` flag is carried over unchanged.
func (col Column) Unique() Column {
	// col is a copy (value receiver), so the original column is left untouched.
	col.unique = true
	return col
}
// EntityColumn constructs a column with name `columnName` holding entities of
// type `columnType`. The column is not unique and has its `ref` flag set.
func EntityColumn(columnType Type, columnName string) Column {
	return Column{
		columnName: columnName,
		columnType: columnType,
		unique:     false,
		ref:        true,
	}
}
// StringColumn constructs a column with name `columnName` holding string
// values. The column is not unique and has its `ref` flag set.
func StringColumn(columnName string) Column {
	return Column{
		columnName: columnName,
		columnType: STRING,
		unique:     false,
		ref:        true,
	}
}
// IntColumn constructs a column with name `columnName` holding integer
// values. The column is not unique and has its `ref` flag set.
func IntColumn(columnName string) Column {
	return Column{
		columnName: columnName,
		columnType: INT,
		unique:     false,
		ref:        true,
	}
}
// FloatColumn constructs a column with name `columnName` holding floating
// point number values. The column is not unique and has its `ref` flag set.
func FloatColumn(columnName string) Column {
	return Column{
		columnName: columnName,
		columnType: FLOAT,
		unique:     false,
		ref:        true,
	}
}
// A Table represents a database table
type Table struct {
// name is the name of the table
name string
// schema lists the columns of the table in declaration order
schema []Column
// keysets holds the keysets added through KeySet; each one is printed as a
// `#keyset[...]` line above the table declaration
keysets [][]string
}
// KeySet adds `keys` as a keyset to this table.
// It returns the table it was called on, so that calls can be chained.
func (tbl *Table) KeySet(keys ...string) *Table {
tbl.keysets = append(tbl.keysets, keys)
return tbl
}
// String renders the declaration of this table: one `#keyset[...]` line per
// keyset, followed by `name(col, col, ...);`. Once the text since the last
// inserted line break exceeds 100 characters, the next column is moved onto a
// new line indented to the column after the opening parenthesis; no break is
// ever inserted before the last column.
func (tbl Table) String() string {
	const maxLineLen = 100

	var out strings.Builder
	for _, keys := range tbl.keysets {
		out.WriteString("#keyset[" + strings.Join(keys, ", ") + "]\n")
	}
	fmt.Fprintf(&out, "%s(", tbl.name)

	continuation := ",\n" + strings.Repeat(" ", len(tbl.name)+1)
	// lineStart deliberately starts at 0, so any keyset lines count towards the
	// length of the first line, exactly as the length is measured on the builder.
	lineStart := 0
	last := len(tbl.schema) - 1
	for idx := range tbl.schema {
		switch {
		case idx == 0:
			// no separator before the first column
		case idx < last && out.Len()-lineStart > maxLineLen:
			out.WriteString(continuation)
			lineStart = out.Len()
		default:
			out.WriteString(", ")
		}
		out.WriteString(tbl.schema[idx].String())
	}
	out.WriteString(");")
	return out.String()
}
// Emit outputs a tuple of `values` for this table using trap writer `tw`.
//
// The tuple is checked against the table schema first. If the number of values
// differs from the number of columns, the process is terminated via
// log.Fatalf. If a value is not valid for the type of its column, Emit panics
// with a message naming the column, the expected type and the offending value.
func (tbl Table) Emit(tw *trap.Writer, values ...interface{}) {
	if ncol, nval := len(tbl.schema), len(values); ncol != nval {
		log.Fatalf("wrong number of values for table %s; expected %d, but got %d", tbl.name, ncol, nval)
	}
	for i, col := range tbl.schema {
		if !col.columnType.valid(values[i]) {
			// Use %v for the value and its type: the value can be of any Go
			// type (or nil), and %s would render non-strings as `%!s(...)`.
			panic(fmt.Sprintf("Invalid value for column %d of table %s; expected a %s, but got %v which is a %v", i, tbl.name, col.columnType.ref(), values[i], reflect.TypeOf(values[i])))
		}
	}
	tw.Emit(tbl.name, values)
}
// Package-level registries that the constructors in this file append to and
// that PrintDbScheme reads when printing the schema.
var (
	// tables holds every table created through NewTable, in creation order.
	tables = []*Table{}
	// types holds the types registered by the New*Type constructors, in creation order.
	types = []Type{}
	// defaultSnippets holds the snippets added through AddDefaultSnippet.
	defaultSnippets = []string{}
)
// NewTable constructs a new table with the given `name` and `columns`, with no
// keysets, and registers it in the package-level list of tables.
func NewTable(name string, columns ...Column) *Table {
	created := &Table{
		name:    name,
		schema:  columns,
		keysets: [][]string{},
	}
	tables = append(tables, created)
	return created
}
// NewPrimaryKeyType constructs a new primary key type with the given `name`,
// registers it in the package-level list of types, and adds it as a component
// to each of the union types `parents` (if any)
func NewPrimaryKeyType(name string, parents ...*UnionType) *PrimaryKeyType {
	created := &PrimaryKeyType{name: name}
	types = append(types, created)
	for i := range parents {
		parents[i].components = append(parents[i].components, created)
	}
	return created
}
// NewUnionType constructs a new, initially empty union type with the given
// `name`, registers it in the package-level list of types, and adds it as a
// component to each of the union types `parents` (if any)
func NewUnionType(name string, parents ...*UnionType) *UnionType {
	created := &UnionType{
		name:       name,
		components: []Type{},
	}
	types = append(types, created)
	for i := range parents {
		parents[i].components = append(parents[i].components, created)
	}
	return created
}
// AddChild adds the type with given `name` to the union type and returns `true`.
// This is useful if a type defined in a snippet should be a child of a type defined in Go.
func (parent *UnionType) AddChild(name string) bool {
	// The child is deliberately not added to `types`: it's expected that it is
	// already declared in the db somehow, e.g. by a snippet.
	parent.components = append(parent.components, &PrimaryKeyType{name: name})
	return true
}
// NewAliasType constructs a new alias type with the given `name` that aliases
// `underlying`, and registers it in the package-level list of types
func NewAliasType(name string, underlying Type) *AliasType {
	created := &AliasType{
		name:       name,
		underlying: underlying,
	}
	types = append(types, created)
	return created
}
// NewCaseType constructs a new case type on the given `base` type whose
// discriminator values come from `column`. The case type starts out without
// branches and is registered in the package-level list of types.
func NewCaseType(base Type, column string) *CaseType {
	created := &CaseType{
		base:     base,
		column:   column,
		branches: []*BranchType{},
	}
	types = append(types, created)
	return created
}
// NewBranch adds a new branch with the given `name` to this case type
// and adds it to the union types `parents` (if any).
// The index of the new branch is the number of branches the case type had before.
func (ct *CaseType) NewBranch(name string, parents ...*UnionType) *BranchType {
	branch := &BranchType{
		idx:  len(ct.branches),
		name: name,
	}
	ct.branches = append(ct.branches, branch)
	for i := range parents {
		parents[i].components = append(parents[i].components, branch)
	}
	return branch
}
// AddDefaultSnippet adds the given text `snippet` to the schema of this database.
// Snippets are printed by PrintDbScheme, in the order they were added, before any table.
// It always returns `true`.
// NOTE(review): the result is presumably there so the call can be used in a
// package-level variable initializer; confirm against callers.
func AddDefaultSnippet(snippet string) bool {
defaultSnippets = append(defaultSnippets, snippet)
return true
}
// PrintDbScheme prints the schema of this database to the writer `w`: a header
// comment, then every default snippet on its own line, then every table
// declaration, then the definition of every registered type that has one.
// Table declarations and type definitions are each followed by a blank line.
// Errors from writing to `w` are not reported.
func PrintDbScheme(w io.Writer) {
	fmt.Fprint(w, "/** Auto-generated dbscheme; do not edit. Run `make gen` in directory `go/` to regenerate. */\n\n")
	for i := range defaultSnippets {
		fmt.Fprintln(w, defaultSnippets[i])
	}
	for _, tbl := range tables {
		fmt.Fprint(w, tbl.String(), "\n\n")
	}
	for _, tp := range types {
		// types without a definition of their own yield the empty string
		if definition := tp.def(); definition != "" {
			fmt.Fprint(w, definition, "\n\n")
		}
	}
}