mirror of
https://github.com/github/codeql.git
synced 2025-12-16 08:43:11 +01:00
427 lines
10 KiB
Go
427 lines
10 KiB
Go
package dbscheme
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/github/codeql-go/extractor/trap"
|
|
)
|
|
|
|
// A Type represents a database type
|
|
type Type interface {
|
|
def() string
|
|
ref() string
|
|
repr() string
|
|
valid(val interface{}) bool
|
|
}
|
|
|
|
// A PrimitiveType represents a primitive dataase type
|
|
type PrimitiveType int
|
|
|
|
const (
|
|
// INT represents the primitive database type `int`
|
|
INT PrimitiveType = iota
|
|
// FLOAT represents the primitive database type `float`
|
|
FLOAT
|
|
// BOOLEAN represents the primitive database type `boolean`
|
|
BOOLEAN
|
|
// DATE represents the primitive database type `date`
|
|
DATE
|
|
// STRING represents the primitive database type `string`
|
|
STRING
|
|
)
|
|
|
|
// A PrimaryKeyType represents a database type defined by a primary key column
|
|
type PrimaryKeyType struct {
|
|
name string
|
|
}
|
|
|
|
// A UnionType represents a database type defined as the union of other database types
|
|
type UnionType struct {
|
|
name string
|
|
components []Type
|
|
}
|
|
|
|
// An AliasType represents a database type which is an alias of another database type
|
|
type AliasType struct {
|
|
name string
|
|
underlying Type
|
|
}
|
|
|
|
// A CaseType represents a database type defined by a primary key column with a supplementary kind column
|
|
type CaseType struct {
|
|
base Type
|
|
column string
|
|
branches []*BranchType
|
|
}
|
|
|
|
// A BranchType represents one branch of a case type
|
|
type BranchType struct {
|
|
idx int
|
|
name string
|
|
}
|
|
|
|
func (pt PrimitiveType) def() string {
|
|
return ""
|
|
}
|
|
|
|
func (pt PrimitiveType) ref() string {
|
|
switch pt {
|
|
case INT:
|
|
return "int"
|
|
case FLOAT:
|
|
return "float"
|
|
case BOOLEAN:
|
|
return "boolean"
|
|
case DATE:
|
|
return "date"
|
|
case STRING:
|
|
return "string"
|
|
default:
|
|
panic(fmt.Sprintf("Unexpected primitive type %d", pt))
|
|
}
|
|
}
|
|
|
|
func (pt PrimitiveType) repr() string {
|
|
switch pt {
|
|
case INT:
|
|
return "int"
|
|
case FLOAT:
|
|
return "float"
|
|
case BOOLEAN:
|
|
return "boolean"
|
|
case DATE:
|
|
return "date"
|
|
case STRING:
|
|
return "string"
|
|
default:
|
|
panic(fmt.Sprintf("Unexpected primitive type %d", pt))
|
|
}
|
|
}
|
|
|
|
func (pt PrimitiveType) valid(value interface{}) bool {
|
|
switch value.(type) {
|
|
case int:
|
|
return pt == INT
|
|
case float64:
|
|
return pt == FLOAT
|
|
case bool:
|
|
return pt == BOOLEAN
|
|
case string:
|
|
return pt == STRING
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (pkt PrimaryKeyType) def() string {
|
|
return ""
|
|
}
|
|
|
|
func (pkt PrimaryKeyType) ref() string {
|
|
return pkt.name
|
|
}
|
|
|
|
func (pkt PrimaryKeyType) repr() string {
|
|
return "int"
|
|
}
|
|
|
|
func (pkt PrimaryKeyType) valid(value interface{}) bool {
|
|
_, ok := value.(trap.Label)
|
|
return ok
|
|
}
|
|
|
|
func (ut UnionType) def() string {
|
|
var b strings.Builder
|
|
nl := 0
|
|
fmt.Fprintf(&b, "%s = ", ut.name)
|
|
for i, comp := range ut.components {
|
|
if i > 0 {
|
|
if i < len(ut.components)-1 && b.Len()-nl > 100 {
|
|
fmt.Fprintf(&b, "\n%s", strings.Repeat(" ", len(ut.name)))
|
|
nl = b.Len()
|
|
}
|
|
fmt.Fprint(&b, " | ")
|
|
}
|
|
fmt.Fprint(&b, comp.ref())
|
|
}
|
|
fmt.Fprint(&b, ";")
|
|
return b.String()
|
|
}
|
|
|
|
func (ut UnionType) ref() string {
|
|
return ut.name
|
|
}
|
|
|
|
func (ut UnionType) repr() string {
|
|
return "int"
|
|
}
|
|
|
|
func (ut UnionType) valid(value interface{}) bool {
|
|
_, ok := value.(trap.Label)
|
|
return ok
|
|
}
|
|
|
|
func (at AliasType) def() string {
|
|
return at.name + " = " + at.underlying.ref() + ";"
|
|
}
|
|
|
|
func (at AliasType) ref() string {
|
|
return at.name
|
|
}
|
|
|
|
func (at AliasType) repr() string {
|
|
return at.underlying.repr()
|
|
}
|
|
|
|
func (at AliasType) valid(value interface{}) bool {
|
|
return at.underlying.valid(value)
|
|
}
|
|
|
|
func (ct CaseType) def() string {
|
|
var b strings.Builder
|
|
fmt.Fprintf(&b, "case %s.%s of", ct.base.ref(), ct.column)
|
|
sep := " "
|
|
for _, branch := range ct.branches {
|
|
fmt.Fprintf(&b, "\n%s%s", sep, branch.def())
|
|
sep = "| "
|
|
}
|
|
fmt.Fprint(&b, ";")
|
|
return b.String()
|
|
}
|
|
|
|
func (ct CaseType) ref() string {
|
|
panic("case types do not have a name")
|
|
}
|
|
|
|
func (ct CaseType) repr() string {
|
|
return "int"
|
|
}
|
|
|
|
func (ct CaseType) valid(value interface{}) bool {
|
|
_, ok := value.(trap.Label)
|
|
return ok
|
|
}
|
|
|
|
func (bt BranchType) def() string {
|
|
return fmt.Sprintf("%d = %s", bt.idx, bt.name)
|
|
}
|
|
|
|
func (bt BranchType) ref() string {
|
|
return bt.name
|
|
}
|
|
|
|
func (bt BranchType) repr() string {
|
|
return "int"
|
|
}
|
|
|
|
func (bt BranchType) valid(value interface{}) bool {
|
|
_, ok := value.(trap.Label)
|
|
return ok
|
|
}
|
|
|
|
// Index returns the numeric index of this branch type
|
|
func (bt BranchType) Index() int {
|
|
return bt.idx
|
|
}
|
|
|
|
// A Column represents a column in a database table
|
|
type Column struct {
|
|
columnName string
|
|
columnType Type
|
|
unique bool
|
|
ref bool
|
|
}
|
|
|
|
func (col Column) String() string {
|
|
var b strings.Builder
|
|
if col.unique {
|
|
fmt.Fprint(&b, "unique ")
|
|
}
|
|
fmt.Fprintf(&b, "%s %s: %s", col.columnType.repr(), col.columnName, col.columnType.ref())
|
|
if col.ref {
|
|
fmt.Fprint(&b, " ref")
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
// Key returns a new column that is the same as this column, but has the `key` flag set to `true`
|
|
func (col Column) Key() Column {
|
|
return Column{col.columnName, col.columnType, true, false}
|
|
}
|
|
|
|
// Unique returns a new column that is the same as this column, but has the `unique` flag set to `true`
|
|
func (col Column) Unique() Column {
|
|
return Column{col.columnName, col.columnType, true, col.ref}
|
|
}
|
|
|
|
// EntityColumn constructs a column with name `columnName` holding entities of type `columnType`
|
|
func EntityColumn(columnType Type, columnName string) Column {
|
|
return Column{columnName, columnType, false, true}
|
|
}
|
|
|
|
// StringColumn constructs a column with name `columnName` holding string values
|
|
func StringColumn(columnName string) Column {
|
|
return Column{columnName, STRING, false, true}
|
|
}
|
|
|
|
// IntColumn constructs a column with name `columnName` holding integer values
|
|
func IntColumn(columnName string) Column {
|
|
return Column{columnName, INT, false, true}
|
|
}
|
|
|
|
// FloatColumn constructs a column with name `columnName` holding floating point number values
|
|
func FloatColumn(columnName string) Column {
|
|
return Column{columnName, FLOAT, false, true}
|
|
}
|
|
|
|
// A Table represents a database table
|
|
type Table struct {
|
|
name string
|
|
schema []Column
|
|
keysets [][]string
|
|
}
|
|
|
|
// KeySet adds `keys` as a keyset to this table
|
|
func (tbl *Table) KeySet(keys ...string) *Table {
|
|
tbl.keysets = append(tbl.keysets, keys)
|
|
return tbl
|
|
}
|
|
|
|
func (tbl Table) String() string {
|
|
var b strings.Builder
|
|
for _, keyset := range tbl.keysets {
|
|
fmt.Fprint(&b, "#keyset[")
|
|
sep := ""
|
|
for _, key := range keyset {
|
|
fmt.Fprintf(&b, "%s%s", sep, key)
|
|
sep = ", "
|
|
}
|
|
fmt.Fprint(&b, "]\n")
|
|
}
|
|
fmt.Fprint(&b, tbl.name)
|
|
fmt.Fprint(&b, "(")
|
|
nl := 0
|
|
for i, column := range tbl.schema {
|
|
if i > 0 {
|
|
// wrap >100 char lines
|
|
if i < len(tbl.schema)-1 && b.Len()-nl > 100 {
|
|
fmt.Fprintf(&b, ",\n%s", strings.Repeat(" ", len(tbl.name)+1))
|
|
nl = b.Len()
|
|
} else {
|
|
fmt.Fprint(&b, ", ")
|
|
}
|
|
}
|
|
fmt.Fprint(&b, column.String())
|
|
}
|
|
fmt.Fprint(&b, ");")
|
|
return b.String()
|
|
}
|
|
|
|
// Emit outputs a tuple of `values` for this table using trap writer `tw`
|
|
// and panicks if the tuple does not have the right schema
|
|
func (tbl Table) Emit(tw *trap.Writer, values ...interface{}) {
|
|
if ncol, nval := len(tbl.schema), len(values); ncol != nval {
|
|
log.Fatalf("wrong number of values for table %s; expected %d, but got %d", tbl.name, ncol, nval)
|
|
}
|
|
for i, col := range tbl.schema {
|
|
if !col.columnType.valid(values[i]) {
|
|
panic(fmt.Sprintf("Invalid value for column %d of table %s; expected a %s, but got %s which is a %s", i, tbl.name, col.columnType.ref(), values[i], reflect.TypeOf(values[i])))
|
|
}
|
|
}
|
|
tw.Emit(tbl.name, values)
|
|
}
|
|
|
|
var tables = []*Table{}
|
|
var types = []Type{}
|
|
var defaultSnippets = []string{}
|
|
|
|
// NewTable constructs a new table with the given `name` and `columns`
|
|
func NewTable(name string, columns ...Column) *Table {
|
|
tbl := &Table{name, columns, [][]string{}}
|
|
tables = append(tables, tbl)
|
|
return tbl
|
|
}
|
|
|
|
// NewPrimaryKeyType constructs a new primary key type with the given `name`,
|
|
// and adds it to the union types `parents` (if any)
|
|
func NewPrimaryKeyType(name string, parents ...*UnionType) *PrimaryKeyType {
|
|
tp := &PrimaryKeyType{name}
|
|
types = append(types, tp)
|
|
for _, parent := range parents {
|
|
parent.components = append(parent.components, tp)
|
|
}
|
|
return tp
|
|
}
|
|
|
|
// NewUnionType constructs a new union type with the given `name`,
|
|
// and adds it to the union types `parents` (if any)
|
|
func NewUnionType(name string, parents ...*UnionType) *UnionType {
|
|
tp := &UnionType{name, []Type{}}
|
|
types = append(types, tp)
|
|
for _, parent := range parents {
|
|
parent.components = append(parent.components, tp)
|
|
}
|
|
return tp
|
|
}
|
|
|
|
// AddChild adds the type with given `name` to the union type.
|
|
// This is useful if a type defined in a snippet should be a child of a type defined in Go.
|
|
func (parent *UnionType) AddChild(name string) bool {
|
|
tp := &PrimaryKeyType{name}
|
|
// don't add tp to types; it's expected that it's already in the db somehow.
|
|
parent.components = append(parent.components, tp)
|
|
return true
|
|
}
|
|
|
|
// NewAliasType constructs a new alias type with the given `name` that aliases `underlying`
|
|
func NewAliasType(name string, underlying Type) *AliasType {
|
|
tp := &AliasType{name, underlying}
|
|
types = append(types, tp)
|
|
return tp
|
|
}
|
|
|
|
// NewCaseType constructs a new case type on the given `base` type whose discriminator values
|
|
// come from `column`
|
|
func NewCaseType(base Type, column string) *CaseType {
|
|
tp := &CaseType{base, column, []*BranchType{}}
|
|
types = append(types, tp)
|
|
return tp
|
|
}
|
|
|
|
// NewBranch adds a new branch with the given `name` to this case type
|
|
// and adds it to the union types `parents` (if any)
|
|
func (ct *CaseType) NewBranch(name string, parents ...*UnionType) *BranchType {
|
|
tp := &BranchType{len(ct.branches), name}
|
|
ct.branches = append(ct.branches, tp)
|
|
for _, parent := range parents {
|
|
parent.components = append(parent.components, tp)
|
|
}
|
|
return tp
|
|
}
|
|
|
|
// AddDefaultSnippet adds the given text `snippet` to the schema of this database
|
|
func AddDefaultSnippet(snippet string) bool {
|
|
defaultSnippets = append(defaultSnippets, snippet)
|
|
return true
|
|
}
|
|
|
|
// PrintDbScheme prints the schema of this database to the writer `w`
|
|
func PrintDbScheme(w io.Writer) {
|
|
fmt.Fprintf(w, "/** Auto-generated dbscheme; do not edit. */\n\n")
|
|
for _, snippet := range defaultSnippets {
|
|
fmt.Fprintf(w, "%s\n", snippet)
|
|
}
|
|
for _, table := range tables {
|
|
fmt.Fprintf(w, "%s\n\n", table.String())
|
|
}
|
|
for _, tp := range types {
|
|
def := tp.def()
|
|
if def != "" {
|
|
fmt.Fprintf(w, "%s\n\n", def)
|
|
}
|
|
}
|
|
}
|