Convert remaining extension host code to handle multiple models

This converts all remaining extension host code to handle multiple
models per method. The only place where we're using the legacy format
is in the webview and in the boundary between the webview and the
extension host.
This commit is contained in:
Koen Vlaswinkel
2023-10-04 14:22:57 +02:00
parent 603c799717
commit b76369330d
12 changed files with 179 additions and 84 deletions

View File

@@ -14,13 +14,13 @@ import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
* the order in the UI.
* @param mode Whether it is application or framework mode.
* @param methods all methods.
* @param modeledMethods the currently modeled methods.
* @param modeledMethodsBySignature the currently modeled methods.
* @returns list of modeled methods that are candidates for modeling.
*/
export function getCandidates(
mode: Mode,
methods: Method[],
modeledMethods: Record<string, ModeledMethod>,
modeledMethodsBySignature: Record<string, ModeledMethod[]>,
): MethodSignature[] {
// Sort the same way as the UI so we send the first ones listed in the UI first
const grouped = groupMethods(methods, mode);
@@ -32,12 +32,11 @@ export function getCandidates(
const candidates: MethodSignature[] = [];
for (const method of sortedMethods) {
const modeledMethod: ModeledMethod = modeledMethods[method.signature] ?? {
type: "none",
};
const modeledMethods: ModeledMethod[] =
modeledMethodsBySignature[method.signature] ?? [];
// Anything that is modeled is not a candidate
if (modeledMethod.type !== "none") {
if (modeledMethods.some((m) => m.type !== "none")) {
continue;
}

View File

@@ -16,7 +16,6 @@ import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases";
import { Mode } from "./shared/mode";
import { CancellationTokenSource } from "vscode";
import { convertToLegacyModeledMethods } from "./modeled-methods-legacy";
// Limit the number of candidates we send to the model in each request
// to avoid long requests.
@@ -43,7 +42,7 @@ export class AutoModeler {
inProgressMethods: string[],
) => Promise<void>,
private readonly addModeledMethods: (
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
) => Promise<void>,
) {
this.jobs = new Map<string, CancellationTokenSource>();
@@ -60,7 +59,7 @@ export class AutoModeler {
public async startModeling(
packageName: string,
methods: Method[],
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
mode: Mode,
): Promise<void> {
if (this.jobs.has(packageName)) {
@@ -107,7 +106,7 @@ export class AutoModeler {
private async modelPackage(
packageName: string,
methods: Method[],
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
mode: Mode,
cancellationTokenSource: CancellationTokenSource,
): Promise<void> {
@@ -193,31 +192,31 @@ export class AutoModeler {
filename: "auto-model.yml",
});
const rawLoadedMethods = loadDataExtensionYaml(models);
if (!rawLoadedMethods) {
const loadedMethods = loadDataExtensionYaml(models);
if (!loadedMethods) {
return;
}
const loadedMethods = convertToLegacyModeledMethods(rawLoadedMethods);
// Any candidate that was part of the response is a negative result
// meaning that the canidate is not a sink for the kinds that the LLM is checking for.
// For now we model this as a sink neutral method, however this is subject
// to discussion.
for (const candidate of candidateMethods) {
if (!(candidate.signature in loadedMethods)) {
loadedMethods[candidate.signature] = {
type: "neutral",
kind: "sink",
input: "",
output: "",
provenance: "ai-generated",
signature: candidate.signature,
packageName: candidate.packageName,
typeName: candidate.typeName,
methodName: candidate.methodName,
methodParameters: candidate.methodParameters,
};
loadedMethods[candidate.signature] = [
{
type: "neutral",
kind: "sink",
input: "",
output: "",
provenance: "ai-generated",
signature: candidate.signature,
packageName: candidate.packageName,
typeName: candidate.typeName,
methodName: candidate.methodName,
methodParameters: candidate.methodParameters,
},
];
}
}

View File

@@ -13,6 +13,10 @@ import { AbstractWebviewViewProvider } from "../../common/vscode/abstract-webvie
import { assertNever } from "../../common/helpers-pure";
import { ModelEditorViewTracker } from "../model-editor-view-tracker";
import { showMultipleModels } from "../../config";
import {
convertFromLegacyModeledMethod,
convertToLegacyModeledMethod,
} from "../modeled-methods-legacy";
export class MethodModelingViewProvider extends AbstractWebviewViewProvider<
ToMethodModelingMessage,
@@ -62,7 +66,9 @@ export class MethodModelingViewProvider extends AbstractWebviewViewProvider<
await this.postMessage({
t: "setSelectedMethod",
method: selectedMethod.method,
modeledMethod: selectedMethod.modeledMethod,
modeledMethod: convertToLegacyModeledMethod(
selectedMethod.modeledMethods,
),
isModified: selectedMethod.isModified,
});
}
@@ -94,9 +100,10 @@ export class MethodModelingViewProvider extends AbstractWebviewViewProvider<
case "setModeledMethod": {
const activeState = this.ensureActiveState();
this.modelingStore.updateModeledMethod(
this.modelingStore.updateModeledMethods(
activeState.databaseItem,
msg.method,
msg.method.signature,
convertFromLegacyModeledMethod(msg.method),
);
break;
}
@@ -141,11 +148,11 @@ export class MethodModelingViewProvider extends AbstractWebviewViewProvider<
this.push(
this.modelingStore.onModeledMethodsChanged(async (e) => {
if (this.webviewView && e.isActiveDb) {
const modeledMethod = e.modeledMethods[this.method?.signature ?? ""];
if (modeledMethod) {
const modeledMethods = e.modeledMethods[this.method?.signature ?? ""];
if (modeledMethods) {
await this.postMessage({
t: "setModeledMethod",
method: modeledMethod,
method: convertToLegacyModeledMethod(modeledMethods),
});
}
}
@@ -171,7 +178,7 @@ export class MethodModelingViewProvider extends AbstractWebviewViewProvider<
await this.postMessage({
t: "setSelectedMethod",
method: e.method,
modeledMethod: e.modeledMethod,
modeledMethod: convertToLegacyModeledMethod(e.modeledMethods),
isModified: e.isModified,
});
}

View File

@@ -14,7 +14,7 @@ import { DatabaseItem } from "../../databases/local-databases";
import { relative } from "path";
import { CodeQLCliServer } from "../../codeql-cli/cli";
import { INITIAL_HIDE_MODELED_METHODS_VALUE } from "../shared/hide-modeled-methods";
import { getModelingStatus } from "../shared/modeling-status";
import { getModelingStatusForModeledMethods } from "../shared/modeling-status";
import { assertNever } from "../../common/helpers-pure";
import { ModeledMethod } from "../modeled-method";
@@ -26,7 +26,7 @@ export class MethodsUsageDataProvider
private databaseItem: DatabaseItem | undefined = undefined;
private sourceLocationPrefix: string | undefined = undefined;
private hideModeledMethods: boolean = INITIAL_HIDE_MODELED_METHODS_VALUE;
private modeledMethods: Record<string, ModeledMethod> = {};
private modeledMethods: Record<string, ModeledMethod[]> = {};
private modifiedMethodSignatures: Set<string> = new Set();
private readonly onDidChangeTreeDataEmitter = this.push(
@@ -52,7 +52,7 @@ export class MethodsUsageDataProvider
methods: Method[],
databaseItem: DatabaseItem,
hideModeledMethods: boolean,
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
modifiedMethodSignatures: Set<string>,
): Promise<void> {
if (
@@ -99,10 +99,13 @@ export class MethodsUsageDataProvider
}
private getModelingStatusIcon(method: Method): ThemeIcon {
const modeledMethod = this.modeledMethods[method.signature];
const modeledMethods = this.modeledMethods[method.signature];
const modifiedMethod = this.modifiedMethodSignatures.has(method.signature);
const status = getModelingStatus(modeledMethod, modifiedMethod);
const status = getModelingStatusForModeledMethods(
modeledMethods,
modifiedMethod,
);
switch (status) {
case "unmodeled":
return new ThemeIcon("error", new ThemeColor("errorForeground"));

View File

@@ -34,7 +34,7 @@ export class MethodsUsagePanel extends DisposableObject {
methods: Method[],
databaseItem: DatabaseItem,
hideModeledMethods: boolean,
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
modifiedMethodSignatures: Set<string>,
): Promise<void> {
await this.dataProvider.setState(

View File

@@ -48,6 +48,7 @@ import { telemetryListener } from "../common/vscode/telemetry";
import { ModelingStore } from "./modeling-store";
import { ModelEditorViewTracker } from "./model-editor-view-tracker";
import {
convertFromLegacyModeledMethod,
convertFromLegacyModeledMethods,
convertToLegacyModeledMethods,
} from "./modeled-methods-legacy";
@@ -259,7 +260,7 @@ export class ModelEditorView extends AbstractWebview<
await this.generateModeledMethodsFromLlm(
msg.packageName,
msg.methods,
msg.modeledMethods,
convertFromLegacyModeledMethods(msg.modeledMethods),
);
void telemetryListener?.sendUIInteraction(
"model-editor-generate-methods-from-llm",
@@ -303,7 +304,10 @@ export class ModelEditorView extends AbstractWebview<
);
break;
case "setModeledMethod": {
this.setModeledMethod(msg.method);
this.setModeledMethods(
msg.method.signature,
convertFromLegacyModeledMethod(msg.method),
);
break;
}
default:
@@ -363,10 +367,7 @@ export class ModelEditorView extends AbstractWebview<
this.cliServer,
this.app.logger,
);
this.modelingStore.setModeledMethods(
this.databaseItem,
convertToLegacyModeledMethods(modeledMethods),
);
this.modelingStore.setModeledMethods(this.databaseItem, modeledMethods);
} catch (e: unknown) {
void showAndLogErrorMessage(
this.app.logger,
@@ -438,10 +439,16 @@ export class ModelEditorView extends AbstractWebview<
queryStorageDir: this.queryStorageDir,
databaseItem: addedDatabase ?? this.databaseItem,
onResults: async (modeledMethods) => {
const modeledMethodsByName: Record<string, ModeledMethod> = {};
const modeledMethodsByName: Record<string, ModeledMethod[]> = {};
for (const modeledMethod of modeledMethods) {
modeledMethodsByName[modeledMethod.signature] = modeledMethod;
if (!(modeledMethod.signature in modeledMethodsByName)) {
modeledMethodsByName[modeledMethod.signature] = [];
}
modeledMethodsByName[modeledMethod.signature].push(
modeledMethod,
);
}
this.addModeledMethods(modeledMethodsByName);
@@ -466,7 +473,7 @@ export class ModelEditorView extends AbstractWebview<
private async generateModeledMethodsFromLlm(
packageName: string,
methods: Method[],
modeledMethods: Record<string, ModeledMethod>,
modeledMethods: Record<string, ModeledMethod[]>,
): Promise<void> {
await this.autoModeler.startModeling(
packageName,
@@ -603,7 +610,7 @@ export class ModelEditorView extends AbstractWebview<
if (event.dbUri === this.databaseItem.databaseUri.toString()) {
await this.postMessage({
t: "setModeledMethods",
methods: event.modeledMethods,
methods: convertToLegacyModeledMethods(event.modeledMethods),
});
}
}),
@@ -621,7 +628,7 @@ export class ModelEditorView extends AbstractWebview<
);
}
private addModeledMethods(modeledMethods: Record<string, ModeledMethod>) {
private addModeledMethods(modeledMethods: Record<string, ModeledMethod[]>) {
this.modelingStore.addModeledMethods(this.databaseItem, modeledMethods);
this.modelingStore.addModifiedMethods(
@@ -630,13 +637,17 @@ export class ModelEditorView extends AbstractWebview<
);
}
private setModeledMethod(method: ModeledMethod) {
private setModeledMethods(signature: string, methods: ModeledMethod[]) {
const state = this.modelingStore.getStateForActiveDb();
if (!state) {
throw new Error("Attempting to set modeled method without active db");
}
this.modelingStore.updateModeledMethod(state.databaseItem, method);
this.modelingStore.addModifiedMethod(state.databaseItem, method.signature);
this.modelingStore.updateModeledMethods(
state.databaseItem,
signature,
methods,
);
this.modelingStore.addModifiedMethod(state.databaseItem, signature);
}
}

View File

@@ -1,23 +1,71 @@
import { ModeledMethod } from "./modeled-method";
/**
* Converts a record of ModeledMethod[] indexed by signature to a record of a single ModeledMethod indexed by signature
* for legacy usage. This function should always be used instead of the trivial conversion to track usages of this
* conversion.
*
* This method should only be called inside a `onMessage` function (or its equivalent). If it's used anywhere else,
* consider whether the boundary is correct: the boundary should as close as possible to the webview -> extension host
* boundary.
*
* @param modeledMethods The record of ModeledMethod[] indexed by signature
*/
export function convertFromLegacyModeledMethods(
modeledMethods: Record<string, ModeledMethod>,
): Record<string, ModeledMethod[]> {
// Convert a single ModeledMethod to an array of ModeledMethods
return Object.fromEntries(
Object.entries(modeledMethods).map(([signature, modeledMethod]) => {
return [signature, [modeledMethod]];
return [signature, convertFromLegacyModeledMethod(modeledMethod)];
}),
);
}
/**
* Converts a record of a single ModeledMethod indexed by signature to a record of ModeledMethod[] indexed by signature
* for legacy usage. This function should always be used instead of the trivial conversion to track usages of this
* conversion.
*
* This method should only be called inside a `postMessage` call. If it's used anywhere else, consider whether the
* boundary is correct: the boundary should as close as possible to the extension host -> webview boundary.
*
* @param modeledMethods The record of a single ModeledMethod indexed by signature
*/
export function convertToLegacyModeledMethods(
modeledMethods: Record<string, ModeledMethod[]>,
): Record<string, ModeledMethod> {
// Always take the first modeled method in the array
return Object.fromEntries(
Object.entries(modeledMethods).map(([signature, modeledMethods]) => {
return [signature, modeledMethods[0]];
return [signature, convertToLegacyModeledMethod(modeledMethods)];
}),
);
}
/**
* Converts a single ModeledMethod to a ModeledMethod[] for legacy usage. This function should always be used instead
* of the trivial conversion to track usages of this conversion.
*
* This method should only be called inside a `onMessage` function (or its equivalent). If it's used anywhere else,
* consider whether the boundary is correct: the boundary should as close as possible to the webview -> extension host
* boundary.
*
* @param modeledMethod The single ModeledMethod
*/
export function convertFromLegacyModeledMethod(modeledMethod: ModeledMethod) {
return [modeledMethod];
}
/**
* Converts a ModeledMethod[] to a single ModeledMethod for legacy usage. This function should always be used instead
* of the trivial conversion to track usages of this conversion.
*
* This method should only be called inside a `postMessage` call. If it's used anywhere else, consider whether the
* boundary is correct: the boundary should as close as possible to the extension host -> webview boundary.
*
* @param modeledMethods The ModeledMethod[]
*/
export function convertToLegacyModeledMethod(modeledMethods: ModeledMethod[]) {
return modeledMethods[0];
}

View File

@@ -10,7 +10,7 @@ export interface DbModelingState {
databaseItem: DatabaseItem;
methods: Method[];
hideModeledMethods: boolean;
modeledMethods: Record<string, ModeledMethod>;
modeledMethods: Record<string, ModeledMethod[]>;
modifiedMethodSignatures: Set<string>;
selectedMethod: Method | undefined;
selectedUsage: Usage | undefined;
@@ -28,7 +28,7 @@ interface HideModeledMethodsChangedEvent {
}
interface ModeledMethodsChangedEvent {
modeledMethods: Record<string, ModeledMethod>;
modeledMethods: Record<string, ModeledMethod[]>;
dbUri: string;
isActiveDb: boolean;
}
@@ -43,7 +43,7 @@ interface SelectedMethodChangedEvent {
databaseItem: DatabaseItem;
method: Method;
usage: Usage;
modeledMethod: ModeledMethod | undefined;
modeledMethods: ModeledMethod[];
isModified: boolean;
}
@@ -199,14 +199,15 @@ export class ModelingStore extends DisposableObject {
public addModeledMethods(
dbItem: DatabaseItem,
methods: Record<string, ModeledMethod>,
methods: Record<string, ModeledMethod[]>,
) {
this.changeModeledMethods(dbItem, (state) => {
const newModeledMethods = {
...methods,
// Keep all methods that are already modeled in some form in the state
...Object.fromEntries(
Object.entries(state.modeledMethods).filter(
([_, value]) => value.type !== "none",
Object.entries(state.modeledMethods).filter(([_, value]) =>
value.some((m) => m.type !== "none"),
),
),
};
@@ -216,17 +217,21 @@ export class ModelingStore extends DisposableObject {
public setModeledMethods(
dbItem: DatabaseItem,
methods: Record<string, ModeledMethod>,
methods: Record<string, ModeledMethod[]>,
) {
this.changeModeledMethods(dbItem, (state) => {
state.modeledMethods = { ...methods };
});
}
public updateModeledMethod(dbItem: DatabaseItem, method: ModeledMethod) {
public updateModeledMethods(
dbItem: DatabaseItem,
signature: string,
modeledMethods: ModeledMethod[],
) {
this.changeModeledMethods(dbItem, (state) => {
const newModeledMethods = { ...state.modeledMethods };
newModeledMethods[method.signature] = method;
newModeledMethods[signature] = modeledMethods;
state.modeledMethods = newModeledMethods;
});
}
@@ -280,7 +285,7 @@ export class ModelingStore extends DisposableObject {
databaseItem: dbItem,
method,
usage,
modeledMethod: dbState.modeledMethods[method.signature],
modeledMethods: dbState.modeledMethods[method.signature],
isModified: dbState.modifiedMethodSignatures.has(method.signature),
});
}
@@ -299,7 +304,7 @@ export class ModelingStore extends DisposableObject {
return {
method: selectedMethod,
usage: dbState.selectedUsage,
modeledMethod: dbState.modeledMethods[selectedMethod.signature],
modeledMethods: dbState.modeledMethods[selectedMethod.signature],
isModified: dbState.modifiedMethodSignatures.has(
selectedMethod.signature,
),

View File

@@ -15,3 +15,24 @@ export function getModelingStatus(
}
return "unmodeled";
}
export function getModelingStatusForModeledMethods(
modeledMethods: ModeledMethod[],
methodIsUnsaved: boolean,
): ModelingStatus {
if (modeledMethods.length === 0) {
return "unmodeled";
}
if (methodIsUnsaved) {
return "unsaved";
}
for (const modeledMethod of modeledMethods) {
if (modeledMethod.type !== "none") {
return "saved";
}
}
return "unmodeled";
}

View File

@@ -99,19 +99,21 @@ describe("getCandidates", () => {
usages: [],
},
];
const modeledMethods: Record<string, ModeledMethod> = {
"org.my.A#x()": {
type: "neutral",
kind: "",
input: "",
output: "",
provenance: "manual",
signature: "org.my.A#x()",
packageName: "org.my",
typeName: "A",
methodName: "x",
methodParameters: "()",
},
const modeledMethods: Record<string, ModeledMethod[]> = {
"org.my.A#x()": [
{
type: "neutral",
kind: "",
input: "",
output: "",
provenance: "manual",
signature: "org.my.A#x()",
packageName: "org.my",
typeName: "A",
methodName: "x",
methodParameters: "()",
},
],
};
const candidates = getCandidates(Mode.Application, methods, modeledMethods);
expect(candidates.length).toEqual(0);

View File

@@ -20,7 +20,7 @@ describe("MethodsUsageDataProvider", () => {
describe("setState", () => {
const hideModeledMethods = false;
const methods: Method[] = [];
const modeledMethods: Record<string, ModeledMethod> = {};
const modeledMethods: Record<string, ModeledMethod[]> = {};
const modifiedMethodSignatures: Set<string> = new Set();
const dbItem = mockedObject<DatabaseItem>({
getSourceLocationPrefix: () => "test",
@@ -125,7 +125,7 @@ describe("MethodsUsageDataProvider", () => {
});
it("should emit onDidChangeTreeData event when modeled methods has changed", async () => {
const modeledMethods2: Record<string, ModeledMethod> = {};
const modeledMethods2: Record<string, ModeledMethod[]> = {};
await dataProvider.setState(
methods,
@@ -213,7 +213,7 @@ describe("MethodsUsageDataProvider", () => {
});
const methods: Method[] = [supportedMethod, unsupportedMethod];
const modeledMethods: Record<string, ModeledMethod> = {};
const modeledMethods: Record<string, ModeledMethod[]> = {};
const modifiedMethodSignatures: Set<string> = new Set();
const dbItem = mockedObject<DatabaseItem>({

View File

@@ -21,7 +21,7 @@ describe("MethodsUsagePanel", () => {
describe("setState", () => {
const hideModeledMethods = false;
const methods: Method[] = [createMethod()];
const modeledMethods: Record<string, ModeledMethod> = {};
const modeledMethods: Record<string, ModeledMethod[]> = {};
const modifiedMethodSignatures: Set<string> = new Set();
it("should update the tree view with the correct batch number", async () => {
@@ -50,7 +50,7 @@ describe("MethodsUsagePanel", () => {
let modelingStore: ModelingStore;
const hideModeledMethods: boolean = false;
const modeledMethods: Record<string, ModeledMethod> = {};
const modeledMethods: Record<string, ModeledMethod[]> = {};
const modifiedMethodSignatures: Set<string> = new Set();
const usage = createUsage();