Merge pull request #3458 from github/robertbrignull/database-prompting

Refactor database-fetcher.ts into a class and simplify
This commit is contained in:
Robert
2024-03-28 18:05:54 +00:00
committed by GitHub
20 changed files with 753 additions and 912 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,8 @@ import { window } from "vscode";
import type { Octokit } from "@octokit/rest"; import type { Octokit } from "@octokit/rest";
import { showNeverAskAgainDialog } from "../../common/vscode/dialog"; import { showNeverAskAgainDialog } from "../../common/vscode/dialog";
import { getLanguageDisplayName } from "../../common/query-language"; import { getLanguageDisplayName } from "../../common/query-language";
import { downloadGitHubDatabaseFromUrl } from "../database-fetcher"; import type { DatabaseFetcher } from "../database-fetcher";
import { withProgress } from "../../common/vscode/progress"; import { withProgress } from "../../common/vscode/progress";
import type { DatabaseManager } from "../local-databases";
import type { CodeQLCliServer } from "../../codeql-cli/cli";
import type { AppCommandManager } from "../../common/commands"; import type { AppCommandManager } from "../../common/commands";
import type { GitHubDatabaseConfig } from "../../config"; import type { GitHubDatabaseConfig } from "../../config";
import type { CodeqlDatabase } from "./api"; import type { CodeqlDatabase } from "./api";
@@ -58,9 +56,7 @@ export async function downloadDatabaseFromGitHub(
owner: string, owner: string,
repo: string, repo: string,
databases: CodeqlDatabase[], databases: CodeqlDatabase[],
databaseManager: DatabaseManager, databaseFetcher: DatabaseFetcher,
storagePath: string,
cliServer: CodeQLCliServer,
commandManager: AppCommandManager, commandManager: AppCommandManager,
): Promise<void> { ): Promise<void> {
const selectedDatabases = await promptForDatabases(databases); const selectedDatabases = await promptForDatabases(databases);
@@ -72,7 +68,7 @@ export async function downloadDatabaseFromGitHub(
selectedDatabases.map((database) => selectedDatabases.map((database) =>
withProgress( withProgress(
async (progress) => { async (progress) => {
await downloadGitHubDatabaseFromUrl( await databaseFetcher.downloadGitHubDatabaseFromUrl(
database.url, database.url,
database.id, database.id,
database.created_at, database.created_at,
@@ -81,9 +77,6 @@ export async function downloadDatabaseFromGitHub(
repo, repo,
octokit, octokit,
progress, progress,
databaseManager,
storagePath,
cliServer,
true, true,
false, false,
); );

View File

@@ -14,7 +14,6 @@ import {
} from "./download"; } from "./download";
import type { GitHubDatabaseConfig } from "../../config"; import type { GitHubDatabaseConfig } from "../../config";
import type { DatabaseManager } from "../local-databases"; import type { DatabaseManager } from "../local-databases";
import type { CodeQLCliServer } from "../../codeql-cli/cli";
import type { CodeqlDatabase, ListDatabasesResult } from "./api"; import type { CodeqlDatabase, ListDatabasesResult } from "./api";
import { listDatabases } from "./api"; import { listDatabases } from "./api";
import type { DatabaseUpdate } from "./updates"; import type { DatabaseUpdate } from "./updates";
@@ -24,6 +23,7 @@ import {
isNewerDatabaseAvailable, isNewerDatabaseAvailable,
} from "./updates"; } from "./updates";
import type { Octokit } from "@octokit/rest"; import type { Octokit } from "@octokit/rest";
import type { DatabaseFetcher } from "../database-fetcher";
export class GitHubDatabasesModule extends DisposableObject { export class GitHubDatabasesModule extends DisposableObject {
/** /**
@@ -33,8 +33,7 @@ export class GitHubDatabasesModule extends DisposableObject {
constructor( constructor(
private readonly app: App, private readonly app: App,
private readonly databaseManager: DatabaseManager, private readonly databaseManager: DatabaseManager,
private readonly databaseStoragePath: string, private readonly databaseFetcher: DatabaseFetcher,
private readonly cliServer: CodeQLCliServer,
private readonly config: GitHubDatabaseConfig, private readonly config: GitHubDatabaseConfig,
) { ) {
super(); super();
@@ -43,15 +42,13 @@ export class GitHubDatabasesModule extends DisposableObject {
public static async initialize( public static async initialize(
app: App, app: App,
databaseManager: DatabaseManager, databaseManager: DatabaseManager,
databaseStoragePath: string, databaseFetcher: DatabaseFetcher,
cliServer: CodeQLCliServer,
config: GitHubDatabaseConfig, config: GitHubDatabaseConfig,
): Promise<GitHubDatabasesModule> { ): Promise<GitHubDatabasesModule> {
const githubDatabasesModule = new GitHubDatabasesModule( const githubDatabasesModule = new GitHubDatabasesModule(
app, app,
databaseManager, databaseManager,
databaseStoragePath, databaseFetcher,
cliServer,
config, config,
); );
app.subscriptions.push(githubDatabasesModule); app.subscriptions.push(githubDatabasesModule);
@@ -185,9 +182,7 @@ export class GitHubDatabasesModule extends DisposableObject {
owner, owner,
repo, repo,
databases, databases,
this.databaseManager, this.databaseFetcher,
this.databaseStoragePath,
this.cliServer,
this.app.commands, this.app.commands,
); );
} }
@@ -212,8 +207,7 @@ export class GitHubDatabasesModule extends DisposableObject {
repo, repo,
databaseUpdates, databaseUpdates,
this.databaseManager, this.databaseManager,
this.databaseStoragePath, this.databaseFetcher,
this.cliServer,
this.app.commands, this.app.commands,
); );
} }

View File

@@ -1,11 +1,10 @@
import type { CodeqlDatabase } from "./api"; import type { CodeqlDatabase } from "./api";
import type { DatabaseItem, DatabaseManager } from "../local-databases"; import type { DatabaseItem, DatabaseManager } from "../local-databases";
import type { Octokit } from "@octokit/rest"; import type { Octokit } from "@octokit/rest";
import type { CodeQLCliServer } from "../../codeql-cli/cli";
import type { AppCommandManager } from "../../common/commands"; import type { AppCommandManager } from "../../common/commands";
import { getLanguageDisplayName } from "../../common/query-language"; import { getLanguageDisplayName } from "../../common/query-language";
import { showNeverAskAgainDialog } from "../../common/vscode/dialog"; import { showNeverAskAgainDialog } from "../../common/vscode/dialog";
import { downloadGitHubDatabaseFromUrl } from "../database-fetcher"; import type { DatabaseFetcher } from "../database-fetcher";
import { withProgress } from "../../common/vscode/progress"; import { withProgress } from "../../common/vscode/progress";
import { window } from "vscode"; import { window } from "vscode";
import type { GitHubDatabaseConfig } from "../../config"; import type { GitHubDatabaseConfig } from "../../config";
@@ -156,8 +155,7 @@ export async function downloadDatabaseUpdateFromGitHub(
repo: string, repo: string,
updates: DatabaseUpdate[], updates: DatabaseUpdate[],
databaseManager: DatabaseManager, databaseManager: DatabaseManager,
storagePath: string, databaseFetcher: DatabaseFetcher,
cliServer: CodeQLCliServer,
commandManager: AppCommandManager, commandManager: AppCommandManager,
): Promise<void> { ): Promise<void> {
const selectedDatabases = await promptForDatabases( const selectedDatabases = await promptForDatabases(
@@ -179,21 +177,19 @@ export async function downloadDatabaseUpdateFromGitHub(
return withProgress( return withProgress(
async (progress) => { async (progress) => {
const newDatabase = await downloadGitHubDatabaseFromUrl( const newDatabase =
database.url, await databaseFetcher.downloadGitHubDatabaseFromUrl(
database.id, database.url,
database.created_at, database.id,
database.commit_oid ?? null, database.created_at,
owner, database.commit_oid ?? null,
repo, owner,
octokit, repo,
progress, octokit,
databaseManager, progress,
storagePath, databaseManager.currentDatabaseItem === update.databaseItem,
cliServer, update.databaseItem.hasSourceArchiveInExplorer(),
databaseManager.currentDatabaseItem === update.databaseItem, );
update.databaseItem.hasSourceArchiveInExplorer(),
);
if (newDatabase === undefined) { if (newDatabase === undefined) {
return; return;
} }

View File

@@ -37,11 +37,7 @@ import {
showAndLogExceptionWithTelemetry, showAndLogExceptionWithTelemetry,
showAndLogErrorMessage, showAndLogErrorMessage,
} from "../common/logging"; } from "../common/logging";
import { import type { DatabaseFetcher } from "./database-fetcher";
importLocalDatabase,
promptImportGithubDatabase,
promptImportInternetDatabase,
} from "./database-fetcher";
import { asError, asyncFilter, getErrorMessage } from "../common/helpers-pure"; import { asError, asyncFilter, getErrorMessage } from "../common/helpers-pure";
import type { QueryRunner } from "../query-server"; import type { QueryRunner } from "../query-server";
import type { App } from "../common/app"; import type { App } from "../common/app";
@@ -248,6 +244,7 @@ export class DatabaseUI extends DisposableObject {
public constructor( public constructor(
private app: App, private app: App,
private databaseManager: DatabaseManager, private databaseManager: DatabaseManager,
private readonly databaseFetcher: DatabaseFetcher,
languageContext: LanguageContextStore, languageContext: LanguageContextStore,
private readonly queryServer: QueryRunner, private readonly queryServer: QueryRunner,
private readonly storagePath: string, private readonly storagePath: string,
@@ -535,13 +532,7 @@ export class DatabaseUI extends DisposableObject {
private async handleChooseDatabaseInternet(): Promise<void> { private async handleChooseDatabaseInternet(): Promise<void> {
return withProgress( return withProgress(
async (progress) => { async (progress) => {
await promptImportInternetDatabase( await this.databaseFetcher.promptImportInternetDatabase(progress);
this.app.commands,
this.databaseManager,
this.storagePath,
progress,
this.queryServer.cliServer,
);
}, },
{ {
title: "Adding database from URL", title: "Adding database from URL",
@@ -552,13 +543,7 @@ export class DatabaseUI extends DisposableObject {
private async handleChooseDatabaseGithub(): Promise<void> { private async handleChooseDatabaseGithub(): Promise<void> {
return withProgress( return withProgress(
async (progress) => { async (progress) => {
await promptImportGithubDatabase( await this.databaseFetcher.promptImportGithubDatabase(progress);
this.app,
this.databaseManager,
this.storagePath,
progress,
this.queryServer.cliServer,
);
}, },
{ {
title: "Adding database from GitHub", title: "Adding database from GitHub",
@@ -707,13 +692,9 @@ export class DatabaseUI extends DisposableObject {
try { try {
// Assume user has selected an archive if the file has a .zip extension // Assume user has selected an archive if the file has a .zip extension
if (uri.path.endsWith(".zip")) { if (uri.path.endsWith(".zip")) {
await importLocalDatabase( await this.databaseFetcher.importLocalDatabase(
this.app.commands,
uri.toString(true), uri.toString(true),
this.databaseManager,
this.storagePath,
progress, progress,
this.queryServer.cliServer,
); );
} else { } else {
await this.databaseManager.openDatabase(uri, { await this.databaseManager.openDatabase(uri, {
@@ -758,13 +739,9 @@ export class DatabaseUI extends DisposableObject {
await this.databaseManager.removeDatabaseItem(existingItem); await this.databaseManager.removeDatabaseItem(existingItem);
} }
await importLocalDatabase( await this.databaseFetcher.importLocalDatabase(
this.app.commands,
uri.toString(true), uri.toString(true),
this.databaseManager,
this.storagePath,
progress, progress,
this.queryServer.cliServer,
); );
if (existingItem !== undefined) { if (existingItem !== undefined) {
@@ -1005,13 +982,9 @@ export class DatabaseUI extends DisposableObject {
// we are selecting a database archive or a testproj. // we are selecting a database archive or a testproj.
// Unzip archives (if an archive) and copy into a workspace-controlled area // Unzip archives (if an archive) and copy into a workspace-controlled area
// before importing. // before importing.
return await importLocalDatabase( return await this.databaseFetcher.importLocalDatabase(
this.app.commands,
uri.toString(true), uri.toString(true),
this.databaseManager,
this.storagePath,
progress, progress,
this.queryServer.cliServer,
); );
} }
} }

View File

@@ -1,4 +1,7 @@
import { pathExists, remove } from "fs-extra";
import { join } from "path";
import type { Uri } from "vscode"; import type { Uri } from "vscode";
import { zip } from "zip-a-folder";
/** /**
* The layout of the database. * The layout of the database.
@@ -28,3 +31,26 @@ export interface DatabaseContents {
export interface DatabaseContentsWithDbScheme extends DatabaseContents { export interface DatabaseContentsWithDbScheme extends DatabaseContents {
dbSchemeUri: Uri; // Always present dbSchemeUri: Uri; // Always present
} }
/**
* Databases created by the old odasa tool will not have a zipped
* source location. However, this extension works better if sources
* are zipped.
*
* This function ensures that the source location is zipped. If the
* `src` folder exists and the `src.zip` file does not, the `src`
* folder will be zipped and then deleted.
*
* @param databasePath The full path to the unzipped database
*/
export async function ensureZippedSourceLocation(
databasePath: string,
): Promise<void> {
const srcFolderPath = join(databasePath, "src");
const srcZipPath = `${srcFolderPath}.zip`;
if ((await pathExists(srcFolderPath)) && !(await pathExists(srcZipPath))) {
await zip(srcFolderPath, srcZipPath);
await remove(srcFolderPath);
}
}

View File

@@ -43,7 +43,7 @@ import { DatabaseResolver } from "./database-resolver";
import { telemetryListener } from "../../common/vscode/telemetry"; import { telemetryListener } from "../../common/vscode/telemetry";
import type { LanguageContextStore } from "../../language-context-store"; import type { LanguageContextStore } from "../../language-context-store";
import type { DatabaseOrigin } from "./database-origin"; import type { DatabaseOrigin } from "./database-origin";
import { ensureZippedSourceLocation } from "../database-fetcher"; import { ensureZippedSourceLocation } from "./database-contents";
/** /**
* The name of the key in the workspaceState dictionary in which we * The name of the key in the workspaceState dictionary in which we

View File

@@ -133,6 +133,7 @@ import { OpenReferencedFileCodeLensProvider } from "./local-queries/open-referen
import { LanguageContextStore } from "./language-context-store"; import { LanguageContextStore } from "./language-context-store";
import { LanguageSelectionPanel } from "./language-selection-panel/language-selection-panel"; import { LanguageSelectionPanel } from "./language-selection-panel/language-selection-panel";
import { GitHubDatabasesModule } from "./databases/github-databases"; import { GitHubDatabasesModule } from "./databases/github-databases";
import { DatabaseFetcher } from "./databases/database-fetcher";
/** /**
* extension.ts * extension.ts
@@ -799,12 +800,20 @@ async function activateWithInstalledDistribution(
// Let this run async. // Let this run async.
void dbm.loadPersistedState(); void dbm.loadPersistedState();
const databaseFetcher = new DatabaseFetcher(
app,
dbm,
getContextStoragePath(ctx),
cliServer,
);
ctx.subscriptions.push(dbm); ctx.subscriptions.push(dbm);
void extLogger.log("Initializing database panel."); void extLogger.log("Initializing database panel.");
const databaseUI = new DatabaseUI( const databaseUI = new DatabaseUI(
app, app,
dbm, dbm,
databaseFetcher,
languageContext, languageContext,
qs, qs,
getContextStoragePath(ctx), getContextStoragePath(ctx),
@@ -881,8 +890,7 @@ async function activateWithInstalledDistribution(
await GitHubDatabasesModule.initialize( await GitHubDatabasesModule.initialize(
app, app,
dbm, dbm,
getContextStoragePath(ctx), databaseFetcher,
cliServer,
githubDatabaseConfigListener, githubDatabaseConfigListener,
); );
@@ -953,6 +961,7 @@ async function activateWithInstalledDistribution(
qs, qs,
qhm, qhm,
dbm, dbm,
databaseFetcher,
cliServer, cliServer,
databaseUI, databaseUI,
localQueryResultsView, localQueryResultsView,
@@ -977,6 +986,7 @@ async function activateWithInstalledDistribution(
const modelEditorModule = await ModelEditorModule.initialize( const modelEditorModule = await ModelEditorModule.initialize(
app, app,
dbm, dbm,
databaseFetcher,
variantAnalysisManager, variantAnalysisManager,
cliServer, cliServer,
qs, qs,

View File

@@ -54,6 +54,7 @@ import type { QueryTreeViewItem } from "../queries-panel/query-tree-view-item";
import { tryGetQueryLanguage } from "../common/query-language"; import { tryGetQueryLanguage } from "../common/query-language";
import type { LanguageContextStore } from "../language-context-store"; import type { LanguageContextStore } from "../language-context-store";
import type { ExtensionApp } from "../common/vscode/vscode-app"; import type { ExtensionApp } from "../common/vscode/vscode-app";
import type { DatabaseFetcher } from "../databases/database-fetcher";
export enum QuickEvalType { export enum QuickEvalType {
None, None,
@@ -69,6 +70,7 @@ export class LocalQueries extends DisposableObject {
private readonly queryRunner: QueryRunner, private readonly queryRunner: QueryRunner,
private readonly queryHistoryManager: QueryHistoryManager, private readonly queryHistoryManager: QueryHistoryManager,
private readonly databaseManager: DatabaseManager, private readonly databaseManager: DatabaseManager,
private readonly databaseFetcher: DatabaseFetcher,
private readonly cliServer: CodeQLCliServer, private readonly cliServer: CodeQLCliServer,
private readonly databaseUI: DatabaseUI, private readonly databaseUI: DatabaseUI,
private readonly localQueryResultsView: ResultsView, private readonly localQueryResultsView: ResultsView,
@@ -319,15 +321,13 @@ export class LocalQueries extends DisposableObject {
private async createSkeletonQuery(): Promise<void> { private async createSkeletonQuery(): Promise<void> {
await withProgress( await withProgress(
async (progress: ProgressCallback) => { async (progress: ProgressCallback) => {
const contextStoragePath =
this.app.workspaceStoragePath || this.app.globalStoragePath;
const language = this.languageContextStore.selectedLanguage; const language = this.languageContextStore.selectedLanguage;
const skeletonQueryWizard = new SkeletonQueryWizard( const skeletonQueryWizard = new SkeletonQueryWizard(
this.cliServer, this.cliServer,
progress, progress,
this.app, this.app,
this.databaseManager, this.databaseManager,
contextStoragePath, this.databaseFetcher,
this.selectedQueryTreeViewItems, this.selectedQueryTreeViewItems,
language, language,
); );

View File

@@ -19,10 +19,7 @@ import {
UserCancellationException, UserCancellationException,
withProgress, withProgress,
} from "../common/vscode/progress"; } from "../common/vscode/progress";
import { import type { DatabaseFetcher } from "../databases/database-fetcher";
askForGitHubRepo,
downloadGitHubDatabase,
} from "../databases/database-fetcher";
import { import {
getQlPackLocation, getQlPackLocation,
isCodespacesTemplate, isCodespacesTemplate,
@@ -62,7 +59,7 @@ export class SkeletonQueryWizard {
private readonly progress: ProgressCallback, private readonly progress: ProgressCallback,
private readonly app: App, private readonly app: App,
private readonly databaseManager: DatabaseManager, private readonly databaseManager: DatabaseManager,
private readonly databaseStoragePath: string | undefined, private readonly databaseFetcher: DatabaseFetcher,
private readonly selectedItems: readonly QueryTreeViewItem[], private readonly selectedItems: readonly QueryTreeViewItem[],
private language: QueryLanguage | undefined = undefined, private language: QueryLanguage | undefined = undefined,
) {} ) {}
@@ -363,10 +360,6 @@ export class SkeletonQueryWizard {
} }
private async downloadDatabase(progress: ProgressCallback) { private async downloadDatabase(progress: ProgressCallback) {
if (this.databaseStoragePath === undefined) {
throw new Error("Database storage path is undefined");
}
if (this.language === undefined) { if (this.language === undefined) {
throw new Error("Language is undefined"); throw new Error("Language is undefined");
} }
@@ -378,20 +371,10 @@ export class SkeletonQueryWizard {
}); });
const githubRepoNwo = QUERY_LANGUAGE_TO_DATABASE_REPO[this.language]; const githubRepoNwo = QUERY_LANGUAGE_TO_DATABASE_REPO[this.language];
const chosenRepo = await askForGitHubRepo(undefined, githubRepoNwo); await this.databaseFetcher.promptImportGithubDatabase(
if (!chosenRepo) {
throw new UserCancellationException("No GitHub repository provided");
}
await downloadGitHubDatabase(
chosenRepo,
this.app,
this.databaseManager,
this.databaseStoragePath,
progress, progress,
this.cliServer,
this.language, this.language,
githubRepoNwo,
); );
} }

View File

@@ -32,6 +32,7 @@ import { INITIAL_MODE } from "./shared/mode";
import { isSupportedLanguage } from "./supported-languages"; import { isSupportedLanguage } from "./supported-languages";
import { DefaultNotifier, checkConsistency } from "./consistency-check"; import { DefaultNotifier, checkConsistency } from "./consistency-check";
import type { VariantAnalysisManager } from "../variant-analysis/variant-analysis-manager"; import type { VariantAnalysisManager } from "../variant-analysis/variant-analysis-manager";
import type { DatabaseFetcher } from "../databases/database-fetcher";
export class ModelEditorModule extends DisposableObject { export class ModelEditorModule extends DisposableObject {
private readonly queryStorageDir: string; private readonly queryStorageDir: string;
@@ -42,6 +43,7 @@ export class ModelEditorModule extends DisposableObject {
private constructor( private constructor(
private readonly app: App, private readonly app: App,
private readonly databaseManager: DatabaseManager, private readonly databaseManager: DatabaseManager,
private readonly databaseFetcher: DatabaseFetcher,
private readonly variantAnalysisManager: VariantAnalysisManager, private readonly variantAnalysisManager: VariantAnalysisManager,
private readonly cliServer: CodeQLCliServer, private readonly cliServer: CodeQLCliServer,
private readonly queryRunner: QueryRunner, private readonly queryRunner: QueryRunner,
@@ -65,6 +67,7 @@ export class ModelEditorModule extends DisposableObject {
public static async initialize( public static async initialize(
app: App, app: App,
databaseManager: DatabaseManager, databaseManager: DatabaseManager,
databaseFetcher: DatabaseFetcher,
variantAnalysisManager: VariantAnalysisManager, variantAnalysisManager: VariantAnalysisManager,
cliServer: CodeQLCliServer, cliServer: CodeQLCliServer,
queryRunner: QueryRunner, queryRunner: QueryRunner,
@@ -73,6 +76,7 @@ export class ModelEditorModule extends DisposableObject {
const modelEditorModule = new ModelEditorModule( const modelEditorModule = new ModelEditorModule(
app, app,
databaseManager, databaseManager,
databaseFetcher,
variantAnalysisManager, variantAnalysisManager,
cliServer, cliServer,
queryRunner, queryRunner,
@@ -236,6 +240,7 @@ export class ModelEditorModule extends DisposableObject {
this.modelingEvents, this.modelingEvents,
this.modelConfig, this.modelConfig,
this.databaseManager, this.databaseManager,
this.databaseFetcher,
this.variantAnalysisManager, this.variantAnalysisManager,
this.cliServer, this.cliServer,
this.queryRunner, this.queryRunner,

View File

@@ -29,7 +29,7 @@ import type {
} from "../databases/local-databases"; } from "../databases/local-databases";
import type { CodeQLCliServer } from "../codeql-cli/cli"; import type { CodeQLCliServer } from "../codeql-cli/cli";
import { asError, assertNever, getErrorMessage } from "../common/helpers-pure"; import { asError, assertNever, getErrorMessage } from "../common/helpers-pure";
import { promptImportGithubDatabase } from "../databases/database-fetcher"; import type { DatabaseFetcher } from "../databases/database-fetcher";
import type { App } from "../common/app"; import type { App } from "../common/app";
import { redactableError } from "../common/errors"; import { redactableError } from "../common/errors";
import { import {
@@ -86,6 +86,7 @@ export class ModelEditorView extends AbstractWebview<
private readonly modelingEvents: ModelingEvents, private readonly modelingEvents: ModelingEvents,
private readonly modelConfig: ModelConfigListener, private readonly modelConfig: ModelConfigListener,
private readonly databaseManager: DatabaseManager, private readonly databaseManager: DatabaseManager,
private readonly databaseFetcher: DatabaseFetcher,
private readonly variantAnalysisManager: VariantAnalysisManager, private readonly variantAnalysisManager: VariantAnalysisManager,
private readonly cliServer: CodeQLCliServer, private readonly cliServer: CodeQLCliServer,
private readonly queryRunner: QueryRunner, private readonly queryRunner: QueryRunner,
@@ -852,6 +853,7 @@ export class ModelEditorView extends AbstractWebview<
this.modelingEvents, this.modelingEvents,
this.modelConfig, this.modelConfig,
this.databaseManager, this.databaseManager,
this.databaseFetcher,
this.variantAnalysisManager, this.variantAnalysisManager,
this.cliServer, this.cliServer,
this.queryRunner, this.queryRunner,
@@ -920,13 +922,10 @@ export class ModelEditorView extends AbstractWebview<
// the user to import the library database. We need to have the database // the user to import the library database. We need to have the database
// imported to the query server, so we need to register it to our workspace. // imported to the query server, so we need to register it to our workspace.
const makeSelected = false; const makeSelected = false;
const addedDatabase = await promptImportGithubDatabase( const addedDatabase = await this.databaseFetcher.promptImportGithubDatabase(
this.app,
this.databaseManager,
this.app.workspaceStoragePath ?? this.app.globalStoragePath,
progress, progress,
this.cliServer,
this.databaseItem.language, this.databaseItem.language,
undefined,
makeSelected, makeSelected,
false, false,
); );

View File

@@ -3,10 +3,7 @@ import { Uri, window } from "vscode";
import type { CodeQLCliServer } from "../../../../src/codeql-cli/cli"; import type { CodeQLCliServer } from "../../../../src/codeql-cli/cli";
import type { DatabaseManager } from "../../../../src/databases/local-databases"; import type { DatabaseManager } from "../../../../src/databases/local-databases";
import { import { DatabaseFetcher } from "../../../../src/databases/database-fetcher";
importLocalDatabase,
promptImportInternetDatabase,
} from "../../../../src/databases/database-fetcher";
import { import {
cleanDatabases, cleanDatabases,
dbLoc, dbLoc,
@@ -15,9 +12,8 @@ import {
storagePath, storagePath,
testprojLoc, testprojLoc,
} from "../../global.helper"; } from "../../global.helper";
import { createMockCommandManager } from "../../../__mocks__/commandsMock"; import { existsSync, remove, utimesSync } from "fs-extra";
import { utimesSync } from "fs"; import { createMockApp } from "../../../__mocks__/appMock";
import { remove, existsSync } from "fs-extra";
/** /**
* Run various integration tests for databases * Run various integration tests for databases
@@ -51,14 +47,16 @@ describe("database-fetcher", () => {
describe("importLocalDatabase", () => { describe("importLocalDatabase", () => {
it("should add a database from an archive", async () => { it("should add a database from an archive", async () => {
const uri = Uri.file(dbLoc); const uri = Uri.file(dbLoc);
let dbItem = await importLocalDatabase( const databaseFetcher = new DatabaseFetcher(
createMockCommandManager(), createMockApp(),
uri.toString(true),
databaseManager, databaseManager,
storagePath, storagePath,
progressCallback,
cli, cli,
); );
let dbItem = await databaseFetcher.importLocalDatabase(
uri.toString(true),
progressCallback,
);
expect(dbItem).toBe(databaseManager.currentDatabaseItem); expect(dbItem).toBe(databaseManager.currentDatabaseItem);
expect(dbItem).toBe(databaseManager.databaseItems[0]); expect(dbItem).toBe(databaseManager.databaseItems[0]);
expect(dbItem).toBeDefined(); expect(dbItem).toBeDefined();
@@ -68,14 +66,16 @@ describe("database-fetcher", () => {
}); });
it("should import a testproj database", async () => { it("should import a testproj database", async () => {
let dbItem = await importLocalDatabase( const databaseFetcher = new DatabaseFetcher(
createMockCommandManager(), createMockApp(),
Uri.file(testprojLoc).toString(true),
databaseManager, databaseManager,
storagePath, storagePath,
progressCallback,
cli, cli,
); );
let dbItem = await databaseFetcher.importLocalDatabase(
Uri.file(testprojLoc).toString(true),
progressCallback,
);
expect(dbItem).toBe(databaseManager.currentDatabaseItem); expect(dbItem).toBe(databaseManager.currentDatabaseItem);
expect(dbItem).toBe(databaseManager.databaseItems[0]); expect(dbItem).toBe(databaseManager.databaseItems[0]);
expect(dbItem).toBeDefined(); expect(dbItem).toBeDefined();
@@ -109,13 +109,14 @@ describe("database-fetcher", () => {
// Provide a database URL when prompted // Provide a database URL when prompted
inputBoxStub.mockResolvedValue(DB_URL); inputBoxStub.mockResolvedValue(DB_URL);
let dbItem = await promptImportInternetDatabase( const databaseFetcher = new DatabaseFetcher(
createMockCommandManager(), createMockApp(),
databaseManager, databaseManager,
storagePath, storagePath,
progressCallback,
cli, cli,
); );
let dbItem =
await databaseFetcher.promptImportInternetDatabase(progressCallback);
expect(dbItem).toBeDefined(); expect(dbItem).toBeDefined();
dbItem = dbItem!; dbItem = dbItem!;
expect(dbItem.name).toBe("db"); expect(dbItem.name).toBe("db");

View File

@@ -23,7 +23,7 @@ import type {
DatabaseManager, DatabaseManager,
FullDatabaseOptions, FullDatabaseOptions,
} from "../../../../src/databases/local-databases"; } from "../../../../src/databases/local-databases";
import * as databaseFetcher from "../../../../src/databases/database-fetcher"; import type { DatabaseFetcher } from "../../../../src/databases/database-fetcher";
import { createMockDB } from "../../../factories/databases/databases"; import { createMockDB } from "../../../factories/databases/databases";
import { asError } from "../../../../src/common/helpers-pure"; import { asError } from "../../../../src/common/helpers-pure";
import { Setting } from "../../../../src/config"; import { Setting } from "../../../../src/config";
@@ -42,6 +42,7 @@ describe("SkeletonQueryWizard", () => {
let mockApp: App; let mockApp: App;
let wizard: SkeletonQueryWizard; let wizard: SkeletonQueryWizard;
let mockDatabaseManager: DatabaseManager; let mockDatabaseManager: DatabaseManager;
let databaseFetcher: DatabaseFetcher;
let dir: DirResult; let dir: DirResult;
let storagePath: string; let storagePath: string;
let quickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>; let quickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>;
@@ -55,11 +56,8 @@ describe("SkeletonQueryWizard", () => {
let createExampleQlFileSpy: jest.SpiedFunction< let createExampleQlFileSpy: jest.SpiedFunction<
typeof QlPackGenerator.prototype.createExampleQlFile typeof QlPackGenerator.prototype.createExampleQlFile
>; >;
let downloadGitHubDatabaseSpy: jest.SpiedFunction< let promptImportGithubDatabaseMock: jest.MockedFunction<
typeof databaseFetcher.downloadGitHubDatabase DatabaseFetcher["promptImportGithubDatabase"]
>;
let askForGitHubRepoSpy: jest.SpiedFunction<
typeof databaseFetcher.askForGitHubRepo
>; >;
let openTextDocumentSpy: jest.SpiedFunction< let openTextDocumentSpy: jest.SpiedFunction<
typeof workspace.openTextDocument typeof workspace.openTextDocument
@@ -115,6 +113,11 @@ describe("SkeletonQueryWizard", () => {
}, },
] as WorkspaceFolder[]); ] as WorkspaceFolder[]);
promptImportGithubDatabaseMock = jest.fn().mockReturnValue(undefined);
databaseFetcher = mockedObject<DatabaseFetcher>({
promptImportGithubDatabase: promptImportGithubDatabaseMock,
});
quickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValueOnce( quickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValueOnce(
mockedQuickPickItem({ mockedQuickPickItem({
label: chosenLanguage, label: chosenLanguage,
@@ -133,9 +136,6 @@ describe("SkeletonQueryWizard", () => {
createExampleQlFileSpy = jest createExampleQlFileSpy = jest
.spyOn(QlPackGenerator.prototype, "createExampleQlFile") .spyOn(QlPackGenerator.prototype, "createExampleQlFile")
.mockResolvedValue(undefined); .mockResolvedValue(undefined);
downloadGitHubDatabaseSpy = jest
.spyOn(databaseFetcher, "downloadGitHubDatabase")
.mockResolvedValue(undefined);
openTextDocumentSpy = jest openTextDocumentSpy = jest
.spyOn(workspace, "openTextDocument") .spyOn(workspace, "openTextDocument")
.mockResolvedValue({} as TextDocument); .mockResolvedValue({} as TextDocument);
@@ -145,13 +145,9 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
askForGitHubRepoSpy = jest
.spyOn(databaseFetcher, "askForGitHubRepo")
.mockResolvedValue(QUERY_LANGUAGE_TO_DATABASE_REPO[chosenLanguage]);
}); });
afterEach(async () => { afterEach(async () => {
@@ -172,7 +168,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
QueryLanguage.Swift, QueryLanguage.Swift,
); );
@@ -202,7 +198,7 @@ describe("SkeletonQueryWizard", () => {
title: "Download database", title: "Download database",
}), }),
); );
expect(downloadGitHubDatabaseSpy).not.toHaveBeenCalled(); expect(promptImportGithubDatabaseMock).not.toHaveBeenCalled();
}); });
it("should download database for selected language when selecting download in prompt", async () => { it("should download database for selected language when selecting download in prompt", async () => {
@@ -219,7 +215,7 @@ describe("SkeletonQueryWizard", () => {
await wizard.execute(); await wizard.execute();
await wizard.waitForDownload(); await wizard.waitForDownload();
expect(downloadGitHubDatabaseSpy).toHaveBeenCalled(); expect(promptImportGithubDatabaseMock).toHaveBeenCalled();
}); });
it("should open the query file", async () => { it("should open the query file", async () => {
@@ -320,7 +316,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManagerWithItems, mockDatabaseManagerWithItems,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
}); });
@@ -328,7 +324,7 @@ describe("SkeletonQueryWizard", () => {
it("should not download a new database for language", async () => { it("should not download a new database for language", async () => {
await wizard.execute(); await wizard.execute();
expect(downloadGitHubDatabaseSpy).not.toHaveBeenCalled(); expect(promptImportGithubDatabaseMock).not.toHaveBeenCalled();
}); });
it("should not select the database", async () => { it("should not select the database", async () => {
@@ -369,7 +365,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManagerWithItems, mockDatabaseManagerWithItems,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
}); });
@@ -377,7 +373,7 @@ describe("SkeletonQueryWizard", () => {
it("should not download a new database for language", async () => { it("should not download a new database for language", async () => {
await wizard.execute(); await wizard.execute();
expect(downloadGitHubDatabaseSpy).not.toHaveBeenCalled(); expect(promptImportGithubDatabaseMock).not.toHaveBeenCalled();
}); });
it("should select an existing database", async () => { it("should select an existing database", async () => {
@@ -409,54 +405,23 @@ describe("SkeletonQueryWizard", () => {
}); });
describe("if database is missing", () => { describe("if database is missing", () => {
describe("if the user chooses to downloaded the suggested database from GitHub", () => { beforeEach(() => {
beforeEach(() => { showInformationMessageSpy.mockImplementation(
showInformationMessageSpy.mockImplementation( async (_message, options, item) => {
async (_message, options, item) => { if (item === undefined) {
if (item === undefined) { return options as MessageItem;
return options as MessageItem; }
}
return item; return item;
}, },
); );
});
it("should download a new database for language", async () => {
await wizard.execute();
await wizard.waitForDownload();
expect(askForGitHubRepoSpy).toHaveBeenCalled();
expect(downloadGitHubDatabaseSpy).toHaveBeenCalled();
});
}); });
describe("if the user choses to download a different database from GitHub than the one suggested", () => { it("should download a new database for language", async () => {
beforeEach(() => { await wizard.execute();
showInformationMessageSpy.mockImplementation( await wizard.waitForDownload();
async (_message, options, item) => {
if (item === undefined) {
return options as MessageItem;
}
return item; expect(promptImportGithubDatabaseMock).toHaveBeenCalled();
},
);
const chosenGitHubRepo = "pickles-owner/pickles-repo";
askForGitHubRepoSpy = jest
.spyOn(databaseFetcher, "askForGitHubRepo")
.mockResolvedValue(chosenGitHubRepo);
});
it("should download the newly chosen database", async () => {
await wizard.execute();
await wizard.waitForDownload();
expect(askForGitHubRepoSpy).toHaveBeenCalled();
expect(downloadGitHubDatabaseSpy).toHaveBeenCalled();
});
}); });
}); });
}); });
@@ -504,7 +469,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
QueryLanguage.Javascript, QueryLanguage.Javascript,
); );
@@ -725,7 +690,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
}); });
@@ -754,7 +719,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
}); });
@@ -787,7 +752,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
QueryLanguage.Swift, QueryLanguage.Swift,
); );
@@ -830,7 +795,7 @@ describe("SkeletonQueryWizard", () => {
jest.fn(), jest.fn(),
mockApp, mockApp,
mockDatabaseManager, mockDatabaseManager,
storagePath, databaseFetcher,
selectedItems, selectedItems,
); );
}); });

View File

@@ -7,8 +7,8 @@ import type {
} from "../../src/databases/local-databases"; } from "../../src/databases/local-databases";
import type { CodeQLCliServer } from "../../src/codeql-cli/cli"; import type { CodeQLCliServer } from "../../src/codeql-cli/cli";
import type { CodeQLExtensionInterface } from "../../src/extension"; import type { CodeQLExtensionInterface } from "../../src/extension";
import { importLocalDatabase } from "../../src/databases/database-fetcher"; import { DatabaseFetcher } from "../../src/databases/database-fetcher";
import { createMockCommandManager } from "../__mocks__/commandsMock"; import { createMockApp } from "../__mocks__/appMock";
// This file contains helpers shared between tests that work with an activated extension. // This file contains helpers shared between tests that work with an activated extension.
@@ -40,15 +40,17 @@ export async function ensureTestDatabase(
// Add a database, but make sure the database manager is empty first // Add a database, but make sure the database manager is empty first
await cleanDatabases(databaseManager); await cleanDatabases(databaseManager);
const uri = Uri.file(dbLoc); const uri = Uri.file(dbLoc);
const maybeDbItem = await importLocalDatabase( const databaseFetcher = new DatabaseFetcher(
createMockCommandManager(), createMockApp(),
uri.toString(true),
databaseManager, databaseManager,
storagePath, storagePath,
cli,
);
const maybeDbItem = await databaseFetcher.importLocalDatabase(
uri.toString(true),
(_p) => { (_p) => {
/**/ /**/
}, },
cli,
); );
if (!maybeDbItem) { if (!maybeDbItem) {

View File

@@ -10,13 +10,11 @@ import {
askForGitHubDatabaseDownload, askForGitHubDatabaseDownload,
downloadDatabaseFromGitHub, downloadDatabaseFromGitHub,
} from "../../../../../src/databases/github-databases/download"; } from "../../../../../src/databases/github-databases/download";
import type { DatabaseManager } from "../../../../../src/databases/local-databases";
import type { GitHubDatabaseConfig } from "../../../../../src/config"; import type { GitHubDatabaseConfig } from "../../../../../src/config";
import type { CodeQLCliServer } from "../../../../../src/codeql-cli/cli"; import type { DatabaseFetcher } from "../../../../../src/databases/database-fetcher";
import { createMockCommandManager } from "../../../../__mocks__/commandsMock";
import * as databaseFetcher from "../../../../../src/databases/database-fetcher";
import * as dialog from "../../../../../src/common/vscode/dialog"; import * as dialog from "../../../../../src/common/vscode/dialog";
import type { CodeqlDatabase } from "../../../../../src/databases/github-databases/api"; import type { CodeqlDatabase } from "../../../../../src/databases/github-databases/api";
import { createMockApp } from "../../../../__mocks__/appMock";
describe("askForGitHubDatabaseDownload", () => { describe("askForGitHubDatabaseDownload", () => {
const setDownload = jest.fn(); const setDownload = jest.fn();
@@ -96,11 +94,9 @@ describe("downloadDatabaseFromGitHub", () => {
let octokit: Octokit; let octokit: Octokit;
const owner = "github"; const owner = "github";
const repo = "codeql"; const repo = "codeql";
let databaseManager: DatabaseManager; let databaseFetcher: DatabaseFetcher;
const storagePath = "/a/b/c/d"; const app = createMockApp();
let cliServer: CodeQLCliServer;
const commandManager = createMockCommandManager();
let databases = [ let databases = [
mockedObject<CodeqlDatabase>({ mockedObject<CodeqlDatabase>({
@@ -116,14 +112,17 @@ describe("downloadDatabaseFromGitHub", () => {
]; ];
let showQuickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>; let showQuickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>;
let downloadGitHubDatabaseFromUrlSpy: jest.SpiedFunction< let downloadGitHubDatabaseFromUrlMock: jest.MockedFunction<
typeof databaseFetcher.downloadGitHubDatabaseFromUrl DatabaseFetcher["downloadGitHubDatabaseFromUrl"]
>; >;
beforeEach(() => { beforeEach(() => {
octokit = mockedObject<Octokit>({}); octokit = mockedObject<Octokit>({});
databaseManager = mockedObject<DatabaseManager>({});
cliServer = mockedObject<CodeQLCliServer>({}); downloadGitHubDatabaseFromUrlMock = jest.fn().mockReturnValue(undefined);
databaseFetcher = mockedObject<DatabaseFetcher>({
downloadGitHubDatabaseFromUrl: downloadGitHubDatabaseFromUrlMock,
});
showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue( showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue(
mockedQuickPickItem([ mockedQuickPickItem([
@@ -132,9 +131,6 @@ describe("downloadDatabaseFromGitHub", () => {
}), }),
]), ]),
); );
downloadGitHubDatabaseFromUrlSpy = jest
.spyOn(databaseFetcher, "downloadGitHubDatabaseFromUrl")
.mockResolvedValue(undefined);
}); });
it("downloads the database", async () => { it("downloads the database", async () => {
@@ -143,14 +139,12 @@ describe("downloadDatabaseFromGitHub", () => {
owner, owner,
repo, repo,
databases, databases,
databaseManager, databaseFetcher,
storagePath, app.commands,
cliServer,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(1); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(1);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
databases[0].url, databases[0].url,
databases[0].id, databases[0].id,
databases[0].created_at, databases[0].created_at,
@@ -159,9 +153,6 @@ describe("downloadDatabaseFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
false, false,
); );
@@ -207,14 +198,12 @@ describe("downloadDatabaseFromGitHub", () => {
owner, owner,
repo, repo,
databases, databases,
databaseManager, databaseFetcher,
storagePath, app.commands,
cliServer,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(1); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(1);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
databases[1].url, databases[1].url,
databases[1].id, databases[1].id,
databases[1].created_at, databases[1].created_at,
@@ -223,9 +212,6 @@ describe("downloadDatabaseFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
false, false,
); );
@@ -263,14 +249,12 @@ describe("downloadDatabaseFromGitHub", () => {
owner, owner,
repo, repo,
databases, databases,
databaseManager, databaseFetcher,
storagePath, app.commands,
cliServer,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(2); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(2);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
databases[0].url, databases[0].url,
databases[0].id, databases[0].id,
databases[0].created_at, databases[0].created_at,
@@ -279,13 +263,10 @@ describe("downloadDatabaseFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
false, false,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
databases[1].url, databases[1].url,
databases[1].id, databases[1].id,
databases[1].created_at, databases[1].created_at,
@@ -294,9 +275,6 @@ describe("downloadDatabaseFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
false, false,
); );
@@ -328,13 +306,11 @@ describe("downloadDatabaseFromGitHub", () => {
owner, owner,
repo, repo,
databases, databases,
databaseManager, databaseFetcher,
storagePath, app.commands,
cliServer,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); expect(downloadGitHubDatabaseFromUrlMock).not.toHaveBeenCalled();
}); });
}); });
}); });

View File

@@ -4,7 +4,6 @@ import { createMockApp } from "../../../../__mocks__/appMock";
import type { App } from "../../../../../src/common/app"; import type { App } from "../../../../../src/common/app";
import type { DatabaseManager } from "../../../../../src/databases/local-databases"; import type { DatabaseManager } from "../../../../../src/databases/local-databases";
import { mockEmptyDatabaseManager } from "../../query-testing/test-runner-helpers"; import { mockEmptyDatabaseManager } from "../../query-testing/test-runner-helpers";
import type { CodeQLCliServer } from "../../../../../src/codeql-cli/cli";
import { mockDatabaseItem, mockedObject } from "../../../utils/mocking.helpers"; import { mockDatabaseItem, mockedObject } from "../../../utils/mocking.helpers";
import type { GitHubDatabaseConfig } from "../../../../../src/config"; import type { GitHubDatabaseConfig } from "../../../../../src/config";
import { GitHubDatabasesModule } from "../../../../../src/databases/github-databases"; import { GitHubDatabasesModule } from "../../../../../src/databases/github-databases";
@@ -16,13 +15,13 @@ import * as githubDatabasesApi from "../../../../../src/databases/github-databas
import * as githubDatabasesDownload from "../../../../../src/databases/github-databases/download"; import * as githubDatabasesDownload from "../../../../../src/databases/github-databases/download";
import * as githubDatabasesUpdates from "../../../../../src/databases/github-databases/updates"; import * as githubDatabasesUpdates from "../../../../../src/databases/github-databases/updates";
import type { DatabaseUpdate } from "../../../../../src/databases/github-databases/updates"; import type { DatabaseUpdate } from "../../../../../src/databases/github-databases/updates";
import type { DatabaseFetcher } from "../../../../../src/databases/database-fetcher";
describe("GitHubDatabasesModule", () => { describe("GitHubDatabasesModule", () => {
describe("promptGitHubRepositoryDownload", () => { describe("promptGitHubRepositoryDownload", () => {
let app: App; let app: App;
let databaseManager: DatabaseManager; let databaseManager: DatabaseManager;
let databaseStoragePath: string; const databaseFetcher = mockedObject<DatabaseFetcher>({});
let cliServer: CodeQLCliServer;
let config: GitHubDatabaseConfig; let config: GitHubDatabaseConfig;
let gitHubDatabasesModule: GitHubDatabasesModule; let gitHubDatabasesModule: GitHubDatabasesModule;
@@ -64,8 +63,6 @@ describe("GitHubDatabasesModule", () => {
beforeEach(() => { beforeEach(() => {
app = createMockApp(); app = createMockApp();
databaseManager = mockEmptyDatabaseManager(); databaseManager = mockEmptyDatabaseManager();
databaseStoragePath = "/a/b/some-path";
cliServer = mockedObject<CodeQLCliServer>({});
config = mockedObject<GitHubDatabaseConfig>({ config = mockedObject<GitHubDatabaseConfig>({
download: "ask", download: "ask",
update: "ask", update: "ask",
@@ -74,8 +71,7 @@ describe("GitHubDatabasesModule", () => {
gitHubDatabasesModule = new GitHubDatabasesModule( gitHubDatabasesModule = new GitHubDatabasesModule(
app, app,
databaseManager, databaseManager,
databaseStoragePath, databaseFetcher,
cliServer,
config, config,
); );
@@ -124,8 +120,7 @@ describe("GitHubDatabasesModule", () => {
gitHubDatabasesModule = new GitHubDatabasesModule( gitHubDatabasesModule = new GitHubDatabasesModule(
app, app,
databaseManager, databaseManager,
databaseStoragePath, databaseFetcher,
cliServer,
config, config,
); );
@@ -206,9 +201,7 @@ describe("GitHubDatabasesModule", () => {
owner, owner,
repo, repo,
databases, databases,
databaseManager, databaseFetcher,
databaseStoragePath,
cliServer,
app.commands, app.commands,
); );
}); });
@@ -250,8 +243,7 @@ describe("GitHubDatabasesModule", () => {
repo, repo,
databaseUpdates, databaseUpdates,
databaseManager, databaseManager,
databaseStoragePath, databaseFetcher,
cliServer,
app.commands, app.commands,
); );
}); });

View File

@@ -10,9 +10,7 @@ import {
import type { CodeqlDatabase } from "../../../../../src/databases/github-databases/api"; import type { CodeqlDatabase } from "../../../../../src/databases/github-databases/api";
import type { DatabaseManager } from "../../../../../src/databases/local-databases"; import type { DatabaseManager } from "../../../../../src/databases/local-databases";
import type { GitHubDatabaseConfig } from "../../../../../src/config"; import type { GitHubDatabaseConfig } from "../../../../../src/config";
import type { CodeQLCliServer } from "../../../../../src/codeql-cli/cli"; import type { DatabaseFetcher } from "../../../../../src/databases/database-fetcher";
import { createMockCommandManager } from "../../../../__mocks__/commandsMock";
import * as databaseFetcher from "../../../../../src/databases/database-fetcher";
import * as dialog from "../../../../../src/common/vscode/dialog"; import * as dialog from "../../../../../src/common/vscode/dialog";
import type { DatabaseUpdate } from "../../../../../src/databases/github-databases/updates"; import type { DatabaseUpdate } from "../../../../../src/databases/github-databases/updates";
import { import {
@@ -20,6 +18,7 @@ import {
downloadDatabaseUpdateFromGitHub, downloadDatabaseUpdateFromGitHub,
isNewerDatabaseAvailable, isNewerDatabaseAvailable,
} from "../../../../../src/databases/github-databases/updates"; } from "../../../../../src/databases/github-databases/updates";
import { createMockApp } from "../../../../__mocks__/appMock";
describe("isNewerDatabaseAvailable", () => { describe("isNewerDatabaseAvailable", () => {
const owner = "github"; const owner = "github";
@@ -344,9 +343,8 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
const owner = "github"; const owner = "github";
const repo = "codeql"; const repo = "codeql";
let databaseManager: DatabaseManager; let databaseManager: DatabaseManager;
const storagePath = "/a/b/c/d"; let databaseFetcher: DatabaseFetcher;
let cliServer: CodeQLCliServer; const app = createMockApp();
const commandManager = createMockCommandManager();
let updates: DatabaseUpdate[] = [ let updates: DatabaseUpdate[] = [
{ {
@@ -367,8 +365,8 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
]; ];
let showQuickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>; let showQuickPickSpy: jest.SpiedFunction<typeof window.showQuickPick>;
let downloadGitHubDatabaseFromUrlSpy: jest.SpiedFunction< let downloadGitHubDatabaseFromUrlMock: jest.MockedFunction<
typeof databaseFetcher.downloadGitHubDatabaseFromUrl DatabaseFetcher["downloadGitHubDatabaseFromUrl"]
>; >;
beforeEach(() => { beforeEach(() => {
@@ -376,7 +374,11 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
databaseManager = mockedObject<DatabaseManager>({ databaseManager = mockedObject<DatabaseManager>({
currentDatabaseItem: mockDatabaseItem(), currentDatabaseItem: mockDatabaseItem(),
}); });
cliServer = mockedObject<CodeQLCliServer>({});
downloadGitHubDatabaseFromUrlMock = jest.fn().mockReturnValue(undefined);
databaseFetcher = mockedObject<DatabaseFetcher>({
downloadGitHubDatabaseFromUrl: downloadGitHubDatabaseFromUrlMock,
});
showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue( showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue(
mockedQuickPickItem([ mockedQuickPickItem([
@@ -385,9 +387,6 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
}), }),
]), ]),
); );
downloadGitHubDatabaseFromUrlSpy = jest
.spyOn(databaseFetcher, "downloadGitHubDatabaseFromUrl")
.mockResolvedValue(undefined);
}); });
it("downloads the database", async () => { it("downloads the database", async () => {
@@ -397,13 +396,12 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
updates, updates,
databaseManager, databaseManager,
storagePath, databaseFetcher,
cliServer, app.commands,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(1); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(1);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
updates[0].database.url, updates[0].database.url,
updates[0].database.id, updates[0].database.id,
updates[0].database.created_at, updates[0].database.created_at,
@@ -412,9 +410,6 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
false, false,
false, false,
); );
@@ -476,13 +471,12 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
updates, updates,
databaseManager, databaseManager,
storagePath, databaseFetcher,
cliServer, app.commands,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(1); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(1);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
updates[1].database.url, updates[1].database.url,
updates[1].database.id, updates[1].database.id,
updates[1].database.created_at, updates[1].database.created_at,
@@ -491,9 +485,6 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
true, true,
); );
@@ -532,13 +523,12 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
updates, updates,
databaseManager, databaseManager,
storagePath, databaseFetcher,
cliServer, app.commands,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(2); expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledTimes(2);
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
updates[0].database.url, updates[0].database.url,
updates[0].database.id, updates[0].database.id,
updates[0].database.created_at, updates[0].database.created_at,
@@ -547,13 +537,10 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
false, false,
false, false,
); );
expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( expect(downloadGitHubDatabaseFromUrlMock).toHaveBeenCalledWith(
updates[1].database.url, updates[1].database.url,
updates[1].database.id, updates[1].database.id,
updates[1].database.created_at, updates[1].database.created_at,
@@ -562,9 +549,6 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
octokit, octokit,
expect.anything(), expect.anything(),
databaseManager,
storagePath,
cliServer,
true, true,
true, true,
); );
@@ -597,12 +581,11 @@ describe("downloadDatabaseUpdateFromGitHub", () => {
repo, repo,
updates, updates,
databaseManager, databaseManager,
storagePath, databaseFetcher,
cliServer, app.commands,
commandManager,
); );
expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); expect(downloadGitHubDatabaseFromUrlMock).not.toHaveBeenCalled();
}); });
}); });
}); });

View File

@@ -20,6 +20,7 @@ import { testDisposeHandler } from "../../test-dispose-handler";
import { createMockApp } from "../../../__mocks__/appMock"; import { createMockApp } from "../../../__mocks__/appMock";
import { QueryLanguage } from "../../../../src/common/query-language"; import { QueryLanguage } from "../../../../src/common/query-language";
import { mockedQuickPickItem, mockedObject } from "../../utils/mocking.helpers"; import { mockedQuickPickItem, mockedObject } from "../../utils/mocking.helpers";
import type { DatabaseFetcher } from "../../../../src/databases/database-fetcher";
describe("local-databases-ui", () => { describe("local-databases-ui", () => {
const storageDir = dirSync({ unsafeCleanup: true }).name; const storageDir = dirSync({ unsafeCleanup: true }).name;
@@ -104,6 +105,7 @@ describe("local-databases-ui", () => {
}, },
setCurrentDatabaseItem: () => {}, setCurrentDatabaseItem: () => {},
} as any, } as any,
mockedObject<DatabaseFetcher>({}),
{ {
onLanguageContextChanged: () => { onLanguageContextChanged: () => {
/**/ /**/
@@ -141,6 +143,7 @@ describe("local-databases-ui", () => {
setCurrentDatabaseItem: () => {}, setCurrentDatabaseItem: () => {},
currentDatabaseItem: { databaseUri: Uri.file(db1) }, currentDatabaseItem: { databaseUri: Uri.file(db1) },
} as any, } as any,
mockedObject<DatabaseFetcher>({}),
{ {
onLanguageContextChanged: () => { onLanguageContextChanged: () => {
/**/ /**/
@@ -177,6 +180,7 @@ describe("local-databases-ui", () => {
const databaseUI = new DatabaseUI( const databaseUI = new DatabaseUI(
app, app,
databaseManager, databaseManager,
mockedObject<DatabaseFetcher>({}),
{ {
onLanguageContextChanged: () => { onLanguageContextChanged: () => {
/**/ /**/

View File

@@ -13,6 +13,7 @@ import type { ModelConfigListener } from "../../../../src/config";
import { createMockModelingEvents } from "../../../__mocks__/model-editor/modelingEventsMock"; import { createMockModelingEvents } from "../../../__mocks__/model-editor/modelingEventsMock";
import { QueryLanguage } from "../../../../src/common/query-language"; import { QueryLanguage } from "../../../../src/common/query-language";
import type { VariantAnalysisManager } from "../../../../src/variant-analysis/variant-analysis-manager"; import type { VariantAnalysisManager } from "../../../../src/variant-analysis/variant-analysis-manager";
import type { DatabaseFetcher } from "../../../../src/databases/database-fetcher";
describe("ModelEditorView", () => { describe("ModelEditorView", () => {
const app = createMockApp({}); const app = createMockApp({});
@@ -22,6 +23,7 @@ describe("ModelEditorView", () => {
onDidChangeConfiguration: jest.fn(), onDidChangeConfiguration: jest.fn(),
}); });
const databaseManager = mockEmptyDatabaseManager(); const databaseManager = mockEmptyDatabaseManager();
const databaseFetcher = mockedObject<DatabaseFetcher>({});
const variantAnalysisManager = mockedObject<VariantAnalysisManager>({}); const variantAnalysisManager = mockedObject<VariantAnalysisManager>({});
const cliServer = mockedObject<CodeQLCliServer>({}); const cliServer = mockedObject<CodeQLCliServer>({});
const queryRunner = mockedObject<QueryRunner>({}); const queryRunner = mockedObject<QueryRunner>({});
@@ -50,6 +52,7 @@ describe("ModelEditorView", () => {
modelingEvents, modelingEvents,
modelConfig, modelConfig,
databaseManager, databaseManager,
databaseFetcher,
variantAnalysisManager, variantAnalysisManager,
cliServer, cliServer,
queryRunner, queryRunner,