Merge pull request #2633 from github/koesie10/automodel-v2

Add LLM functionality using auto-model V2
This commit is contained in:
Koen Vlaswinkel
2023-07-27 10:02:55 +02:00
committed by GitHub
11 changed files with 722 additions and 53 deletions

View File

@@ -0,0 +1,12 @@
import { promisify } from "util";
import { gzip, gunzip } from "zlib";
/**
* Promisified version of zlib.gzip
*/
export const gzipEncode = promisify(gzip);
/**
* Promisified version of zlib.gunzip
*/
export const gzipDecode = promisify(gunzip);

View File

@@ -704,6 +704,7 @@ export function showQueriesPanel(): boolean {
const DATA_EXTENSIONS = new Setting("dataExtensions", ROOT_SETTING);
const LLM_GENERATION = new Setting("llmGeneration", DATA_EXTENSIONS);
const LLM_GENERATION_V2 = new Setting("llmGenerationV2", DATA_EXTENSIONS);
const FRAMEWORK_MODE = new Setting("frameworkMode", DATA_EXTENSIONS);
const DISABLE_AUTO_NAME_EXTENSION_PACK = new Setting(
"disableAutoNameExtensionPack",
@@ -718,6 +719,10 @@ export function showLlmGeneration(): boolean {
return !!LLM_GENERATION.getValue<boolean>();
}
export function useLlmGenerationV2(): boolean {
return !!LLM_GENERATION_V2.getValue<boolean>();
}
export function enableFrameworkMode(): boolean {
return !!FRAMEWORK_MODE.getValue<boolean>();
}

View File

@@ -0,0 +1,34 @@
import { Credentials } from "../common/authentication";
import { OctokitResponse } from "@octokit/types";
export enum AutomodelMode {
Unspecified = "AUTOMODEL_MODE_UNSPECIFIED",
Framework = "AUTOMODEL_MODE_FRAMEWORK",
Application = "AUTOMODEL_MODE_APPLICATION",
}
export interface ModelRequest {
mode: AutomodelMode;
// Base64-encoded GZIP-compressed SARIF log
candidates: string;
}
export interface ModelResponse {
models: string;
}
export async function autoModelV2(
credentials: Credentials,
request: ModelRequest,
): Promise<ModelResponse> {
const octokit = await credentials.getOctokit();
const response: OctokitResponse<ModelResponse> = await octokit.request(
"POST /repos/github/codeql/code-scanning/codeql/auto-model",
{
data: request,
},
);
return response.data;
}

View File

@@ -0,0 +1,230 @@
import { CodeQLCliServer, SourceInfo } from "../codeql-cli/cli";
import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases";
import { ProgressCallback } from "../common/vscode/progress";
import * as Sarif from "sarif";
import { qlpackOfDatabase, resolveQueries } from "../local-queries";
import { extLogger } from "../common/logging/vscode";
import { Mode } from "./shared/mode";
import { QlPacksForLanguage } from "../databases/qlpack";
import { createLockFileForStandardQuery } from "../local-queries/standard-queries";
import { CancellationToken, CancellationTokenSource } from "vscode";
import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders";
import { showAndLogExceptionWithTelemetry, TeeLogger } from "../common/logging";
import { QueryResultType } from "../query-server/new-messages";
import { telemetryListener } from "../common/vscode/telemetry";
import { redactableError } from "../common/errors";
import { interpretResultsSarif } from "../query-results";
import { join } from "path";
import { assertNever } from "../common/helpers-pure";
type AutoModelQueryOptions = {
queryTag: string;
mode: Mode;
cliServer: CodeQLCliServer;
queryRunner: QueryRunner;
databaseItem: DatabaseItem;
qlpack: QlPacksForLanguage;
sourceInfo: SourceInfo | undefined;
extensionPacks: string[];
queryStorageDir: string;
progress: ProgressCallback;
token: CancellationToken;
};
function modeTag(mode: Mode): string {
switch (mode) {
case Mode.Application:
return "application-mode";
case Mode.Framework:
return "framework-mode";
default:
assertNever(mode);
}
}
async function runAutoModelQuery({
queryTag,
mode,
cliServer,
queryRunner,
databaseItem,
qlpack,
sourceInfo,
extensionPacks,
queryStorageDir,
progress,
token,
}: AutoModelQueryOptions): Promise<Sarif.Log | undefined> {
// First, resolve the query that we want to run.
// All queries are tagged like this:
// internal extract automodel <mode> <queryTag>
// Example: internal extract automodel framework-mode candidates
const queries = await resolveQueries(
cliServer,
qlpack,
`Extract automodel ${queryTag}`,
{
kind: "problem",
"tags contain all": ["automodel", modeTag(mode), ...queryTag.split(" ")],
},
);
if (queries.length > 1) {
throw new Error(
`Found multiple auto model queries for ${mode} ${queryTag}. Can't continue`,
);
}
if (queries.length === 0) {
throw new Error(
`Did not found any auto model queries for ${mode} ${queryTag}. Can't continue`,
);
}
const queryPath = queries[0];
const { cleanup: cleanupLockFile } = await createLockFileForStandardQuery(
cliServer,
queryPath,
);
// Get metadata for the query. This is required to interpret the results. We already know the kind is problem
// (because of the constraint in resolveQueries), so we don't need any more checks on the metadata.
const metadata = await cliServer.resolveMetadata(queryPath);
const queryRun = queryRunner.createQueryRun(
databaseItem.databaseUri.fsPath,
{
queryPath,
quickEvalPosition: undefined,
quickEvalCountOnly: false,
},
false,
getOnDiskWorkspaceFolders(),
extensionPacks,
queryStorageDir,
undefined,
undefined,
);
const completedQuery = await queryRun.evaluate(
progress,
token,
new TeeLogger(queryRunner.logger, queryRun.outputDir.logPath),
);
await cleanupLockFile?.();
if (completedQuery.resultType !== QueryResultType.SUCCESS) {
void showAndLogExceptionWithTelemetry(
extLogger,
telemetryListener,
redactableError`Auto-model query ${queryTag} failed: ${
completedQuery.message ?? "No message"
}`,
);
return;
}
const interpretedResultsPath = join(
queryStorageDir,
`interpreted-results-${queryTag.replaceAll(" ", "-")}-${queryRun.id}.sarif`,
);
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- We only need the actual SARIF data, not the extra fields added by SarifInterpretationData
const { t, sortState, ...sarif } = await interpretResultsSarif(
cliServer,
metadata,
{
resultsPath: completedQuery.outputDir.bqrsPath,
interpretedResultsPath,
},
sourceInfo,
["--sarif-add-snippets"],
);
return sarif;
}
type AutoModelQueriesOptions = {
mode: Mode;
cliServer: CodeQLCliServer;
queryRunner: QueryRunner;
databaseItem: DatabaseItem;
queryStorageDir: string;
progress: ProgressCallback;
};
export type AutoModelQueriesResult = {
candidates: Sarif.Log;
};
export async function runAutoModelQueries({
mode,
cliServer,
queryRunner,
databaseItem,
queryStorageDir,
progress,
}: AutoModelQueriesOptions): Promise<AutoModelQueriesResult | undefined> {
// maxStep for this part is 1500
const maxStep = 1500;
const cancellationTokenSource = new CancellationTokenSource();
const qlpack = await qlpackOfDatabase(cliServer, databaseItem);
// CodeQL needs to have access to the database to be able to retrieve the
// snippets from it. The source location prefix is used to determine the
// base path of the database.
const sourceLocationPrefix = await databaseItem.getSourceLocationPrefix(
cliServer,
);
const sourceArchiveUri = databaseItem.sourceArchive;
const sourceInfo =
sourceArchiveUri === undefined
? undefined
: {
sourceArchive: sourceArchiveUri.fsPath,
sourceLocationPrefix,
};
const additionalPacks = getOnDiskWorkspaceFolders();
const extensionPacks = Object.keys(
await cliServer.resolveQlpacks(additionalPacks, true),
);
progress({
step: 0,
maxStep,
message: "Finding candidates and examples",
});
const candidates = await runAutoModelQuery({
mode,
queryTag: "candidates",
cliServer,
queryRunner,
databaseItem,
qlpack,
sourceInfo,
extensionPacks,
queryStorageDir,
progress: (update) => {
progress({
step: update.step,
maxStep,
message: "Finding candidates and examples",
});
},
token: cancellationTokenSource.token,
});
if (!candidates) {
return undefined;
}
return {
candidates,
};
}

View File

@@ -0,0 +1,40 @@
import { AutomodelMode, ModelRequest } from "./auto-model-api-v2";
import { Mode } from "./shared/mode";
import { AutoModelQueriesResult } from "./auto-model-codeml-queries";
import { assertNever } from "../common/helpers-pure";
import * as Sarif from "sarif";
import { gzipEncode } from "../common/zlib";
/**
* Encode a SARIF log to the format expected by the server: JSON, GZIP-compressed, base64-encoded
* @param log SARIF log to encode
* @returns base64-encoded GZIP-compressed SARIF log
*/
export async function encodeSarif(log: Sarif.Log): Promise<string> {
const json = JSON.stringify(log);
const buffer = Buffer.from(json, "utf-8");
const compressed = await gzipEncode(buffer);
return compressed.toString("base64");
}
export async function createAutoModelV2Request(
mode: Mode,
result: AutoModelQueriesResult,
): Promise<ModelRequest> {
let requestMode: AutomodelMode;
switch (mode) {
case Mode.Application:
requestMode = AutomodelMode.Application;
break;
case Mode.Framework:
requestMode = AutomodelMode.Framework;
break;
default:
assertNever(mode);
}
return {
mode: requestMode,
candidates: await encodeSarif(result.candidates),
};
}

View File

@@ -35,17 +35,30 @@ import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { ExtensionPack } from "./shared/extension-pack";
import { autoModel, ModelRequest, ModelResponse } from "./auto-model-api";
import {
autoModelV2,
ModelRequest as ModelRequestV2,
ModelResponse as ModelResponseV2,
} from "./auto-model-api-v2";
import {
createAutoModelRequest,
parsePredictedClassifications,
} from "./auto-model";
import { enableFrameworkMode, showLlmGeneration } from "../config";
import {
enableFrameworkMode,
showLlmGeneration,
useLlmGenerationV2,
} from "../config";
import { getAutoModelUsages } from "./auto-model-usages-query";
import { Mode } from "./shared/mode";
import { loadModeledMethods, saveModeledMethods } from "./modeled-method-fs";
import { join } from "path";
import { pickExtensionPack } from "./extension-pack-picker";
import { getLanguageDisplayName } from "../common/query-language";
import { runAutoModelQueries } from "./auto-model-codeml-queries";
import { createAutoModelV2Request } from "./auto-model-v2";
import { load as loadYaml } from "js-yaml";
import { loadDataExtensionYaml } from "./yaml";
export class DataExtensionsEditorView extends AbstractWebview<
ToDataExtensionsEditorMessage,
@@ -361,16 +374,66 @@ export class DataExtensionsEditorView extends AbstractWebview<
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
): Promise<void> {
await withProgress(
async (progress) => {
const maxStep = 3000;
await withProgress(async (progress) => {
const maxStep = 3000;
progress({
step: 0,
maxStep,
message: "Retrieving usages",
});
let predictedModeledMethods: Record<string, ModeledMethod>;
if (useLlmGenerationV2()) {
const usages = await runAutoModelQueries({
mode: this.mode,
cliServer: this.cliServer,
queryRunner: this.queryRunner,
queryStorageDir: this.queryStorageDir,
databaseItem: this.databaseItem,
progress: (update) => progress({ ...update, maxStep }),
});
if (!usages) {
return;
}
progress({
step: 0,
step: 1800,
maxStep,
message: "Retrieving usages",
message: "Creating request",
});
const request = await createAutoModelV2Request(this.mode, usages);
progress({
step: 2000,
maxStep,
message: "Sending request",
});
const response = await this.callAutoModelApiV2(request);
if (!response) {
return;
}
progress({
step: 2500,
maxStep,
message: "Parsing response",
});
const models = loadYaml(response.models, {
filename: "auto-model.yml",
});
const modeledMethods = loadDataExtensionYaml(models);
if (!modeledMethods) {
return;
}
predictedModeledMethods = modeledMethods;
} else {
const usages = await getAutoModelUsages({
cliServer: this.cliServer,
queryRunner: this.queryRunner,
@@ -410,23 +473,22 @@ export class DataExtensionsEditorView extends AbstractWebview<
message: "Parsing response",
});
const predictedModeledMethods = parsePredictedClassifications(
predictedModeledMethods = parsePredictedClassifications(
response.predicted || [],
);
}
progress({
step: 2800,
maxStep,
message: "Applying results",
});
progress({
step: 2800,
maxStep,
message: "Applying results",
});
await this.postMessage({
t: "addModeledMethods",
modeledMethods: predictedModeledMethods,
});
},
{ cancellable: false },
);
await this.postMessage({
t: "addModeledMethods",
modeledMethods: predictedModeledMethods,
});
});
}
private async modelDependency(): Promise<void> {
@@ -505,4 +567,23 @@ export class DataExtensionsEditorView extends AbstractWebview<
}
}
}
private async callAutoModelApiV2(
request: ModelRequestV2,
): Promise<ModelResponseV2 | null> {
try {
return await autoModelV2(this.app.credentials, request);
} catch (e) {
if (e instanceof RequestError && e.status === 429) {
void showAndLogExceptionWithTelemetry(
this.app.logger,
this.app.telemetry,
redactableError(e)`Rate limit hit, please try again soon.`,
);
return null;
} else {
throw e;
}
}
}
}

