diff --git a/extensions/ql-vscode/src/databases/github-database-prompt.ts b/extensions/ql-vscode/src/databases/github-database-prompt.ts index f5cd79450..235af1f63 100644 --- a/extensions/ql-vscode/src/databases/github-database-prompt.ts +++ b/extensions/ql-vscode/src/databases/github-database-prompt.ts @@ -3,10 +3,7 @@ import { RestEndpointMethodTypes } from "@octokit/plugin-rest-endpoint-methods"; import { Octokit } from "@octokit/rest"; import { showNeverAskAgainDialog } from "../common/vscode/dialog"; import { getLanguageDisplayName } from "../common/query-language"; -import { - downloadGitHubDatabaseFromUrl, - promptForLanguage, -} from "./database-fetcher"; +import { downloadGitHubDatabaseFromUrl } from "./database-fetcher"; import { withProgress } from "../common/vscode/progress"; import { DatabaseManager } from "./local-databases"; import { CodeQLCliServer } from "../codeql-cli/cli"; @@ -77,38 +74,46 @@ export async function promptGitHubDatabaseDownload( return; } - const language = await promptForLanguage(languages, undefined); - if (!language) { + const selectedDatabases = await promptForDatabases(databases); + if (selectedDatabases.length === 0) { return; } - const database = databases.find((database) => database.language === language); - if (!database) { - return; - } + await Promise.all( + selectedDatabases.map((database) => + withProgress( + async (progress) => { + await downloadGitHubDatabaseFromUrl( + database.url, + database.id, + database.created_at, + database.commit_oid ?? null, + owner, + repo, + octokit, + progress, + databaseManager, + storagePath, + cliServer, + true, + false, + ); - await withProgress(async (progress) => { - await downloadGitHubDatabaseFromUrl( - database.url, - database.id, - database.created_at, - database.commit_oid ?? null, - owner, - repo, - octokit, - progress, - databaseManager, - storagePath, - cliServer, - true, - false, - ); - - await commandManager.execute("codeQLDatabases.focus"); - void window.showInformationMessage( - `Downloaded ${getLanguageDisplayName(language)} database from GitHub.`, - ); - }); + await commandManager.execute("codeQLDatabases.focus"); + void window.showInformationMessage( + `Downloaded ${getLanguageDisplayName( + database.language, + )} database from GitHub.`, + ); + }, + { + title: `Adding ${getLanguageDisplayName( + database.language, + )} database from GitHub`, + }, + ), + ), + ); } /** @@ -135,3 +140,34 @@ function joinLanguages(languages: string[]): string { return result; } + +async function promptForDatabases( + databases: CodeqlDatabase[], +): Promise { + if (databases.length === 1) { + return databases; + } + + const items = databases + .map((database) => { + const bytesToDisplayMB = `${(database.size / (1024 * 1024)).toFixed( + 1, + )} MB`; + + return { + label: getLanguageDisplayName(database.language), + description: bytesToDisplayMB, + database, + }; + }) + .sort((a, b) => a.label.localeCompare(b.label)); + + const selectedItems = await window.showQuickPick(items, { + title: "Select databases to download", + placeHolder: "Databases found in this repository", + ignoreFocusOut: true, + canPickMany: true, + }); + + return selectedItems?.map((selectedItem) => selectedItem.database) ?? []; +} diff --git a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts index 5790ef803..42ca917a1 100644 --- a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts +++ b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts @@ -1,6 +1,7 @@ import { faker } from "@faker-js/faker"; import { Octokit } from "@octokit/rest"; -import { mockedObject } from "../../utils/mocking.helpers"; +import { QuickPickItem, window } from "vscode"; +import { mockedObject, mockedQuickPickItem } from "../../utils/mocking.helpers"; import { CodeqlDatabase, promptGitHubDatabaseDownload, @@ -29,6 +30,7 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), @@ -38,9 +40,7 @@ describe("promptGitHubDatabaseDownload", () => { let showNeverAskAgainDialogSpy: jest.SpiedFunction< typeof dialog.showNeverAskAgainDialog >; - let promptForLanguageSpy: jest.SpiedFunction< - typeof databaseFetcher.promptForLanguage - >; + let showQuickPickSpy: jest.SpiedFunction; let downloadGitHubDatabaseFromUrlSpy: jest.SpiedFunction< typeof databaseFetcher.downloadGitHubDatabaseFromUrl >; @@ -56,9 +56,13 @@ describe("promptGitHubDatabaseDownload", () => { showNeverAskAgainDialogSpy = jest .spyOn(dialog, "showNeverAskAgainDialog") .mockResolvedValue("Connect"); - promptForLanguageSpy = jest - .spyOn(databaseFetcher, "promptForLanguage") - .mockResolvedValue(databases[0].language); + showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + ]), + ); downloadGitHubDatabaseFromUrlSpy = jest .spyOn(databaseFetcher, "downloadGitHubDatabaseFromUrl") .mockResolvedValue(undefined); @@ -93,7 +97,7 @@ describe("promptGitHubDatabaseDownload", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith(["swift"], undefined); + expect(showQuickPickSpy).not.toHaveBeenCalled(); expect(config.setDownload).not.toHaveBeenCalled(); }); @@ -180,28 +184,6 @@ describe("promptGitHubDatabaseDownload", () => { }); }); - describe("when not selecting language", () => { - beforeEach(() => { - promptForLanguageSpy.mockResolvedValue(undefined); - }); - - it("does not download the database", async () => { - await promptGitHubDatabaseDownload( - octokit, - owner, - repo, - databases, - config, - databaseManager, - storagePath, - cliServer, - commandManager, - ); - - expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); - }); - }); - describe("when there are multiple languages", () => { beforeEach(() => { databases = [ @@ -210,6 +192,7 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), @@ -219,16 +202,23 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: null, language: "go", + size: 2930572385, url: faker.internet.url({ protocol: "https", }), }), ]; - - promptForLanguageSpy.mockResolvedValue(databases[1].language); }); - it("downloads the correct database", async () => { + it("downloads a single selected language", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[1], + }), + ]), + ); + await promptGitHubDatabaseDownload( octokit, owner, @@ -257,11 +247,117 @@ describe("promptGitHubDatabaseDownload", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith( - ["swift", "go"], - undefined, + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), ); expect(config.setDownload).not.toHaveBeenCalled(); }); + + it("downloads multiple selected languages", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + mockedObject({ + database: databases[1], + }), + ]), + ); + + await promptGitHubDatabaseDownload( + octokit, + owner, + repo, + databases, + config, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(2); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[0].url, + databases[0].id, + databases[0].created_at, + databases[0].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[1].url, + databases[1].id, + databases[1].created_at, + databases[1].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), + ); + expect(config.setDownload).not.toHaveBeenCalled(); + }); + + describe("when not selecting language", () => { + beforeEach(() => { + showQuickPickSpy.mockResolvedValue(undefined); + }); + + it("does not download the database", async () => { + await promptGitHubDatabaseDownload( + octokit, + owner, + repo, + databases, + config, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); + }); + }); }); });