Move model generation functions to language definition

This commit is contained in:
Koen Vlaswinkel
2023-11-01 14:52:58 +01:00
parent 2a477140a6
commit 71f59b19b4
14 changed files with 352 additions and 295 deletions

View File

@@ -1,5 +1,6 @@
export * from "./local-queries";
export * from "./local-query-run";
export * from "./query-constraints";
export * from "./query-resolver";
export * from "./quick-eval-code-lens-provider";
export * from "./quick-query";

View File

@@ -0,0 +1,7 @@
/**
 * A set of optional constraints describing which queries should be selected.
 *
 * Every property is optional. Several property names contain spaces and must
 * therefore be written as quoted keys.
 *
 * NOTE(review): the key names presumably mirror the filter keys understood by
 * the code that consumes these constraints — confirm against that consumer.
 */
export interface QueryConstraints {
// Constrains the query kind.
kind?: string;
// Constrains the query tags (list of tag names).
"tags contain"?: string[];
// Constrains the query tags (list of tag names); NOTE(review): presumably all
// listed tags must be present, as opposed to "tags contain" — confirm.
"tags contain all"?: string[];
// Constrains the file name of the query.
"query filename"?: string;
// Constrains the path of the query.
"query path"?: string;
}

View File

@@ -14,6 +14,7 @@ import { showAndLogExceptionWithTelemetry } from "../common/logging";
import { extLogger } from "../common/logging/vscode";
import { telemetryListener } from "../common/vscode/telemetry";
import { SuiteInstruction } from "../packaging/suite-instruction";
import { QueryConstraints } from "./query-constraints";
export async function qlpackOfDatabase(
cli: Pick<CodeQLCliServer, "resolveQlpacks">,
@@ -27,14 +28,6 @@ export async function qlpackOfDatabase(
return await getQlPackForDbscheme(cli, dbscheme);
}
/**
 * Optional constraints used to select queries, e.g. by kind and tags.
 *
 * All properties are optional; property names containing spaces are quoted.
 *
 * NOTE(review): exact matching semantics of each key are decided by the code
 * that consumes the constraints — confirm there.
 */
export interface QueryConstraints {
// Constrains the query kind.
kind?: string;
// Constrains the query tags (list of tag names).
"tags contain"?: string[];
// Constrains the query tags (list of tag names); presumably stricter than
// "tags contain" — TODO confirm.
"tags contain all"?: string[];
// Constrains the file name of the query.
"query filename"?: string;
// Constrains the path of the query.
"query path"?: string;
}
/**
* Finds the queries with the specified kind and tags in a list of CodeQL packs.
*

View File

@@ -1,109 +0,0 @@
import { basename } from "path";
import { BaseLogger } from "../common/logging";
import {
getModelsAsDataLanguage,
ModelsAsDataLanguage,
ModelsAsDataLanguageModelType,
} from "./languages";
import { ModeledMethod } from "./modeled-method";
import { QueryLanguage } from "../common/query-language";
import { GenerateQueriesOptions, runGenerateQueries } from "./generate";
import { DecodedBqrs } from "../common/bqrs-cli-types";
/** Languages for which flow model generation is available. */
const FLOW_MODEL_SUPPORTED_LANGUAGES = [QueryLanguage.CSharp, QueryLanguage.Java];

/**
 * Tells whether flow model generation is available for a language.
 *
 * @param language The query language to check.
 * @returns `true` when the language is listed in
 *   `FLOW_MODEL_SUPPORTED_LANGUAGES`, `false` otherwise.
 */
export function isFlowModelGenerationSupported(
  language: QueryLanguage,
): boolean {
  return FLOW_MODEL_SUPPORTED_LANGUAGES.some(
    (supported) => supported === language,
  );
}
/**
 * Options accepted by `runFlowModelQueries`: the pass-through options of
 * `runGenerateQueries` plus a logger and the language of the database.
 */
type FlowModelOptions = GenerateQueriesOptions & {
// Receives messages logged while parsing query results.
logger: BaseLogger;
// Used to look up the models-as-data language definition.
language: QueryLanguage;
};
/**
 * Maps the file name of each supported model generator query to the type of
 * model that the query produces.
 */
const queriesToModel: Record<string, ModelsAsDataLanguageModelType> = {
  "CaptureSummaryModels.ql": "summary",
  "CaptureSinkModels.ql": "sink",
  "CaptureSourceModels.ql": "source",
  "CaptureNeutralModels.ql": "neutral",
};

/**
 * Turns the decoded results of a single model generator query into modeled
 * methods.
 *
 * The model type is derived from the file name of the query. Each result tuple
 * is expected to hold a `;`-separated string in its first column, which is
 * split and handed to the `readModeledMethod` function of the matching
 * predicate definition.
 *
 * @param queryPath Path of the query that produced the results.
 * @param bqrs Decoded results; must contain exactly one result set.
 * @param modelsAsDataLanguage Language definition providing the predicates.
 * @param logger Receives a message when the query file name is not recognized.
 * @returns The parsed modeled methods, or an empty array when the query file
 *   name is not a known model generator query.
 * @throws Error when `bqrs` does not contain exactly one result set, or when
 *   the language has no definition for the model type.
 */
function parseFlowModelResults(
  queryPath: string,
  bqrs: DecodedBqrs,
  modelsAsDataLanguage: ModelsAsDataLanguage,
  logger: BaseLogger,
): ModeledMethod[] {
  const resultSetNames = Object.keys(bqrs);
  if (resultSetNames.length !== 1) {
    throw new Error(
      `Expected exactly one result set from ${queryPath}, but got ${resultSetNames.length}`,
    );
  }

  const modelType = queriesToModel[basename(queryPath)];
  if (!modelType) {
    void logger.log(`Unknown model type for ${queryPath}`);
    return [];
  }

  const tuples = bqrs[resultSetNames[0]].tuples;

  const definition = modelsAsDataLanguage.predicates[modelType];
  if (!definition) {
    throw new Error(`No definition for ${modelType}`);
  }

  const modeledMethods: ModeledMethod[] = [];
  for (const tuple of tuples) {
    const firstColumn = tuple[0];
    // This is just a sanity check. The query should only return strings.
    if (typeof firstColumn !== "string") {
      continue;
    }
    modeledMethods.push(definition.readModeledMethod(firstColumn.split(";")));
  }
  return modeledMethods;
}
/**
 * Runs the model generator queries for a database and parses their results
 * into modeled methods.
 *
 * Queries are selected by the `modelgenerator` tag and then restricted to the
 * file names listed in `queriesToModel`. The results of each query are parsed
 * by `parseFlowModelResults` using the models-as-data definition of
 * `language`.
 *
 * @param options The pass-through options for `runGenerateQueries`, plus the
 *   `logger` used while parsing and the `language` of the database.
 * @returns Whatever `runGenerateQueries` returns.
 */
export async function runFlowModelQueries({
  cliServer,
  queryRunner,
  logger,
  queryStorageDir,
  databaseItem,
  language,
  progress,
  token,
  onResults,
}: FlowModelOptions) {
  const modelsAsDataLanguage = getModelsAsDataLanguage(language);

  return runGenerateQueries(
    {
      queryConstraints: {
        "tags contain": ["modelgenerator"],
      },
      // Only own keys count: the `in` operator would also match names found on
      // Object.prototype (e.g. "constructor" or "toString").
      filterQueries: (queryPath) =>
        Object.prototype.hasOwnProperty.call(
          queriesToModel,
          basename(queryPath),
        ),
      parseResults: (queryPath, results) =>
        parseFlowModelResults(queryPath, results, modelsAsDataLanguage, logger),
    },
    {
      cliServer,
      queryRunner,
      queryStorageDir,
      databaseItem,
      progress,
      token,
      onResults,
    },
  );
}

View File

@@ -1,81 +0,0 @@
import { BaseLogger, NotificationLogger } from "../common/logging";
import { getModelsAsDataLanguage, ModelsAsDataLanguage } from "./languages";
import { ModeledMethod } from "./modeled-method";
import { QueryLanguage } from "../common/query-language";
import { DataTuple } from "./model-extension-file";
import { GenerateQueriesOptions, runGenerateQueries } from "./generate";
import { DecodedBqrs } from "../common/bqrs-cli-types";
/** Languages for which the generate-model query is available. */
const GENERATE_MODEL_SUPPORTED_LANGUAGES = [QueryLanguage.Ruby];

/**
 * Tells whether model generation via the generate-model query is available
 * for a language.
 *
 * @param language The query language to check.
 * @returns `true` when the language is listed in
 *   `GENERATE_MODEL_SUPPORTED_LANGUAGES`, `false` otherwise.
 */
export function isGenerateModelSupported(language: QueryLanguage): boolean {
  return GENERATE_MODEL_SUPPORTED_LANGUAGES.some(
    (supported) => supported === language,
  );
}
/**
 * Options accepted by `runGenerateModelQuery`: the pass-through options of
 * `runGenerateQueries` plus a logger and the language of the database.
 */
type GenerateModelOptions = GenerateQueriesOptions & {
// Receives messages logged while parsing query results.
logger: NotificationLogger;
// Used to look up the models-as-data language definition.
language: QueryLanguage;
};
/**
 * Converts the result sets of a generate-model query into modeled methods.
 *
 * Each result set is matched by name against the `extensiblePredicate` of the
 * predicates defined for the language. Result sets without a matching
 * predicate, and result sets in which any value is of type `object`, are
 * skipped with a log message. Every tuple of the remaining result sets is
 * passed to the `readModeledMethod` function of the matching predicate.
 *
 * @param bqrs Decoded query results, keyed by result set name.
 * @param modelsAsDataLanguage Language definition providing the predicates.
 * @param logger Receives a message for every skipped result set.
 * @returns The modeled methods read from all accepted result sets.
 */
function parseGenerateModelResults(
  bqrs: DecodedBqrs,
  modelsAsDataLanguage: ModelsAsDataLanguage,
  logger: BaseLogger,
): ModeledMethod[] {
  const modeledMethods: ModeledMethod[] = [];

  for (const resultSetName in bqrs) {
    const definition = Object.values(modelsAsDataLanguage.predicates).find(
      (candidate) => candidate.extensiblePredicate === resultSetName,
    );
    if (definition === undefined) {
      void logger.log(`No predicate found for ${resultSetName}`);
      continue;
    }

    const resultSet = bqrs[resultSetName];

    if (
      resultSet.tuples.some((tuple) =>
        tuple.some((value) => typeof value === "object"),
      )
    ) {
      void logger.log(
        `Skipping ${resultSetName} because it contains undefined values`,
      );
      continue;
    }

    // Push one method at a time. Spreading the whole mapped array into a single
    // `push(...)` call passes every element as a call argument, which throws a
    // RangeError once a result set has more tuples than the engine's argument
    // limit allows.
    for (const tuple of resultSet.tuples) {
      // No value is an object at this point; the filter only narrows the type.
      const row = tuple.filter(
        (value): value is DataTuple => typeof value !== "object",
      );
      modeledMethods.push(definition.readModeledMethod(row));
    }
  }

  return modeledMethods;
}
/**
 * Runs the `queries/modeling/GenerateModel.ql` query and parses its results
 * into modeled methods.
 *
 * @param options The pass-through options for `runGenerateQueries`, plus the
 *   `logger` used while parsing and the `language` of the database.
 * @returns Whatever `runGenerateQueries` returns.
 */
export async function runGenerateModelQuery({
  logger,
  language,
  ...options
}: GenerateModelOptions) {
  const modelsAsDataLanguage = getModelsAsDataLanguage(language);

  // The query path is irrelevant for parsing: result sets are matched to
  // predicates by name.
  const parseResults = (
    _queryPath: string,
    results: DecodedBqrs,
  ): ModeledMethod[] =>
    parseGenerateModelResults(results, modelsAsDataLanguage, logger);

  return runGenerateQueries(
    {
      queryConstraints: { "query path": "queries/modeling/GenerateModel.ql" },
      parseResults,
    },
    options,
  );
}

View File

@@ -25,7 +25,7 @@ type GenerateQueriesQueryOptions = {
/**
* Options that are passed through by the caller of `runGenerateQueries`.
*/
export type GenerateQueriesOptions = {
type GenerateQueriesOptions = {
cliServer: CodeQLCliServer;
queryRunner: QueryRunner;
queryStorageDir: string;

View File

@@ -2,6 +2,9 @@ import { MethodDefinition } from "../method";
import { ModeledMethod, ModeledMethodType } from "../modeled-method";
import { DataTuple } from "../model-extension-file";
import { Mode } from "../shared/mode";
import type { QueryConstraints } from "../../local-queries/query-constraints";
import { DecodedBqrs } from "../../common/bqrs-cli-types";
import { BaseLogger } from "../../common/logging";
// Converts a modeled method into a row of data values.
// NOTE(review): presumably the row stored for the extensible predicate — confirm.
type GenerateMethodDefinition = (method: ModeledMethod) => DataTuple[];
// Converts one row of data values (e.g. the values of a query result tuple)
// into a modeled method.
type ReadModeledMethod = (row: DataTuple[]) => ModeledMethod;
@@ -15,6 +18,22 @@ export type ModelsAsDataLanguagePredicate = {
readModeledMethod: ReadModeledMethod;
};
/**
 * Describes how models are generated for a language: which queries to run and
 * how to turn their results into modeled methods.
 */
type ModelsAsDataLanguageModelGeneration = {
// Constraints selecting the queries to run.
queryConstraints: QueryConstraints;
// Optional predicate over the path of a selected query.
// NOTE(review): presumably queries for which this returns false are not run —
// confirm in the code that consumes this definition.
filterQueries?: (queryPath: string) => boolean;
// Converts the decoded results of one query into modeled methods.
parseResults: (
// The path to the query that generated the results.
queryPath: string,
// The decoded results of the query.
bqrs: DecodedBqrs,
// The language definition whose predicates are used to interpret the results.
// It is passed in (rather than captured) to allow sharing of parsing code
// between different languages.
modelsAsDataLanguage: ModelsAsDataLanguage,
// The logger to use for logging.
logger: BaseLogger,
) => ModeledMethod[];
};
export type ModelsAsDataLanguagePredicates = Record<
ModelsAsDataLanguageModelType,
ModelsAsDataLanguagePredicate
@@ -28,4 +47,5 @@ export type ModelsAsDataLanguage = {
availableModes?: Mode[];
createMethodSignature: (method: MethodDefinition) => string;
predicates: ModelsAsDataLanguagePredicates;
modelGeneration?: ModelsAsDataLanguageModelGeneration;
};

View File

@@ -0,0 +1,50 @@
import { BaseLogger } from "../../../common/logging";
import { DecodedBqrs } from "../../../common/bqrs-cli-types";
import { ModelsAsDataLanguage } from "../models-as-data";
import { ModeledMethod } from "../../modeled-method";
import { DataTuple } from "../../model-extension-file";
/**
 * Converts the result sets of a model generation query into modeled methods.
 *
 * Each result set is matched by name against the `extensiblePredicate` of the
 * predicates defined for the language. Result sets without a matching
 * predicate, and result sets in which any value is of type `object`, are
 * skipped with a log message. Every tuple of the remaining result sets is
 * passed to the `readModeledMethod` function of the matching predicate.
 *
 * @param _queryPath Path of the query that produced the results (unused).
 * @param bqrs Decoded query results, keyed by result set name.
 * @param modelsAsDataLanguage Language definition providing the predicates.
 * @param logger Receives a message for every skipped result set.
 * @returns The modeled methods read from all accepted result sets.
 */
export function parseGenerateModelResults(
  _queryPath: string,
  bqrs: DecodedBqrs,
  modelsAsDataLanguage: ModelsAsDataLanguage,
  logger: BaseLogger,
): ModeledMethod[] {
  const modeledMethods: ModeledMethod[] = [];

  for (const resultSetName in bqrs) {
    const definition = Object.values(modelsAsDataLanguage.predicates).find(
      (candidate) => candidate.extensiblePredicate === resultSetName,
    );
    if (definition === undefined) {
      void logger.log(`No predicate found for ${resultSetName}`);
      continue;
    }

    const resultSet = bqrs[resultSetName];

    if (
      resultSet.tuples.some((tuple) =>
        tuple.some((value) => typeof value === "object"),
      )
    ) {
      void logger.log(
        `Skipping ${resultSetName} because it contains undefined values`,
      );
      continue;
    }

    // Push one method at a time. Spreading the whole mapped array into a single
    // `push(...)` call passes every element as a call argument, which throws a
    // RangeError once a result set has more tuples than the engine's argument
    // limit allows.
    for (const tuple of resultSet.tuples) {
      // No value is an object at this point; the filter only narrows the type.
      const row = tuple.filter(
        (value): value is DataTuple => typeof value !== "object",
      );
      modeledMethods.push(definition.readModeledMethod(row));
    }
  }

  return modeledMethods;
}

View File

@@ -1,6 +1,7 @@
import { ModelsAsDataLanguage } from "../models-as-data";
import { sharedExtensiblePredicates, sharedKinds } from "../shared";
import { Mode } from "../../shared/mode";
import { parseGenerateModelResults } from "./generate";
function parseRubyMethodFromPath(path: string): string {
const match = path.match(/Method\[([^\]]+)].*/);
@@ -150,4 +151,10 @@ export const ruby: ModelsAsDataLanguage = {
},
},
},
modelGeneration: {
queryConstraints: {
"query path": "queries/modeling/GenerateModel.ql",
},
parseResults: parseGenerateModelResults,
},
};

View File

@@ -0,0 +1,60 @@
import { BaseLogger } from "../../../common/logging";
import {
ModelsAsDataLanguage,
ModelsAsDataLanguageModelType,
} from "../models-as-data";
import { DecodedBqrs } from "../../../common/bqrs-cli-types";
import { ModeledMethod } from "../../modeled-method";
import { basename } from "../../../common/path";
/**
 * Maps the file name of each supported model generator query to the type of
 * model that the query produces.
 */
const queriesToModel: Record<string, ModelsAsDataLanguageModelType> = {
  "CaptureSummaryModels.ql": "summary",
  "CaptureSinkModels.ql": "sink",
  "CaptureSourceModels.ql": "source",
  "CaptureNeutralModels.ql": "neutral",
};

/**
 * Tells whether a query is one of the known model generator queries.
 *
 * @param queryPath Path of the query; only its file name is considered.
 * @returns `true` when the file name is a key of `queriesToModel`.
 */
export function filterFlowModelQueries(queryPath: string): boolean {
  const fileName = basename(queryPath);
  return Object.prototype.hasOwnProperty.call(queriesToModel, fileName);
}
/**
 * Turns the decoded results of a single model generator query into modeled
 * methods.
 *
 * The model type is derived from the file name of the query. Each result tuple
 * is expected to hold a `;`-separated string in its first column, which is
 * split and handed to the `readModeledMethod` function of the matching
 * predicate definition.
 *
 * @param queryPath Path of the query that produced the results.
 * @param bqrs Decoded results; must contain exactly one result set.
 * @param modelsAsDataLanguage Language definition providing the predicates.
 * @param logger Receives a message when the query file name is not recognized.
 * @returns The parsed modeled methods, or an empty array when the query file
 *   name is not a known model generator query.
 * @throws Error when `bqrs` does not contain exactly one result set, or when
 *   the language has no definition for the model type.
 */
export function parseFlowModelResults(
  queryPath: string,
  bqrs: DecodedBqrs,
  modelsAsDataLanguage: ModelsAsDataLanguage,
  logger: BaseLogger,
): ModeledMethod[] {
  const resultSetNames = Object.keys(bqrs);
  if (resultSetNames.length !== 1) {
    throw new Error(
      `Expected exactly one result set from ${queryPath}, but got ${resultSetNames.length}`,
    );
  }

  const modelType = queriesToModel[basename(queryPath)];
  if (!modelType) {
    void logger.log(`Unknown model type for ${queryPath}`);
    return [];
  }

  const tuples = bqrs[resultSetNames[0]].tuples;

  const definition = modelsAsDataLanguage.predicates[modelType];
  if (!definition) {
    throw new Error(`No definition for ${modelType}`);
  }

  const modeledMethods: ModeledMethod[] = [];
  for (const tuple of tuples) {
    const firstColumn = tuple[0];
    // This is just a sanity check. The query should only return strings.
    if (typeof firstColumn !== "string") {
      continue;
    }
    modeledMethods.push(definition.readModeledMethod(firstColumn.split(";")));
  }
  return modeledMethods;
}

View File

@@ -2,6 +2,7 @@ import { ModelsAsDataLanguage } from "../models-as-data";
import { ModeledMethodType, Provenance } from "../../modeled-method";
import { DataTuple } from "../../model-extension-file";
import { sharedExtensiblePredicates, sharedKinds } from "../shared";
import { filterFlowModelQueries, parseFlowModelResults } from "./generate";
function readRowToMethod(row: DataTuple[]): string {
return `${row[0]}.${row[1]}#${row[3]}${row[4]}`;
@@ -137,4 +138,11 @@ export const staticLanguage: ModelsAsDataLanguage = {
}),
},
},
modelGeneration: {
queryConstraints: {
"tags contain": ["modelgenerator"],
},
filterQueries: filterFlowModelQueries,
parseResults: parseFlowModelResults,
},
};

