Use ModelConfig for all model settings

This switches all places where we're retrieving some model configuration
to use the `ModelConfig` or `ModelConfigListener` types. This makes it
much easier to mock these settings in tests.

This also adds a listener to the `ModelEditorView` to send the new view
state when any of the settings is changed. This should make it easier
to test settings changes in the model editor without having to re-open
the model editor.
This commit is contained in:
Koen Vlaswinkel
2023-10-06 14:29:07 +02:00
parent 08944a292c
commit 951bd13881
8 changed files with 80 additions and 122 deletions

View File

@@ -714,7 +714,7 @@ const SHOW_MULTIPLE_MODELS = new Setting("showMultipleModels", MODEL_SETTING);
export interface ModelConfig {
flowGeneration: boolean;
llmGeneration: boolean;
extensionsDirectory: string | undefined;
getExtensionsDirectory(languageId: string): string | undefined;
showMultipleModels: boolean;
}
@@ -731,29 +731,13 @@ export class ModelConfigListener extends ConfigListener implements ModelConfig {
return !!LLM_GENERATION.getValue<boolean>();
}
public get extensionsDirectory(): string | undefined {
return EXTENSIONS_DIRECTORY.getValue<string>();
public getExtensionsDirectory(languageId: string): string | undefined {
return EXTENSIONS_DIRECTORY.getValue<string>({
languageId,
});
}
public get showMultipleModels(): boolean {
return !!SHOW_MULTIPLE_MODELS.getValue<boolean>();
}
}
export function showFlowGeneration(): boolean {
return !!FLOW_GENERATION.getValue<boolean>();
}
export function showLlmGeneration(): boolean {
return !!LLM_GENERATION.getValue<boolean>();
}
export function getExtensionsDirectory(languageId: string): string | undefined {
return EXTENSIONS_DIRECTORY.getValue<string>({
languageId,
});
}
export function showMultipleModels(): boolean {
return !!SHOW_MULTIPLE_MODELS.getValue<boolean>();
}

View File

@@ -11,7 +11,7 @@ import { getQlPackPath, QLPACK_FILENAMES } from "../common/ql";
import { getErrorMessage } from "../common/helpers-pure";
import { ExtensionPack } from "./shared/extension-pack";
import { NotificationLogger, showAndLogErrorMessage } from "../common/logging";
import { getExtensionsDirectory } from "../config";
import { ModelConfig } from "../config";
import {
autoNameExtensionPack,
ExtensionPackName,
@@ -28,6 +28,7 @@ const extensionPackValidate = ajv.compile(extensionPackMetadataSchemaJson);
export async function pickExtensionPack(
cliServer: Pick<CodeQLCliServer, "resolveQlpacks">,
databaseItem: Pick<DatabaseItem, "name" | "language">,
modelConfig: ModelConfig,
logger: NotificationLogger,
progress: ProgressCallback,
maxStep: number,
@@ -56,7 +57,9 @@ export async function pickExtensionPack(
});
// Get the `codeQL.model.extensionsDirectory` setting for the language
const userExtensionsDirectory = getExtensionsDirectory(databaseItem.language);
const userExtensionsDirectory = modelConfig.getExtensionsDirectory(
databaseItem.language,
);
// If the setting is not set, automatically pick a suitable directory
const extensionsDirectory = userExtensionsDirectory

View File

@@ -21,6 +21,7 @@ import { MethodModelingPanel } from "./method-modeling/method-modeling-panel";
import { ModelingStore } from "./modeling-store";
import { showResolvableLocation } from "../databases/local-databases/locations";
import { ModelEditorViewTracker } from "./model-editor-view-tracker";
import { ModelConfigListener } from "../config";
const SUPPORTED_LANGUAGES: string[] = ["java", "csharp"];
@@ -150,9 +151,12 @@ export class ModelEditorModule extends DisposableObject {
return;
}
const modelConfig = this.push(new ModelConfigListener());
const modelFile = await pickExtensionPack(
this.cliServer,
db,
modelConfig,
this.app.logger,
progress,
maxStep,
@@ -172,7 +176,12 @@ export class ModelEditorModule extends DisposableObject {
unsafeCleanup: true,
});
const success = await setUpPack(this.cliServer, queryDir, language);
const success = await setUpPack(
this.cliServer,
queryDir,
language,
modelConfig,
);
if (!success) {
await cleanupQueryDir();
return;
@@ -188,6 +197,7 @@ export class ModelEditorModule extends DisposableObject {
this.app,
this.modelingStore,
this.editorViewTracker,
modelConfig,
this.databaseManager,
this.cliServer,
this.queryRunner,

View File

@@ -4,7 +4,7 @@ import { writeFile } from "fs-extra";
import { dump } from "js-yaml";
import { prepareExternalApiQuery } from "./external-api-usage-queries";
import { CodeQLCliServer } from "../codeql-cli/cli";
import { showLlmGeneration } from "../config";
import { ModelConfig } from "../config";
import { Mode } from "./shared/mode";
import { resolveQueriesFromPacks } from "../local-queries";
import { modeTag } from "./mode-tag";
@@ -28,12 +28,14 @@ export const syntheticQueryPackName = "codeql/external-api-usage";
* @param cliServer The CodeQL CLI server to use.
* @param queryDir The directory to set up.
* @param language The language to use for the queries.
* @param modelConfig The model config to use.
* @returns true if the setup was successful, false otherwise.
*/
export async function setUpPack(
cliServer: CodeQLCliServer,
queryDir: string,
language: QueryLanguage,
modelConfig: ModelConfig,
): Promise<boolean> {
// Download the required query packs
await cliServer.packDownload([`codeql/${language}-queries`]);
@@ -84,7 +86,7 @@ export async function setUpPack(
}
// Download any other required packs
if (language === "java" && showLlmGeneration()) {
if (language === "java" && modelConfig.llmGeneration) {
await cliServer.packDownload([`codeql/${language}-automodel-queries`]);
}

View File

@@ -17,8 +17,8 @@ import {
import { ProgressCallback, withProgress } from "../common/vscode/progress";
import { QueryRunner } from "../query-server";
import {
showAndLogExceptionWithTelemetry,
showAndLogErrorMessage,
showAndLogExceptionWithTelemetry,
} from "../common/logging";
import { DatabaseItem, DatabaseManager } from "../databases/local-databases";
import { CodeQLCliServer } from "../codeql-cli/cli";
@@ -34,11 +34,7 @@ import {
import { Method, Usage } from "./method";
import { ModeledMethod } from "./modeled-method";
import { ExtensionPack } from "./shared/extension-pack";
import {
showFlowGeneration,
showLlmGeneration,
showMultipleModels,
} from "../config";
import { ModelConfigListener } from "../config";
import { Mode } from "./shared/mode";
import { loadModeledMethods, saveModeledMethods } from "./modeled-method-fs";
import { pickExtensionPack } from "./extension-pack-picker";
@@ -58,6 +54,7 @@ export class ModelEditorView extends AbstractWebview<
protected readonly app: App,
private readonly modelingStore: ModelingStore,
private readonly viewTracker: ModelEditorViewTracker<ModelEditorView>,
private readonly modelConfig: ModelConfigListener,
private readonly databaseManager: DatabaseManager,
private readonly cliServer: CodeQLCliServer,
private readonly queryRunner: QueryRunner,
@@ -71,6 +68,7 @@ export class ModelEditorView extends AbstractWebview<
this.modelingStore.initializeStateForDb(databaseItem);
this.registerToModelingStoreEvents();
this.registerToModelConfigEvents();
this.viewTracker.registerView(this);
@@ -334,15 +332,15 @@ export class ModelEditorView extends AbstractWebview<
private async setViewState(): Promise<void> {
const showLlmButton =
this.databaseItem.language === "java" && showLlmGeneration();
this.databaseItem.language === "java" && this.modelConfig.llmGeneration;
await this.postMessage({
t: "setModelEditorViewState",
viewState: {
extensionPack: this.extensionPack,
showFlowGeneration: showFlowGeneration(),
showFlowGeneration: this.modelConfig.flowGeneration,
showLlmButton,
showMultipleModels: showMultipleModels(),
showMultipleModels: this.modelConfig.showMultipleModels,
mode: this.mode,
},
});
@@ -481,6 +479,7 @@ export class ModelEditorView extends AbstractWebview<
const modelFile = await pickExtensionPack(
this.cliServer,
addedDatabase,
this.modelConfig,
this.app.logger,
progress,
3,
@@ -493,6 +492,7 @@ export class ModelEditorView extends AbstractWebview<
this.app,
this.modelingStore,
this.viewTracker,
this.modelConfig,
this.databaseManager,
this.cliServer,
this.queryRunner,
@@ -614,6 +614,14 @@ export class ModelEditorView extends AbstractWebview<
);
}
private registerToModelConfigEvents() {
this.push(
this.modelConfig.onDidChangeConfiguration(() => {
void this.setViewState();
}),
);
}
private addModeledMethods(modeledMethods: Record<string, ModeledMethod>) {
this.modelingStore.addModeledMethods(this.databaseItem, modeledMethods);

View File

@@ -1,10 +1,4 @@
import {
ConfigurationScope,
Uri,
workspace,
WorkspaceConfiguration as VSCodeWorkspaceConfiguration,
WorkspaceFolder,
} from "vscode";
import { Uri, workspace, WorkspaceFolder } from "vscode";
import { dump as dumpYaml, load as loadYaml } from "js-yaml";
import { outputFile, readFile } from "fs-extra";
import { join } from "path";
@@ -14,7 +8,8 @@ import { QlpacksInfo } from "../../../../src/codeql-cli/cli";
import { pickExtensionPack } from "../../../../src/model-editor/extension-pack-picker";
import { ExtensionPack } from "../../../../src/model-editor/shared/extension-pack";
import { createMockLogger } from "../../../__mocks__/loggerMock";
import { vscodeGetConfigurationMock } from "../../test-config";
import { ModelConfig } from "../../../../src/config";
import { mockedObject } from "../../utils/mocking.helpers";
describe("pickExtensionPack", () => {
let tmpDir: string;
@@ -32,6 +27,7 @@ describe("pickExtensionPack", () => {
let workspaceFoldersSpy: jest.SpyInstance;
let additionalPacks: string[];
let workspaceFolder: WorkspaceFolder;
let modelConfig: ModelConfig;
const logger = createMockLogger();
const maxStep = 4;
@@ -67,41 +63,20 @@ describe("pickExtensionPack", () => {
workspaceFoldersSpy = jest
.spyOn(workspace, "workspaceFolders", "get")
.mockReturnValue([workspaceFolder]);
modelConfig = mockedObject<ModelConfig>({
getExtensionsDirectory: jest.fn().mockReturnValue(undefined),
});
});
it("selects an existing extension pack", async () => {
vscodeGetConfigurationMock.mockImplementation(
(
section?: string,
scope?: ConfigurationScope | null,
): VSCodeWorkspaceConfiguration => {
expect(section).toEqual("codeQL.model");
expect((scope as any)?.languageId).toEqual("java");
return {
get: (key: string) => {
expect(key).toEqual("extensionsDirectory");
return undefined;
},
has: (key: string) => {
return key === "extensionsDirectory";
},
inspect: () => {
throw new Error("inspect not implemented");
},
update: () => {
throw new Error("update not implemented");
},
};
},
);
const cliServer = mockCliServer(qlPacks);
expect(
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -112,35 +87,10 @@ describe("pickExtensionPack", () => {
additionalPacks,
true,
);
expect(modelConfig.getExtensionsDirectory).toHaveBeenCalledWith("java");
});
it("creates a new extension pack using default extensions directory", async () => {
vscodeGetConfigurationMock.mockImplementation(
(
section?: string,
scope?: ConfigurationScope | null,
): VSCodeWorkspaceConfiguration => {
expect(section).toEqual("codeQL.model");
expect((scope as any)?.languageId).toEqual("java");
return {
get: (key: string) => {
expect(key).toEqual("extensionsDirectory");
return undefined;
},
has: (key: string) => {
return key === "extensionsDirectory";
},
inspect: () => {
throw new Error("inspect not implemented");
},
update: () => {
throw new Error("update not implemented");
},
};
},
);
const tmpDir = await dir({
unsafeCleanup: true,
});
@@ -183,6 +133,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -199,6 +150,7 @@ describe("pickExtensionPack", () => {
dataExtensions: ["models/**/*.yml"],
});
expect(cliServer.resolveQlpacks).toHaveBeenCalled();
expect(modelConfig.getExtensionsDirectory).toHaveBeenCalledWith("java");
expect(
loadYaml(await readFile(join(newPackDir, "codeql-pack.yml"), "utf8")),
@@ -223,31 +175,9 @@ describe("pickExtensionPack", () => {
"my-custom-extensions-directory",
);
vscodeGetConfigurationMock.mockImplementation(
(
section?: string,
scope?: ConfigurationScope | null,
): VSCodeWorkspaceConfiguration => {
expect(section).toEqual("codeQL.model");
expect((scope as any)?.languageId).toEqual("java");
return {
get: (key: string) => {
expect(key).toEqual("extensionsDirectory");
return configExtensionsDir;
},
has: (key: string) => {
return key === "extensionsDirectory";
},
inspect: () => {
throw new Error("inspect not implemented");
},
update: () => {
throw new Error("update not implemented");
},
};
},
);
const modelConfig = mockedObject<ModelConfig>({
getExtensionsDirectory: jest.fn().mockReturnValue(configExtensionsDir),
});
const newPackDir = join(configExtensionsDir, "vscode-codeql-java");
@@ -257,6 +187,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -273,6 +204,7 @@ describe("pickExtensionPack", () => {
dataExtensions: ["models/**/*.yml"],
});
expect(cliServer.resolveQlpacks).toHaveBeenCalled();
expect(modelConfig.getExtensionsDirectory).toHaveBeenCalledWith("java");
expect(
loadYaml(await readFile(join(newPackDir, "codeql-pack.yml"), "utf8")),
@@ -299,6 +231,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -324,6 +257,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -351,6 +285,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -388,6 +323,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -425,6 +361,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -465,6 +402,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,
@@ -522,6 +460,7 @@ describe("pickExtensionPack", () => {
await pickExtensionPack(
cliServer,
databaseItem,
modelConfig,
logger,
progress,
maxStep,

View File

@@ -8,6 +8,7 @@ import { QueryLanguage } from "../../../../src/common/query-language";
import { Mode } from "../../../../src/model-editor/shared/mode";
import { mockedObject } from "../../utils/mocking.helpers";
import { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import { ModelConfig } from "../../../../src/config";
describe("setUpPack", () => {
let queryDir: string;
@@ -32,8 +33,11 @@ describe("setUpPack", () => {
packInstall: jest.fn(),
resolveQueriesInSuite: jest.fn().mockResolvedValue([]),
});
const modelConfig = mockedObject<ModelConfig>({
llmGeneration: false,
});
await setUpPack(cliServer, queryDir, language);
await setUpPack(cliServer, queryDir, language, modelConfig);
const queryFiles = await readdir(queryDir);
expect(queryFiles.sort()).toEqual(
@@ -89,8 +93,11 @@ describe("setUpPack", () => {
.fn()
.mockResolvedValue(["/a/b/c/ApplicationModeEndpoints.ql"]),
});
const modelConfig = mockedObject<ModelConfig>({
llmGeneration: false,
});
await setUpPack(cliServer, queryDir, language);
await setUpPack(cliServer, queryDir, language, modelConfig);
const queryFiles = await readdir(queryDir);
expect(queryFiles.sort()).toEqual(["codeql-pack.yml"].sort());

View File

@@ -10,11 +10,15 @@ import { QueryRunner } from "../../../../src/query-server";
import { ExtensionPack } from "../../../../src/model-editor/shared/extension-pack";
import { createMockModelingStore } from "../../../__mocks__/model-editor/modelingStoreMock";
import { createMockModelEditorViewTracker } from "../../../__mocks__/model-editor/modelEditorViewTrackerMock";
import { ModelConfigListener } from "../../../../src/config";
describe("ModelEditorView", () => {
const app = createMockApp({});
const modelingStore = createMockModelingStore();
const viewTracker = createMockModelEditorViewTracker();
const modelConfig = mockedObject<ModelConfigListener>({
onDidChangeConfiguration: jest.fn(),
});
const databaseManager = mockEmptyDatabaseManager();
const cliServer = mockedObject<CodeQLCliServer>({});
const queryRunner = mockedObject<QueryRunner>({});
@@ -41,6 +45,7 @@ describe("ModelEditorView", () => {
app,
modelingStore,
viewTracker,
modelConfig,
databaseManager,
cliServer,
queryRunner,