Merge pull request #3033 from github/koesie10/generate-model

Add generation of Ruby models
This commit is contained in:
Koen Vlaswinkel
2023-11-01 14:17:10 +01:00
committed by GitHub
14 changed files with 433 additions and 30 deletions

View File

@@ -10,7 +10,11 @@ import tk from "tree-kill";
import { promisify } from "util";
import { CancellationToken, Disposable, Uri } from "vscode";
import { BQRSInfo, DecodedBqrsChunk } from "../common/bqrs-cli-types";
import {
BQRSInfo,
DecodedBqrs,
DecodedBqrsChunk,
} from "../common/bqrs-cli-types";
import { allowCanaryQueryServer, CliConfig } from "../config";
import {
DistributionProvider,
@@ -1040,6 +1044,18 @@ export class CodeQLCliServer implements Disposable {
);
}
/**
* Gets all results from a bqrs.
* @param bqrsPath The path to the bqrs.
*/
async bqrsDecodeAll(bqrsPath: string): Promise<DecodedBqrs> {
return await this.runJsonCodeQlCliCommand<DecodedBqrs>(
["bqrs", "decode"],
[bqrsPath],
"Reading all bqrs data",
);
}
async runInterpretCommand(
format: string,
additonalArgs: string[],

View File

@@ -121,3 +121,5 @@ export interface DecodedBqrsChunk {
next?: number;
columns: BqrsColumn[];
}
export type DecodedBqrs = Record<string, DecodedBqrsChunk>;

View File

@@ -31,6 +31,8 @@ export interface QueryConstraints {
kind?: string;
"tags contain"?: string[];
"tags contain all"?: string[];
"query filename"?: string;
"query path"?: string;
}
/**
@@ -132,6 +134,14 @@ export async function resolveQueries(
`tagged all of "${constraints["tags contain all"].join(" ")}"`,
);
}
if (constraints["query filename"] !== undefined) {
humanConstraints.push(
`with query filename "${constraints["query filename"]}"`,
);
}
if (constraints["query path"] !== undefined) {
humanConstraints.push(`with query path "${constraints["query path"]}"`);
}
const joinedPacksToSearch = packsToSearch.join(", ");
const error = redactableError`No ${name} queries (${humanConstraints.join(

View File

@@ -0,0 +1,162 @@
import { CancellationToken } from "vscode";
import { DatabaseItem } from "../databases/local-databases";
import { QueryRunner } from "../query-server";
import { CodeQLCliServer } from "../codeql-cli/cli";
import {
NotificationLogger,
showAndLogExceptionWithTelemetry,
} from "../common/logging";
import { getModelsAsDataLanguage } from "./languages";
import { ProgressCallback } from "../common/vscode/progress";
import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders";
import { ModeledMethod } from "./modeled-method";
import { redactableError } from "../common/errors";
import { telemetryListener } from "../common/vscode/telemetry";
import { runQuery } from "../local-queries/run-query";
import { resolveQueries } from "../local-queries";
import { QueryLanguage } from "../common/query-language";
import { DataTuple } from "./model-extension-file";
const GENERATE_MODEL_SUPPORTED_LANGUAGES = [QueryLanguage.Ruby];
export function isGenerateModelSupported(language: QueryLanguage): boolean {
return GENERATE_MODEL_SUPPORTED_LANGUAGES.includes(language);
}
type GenerateModelOptions = {
cliServer: CodeQLCliServer;
queryRunner: QueryRunner;
logger: NotificationLogger;
queryStorageDir: string;
databaseItem: DatabaseItem;
language: QueryLanguage;
progress: ProgressCallback;
token: CancellationToken;
};
// resolve (100) + query (1000) + interpret (100)
const maxStep = 1200;
export async function runGenerateModelQuery({
cliServer,
queryRunner,
logger,
queryStorageDir,
databaseItem,
language,
progress,
token,
}: GenerateModelOptions): Promise<ModeledMethod[]> {
progress({
message: "Resolving generate model query",
step: 100,
maxStep,
});
const queryPath = await resolveGenerateModelQuery(
cliServer,
logger,
databaseItem,
);
if (queryPath === undefined) {
return [];
}
// Run the query
const completedQuery = await runQuery({
queryRunner,
databaseItem,
queryPath,
queryStorageDir,
additionalPacks: getOnDiskWorkspaceFolders(),
extensionPacks: undefined,
progress: ({ step, message }) =>
progress({
message: `Generating models: ${message}`,
step: 100 + step,
maxStep,
}),
token,
});
if (!completedQuery) {
return [];
}
progress({
message: "Decoding results",
step: 1100,
maxStep,
});
const decodedBqrs = await cliServer.bqrsDecodeAll(
completedQuery.outputDir.bqrsPath,
);
const modelsAsDataLanguage = getModelsAsDataLanguage(language);
const modeledMethods: ModeledMethod[] = [];
for (const resultSetName in decodedBqrs) {
const definition = Object.values(modelsAsDataLanguage.predicates).find(
(definition) => definition.extensiblePredicate === resultSetName,
);
if (definition === undefined) {
void logger.log(`No predicate found for ${resultSetName}`);
continue;
}
const resultSet = decodedBqrs[resultSetName];
if (
resultSet.tuples.some((tuple) =>
tuple.some((value) => typeof value === "object"),
)
) {
void logger.log(
`Skipping ${resultSetName} because it contains undefined values`,
);
continue;
}
modeledMethods.push(
...resultSet.tuples.map((tuple) => {
const row = tuple.filter(
(value): value is DataTuple => typeof value !== "object",
);
return definition.readModeledMethod(row);
}),
);
}
return modeledMethods;
}
async function resolveGenerateModelQuery(
cliServer: CodeQLCliServer,
logger: NotificationLogger,
databaseItem: DatabaseItem,
): Promise<string | undefined> {
const packsToSearch = [`codeql/${databaseItem.language}-queries`];
const queries = await resolveQueries(
cliServer,
packsToSearch,
"generate model",
{
"query path": "queries/modeling/GenerateModel.ql",
},
);
if (queries.length !== 1) {
void showAndLogExceptionWithTelemetry(
logger,
telemetryListener,
redactableError`Expected exactly one generate model query, got ${queries.length}`,
);
return undefined;
}
return queries[0];
}

View File

@@ -51,6 +51,10 @@ import { ModelingStore } from "./modeling-store";
import { ModelEditorViewTracker } from "./model-editor-view-tracker";
import { ModelingEvents } from "./modeling-events";
import { getModelsAsDataLanguage, ModelsAsDataLanguage } from "./languages";
import {
isGenerateModelSupported,
runGenerateModelQuery,
} from "./generate-model-queries";
export class ModelEditorView extends AbstractWebview<
ToModelEditorMessage,
@@ -266,7 +270,11 @@ export class ModelEditorView extends AbstractWebview<
break;
case "generateMethod":
await this.generateModeledMethods();
if (isFlowModelGenerationSupported(this.language)) {
await this.generateModeledMethodsFromFlow();
} else if (isGenerateModelSupported(this.language)) {
await this.generateModeledMethodsFromGenerateModel();
}
void telemetryListener?.sendUIInteraction(
"model-editor-generate-modeled-methods",
);
@@ -369,9 +377,10 @@ export class ModelEditorView extends AbstractWebview<
}
private async setViewState(): Promise<void> {
const showFlowGeneration =
const showGenerateButton =
this.modelConfig.flowGeneration &&
isFlowModelGenerationSupported(this.language);
(isFlowModelGenerationSupported(this.language) ||
isGenerateModelSupported(this.language));
const showLlmButton =
this.databaseItem.language === "java" && this.modelConfig.llmGeneration;
@@ -388,7 +397,7 @@ export class ModelEditorView extends AbstractWebview<
viewState: {
extensionPack: this.extensionPack,
language: this.language,
showFlowGeneration,
showGenerateButton,
showLlmButton,
showMultipleModels: this.modelConfig.showMultipleModels,
mode: this.modelingStore.getMode(this.databaseItem),
@@ -465,7 +474,7 @@ export class ModelEditorView extends AbstractWebview<
}
}
protected async generateModeledMethods(): Promise<void> {
protected async generateModeledMethodsFromFlow(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
@@ -508,19 +517,7 @@ export class ModelEditorView extends AbstractWebview<
databaseItem: addedDatabase ?? this.databaseItem,
language: this.language,
onResults: async (modeledMethods) => {
const modeledMethodsByName: Record<string, ModeledMethod[]> = {};
for (const modeledMethod of modeledMethods) {
if (!(modeledMethod.signature in modeledMethodsByName)) {
modeledMethodsByName[modeledMethod.signature] = [];
}
modeledMethodsByName[modeledMethod.signature].push(
modeledMethod,
);
}
this.addModeledMethods(modeledMethodsByName);
this.addModeledMethodsFromArray(modeledMethods);
},
progress,
token: tokenSource.token,
@@ -539,6 +536,38 @@ export class ModelEditorView extends AbstractWebview<
);
}
protected async generateModeledMethodsFromGenerateModel(): Promise<void> {
await withProgress(
async (progress) => {
const tokenSource = new CancellationTokenSource();
try {
const modeledMethods = await runGenerateModelQuery({
cliServer: this.cliServer,
queryRunner: this.queryRunner,
logger: this.app.logger,
queryStorageDir: this.queryStorageDir,
databaseItem: this.databaseItem,
language: this.language,
progress,
token: tokenSource.token,
});
this.addModeledMethodsFromArray(modeledMethods);
} catch (e: unknown) {
void showAndLogExceptionWithTelemetry(
this.app.logger,
this.app.telemetry,
redactableError(
asError(e),
)`Failed to generate models: ${getErrorMessage(e)}`,
);
}
},
{ cancellable: false },
);
}
private async generateModeledMethodsFromLlm(
packageName: string,
methodSignatures: string[],
@@ -757,6 +786,20 @@ export class ModelEditorView extends AbstractWebview<
);
}
private addModeledMethodsFromArray(modeledMethods: ModeledMethod[]) {
const modeledMethodsByName: Record<string, ModeledMethod[]> = {};
for (const modeledMethod of modeledMethods) {
if (!(modeledMethod.signature in modeledMethodsByName)) {
modeledMethodsByName[modeledMethod.signature] = [];
}
modeledMethodsByName[modeledMethod.signature].push(modeledMethod);
}
this.addModeledMethods(modeledMethodsByName);
}
private setModeledMethods(signature: string, methods: ModeledMethod[]) {
this.modelingStore.updateModeledMethods(
this.databaseItem,

View File

@@ -5,7 +5,7 @@ import { QueryLanguage } from "../../common/query-language";
export interface ModelEditorViewState {
extensionPack: ExtensionPack;
language: QueryLanguage;
showFlowGeneration: boolean;
showGenerateButton: boolean;
showLlmButton: boolean;
showMultipleModels: boolean;
mode: Mode;

View File

@@ -219,7 +219,7 @@ LibraryRow.args = {
modifiedSignatures: new Set(["org.sql2o.Sql2o#Sql2o(String)"]),
inProgressMethods: new Set(),
viewState: createMockModelEditorViewState({
showFlowGeneration: true,
showGenerateButton: true,
showLlmButton: true,
showMultipleModels: true,
}),

View File

@@ -98,7 +98,7 @@ const modeledMethod: ModeledMethod = {
};
const viewState = createMockModelEditorViewState({
showFlowGeneration: true,
showGenerateButton: true,
showLlmButton: true,
showMultipleModels: true,
});

View File

@@ -28,7 +28,7 @@ ModelEditor.args = {
extensionTargets: {},
dataExtensions: [],
},
showFlowGeneration: true,
showGenerateButton: true,
showLlmButton: true,
showMultipleModels: true,
}),

View File

@@ -208,7 +208,7 @@ export const LibraryRow = ({
&nbsp;Stop
</VSCodeButton>
)}
{viewState.showFlowGeneration &&
{viewState.showGenerateButton &&
viewState.mode === Mode.Application && (
<VSCodeButton appearance="icon" onClick={handleModelFromSource}>
<Codicon name="code" label="Model from source" />

View File

@@ -323,7 +323,7 @@ export function ModelEditor({
<VSCodeButton appearance="secondary" onClick={onRefreshClick}>
Refresh
</VSCodeButton>
{viewState.showFlowGeneration &&
{viewState.showGenerateButton &&
viewState.mode === Mode.Framework && (
<VSCodeButton onClick={onGenerateFromSourceClick}>
Generate

View File

@@ -57,11 +57,11 @@ describe(LibraryRow.name, () => {
expect(screen.queryByText("Model dependency")).toBeInTheDocument();
});
it("renders the row when flow generation is enabled", () => {
it("renders the row when generate button is enabled", () => {
render({
viewState: {
...viewState,
showFlowGeneration: true,
showGenerateButton: true,
},
});
@@ -83,11 +83,11 @@ describe(LibraryRow.name, () => {
expect(screen.queryByText("Model dependency")).toBeInTheDocument();
});
it("renders the row when flow generation and LLM are enabled", () => {
it("renders the row when generate button and LLM are enabled", () => {
render({
viewState: {
...viewState,
showFlowGeneration: true,
showGenerateButton: true,
showLlmButton: true,
},
});

View File

@@ -9,7 +9,7 @@ export function createMockModelEditorViewState(
return {
language: QueryLanguage.Java,
mode: Mode.Application,
showFlowGeneration: false,
showGenerateButton: false,
showLlmButton: false,
showMultipleModels: false,
showModeSwitchButton: true,

View File

@@ -0,0 +1,170 @@
import { createMockLogger } from "../../../__mocks__/loggerMock";
import {
DatabaseItem,
DatabaseKind,
} from "../../../../src/databases/local-databases";
import { file } from "tmp-promise";
import { QueryResultType } from "../../../../src/query-server/new-messages";
import { Mode } from "../../../../src/model-editor/shared/mode";
import { mockedObject, mockedUri } from "../../utils/mocking.helpers";
import { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import { QueryRunner } from "../../../../src/query-server";
import { join } from "path";
import { CancellationTokenSource } from "vscode-jsonrpc";
import { QueryOutputDir } from "../../../../src/run-queries-shared";
import { runGenerateModelQuery } from "../../../../src/model-editor/generate-model-queries";
import { QueryLanguage } from "../../../../src/common/query-language";
describe("runGenerateModelQuery", () => {
it("should run the query and return the results", async () => {
const queryStorageDir = (await file()).path;
const outputDir = new QueryOutputDir(join(queryStorageDir, "1"));
const options = {
mode: Mode.Application,
cliServer: mockedObject<CodeQLCliServer>({
resolveQueriesInSuite: jest
.fn()
.mockResolvedValue(["/a/b/c/GenerateModel.ql"]),
bqrsDecodeAll: jest.fn().mockResolvedValue({
sourceModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [],
},
sinkModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [],
},
typeVariableModel: {
columns: [
{ name: "name", kind: "String" },
{ name: "path", kind: "String" },
],
tuples: [],
},
typeModel: {
columns: [
{ name: "type1", kind: "String" },
{ name: "type2", kind: "String" },
{ name: "path", kind: "String" },
],
tuples: [
["Array", "SQLite3::ResultSet", "Method[types].ReturnValue"],
["Array", "SQLite3::ResultSet", "Method[columns].ReturnValue"],
["Array", "SQLite3::Statement", "Method[types].ReturnValue"],
["Array", "SQLite3::Statement", "Method[columns].ReturnValue"],
],
},
summaryModel: {
columns: [
{ name: "type", kind: "String" },
{ name: "path", kind: "String" },
{ name: "input", kind: "String" },
{ name: "output", kind: "String" },
{ name: "kind", kind: "String" },
],
tuples: [
[
"SQLite3::Database",
"Method[create_function]",
"Argument[self]",
"ReturnValue",
"value",
],
[
"SQLite3::Value!",
"Method[new]",
"Argument[1]",
"ReturnValue",
"value",
],
],
},
}),
}),
queryRunner: mockedObject<QueryRunner>({
createQueryRun: jest.fn().mockReturnValue({
evaluate: jest.fn().mockResolvedValue({
resultType: QueryResultType.SUCCESS,
outputDir,
}),
outputDir,
}),
logger: createMockLogger(),
}),
logger: createMockLogger(),
databaseItem: mockedObject<DatabaseItem>({
databaseUri: mockedUri("/a/b/c/src.zip"),
contents: {
kind: DatabaseKind.Database,
name: "foo",
datasetUri: mockedUri(),
},
language: "ruby",
getSourceLocationPrefix: jest
.fn()
.mockResolvedValue("/home/runner/work/my-repo/my-repo"),
sourceArchive: mockedUri("/a/b/c/src.zip"),
}),
language: QueryLanguage.Ruby,
queryStorageDir: "/tmp/queries",
progress: jest.fn(),
token: new CancellationTokenSource().token,
};
const result = await runGenerateModelQuery(options);
expect(result.sort()).toEqual(
[
{
input: "Argument[self]",
kind: "value",
methodName: "create_function",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Database#create_function",
type: "summary",
typeName: "SQLite3::Database",
},
{
input: "Argument[1]",
kind: "value",
methodName: "new",
methodParameters: "",
output: "ReturnValue",
packageName: "",
provenance: "manual",
signature: "SQLite3::Value!#new",
type: "summary",
typeName: "SQLite3::Value!",
},
].sort(),
);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledTimes(1);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledWith(
"/a/b/c/src.zip",
{
queryPath: "/a/b/c/GenerateModel.ql",
quickEvalPosition: undefined,
quickEvalCountOnly: false,
},
false,
[],
undefined,
{},
"/tmp/queries",
undefined,
undefined,
);
});
});