Add generation of Ruby models

This adds the ability to generate Ruby models from a database. It uses
the `GenerateModel.ql` query to do this. The query will essentially
return data in the data extensions format, so this will just parse it
and return it as `ModeledMethod` objects.
This commit is contained in:
Koen Vlaswinkel
2023-10-31 10:19:18 +01:00
parent 135bce889e
commit 474ec197a0
6 changed files with 411 additions and 4 deletions

View File

@@ -10,7 +10,11 @@ import tk from "tree-kill";
import { promisify } from "util";
import { CancellationToken, Disposable, Uri } from "vscode";
import { BQRSInfo, DecodedBqrsChunk } from "../common/bqrs-cli-types";
import {
BQRSInfo,
DecodedBqrs,
DecodedBqrsChunk,
} from "../common/bqrs-cli-types";
import { allowCanaryQueryServer, CliConfig } from "../config";
import {
DistributionProvider,
@@ -1040,6 +1044,18 @@ export class CodeQLCliServer implements Disposable {
);
}
/**
* Gets all results from a bqrs.
* @param bqrsPath The path to the bqrs.
*/
async bqrsDecodeAll(bqrsPath: string): Promise<DecodedBqrs> {
return await this.runJsonCodeQlCliCommand<DecodedBqrs>(
["bqrs", "decode"],
[bqrsPath],
"Reading all bqrs data",
);
}
async runInterpretCommand(
format: string,
additonalArgs: string[],

View File

@@ -121,3 +121,5 @@ export interface DecodedBqrsChunk {
next?: number;
columns: BqrsColumn[];
}
export type DecodedBqrs = Record<string, DecodedBqrsChunk>;

View File

@@ -31,6 +31,8 @@ export interface QueryConstraints {
kind?: string;
"tags contain"?: string[];
"tags contain all"?: string[];
"query filename"?: string;
"query path"?: string;
}
/**
@@ -132,6 +134,14 @@ export async function resolveQueries(
`tagged all of "${constraints["tags contain all"].join(" ")}"`,
);
}
if (constraints["query filename"] !== undefined) {
humanConstraints.push(
`with query filename "${constraints["query filename"]}"`,
);
}
if (constraints["query path"] !== undefined) {
humanConstraints.push(`with query path "${constraints["query path"]}"`);
}
const joinedPacksToSearch = packsToSearch.join(", ");
const error = redactableError`No ${name} queries (${humanConstraints.join(

View File

@@ -0,0 +1,158 @@
import { CancellationToken } from "vscode";
import { DatabaseItem } from "../databases/local-databases";
import { QueryRunner } from "../query-server";
import { CodeQLCliServer } from "../codeql-cli/cli";
import {
BaseLogger,
showAndLogExceptionWithTelemetry,
} from "../common/logging";
import { extLogger } from "../common/logging/vscode";
import { getModelsAsDataLanguage } from "./languages";
import { ProgressCallback } from "../common/vscode/progress";
import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders";
import { ModeledMethod } from "./modeled-method";
import { redactableError } from "../common/errors";
import { telemetryListener } from "../common/vscode/telemetry";
import { runQuery } from "../local-queries/run-query";
import { resolveQueries } from "../local-queries";
import { QueryLanguage } from "../common/query-language";
import { DataTuple } from "./model-extension-file";
const GENERATE_MODEL_SUPPORTED_LANGUAGES = [QueryLanguage.Ruby];
export function isGenerateModelSupported(language: QueryLanguage): boolean {
return GENERATE_MODEL_SUPPORTED_LANGUAGES.includes(language);
}
type GenerateModelOptions = {
cliServer: CodeQLCliServer;
queryRunner: QueryRunner;
logger: BaseLogger;
queryStorageDir: string;
databaseItem: DatabaseItem;
language: QueryLanguage;
progress: ProgressCallback;
token: CancellationToken;
};
// resolve (100) + query (1000) + interpret (100)
const maxStep = 1200;
export async function runGenerateModelQuery({
cliServer,
queryRunner,
logger,
queryStorageDir,
databaseItem,
language,
progress,
token,
}: GenerateModelOptions): Promise<ModeledMethod[]> {
progress({
message: "Resolving generate model query",
step: 100,
maxStep,
});
const queryPath = await resolveGenerateModelQuery(cliServer, databaseItem);
if (queryPath === undefined) {
return [];
}
// Run the query
const completedQuery = await runQuery({
queryRunner,
databaseItem,
queryPath,
queryStorageDir,
additionalPacks: getOnDiskWorkspaceFolders(),
extensionPacks: undefined,
progress: ({ step, message }) =>
progress({
message: `Generating models: ${message}`,
step: 100 + step,
maxStep,
}),
token,
});
if (!completedQuery) {
return [];
}
progress({
message: "Decoding results",
step: 1100,
maxStep,
});
const decodedBqrs = await cliServer.bqrsDecodeAll(
completedQuery.outputDir.bqrsPath,
);
const modelsAsDataLanguage = getModelsAsDataLanguage(language);
const modeledMethods: ModeledMethod[] = [];
for (const resultSetName in decodedBqrs) {
const definition = Object.values(modelsAsDataLanguage.predicates).find(
(definition) => definition.extensiblePredicate === resultSetName,
);
if (definition === undefined) {
void logger.log(`No predicate found for ${resultSetName}`);
continue;
}
const resultSet = decodedBqrs[resultSetName];
if (
resultSet.tuples.some((tuple) =>
tuple.some((value) => typeof value === "object"),
)
) {
void logger.log(
`Skipping ${resultSetName} because it contains undefined values`,
);
continue;
}
modeledMethods.push(
...resultSet.tuples.map((tuple) => {
const row = tuple.filter(
(value): value is DataTuple => typeof value !== "object",
);
return definition.readModeledMethod(row);
}),
);
}
return modeledMethods;
}
async function resolveGenerateModelQuery(
cliServer: CodeQLCliServer,
databaseItem: DatabaseItem,
): Promise<string | undefined> {
const packsToSearch = [`codeql/${databaseItem.language}-queries`];
const queries = await resolveQueries(
cliServer,
packsToSearch,
"generate model",
{
"query path": "queries/modeling/GenerateModel.ql",
},
);
if (queries.length !== 1) {
void showAndLogExceptionWithTelemetry(
extLogger,
telemetryListener,
redactableError`Expected exactly one generate model query, got ${queries.length}`,
);
return undefined;
}
return queries[0];
}

View File

@@ -51,6 +51,10 @@ import { ModelingStore } from "./modeling-store";
import { ModelEditorViewTracker } from "./model-editor-view-tracker";
import { ModelingEvents } from "./modeling-events";
import { getModelsAsDataLanguage, ModelsAsDataLanguage } from "./languages";
import {
isGenerateModelSupported,
runGenerateModelQuery,
} from "./generate-model-queries";
export class ModelEditorView extends AbstractWebview<
ToModelEditorMessage,
@@ -266,7 +270,11 @@ export class ModelEditorView extends AbstractWebview<
break;
case "generateMethod":
await this.generateModeledMethods();
if (isFlowModelGenerationSupported(this.language)) {
await this.generateModeledMethodsFromFlow();
} else if (isGenerateModelSupported(this.language)) {
await this.generateModeledMethodsFromGenerateModel();
}
void telemetryListener?.sendUIInteraction(
"model-editor-generate-modeled-methods",
);
@@ -371,7 +379,8 @@ export class ModelEditorView extends AbstractWebview<
private async setViewState(): Promise<void> {
const showFlowGeneration =
this.modelConfig.flowGeneration &&
isFlowModelGenerationSupported(this.language);
(isFlowModelGenerationSupported(this.language) ||
isGenerateModelSupported(this.language));
const showLlmButton =
this.databaseItem.language === "java" && this.modelConfig.llmGeneration;
@@ -464,7 +473,7 @@ export class ModelEditorView extends AbstractWebview<
}
}
protected async generateModeledMethods(): Promise<void> {
protected async generateModeledMethodsFromFlow(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
@@ -537,6 +546,48 @@ export class ModelEditorView extends AbstractWebview<
);
}
protected async generateModeledMethodsFromGenerateModel(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
try {
const modeledMethods = await runGenerateModelQuery({
cliServer: this.cliServer,
queryRunner: this.queryRunner,
logger: this.app.logger,
queryStorageDir: this.queryStorageDir,
databaseItem: this.databaseItem,
language: this.language,
progress,
token: tokenSource.token,
});
const modeledMethodsByName: Record<string, ModeledMethod[]> = {};
for (const modeledMethod of modeledMethods) {
if (!(modeledMethod.signature in modeledMethodsByName)) {
modeledMethodsByName[modeledMethod.signature] = [];
}
modeledMethodsByName[modeledMethod.signature].push(modeledMethod);
}
this.addModeledMethods(modeledMethodsByName);
} catch (e: unknown) {
void showAndLogExceptionWithTelemetry(
this.app.logger,
this.app.telemetry,
redactableError(
asError(e),
)`Failed to generate models: ${getErrorMessage(e)}`,
);
}
},
{ cancellable: false },
);
}
private async generateModeledMethodsFromLlm(
packageName: string,
methodSignatures: string[],

View File

@@ -0,0 +1,170 @@
import { createMockLogger } from "../../../__mocks__/loggerMock";
import {
DatabaseItem,
DatabaseKind,
} from "../../../../src/databases/local-databases";
import { file } from "tmp-promise";
import { QueryResultType } from "../../../../src/query-server/new-messages";
import { Mode } from "../../../../src/model-editor/shared/mode";
import { mockedObject, mockedUri } from "../../utils/mocking.helpers";
import { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import { QueryRunner } from "../../../../src/query-server";
import { join } from "path";
import { CancellationTokenSource } from "vscode-jsonrpc";
import { QueryOutputDir } from "../../../../src/run-queries-shared";
import { runGenerateModelQuery } from "../../../../src/model-editor/generate-model-queries";
import { QueryLanguage } from "../../../../src/common/query-language";
describe("runGenerateModelQuery", () => {
it("should run the query and return the results", async () => {
const queryStorageDir = (await file()).path;
const outputDir = new QueryOutputDir(join(queryStorageDir, "1"));
const options = {
mode: Mode.Application,
cliServer: mockedObject<CodeQLCliServer>({
resolveQueriesInSuite: jest
.fn()
.mockResolvedValue(["/a/b/c/GenerateModel.ql"]),
bqrsDecodeAll: jest.fn().mockResolvedValue({
sourceModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [],
},
sinkModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [],
},
typeVariableModel: {
columns: [
{ name: "name", kind: "String" },
{ name: "path", kind: "String" },
],
tuples: [],
},
typeModel: {
columns: [
{ name: "type1", kind: "String" },
{ name: "type2", kind: "String" },
{ name: "path", kind: "String" },
],
tuples: [
["Array", "SQLite3::ResultSet", "Method[types].ReturnValue"],
["Array", "SQLite3::ResultSet", "Method[columns].ReturnValue"],
["Array", "SQLite3::Statement", "Method[types].ReturnValue"],
["Array", "SQLite3::Statement", "Method[columns].ReturnValue"],
],
},
summaryModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "input", kind: "String" },
{ name: "output", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [
[
"SQLite3::Database",
"Method[create_function]",
"Argument[self]",
"ReturnValue",
"value",
],
[
"SQLite3::Value!",
"Method[new]",
"Argument[1]",
"ReturnValue",
"value",
],
],
},
}),
}),
queryRunner: mockedObject<QueryRunner>({
createQueryRun: jest.fn().mockReturnValue({
evaluate: jest.fn().mockResolvedValue({
resultType: QueryResultType.SUCCESS,
outputDir,
}),
outputDir,
}),
logger: createMockLogger(),
}),
logger: createMockLogger(),
databaseItem: mockedObject<DatabaseItem>({
databaseUri: mockedUri("/a/b/c/src.zip"),
contents: {
kind: DatabaseKind.Database,
name: "foo",
datasetUri: mockedUri(),
},
language: "ruby",
getSourceLocationPrefix: jest
.fn()
.mockResolvedValue("/home/runner/work/my-repo/my-repo"),
sourceArchive: mockedUri("/a/b/c/src.zip"),
}),
language: QueryLanguage.Ruby,
queryStorageDir: "/tmp/queries",
progress: jest.fn(),
token: new CancellationTokenSource().token,
};
const result = await runGenerateModelQuery(options);
expect(result.sort()).toEqual(
[
{
input: "Argument[self]",
kind: "value",
methodName: "create_function",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Database#create_function",
type: "summary",
typeName: "SQLite3::Database",
},
{
input: "Argument[1]",
kind: "value",
methodName: "new",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Value!#new",
type: "summary",
typeName: "SQLite3::Value!",
},
].sort(),
);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledTimes(1);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledWith(
"/a/b/c/src.zip",
{
queryPath: "/a/b/c/GenerateModel.ql",
quickEvalPosition: undefined,
quickEvalCountOnly: false,
},
false,
[],
undefined,
{},
"/tmp/queries",
undefined,
undefined,
);
});
});