View File

@@ -0,0 +1,82 @@
import {
createAutoModelV2Request,
encodeSarif,
} from "../../../src/data-extensions-editor/auto-model-v2";
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
import { AutomodelMode } from "../../../src/data-extensions-editor/auto-model-api-v2";
import { AutoModelQueriesResult } from "../../../src/data-extensions-editor/auto-model-codeml-queries";
import * as sarif from "sarif";
import { gzipDecode } from "../../../src/common/zlib";
describe("createAutoModelV2Request", () => {
const createSarifLog = (queryId: string): sarif.Log => {
return {
version: "2.1.0",
$schema: "http://json.schemastore.org/sarif-2.1.0-rtm.4",
runs: [
{
tool: {
driver: {
name: "CodeQL",
rules: [
{
id: queryId,
},
],
},
},
results: [
{
message: {
text: "msg",
},
locations: [
{
physicalLocation: {
contextRegion: {
startLine: 10,
endLine: 12,
snippet: {
text: "Foo",
},
},
region: {
startLine: 10,
startColumn: 1,
endColumn: 3,
},
artifactLocation: {
uri: "foo.js",
},
},
},
],
},
],
},
],
};
};
const result: AutoModelQueriesResult = {
candidates: createSarifLog(
"java/ml/extract-automodel-application-candidates",
),
};
it("creates a matching request", async () => {
expect(await createAutoModelV2Request(Mode.Application, result)).toEqual({
mode: AutomodelMode.Application,
candidates: await encodeSarif(result.candidates),
});
});
it("can decode the SARIF", async () => {
const request = await createAutoModelV2Request(Mode.Application, result);
const decoded = Buffer.from(request.candidates, "base64");
const decompressed = await gzipDecode(decoded);
const json = decompressed.toString("utf-8");
const parsed = JSON.parse(json);
expect(parsed).toEqual(result.candidates);
});
});