View File

@@ -23,10 +23,6 @@ import {
import { DatabaseItem, DatabaseManager } from "../databases/local-databases";
import { CodeQLCliServer } from "../codeql-cli/cli";
import { asError, assertNever, getErrorMessage } from "../common/helpers-pure";
import {
isFlowModelGenerationSupported,
runFlowModelQueries,
} from "./flow-model-queries";
import { promptImportGithubDatabase } from "../databases/database-fetcher";
import { App } from "../common/app";
import { redactableError } from "../common/errors";
@@ -51,10 +47,7 @@ import { ModelingStore } from "./modeling-store";
import { ModelEditorViewTracker } from "./model-editor-view-tracker";
import { ModelingEvents } from "./modeling-events";
import { getModelsAsDataLanguage, ModelsAsDataLanguage } from "./languages";
import {
isGenerateModelSupported,
runGenerateModelQuery,
} from "./generate-model-queries";
import { runGenerateQueries } from "./generate";
export class ModelEditorView extends AbstractWebview<
ToModelEditorMessage,
@@ -270,11 +263,8 @@ export class ModelEditorView extends AbstractWebview<
break;
case "generateMethod":
if (isFlowModelGenerationSupported(this.language)) {
await this.generateModeledMethodsFromFlow();
} else if (isGenerateModelSupported(this.language)) {
await this.generateModeledMethodsFromGenerateModel();
}
await this.generateModeledMethods();
void telemetryListener?.sendUIInteraction(
"model-editor-generate-modeled-methods",
);
@@ -377,10 +367,10 @@ export class ModelEditorView extends AbstractWebview<
}
private async setViewState(): Promise<void> {
const modelsAsDataLanguage = getModelsAsDataLanguage(this.language);
const showGenerateButton =
this.modelConfig.flowGeneration &&
(isFlowModelGenerationSupported(this.language) ||
isGenerateModelSupported(this.language));
this.modelConfig.flowGeneration && !!modelsAsDataLanguage.modelGeneration;
const showLlmButton =
this.databaseItem.language === "java" && this.modelConfig.llmGeneration;
@@ -474,13 +464,23 @@ export class ModelEditorView extends AbstractWebview<
}
}
protected async generateModeledMethodsFromFlow(): Promise<void> {
protected async generateModeledMethods(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
const mode = this.modelingStore.getMode(this.databaseItem);
const modelsAsDataLanguage = getModelsAsDataLanguage(this.language);
const modelGeneration = modelsAsDataLanguage.modelGeneration;
if (!modelGeneration) {
void showAndLogErrorMessage(
this.app.logger,
`Model generation is not supported for ${this.language}.`,
);
return;
}
let addedDatabase: DatabaseItem | undefined;
// In application mode, we need the database of a specific library to generate
@@ -509,52 +509,30 @@ export class ModelEditorView extends AbstractWebview<
});
try {
await runFlowModelQueries({
cliServer: this.cliServer,
queryRunner: this.queryRunner,
logger: this.app.logger,
queryStorageDir: this.queryStorageDir,
databaseItem: addedDatabase ?? this.databaseItem,
language: this.language,
onResults: async (modeledMethods) => {
this.addModeledMethodsFromArray(modeledMethods);
await runGenerateQueries(
{
queryConstraints: modelGeneration.queryConstraints,
filterQueries: modelGeneration.filterQueries,
parseResults: (queryPath, results) =>
modelGeneration.parseResults(
queryPath,
results,
modelsAsDataLanguage,
this.app.logger,
),
},
{
cliServer: this.cliServer,
queryRunner: this.queryRunner,
queryStorageDir: this.queryStorageDir,
databaseItem: addedDatabase ?? this.databaseItem,
onResults: async (modeledMethods) => {
this.addModeledMethodsFromArray(modeledMethods);
},
progress,
token: tokenSource.token,
},
progress,
token: tokenSource.token,
});
} catch (e: unknown) {
void showAndLogExceptionWithTelemetry(
this.app.logger,
this.app.telemetry,
redactableError(
asError(e),
)`Failed to generate flow model: ${getErrorMessage(e)}`,
);
}
},
{ cancellable: false },
);
}
protected async generateModeledMethodsFromGenerateModel(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
try {
await runGenerateModelQuery({
cliServer: this.cliServer,
queryRunner: this.queryRunner,
logger: this.app.logger,
queryStorageDir: this.queryStorageDir,
databaseItem: this.databaseItem,
language: this.language,
progress,
token: tokenSource.token,
onResults: async (modeledMethods) => {
this.addModeledMethodsFromArray(modeledMethods);
},
});
} catch (e: unknown) {
void showAndLogExceptionWithTelemetry(
this.app.logger,

View File

@@ -0,0 +1,107 @@
import { DecodedBqrs } from "../../../../../src/common/bqrs-cli-types";
import { parseGenerateModelResults } from "../../../../../src/model-editor/languages/ruby/generate";
import { ruby } from "../../../../../src/model-editor/languages/ruby";
import { createMockLogger } from "../../../../__mocks__/loggerMock";
// Checks that parseGenerateModelResults turns the result sets of a Ruby
// generate-model query into modeled methods.
describe("parseGenerateModelResults", () => {
  it("should return the results", () => {
    // Decoded results with three empty result sets, a populated typeModel
    // result set and a populated summaryModel result set.
    const bqrs: DecodedBqrs = {
      sourceModel: {
        columns: [
          { name: "type", kind: "String" },
          { name: "path", kind: "String" },
          { name: "kind", kind: "String" },
        ],
        tuples: [],
      },
      sinkModel: {
        columns: [
          { name: "type", kind: "String" },
          { name: "path", kind: "String" },
          { name: "kind", kind: "String" },
        ],
        tuples: [],
      },
      typeVariableModel: {
        columns: [
          { name: "name", kind: "String" },
          { name: "path", kind: "String" },
        ],
        tuples: [],
      },
      typeModel: {
        columns: [
          { name: "type1", kind: "String" },
          { name: "type2", kind: "String" },
          { name: "path", kind: "String" },
        ],
        tuples: [
          ["Array", "SQLite3::ResultSet", "Method[types].ReturnValue"],
          ["Array", "SQLite3::ResultSet", "Method[columns].ReturnValue"],
          ["Array", "SQLite3::Statement", "Method[types].ReturnValue"],
          ["Array", "SQLite3::Statement", "Method[columns].ReturnValue"],
        ],
      },
      summaryModel: {
        columns: [
          { name: "type", kind: "String" },
          { name: "path", kind: "String" },
          { name: "input", kind: "String" },
          { name: "output", kind: "String" },
          { name: "kind", kind: "String" },
        ],
        tuples: [
          [
            "SQLite3::Database",
            "Method[create_function]",
            "Argument[self]",
            "ReturnValue",
            "value",
          ],
          [
            "SQLite3::Value!",
            "Method[new]",
            "Argument[1]",
            "ReturnValue",
            "value",
          ],
        ],
      },
    };

    const result = parseGenerateModelResults(
      "/a/b/c/query.ql",
      bqrs,
      ruby,
      createMockLogger(),
    );

    // Compare in result order. The default `sort()` compares elements by their
    // string form, which is "[object Object]" for every element here, so
    // sorting both sides never reordered anything and only mutated `result`.
    // The expectation lists the two summaryModel tuples in their input order.
    expect(result).toEqual([
      {
        input: "Argument[self]",
        kind: "value",
        methodName: "create_function",
        methodParameters: "",
        output: "ReturnValue",
        packageName: "",
        provenance: "manual",
        signature: "SQLite3::Database#create_function",
        type: "summary",
        typeName: "SQLite3::Database",
      },
      {
        input: "Argument[1]",
        kind: "value",
        methodName: "new",
        methodParameters: "",
        output: "ReturnValue",
        packageName: "",
        provenance: "manual",
        signature: "SQLite3::Value!#new",
        type: "summary",
        typeName: "SQLite3::Value!",
      },
    ]);
  });
});

View File

@@ -5,23 +5,29 @@ import {
} from "../../../../src/databases/local-databases";
import { file } from "tmp-promise";
import { QueryResultType } from "../../../../src/query-server/new-messages";
import { Mode } from "../../../../src/model-editor/shared/mode";
import { mockedObject, mockedUri } from "../../utils/mocking.helpers";
import { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import { QueryRunner } from "../../../../src/query-server";
import { join } from "path";
import { CancellationTokenSource } from "vscode-jsonrpc";
import { QueryOutputDir } from "../../../../src/run-queries-shared";
import { runGenerateModelQuery } from "../../../../src/model-editor/generate-model-queries";
import { QueryLanguage } from "../../../../src/common/query-language";
import { runGenerateQueries } from "../../../../src/model-editor/generate";
import { ruby } from "../../../../src/model-editor/languages/ruby";
describe("runGenerateQueries", () => {
const modelsAsDataLanguage = ruby;
const modelGeneration = modelsAsDataLanguage.modelGeneration;
if (!modelGeneration) {
throw new Error("Test requires a model generation step");
}
describe("runGenerateModelQuery", () => {
it("should run the query and return the results", async () => {
const queryStorageDir = (await file()).path;
const outputDir = new QueryOutputDir(join(queryStorageDir, "1"));
const onResults = jest.fn();
const options = {
mode: Mode.Application,
cliServer: mockedObject<CodeQLCliServer>({
resolveQueriesInSuite: jest
.fn()
@@ -100,7 +106,6 @@ describe("runGenerateModelQuery", () => {
}),
logger: createMockLogger(),
}),
logger: createMockLogger(),
databaseItem: mockedObject<DatabaseItem>({
databaseUri: mockedUri("/a/b/c/src.zip"),
contents: {
@@ -114,41 +119,52 @@ describe("runGenerateModelQuery", () => {
.mockResolvedValue("/home/runner/work/my-repo/my-repo"),
sourceArchive: mockedUri("/a/b/c/src.zip"),
}),
language: QueryLanguage.Ruby,
queryStorageDir: "/tmp/queries",
progress: jest.fn(),
token: new CancellationTokenSource().token,
onResults,
};
const result = await runGenerateModelQuery(options);
expect(result.sort()).toEqual(
[
{
input: "Argument[self]",
kind: "value",
methodName: "create_function",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Database#create_function",
type: "summary",
typeName: "SQLite3::Database",
},
{
input: "Argument[1]",
kind: "value",
methodName: "new",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Value!#new",
type: "summary",
typeName: "SQLite3::Value!",
},
].sort(),
await runGenerateQueries(
{
queryConstraints: modelGeneration.queryConstraints,
filterQueries: modelGeneration.filterQueries,
parseResults: (queryPath, results) =>
modelGeneration.parseResults(
queryPath,
results,
modelsAsDataLanguage,
createMockLogger(),
),
},
options,
);
expect(onResults).toHaveBeenCalledWith([
{
input: "Argument[self]",
kind: "value",
methodName: "create_function",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Database#create_function",
type: "summary",
typeName: "SQLite3::Database",
},
{
input: "Argument[1]",
kind: "value",
methodName: "new",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Value!#new",
type: "summary",
typeName: "SQLite3::Value!",
},
]);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledTimes(1);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledWith(