diff --git a/ql/src/semmle/go/PrintAst.ql b/ql/src/semmle/go/PrintAst.ql index f73b2a21ad9..153f30bf06f 100644 --- a/ql/src/semmle/go/PrintAst.ql +++ b/ql/src/semmle/go/PrintAst.ql @@ -13,4 +13,6 @@ import PrintAst */ class Cfg extends PrintAstConfiguration { override predicate shouldPrintFunction(FuncDef func) { any() } + + override predicate shouldPrintFile(File file) { any() } } diff --git a/ql/src/semmle/go/PrintAst.qll b/ql/src/semmle/go/PrintAst.qll index a4bba4b06e1..f4a88fe7108 100644 --- a/ql/src/semmle/go/PrintAst.qll +++ b/ql/src/semmle/go/PrintAst.qll @@ -5,7 +5,11 @@ import go /** - * Hook to customize the functions printed by this module. + * Hook to customize the files and functions printed by this module. + * + * For an AstNode to be printed, it always requires `shouldPrintFile(f)` to hold + * for its containing file `f`, and additionally requires `shouldPrintFunction(fun)` + * if it is, or falls within, function `fun`. */ class PrintAstConfiguration extends string { /** @@ -18,19 +22,34 @@ class PrintAstConfiguration extends string { * functions. */ predicate shouldPrintFunction(FuncDef func) { any() } + + /** + * Holds if the AST for `file` should be printed. By default, holds for all + * files. + */ + predicate shouldPrintFile(File file) { any() } } private predicate shouldPrintFunction(FuncDef func) { exists(PrintAstConfiguration config | config.shouldPrintFunction(func)) } +private predicate shouldPrintFile(File file) { + exists(PrintAstConfiguration config | config.shouldPrintFile(file)) +} + +private FuncDef getEnclosingFunction(AstNode n) { + result = n or + result = n.getEnclosingFunction() +} + /** * An AST node that should be printed. */ private newtype TPrintAstNode = TAstNode(AstNode ast) { - // Do print ast nodes without an enclosing function, e.g. file headers - forall(FuncDef f | f = ast.getEnclosingFunction() | shouldPrintFunction(f)) + shouldPrintFile(ast.getFile()) and + forall(FuncDef f | f = getEnclosingFunction(ast) | shouldPrintFunction(f)) } /** diff --git a/ql/test/library-tests/semmle/go/PrintAst/PrintAst.expected b/ql/test/library-tests/semmle/go/PrintAst/PrintAst.expected index c361de446a3..3ac5a533462 100644 --- a/ql/test/library-tests/semmle/go/PrintAst/PrintAst.expected +++ b/ql/test/library-tests/semmle/go/PrintAst/PrintAst.expected @@ -1,3 +1,34 @@ +other.go: +# 0| [File] library-tests/semmle/go/PrintAst/other.go +# 3| 0: [FuncDecl] function declaration +# 3| 0: [FunctionName, Ident] main +# 3| Type = func() +# 3| 1: [FuncTypeExpr] function type +# 3| 2: [BlockStmt] block statement +# 5| 1: [FuncDecl] function declaration +# 5| 0: [FunctionName, Ident] f +# 5| Type = func() +# 5| 1: [FuncTypeExpr] function type +# 5| 2: [BlockStmt] block statement +# 6| 2: [FuncDecl] function declaration +# 6| 0: [FunctionName, Ident] g +# 6| Type = func() +# 6| 1: [FuncTypeExpr] function type +# 6| 2: [BlockStmt] block statement +# 8| 3: [VarDecl] variable declaration +# 8| 0: [ValueSpec] value declaration specifier +# 8| 0: [Ident, VariableName] x +# 8| Type = int +# 8| 1: [Ident, TypeName] int +# 8| Type = int +# 8| 2: [IntLit] 0 +# 8| Type = int +# 8| Value = [IntLit] 0 +# 1| 4: [Ident] main +go.mod: +# 0| [GoModFile] library-tests/semmle/go/PrintAst/go.mod +# 1| 0: [GoModModuleLine] go.mod module line +# 3| 1: [GoModGoLine] go.mod go line input.go: # 0| [File] library-tests/semmle/go/PrintAst/input.go # 5| 0: [CommentGroup] comment group diff --git a/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.expected b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.expected new file mode 100644 index 00000000000..20ffb795d94 --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.expected @@ -0,0 +1,27 @@ +other.go: +# 0| [File] library-tests/semmle/go/PrintAst/other.go +# 3| 0: [FuncDecl] function declaration +# 3| 0: [FunctionName, Ident] main +# 3| Type = func() +# 3| 1: [FuncTypeExpr] function type +# 3| 2: [BlockStmt] block statement +# 5| 1: [FuncDecl] function declaration +# 5| 0: [FunctionName, Ident] f +# 5| Type = func() +# 5| 1: [FuncTypeExpr] function type +# 5| 2: [BlockStmt] block statement +# 6| 2: [FuncDecl] function declaration +# 6| 0: [FunctionName, Ident] g +# 6| Type = func() +# 6| 1: [FuncTypeExpr] function type +# 6| 2: [BlockStmt] block statement +# 8| 3: [VarDecl] variable declaration +# 8| 0: [ValueSpec] value declaration specifier +# 8| 0: [Ident, VariableName] x +# 8| Type = int +# 8| 1: [Ident, TypeName] int +# 8| Type = int +# 8| 2: [IntLit] 0 +# 8| Type = int +# 8| Value = [IntLit] 0 +# 1| 4: [Ident] main diff --git a/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.ql b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.ql new file mode 100644 index 00000000000..c0c7e57abc6 --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFile.ql @@ -0,0 +1,12 @@ +/** + * @kind graph + */ + +import go +import semmle.go.PrintAst + +class Cfg extends PrintAstConfiguration { + override predicate shouldPrintFunction(FuncDef func) { any() } + + override predicate shouldPrintFile(File file) { file.getBaseName() = "other.go" } +} diff --git a/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.expected b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.expected new file mode 100644 index 00000000000..b64578d6e31 --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.expected @@ -0,0 +1,47 @@ +other.go: +# 0| [File] library-tests/semmle/go/PrintAst/other.go +# 6| 2: [FuncDecl] function declaration +# 6| 0: [FunctionName, Ident] g +# 6| Type = func() +# 6| 1: [FuncTypeExpr] function type +# 6| 2: [BlockStmt] block statement +# 8| 3: [VarDecl] variable declaration +# 8| 0: [ValueSpec] value declaration specifier +# 8| 0: [Ident, VariableName] x +# 8| Type = int +# 8| 1: [Ident, TypeName] int +# 8| Type = int +# 8| 2: [IntLit] 0 +# 8| Type = int +# 8| Value = [IntLit] 0 +# 1| 4: [Ident] main +go.mod: +# 0| [GoModFile] library-tests/semmle/go/PrintAst/go.mod +# 1| 0: [GoModModuleLine] go.mod module line +# 3| 1: [GoModGoLine] go.mod go line +input.go: +# 0| [File] library-tests/semmle/go/PrintAst/input.go +# 5| 0: [CommentGroup] comment group +# 5| 0: [SlashSlashComment] comment +# 7| 1: [CommentGroup] comment group +# 7| 0: [SlashSlashComment] comment +# 9| 2: [DocComment] comment group +# 9| 0: [SlashSlashComment] comment +# 17| 3: [CommentGroup] comment group +# 17| 0: [SlashSlashComment] comment +# 45| 4: [DocComment] comment group +# 45| 0: [SlashSlashComment] comment +# 64| 5: [DocComment] comment group +# 64| 0: [SlashSlashComment] comment +# 74| 6: [DocComment] comment group +# 74| 0: [SlashSlashComment] comment +# 111| 7: [DocComment] comment group +# 111| 0: [SlashSlashComment] comment +# 127| 8: [DocComment] comment group +# 127| 0: [SlashSlashComment] comment +# 132| 9: [DocComment] comment group +# 132| 0: [SlashSlashComment] comment +# 3| 10: [ImportDecl] import declaration +# 3| 0: [ImportSpec] import specifier +# 3| 0: [StringLit] "fmt" +# 1| 18: [Ident] main diff --git a/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.ql b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.ql new file mode 100644 index 00000000000..83571f02c0d --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/PrintAstRestrictFunction.ql @@ -0,0 +1,12 @@ +/** + * @kind graph + */ + +import go +import semmle.go.PrintAst + +class Cfg extends PrintAstConfiguration { + override predicate shouldPrintFunction(FuncDef func) { func.getName() = "g" } + + override predicate shouldPrintFile(File file) { any() } +} diff --git a/ql/test/library-tests/semmle/go/PrintAst/go.mod b/ql/test/library-tests/semmle/go/PrintAst/go.mod new file mode 100644 index 00000000000..c3149650c4c --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/go.mod @@ -0,0 +1,4 @@ +module codeql-go-tests/printast + +go 1.14 + diff --git a/ql/test/library-tests/semmle/go/PrintAst/other.go b/ql/test/library-tests/semmle/go/PrintAst/other.go new file mode 100644 index 00000000000..8a03dea095d --- /dev/null +++ b/ql/test/library-tests/semmle/go/PrintAst/other.go @@ -0,0 +1,8 @@ +package main + +func main() {} + +func f() {} +func g() {} + +var x int = 0