View File

@@ -0,0 +1,175 @@
import { createMockLogger } from "../../../__mocks__/loggerMock";
import {
DatabaseItem,
DatabaseKind,
} from "../../../../src/databases/local-databases";
import { file } from "tmp-promise";
import { QueryResultType } from "../../../../src/query-server/new-messages";
import { runAutoModelQueries } from "../../../../src/data-extensions-editor/auto-model-codeml-queries";
import { Mode } from "../../../../src/data-extensions-editor/shared/mode";
import { mockedObject, mockedUri } from "../../utils/mocking.helpers";
import { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import { QueryRunner } from "../../../../src/query-server";
import * as queryResolver from "../../../../src/local-queries/query-resolver";
import * as standardQueries from "../../../../src/local-queries/standard-queries";
describe("runAutoModelQueries", () => {
const qlpack = {
dbschemePack: "dbschemePack",
dbschemePackIsLibraryPack: false,
};
let resolveQueriesSpy: jest.SpiedFunction<
typeof queryResolver.resolveQueries
>;
let createLockFileForStandardQuerySpy: jest.SpiedFunction<
typeof standardQueries.createLockFileForStandardQuery
>;
beforeEach(() => {
jest.spyOn(queryResolver, "qlpackOfDatabase").mockResolvedValue(qlpack);
resolveQueriesSpy = jest
.spyOn(queryResolver, "resolveQueries")
.mockImplementation(async (_cliServer, _qlPack, _name, constraints) => {
if (constraints["tags contain all"]?.includes("candidates")) {
return ["/a/b/c/ql/candidates.ql"];
}
if (constraints["tags contain all"]?.includes("positive")) {
return ["/a/b/c/ql/positive-examples.ql"];
}
if (constraints["tags contain all"]?.includes("negative")) {
return ["/a/b/c/ql/negative-examples.ql"];
}
return [];
});
createLockFileForStandardQuerySpy = jest
.spyOn(standardQueries, "createLockFileForStandardQuery")
.mockResolvedValue({});
});
it("should run the query and return the results", async () => {
const logPath = (await file()).path;
const bqrsPath = (await file()).path;
const outputDir = {
logPath,
bqrsPath,
};
const options = {
mode: Mode.Application,
cliServer: mockedObject<CodeQLCliServer>({
resolveQlpacks: jest.fn().mockResolvedValue({
"/a/b/c/my-extension-pack": {},
}),
resolveMetadata: jest.fn().mockResolvedValue({
kind: "problem",
}),
interpretBqrsSarif: jest.fn().mockResolvedValue({
version: "2.1.0",
$schema: "http://json.schemastore.org/sarif-2.1.0-rtm.4",
runs: [
{
tool: {
driver: {
name: "CodeQL",
},
},
results: [
{
message: {
text: "msg",
},
locations: [
{
physicalLocation: {
contextRegion: {
startLine: 10,
endLine: 12,
snippet: {
text: "Foo",
},
},
region: {
startLine: 10,
startColumn: 1,
endColumn: 3,
},
artifactLocation: {
uri: "foo.js",
},
},
},
],
},
],
},
],
}),
}),
queryRunner: mockedObject<QueryRunner>({
createQueryRun: jest.fn().mockReturnValue({
evaluate: jest.fn().mockResolvedValue({
resultType: QueryResultType.SUCCESS,
outputDir,
}),
outputDir,
}),
logger: createMockLogger(),
}),
databaseItem: mockedObject<DatabaseItem>({
databaseUri: mockedUri("/a/b/c/src.zip"),
contents: {
kind: DatabaseKind.Database,
name: "foo",
datasetUri: mockedUri(),
},
language: "java",
getSourceLocationPrefix: jest
.fn()
.mockResolvedValue("/home/runner/work/my-repo/my-repo"),
sourceArchive: mockedUri("/a/b/c/src.zip"),
}),
queryStorageDir: "/tmp/queries",
progress: jest.fn(),
};
const result = await runAutoModelQueries(options);
expect(result).not.toBeUndefined();
expect(options.cliServer.resolveQlpacks).toHaveBeenCalledTimes(1);
expect(options.cliServer.resolveQlpacks).toHaveBeenCalledWith([], true);
expect(resolveQueriesSpy).toHaveBeenCalledTimes(1);
expect(resolveQueriesSpy).toHaveBeenCalledWith(
options.cliServer,
qlpack,
"Extract automodel candidates",
{
kind: "problem",
"tags contain all": ["automodel", "application-mode", "candidates"],
},
);
expect(createLockFileForStandardQuerySpy).toHaveBeenCalledTimes(1);
expect(createLockFileForStandardQuerySpy).toHaveBeenCalledWith(
options.cliServer,
"/a/b/c/ql/candidates.ql",
);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledTimes(1);
expect(options.queryRunner.createQueryRun).toHaveBeenCalledWith(
"/a/b/c/src.zip",
{
queryPath: "/a/b/c/ql/candidates.ql",
quickEvalPosition: undefined,
quickEvalCountOnly: false,
},
false,
[],
["/a/b/c/my-extension-pack"],
"/tmp/queries",
undefined,
undefined,
);
});
});

View File

@@ -3,7 +3,6 @@ import {
runQuery,
} from "../../../../src/data-extensions-editor/external-api-usage-query";
import { createMockLogger } from "../../../__mocks__/loggerMock";
import type { Uri } from "vscode";
import { DatabaseKind } from "../../../../src/databases/local-databases";
import { file } from "tmp-promise";
import { QueryResultType } from "../../../../src/query-server/new-messages";
@@ -16,19 +15,7 @@ import { RedactableError } from "../../../../src/common/errors";
import { showAndLogExceptionWithTelemetry } from "../../../../src/common/logging";
import { QueryLanguage } from "../../../../src/common/query-language";
import { Query } from "../../../../src/data-extensions-editor/queries/query";
function createMockUri(path = "/a/b/c/foo"): Uri {
return {
scheme: "file",
authority: "",
path,
query: "",
fragment: "",
fsPath: path,
with: jest.fn(),
toJSON: jest.fn(),
};
}
import { mockedUri } from "../../utils/mocking.helpers";
describe("runQuery", () => {
const cases = Object.keys(fetchExternalApiQueries).flatMap((lang) => {
@@ -74,11 +61,11 @@ describe("runQuery", () => {
logger: createMockLogger(),
},
databaseItem: {
databaseUri: createMockUri("/a/b/c/src.zip"),
databaseUri: mockedUri("/a/b/c/src.zip"),
contents: {
kind: DatabaseKind.Database,
name: "foo",
datasetUri: createMockUri(),
datasetUri: mockedUri(),
},
language,
},

View File

@@ -3,23 +3,29 @@ import { mockDatabaseItem } from "../../../utils/mocking.helpers";
import { tryResolveLocation } from "../../../../../src/databases/local-databases/locations";
describe("tryResolveLocation", () => {
const resolveSourceFile = jest.fn();
const databaseItem = mockDatabaseItem({
resolveSourceFile,
});
beforeEach(() => {
resolveSourceFile.mockReturnValue(Uri.file("abc"));
});
it("should resolve a whole file location", () => {
const databaseItem = mockDatabaseItem();
expect(tryResolveLocation("file://hucairz:0:0:0:0", databaseItem)).toEqual(
new Location(Uri.file("abc"), new Range(0, 0, 0, 0)),
);
});
it("should resolve a five-part location edge case", () => {
const databaseItem = mockDatabaseItem();
expect(tryResolveLocation("file://hucairz:1:1:1:1", databaseItem)).toEqual(
new Location(Uri.file("abc"), new Range(0, 0, 0, 1)),
);
});
it("should resolve a five-part location", () => {
const databaseItem = mockDatabaseItem();
expect(
tryResolveLocation(
{
@@ -33,7 +39,7 @@ describe("tryResolveLocation", () => {
),
).toEqual(
new Location(
Uri.parse("abc"),
Uri.file("abc"),
new Range(new Position(4, 3), new Position(3, 0)),
),
);
@@ -42,8 +48,6 @@ describe("tryResolveLocation", () => {
});
it("should resolve a five-part location with an empty path", () => {
const databaseItem = mockDatabaseItem();
expect(
tryResolveLocation(
{
@@ -59,21 +63,17 @@ describe("tryResolveLocation", () => {
});
it("should resolve a string location for whole file", () => {
const databaseItem = mockDatabaseItem();
expect(tryResolveLocation("file://hucairz:0:0:0:0", databaseItem)).toEqual(
new Location(Uri.parse("abc"), new Range(0, 0, 0, 0)),
new Location(Uri.file("abc"), new Range(0, 0, 0, 0)),
);
expect(databaseItem.resolveSourceFile).toHaveBeenCalledTimes(1);
expect(databaseItem.resolveSourceFile).toHaveBeenCalledWith("hucairz");
});
it("should resolve a string location for five-part location", () => {
const databaseItem = mockDatabaseItem();
expect(tryResolveLocation("file://hucairz:5:4:3:2", databaseItem)).toEqual(
new Location(
Uri.parse("abc"),
Uri.file("abc"),
new Range(new Position(4, 3), new Position(2, 2)),
),
);
@@ -82,8 +82,6 @@ describe("tryResolveLocation", () => {
});
it("should resolve a string location for invalid string", () => {
const databaseItem = mockDatabaseItem();
expect(
tryResolveLocation("file://hucairz:x:y:z:a", databaseItem),
).toBeUndefined();

View File

@@ -1,4 +1,4 @@
import { QuickPickItem, window, Uri } from "vscode";
import type { QuickPickItem, window, Uri } from "vscode";
import { DatabaseItem } from "../../../src/databases/local-databases";
export type DeepPartial<T> = T extends object
@@ -40,6 +40,18 @@ export function mockedObject<T extends object>(
return undefined;
}
// The `asymmetricMatch` is accessed by jest to check if the object is a matcher.
// We don't want to throw an error when this happens.
if (prop === "asymmetricMatch") {
return undefined;
}
// The `Symbol.iterator` is accessed by jest to check if the object is iterable.
// We don't want to throw an error when this happens.
if (prop === Symbol.iterator) {
return undefined;
}
throw new Error(`Method ${String(prop)} not mocked`);
},
});
@@ -49,11 +61,11 @@ export function mockDatabaseItem(
props: DeepPartial<DatabaseItem> = {},
): DatabaseItem {
return mockedObject<DatabaseItem>({
databaseUri: Uri.file("abc"),
databaseUri: mockedUri("abc"),
name: "github/codeql",
language: "javascript",
sourceArchive: undefined,
resolveSourceFile: jest.fn().mockReturnValue(Uri.file("abc")),
resolveSourceFile: jest.fn().mockReturnValue(mockedUri("abc")),
...props,
});
}
@@ -63,3 +75,16 @@ export function mockedQuickPickItem<T extends QuickPickItem | string>(
): Awaited<ReturnType<typeof window.showQuickPick>> {
return value as Awaited<ReturnType<typeof window.showQuickPick>>;
}
export function mockedUri(path = "/a/b/c/foo"): Uri {
return {
scheme: "file",
authority: "",
path,
query: "",
fragment: "",
fsPath: path,
with: jest.fn(),
toJSON: jest.fn(),
};
}