diff --git a/extensions/ql-vscode/src/databaseFetcher.ts b/extensions/ql-vscode/src/databaseFetcher.ts index 69ab63d24..a8994d955 100644 --- a/extensions/ql-vscode/src/databaseFetcher.ts +++ b/extensions/ql-vscode/src/databaseFetcher.ts @@ -1,7 +1,7 @@ import fetch, { Response } from "node-fetch"; import { zip } from "zip-a-folder"; import { Open } from "unzipper"; -import { Uri, CancellationToken, window } from "vscode"; +import { Uri, CancellationToken, window, InputBoxOptions } from "vscode"; import { CodeQLCliServer } from "./cli"; import { ensureDir, @@ -92,17 +92,7 @@ export async function promptImportGithubDatabase( token: CancellationToken, cli?: CodeQLCliServer, ): Promise { - progress({ - message: "Choose repository", - step: 1, - maxStep: 2, - }); - const githubRepo = await window.showInputBox({ - title: - 'Enter a GitHub repository URL or "name with owner" (e.g. https://github.com/github/codeql or github/codeql)', - placeHolder: "https://github.com// or /", - ignoreFocusOut: true, - }); + const githubRepo = await askForGitHubRepo(progress); if (!githubRepo) { return; } @@ -128,6 +118,30 @@ export async function promptImportGithubDatabase( return; } +export async function askForGitHubRepo( + progress?: ProgressCallback, + suggestedValue?: string, +): Promise { + progress?.({ + message: "Choose repository", + step: 1, + maxStep: 2, + }); + + const options: InputBoxOptions = { + title: + 'Enter a GitHub repository URL or "name with owner" (e.g. https://github.com/github/codeql or github/codeql)', + placeHolder: "https://github.com// or /", + ignoreFocusOut: true, + }; + + if (suggestedValue) { + options.value = suggestedValue; + } + + return await window.showInputBox(options); +} + /** * Downloads a database from GitHub * diff --git a/extensions/ql-vscode/src/local-databases.ts b/extensions/ql-vscode/src/local-databases.ts index c1c1c7f05..97e7dd931 100644 --- a/extensions/ql-vscode/src/local-databases.ts +++ b/extensions/ql-vscode/src/local-databases.ts @@ -911,6 +911,17 @@ export class DatabaseManager extends DisposableObject { return dbs[0]; } + public async digForDatabaseWithSameLanguage( + language: string, + ): Promise { + const dbItems = this.databaseItems || []; + const dbs = dbItems.filter((db) => db.language === language); + if (dbs.length === 0) { + return undefined; + } + return dbs[0]; + } + /** * Returns the index of the workspace folder that corresponds to the source archive of `item` * if there is one, and -1 otherwise. diff --git a/extensions/ql-vscode/src/skeleton-query-wizard.ts b/extensions/ql-vscode/src/skeleton-query-wizard.ts index 54a265fbe..d1e5f8865 100644 --- a/extensions/ql-vscode/src/skeleton-query-wizard.ts +++ b/extensions/ql-vscode/src/skeleton-query-wizard.ts @@ -9,7 +9,7 @@ import { getErrorMessage } from "./pure/helpers-pure"; import { QlPackGenerator } from "./qlpack-generator"; import { DatabaseManager } from "./local-databases"; import * as databaseFetcher from "./databaseFetcher"; -import { ProgressCallback } from "./progress"; +import { ProgressCallback, UserCancellationException } from "./progress"; type QueryLanguagesToDatabaseMap = Record; @@ -207,9 +207,17 @@ export class SkeletonQueryWizard { }); const githubRepoNwo = QUERY_LANGUAGE_TO_DATABASE_REPO[this.language]; + const chosenRepo = await databaseFetcher.askForGitHubRepo( + undefined, + githubRepoNwo, + ); + + if (!chosenRepo) { + throw new UserCancellationException("No GitHub repository provided"); + } await databaseFetcher.downloadGitHubDatabase( - githubRepoNwo, + chosenRepo, this.databaseManager, this.storagePath, this.credentials, @@ -231,17 +239,30 @@ export class SkeletonQueryWizard { const databaseNwo = QUERY_LANGUAGE_TO_DATABASE_REPO[this.language]; - const databaseItem = await this.databaseManager.digForDatabaseItem( + // Check that we haven't already downloaded a database for this language + const existingDatabaseItem = await this.databaseManager.digForDatabaseItem( this.language, databaseNwo, ); - if (databaseItem) { + if (existingDatabaseItem) { // select the found database - await this.databaseManager.setCurrentDatabaseItem(databaseItem); + await this.databaseManager.setCurrentDatabaseItem(existingDatabaseItem); } else { - // download new database and select it - await this.downloadDatabase(); + const sameLanguageDatabaseItem = + await this.databaseManager.digForDatabaseWithSameLanguage( + this.language, + ); + + if (sameLanguageDatabaseItem) { + // select the found database + await this.databaseManager.setCurrentDatabaseItem( + sameLanguageDatabaseItem, + ); + } else { + // download new database and select it + await this.downloadDatabase(); + } } } } diff --git a/extensions/ql-vscode/test/vscode-tests/cli-integration/skeleton-query-wizard.test.ts b/extensions/ql-vscode/test/vscode-tests/cli-integration/skeleton-query-wizard.test.ts index 2202a3487..8e9685cfd 100644 --- a/extensions/ql-vscode/test/vscode-tests/cli-integration/skeleton-query-wizard.test.ts +++ b/extensions/ql-vscode/test/vscode-tests/cli-integration/skeleton-query-wizard.test.ts @@ -32,6 +32,9 @@ describe("SkeletonQueryWizard", () => { let downloadGitHubDatabaseSpy: jest.SpiedFunction< typeof databaseFetcher.downloadGitHubDatabase >; + let askForGitHubRepoSpy: jest.SpiedFunction< + typeof databaseFetcher.askForGitHubRepo + >; let openTextDocumentSpy: jest.SpiedFunction< typeof workspace.openTextDocument >; @@ -42,6 +45,7 @@ describe("SkeletonQueryWizard", () => { const mockDatabaseManager = mockedObject({ setCurrentDatabaseItem: jest.fn(), digForDatabaseItem: jest.fn(), + digForDatabaseWithSameLanguage: jest.fn(), }); const mockCli = mockedObject({ resolveLanguages: jest @@ -58,6 +62,8 @@ describe("SkeletonQueryWizard", () => { getSupportedLanguages: jest.fn(), }); + jest.spyOn(extLogger, "log").mockResolvedValue(undefined); + beforeEach(async () => { dir = tmp.dirSync({ prefix: "skeleton_query_wizard_", @@ -101,6 +107,10 @@ describe("SkeletonQueryWizard", () => { mockDatabaseManager, token, ); + + askForGitHubRepoSpy = jest + .spyOn(databaseFetcher, "askForGitHubRepo") + .mockResolvedValue(QUERY_LANGUAGE_TO_DATABASE_REPO[chosenLanguage]); }); afterEach(async () => { @@ -250,10 +260,30 @@ describe("SkeletonQueryWizard", () => { .mockResolvedValue(undefined); }); - it("should download a new database for language", async () => { - await wizard.execute(); + describe("if the user choses to downloaded the suggested database from GitHub", () => { + it("should download a new database for language", async () => { + await wizard.execute(); - expect(downloadGitHubDatabaseSpy).toHaveBeenCalled(); + expect(askForGitHubRepoSpy).toHaveBeenCalled(); + expect(downloadGitHubDatabaseSpy).toHaveBeenCalled(); + }); + }); + + describe("if the user choses to download a different database from GitHub than the one suggested", () => { + beforeEach(() => { + const chosenGitHubRepo = "pickles-owner/pickles-repo"; + + askForGitHubRepoSpy = jest + .spyOn(databaseFetcher, "askForGitHubRepo") + .mockResolvedValue(chosenGitHubRepo); + }); + + it("should download the newly chosen database", async () => { + await wizard.execute(); + + expect(askForGitHubRepoSpy).toHaveBeenCalled(); + expect(downloadGitHubDatabaseSpy).toHaveBeenCalled(); + }); }); }); });