Merge pull request #2682 from github/charisk/batch-automodeling

Batch automodeling
Charis Kyriakou
2023-08-10 08:58:23 +01:00
committed by GitHub
10 changed files with 143 additions and 51 deletions

View File

@@ -554,6 +554,11 @@ interface GenerateExternalApiFromLlmMessage {
modeledMethods: Record<string, ModeledMethod>;
}
interface StopGeneratingExternalApiFromLlmMessage {
t: "stopGeneratingExternalApiFromLlm";
dependencyName: string;
}
interface ModelDependencyMessage {
t: "modelDependency";
}
@@ -575,4 +580,5 @@ export type FromDataExtensionsEditorMessage =
| SaveModeledMethods
| GenerateExternalApiMessage
| GenerateExternalApiFromLlmMessage
| StopGeneratingExternalApiFromLlmMessage
| ModelDependencyMessage;
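
Note (reviewer sketch, not part of the diff): the hunk above adds a stopGeneratingExternalApiFromLlm variant to the webview's outbound message union. Because every message carries a `t` discriminant, a handler can switch on it and get typed access to dependencyName. The handler below is a minimal, self-contained illustration, not the extension's actual code.

// Two of the union members from this file, reproduced for a standalone example.
interface StopGeneratingExternalApiFromLlmMessage {
  t: "stopGeneratingExternalApiFromLlm";
  dependencyName: string;
}

interface ModelDependencyMessage {
  t: "modelDependency";
}

type Message = StopGeneratingExternalApiFromLlmMessage | ModelDependencyMessage;

function handle(msg: Message): void {
  switch (msg.t) {
    case "stopGeneratingExternalApiFromLlm":
      // Narrowed: dependencyName is only available on this branch.
      console.log(`Stop requested for ${msg.dependencyName}`);
      break;
    case "modelDependency":
      console.log("Model dependency requested");
      break;
  }
}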

View File

@@ -160,6 +160,7 @@ type AutoModelQueriesOptions = {
queryStorageDir: string;
progress: ProgressCallback;
cancellationTokenSource: CancellationTokenSource;
};
export type AutoModelQueriesResult = {
@@ -174,12 +175,11 @@ export async function runAutoModelQueries({
databaseItem,
queryStorageDir,
progress,
cancellationTokenSource,
}: AutoModelQueriesOptions): Promise<AutoModelQueriesResult | undefined> {
// maxStep for this part is 1500
const maxStep = 1500;
const cancellationTokenSource = new CancellationTokenSource();
const qlpack = await qlpackOfDatabase(cliServer, databaseItem);
// CodeQL needs to have access to the database to be able to retrieve the

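Note (reviewer sketch, not part of the diff): runAutoModelQueries no longer creates its own CancellationTokenSource; the caller now passes one in through AutoModelQueriesOptions, so whoever started the run can also cancel it. Below is a self-contained sketch of that caller-owned-cancellation pattern, with a placeholder runLongTask standing in for runAutoModelQueries.

import { CancellationToken, CancellationTokenSource } from "vscode";

// Placeholder options type; the real AutoModelQueriesOptions carries the CLI
// server, database item, query storage dir, and progress callback as well.
type Options = {
  cancellationTokenSource: CancellationTokenSource;
};

async function runLongTask({ cancellationTokenSource }: Options): Promise<string | undefined> {
  const token: CancellationToken = cancellationTokenSource.token;
  for (let step = 0; step < 100; step++) {
    if (token.isCancellationRequested) {
      return undefined; // bail out early, mirroring an aborted query run
    }
    await new Promise((resolve) => setTimeout(resolve, 10));
  }
  return "done";
}

// The caller keeps the source and can cancel from a separate code path,
// for example when the user clicks "Stop" in the UI:
const source = new CancellationTokenSource();
void runLongTask({ cancellationTokenSource: source });
source.cancel();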
View File

@@ -8,9 +8,6 @@ import { ExternalApiUsage, MethodSignature } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
// Soft limit on the number of candidates to send to the model.
// Note that the model may return fewer than this number of candidates.
const candidateLimit = 20;
/**
* Return the candidates that the model should be run on. This includes limiting the number of
* candidates to the candidate limit and filtering out anything that is already modeled and respecting
@@ -41,11 +38,6 @@ export function getCandidates(
type: "none",
};
// If we have reached the max number of candidates then stop
if (candidates.length >= candidateLimit) {
break;
}
// Anything that is modeled is not a candidate
if (modeledMethod.type !== "none") {
continue;

View File

@@ -15,18 +15,31 @@ import { CodeQLCliServer } from "../codeql-cli/cli";
import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases";
import { Mode } from "./shared/mode";
import { CancellationTokenSource } from "vscode";
// Limit the number of candidates we send to the model in each request
// to avoid long requests.
// Note that the model may return fewer than this number of candidates.
const candidateBatchSize = 20;
export class AutoModeler {
private readonly jobs: Map<string, CancellationTokenSource>;
constructor(
private readonly app: App,
private readonly cliServer: CodeQLCliServer,
private readonly queryRunner: QueryRunner,
private readonly queryStorageDir: string,
private readonly databaseItem: DatabaseItem,
private readonly setInProgressMethods: (
inProgressMethods: string[],
) => Promise<void>,
private readonly addModeledMethods: (
modeledMethods: Record<string, ModeledMethod>,
) => Promise<void>,
) {}
) {
this.jobs = new Map<string, CancellationTokenSource>();
}
public async startModeling(
dependency: string,
@@ -34,12 +47,38 @@ export class AutoModeler {
modeledMethods: Record<string, ModeledMethod>,
mode: Mode,
): Promise<void> {
await this.modelDependency(
dependency,
externalApiUsages,
modeledMethods,
mode,
);
if (this.jobs.has(dependency)) {
return;
}
const cancellationTokenSource = new CancellationTokenSource();
this.jobs.set(dependency, cancellationTokenSource);
try {
await this.modelDependency(
dependency,
externalApiUsages,
modeledMethods,
mode,
cancellationTokenSource,
);
} finally {
this.jobs.delete(dependency);
}
}
public async stopModeling(dependency: string): Promise<void> {
void extLogger.log(`Stopping modeling for dependency ${dependency}`);
const cancellationTokenSource = this.jobs.get(dependency);
if (cancellationTokenSource) {
cancellationTokenSource.cancel();
}
}
public async stopAllModeling(): Promise<void> {
for (const cancellationTokenSource of this.jobs.values()) {
cancellationTokenSource.cancel();
}
}
private async modelDependency(
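
Note (reviewer sketch, not part of the diff): startModeling now registers one CancellationTokenSource per dependency in the jobs map, making it a no-op if a job for that dependency is already running, while stopModeling and stopAllModeling cancel through the same map. A condensed, standalone version of that registry pattern (names are illustrative):

import { CancellationTokenSource } from "vscode";

class JobRegistry {
  private readonly jobs = new Map<string, CancellationTokenSource>();

  async start(
    dependency: string,
    work: (source: CancellationTokenSource) => Promise<void>,
  ): Promise<void> {
    if (this.jobs.has(dependency)) {
      return; // a job for this dependency is already running
    }
    const source = new CancellationTokenSource();
    this.jobs.set(dependency, source);
    try {
      await work(source);
    } finally {
      this.jobs.delete(dependency); // always release the slot, even on error
    }
  }

  stop(dependency: string): void {
    this.jobs.get(dependency)?.cancel();
  }

  stopAll(): void {
    for (const source of this.jobs.values()) {
      source.cancel();
    }
  }
}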
@@ -47,31 +86,63 @@ export class AutoModeler {
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
mode: Mode,
cancellationTokenSource: CancellationTokenSource,
): Promise<void> {
void extLogger.log(`Modeling dependency ${dependency}`);
await withProgress(async (progress) => {
const maxStep = 3000;
progress({
step: 0,
maxStep,
message: "Retrieving usages",
});
// Fetch the candidates to send to the model
const candidateMethods = getCandidates(
const allCandidateMethods = getCandidates(
mode,
externalApiUsages,
modeledMethods,
);
// If there are no candidates, there is nothing to model and we just return
if (candidateMethods.length === 0) {
if (allCandidateMethods.length === 0) {
void extLogger.log("No candidates to model. Stopping.");
return;
}
await this.modelCandidates(candidateMethods, mode, progress, maxStep);
// Find number of slices to make
const batchNumber = Math.ceil(
allCandidateMethods.length / candidateBatchSize,
);
try {
for (let i = 0; i < batchNumber; i++) {
if (cancellationTokenSource.token.isCancellationRequested) {
break;
}
const start = i * candidateBatchSize;
const end = start + candidateBatchSize;
const candidatesToProcess = allCandidateMethods.slice(start, end);
await this.setInProgressMethods(
candidatesToProcess.map((c) => c.signature),
);
progress({
step: 1800 + i * 100,
maxStep,
message: `Automodeling candidates, batch ${
i + 1
} of ${batchNumber}`,
});
await this.modelCandidates(
candidatesToProcess,
mode,
progress,
maxStep,
cancellationTokenSource,
);
}
} finally {
// Clear out in progress methods
await this.setInProgressMethods([]);
}
});
}
@@ -80,6 +151,7 @@ export class AutoModeler {
mode: Mode,
progress: ProgressCallback,
maxStep: number,
cancellationTokenSource: CancellationTokenSource,
): Promise<void> {
const usages = await runAutoModelQueries({
mode,
@@ -89,6 +161,7 @@ export class AutoModeler {
queryStorageDir: this.queryStorageDir,
databaseItem: this.databaseItem,
progress: (update) => progress({ ...update, maxStep }),
cancellationTokenSource,
});
if (!usages) {
return;

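Note (reviewer sketch, not part of the diff): this file carries the core of the batching change. Instead of truncating the candidate list to 20 in getCandidates, all candidates are sliced into batches of candidateBatchSize, each batch is sent to the model in its own request, the cancellation token is checked between batches, and the in-progress markers are always cleared in the finally block. A standalone sketch of that loop, with markInProgress and processBatch standing in for setInProgressMethods and modelCandidates:

import { CancellationTokenSource } from "vscode";

const candidateBatchSize = 20; // same soft limit as before, now applied per request

async function modelInBatches<T>(
  candidates: T[],
  source: CancellationTokenSource,
  markInProgress: (batch: T[]) => Promise<void>,
  processBatch: (batch: T[]) => Promise<void>,
): Promise<void> {
  const batchCount = Math.ceil(candidates.length / candidateBatchSize);
  try {
    for (let i = 0; i < batchCount; i++) {
      if (source.token.isCancellationRequested) {
        break; // user pressed Stop between batches
      }
      const batch = candidates.slice(
        i * candidateBatchSize,
        (i + 1) * candidateBatchSize,
      );
      await markInProgress(batch); // drives the in-progress UI state
      await processBatch(batch); // one request to the model per batch
    }
  } finally {
    await markInProgress([]); // always clear the in-progress markers
  }
}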
View File

@@ -83,6 +83,12 @@ export class DataExtensionsEditorView extends AbstractWebview<
queryRunner,
queryStorageDir,
databaseItem,
async (inProgressMethods) => {
await this.postMessage({
t: "setInProgressMethods",
inProgressMethods,
});
},
async (modeledMethods) => {
await this.postMessage({ t: "addModeledMethods", modeledMethods });
},
@@ -182,6 +188,9 @@ export class DataExtensionsEditorView extends AbstractWebview<
);
}
break;
case "stopGeneratingExternalApiFromLlm":
await this.autoModeler.stopModeling(msg.dependencyName);
break;
case "modelDependency":
await this.modelDependency();
break;

View File

@@ -241,6 +241,13 @@ export function DataExtensionsEditor({
[],
);
const onStopGenerateFromLlmClick = useCallback((dependencyName: string) => {
vscode.postMessage({
t: "stopGeneratingExternalApiFromLlm",
dependencyName,
});
}, []);
const onOpenDatabaseClick = useCallback(() => {
vscode.postMessage({
t: "openDatabase",
@@ -345,6 +352,7 @@ export function DataExtensionsEditor({
onChange={onChange}
onSaveModelClick={onSaveModelClick}
onGenerateFromLlmClick={onGenerateFromLlmClick}
onStopGenerateFromLlmClick={onStopGenerateFromLlmClick}
onGenerateFromSourceClick={onGenerateFromSourceClick}
onModelDependencyClick={onModelDependencyClick}
/>

View File

@@ -89,6 +89,7 @@ type Props = {
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
) => void;
onStopGenerateFromLlmClick: (dependencyName: string) => void;
onGenerateFromSourceClick: () => void;
onModelDependencyClick: () => void;
};
@@ -105,6 +106,7 @@ export const LibraryRow = ({
onChange,
onSaveModelClick,
onGenerateFromLlmClick,
onStopGenerateFromLlmClick,
onGenerateFromSourceClick,
onModelDependencyClick,
}: Props) => {
@@ -127,6 +129,15 @@ export const LibraryRow = ({
[title, externalApiUsages, modeledMethods, onGenerateFromLlmClick],
);
const handleStopModelWithAI = useCallback(
async (e: React.MouseEvent) => {
onStopGenerateFromLlmClick(title);
e.stopPropagation();
e.preventDefault();
},
[title, onStopGenerateFromLlmClick],
);
const handleModelFromSource = useCallback(
async (e: React.MouseEvent) => {
onGenerateFromSourceClick();
@@ -167,6 +178,12 @@ export const LibraryRow = ({
);
}, [externalApiUsages, modifiedSignatures]);
const canStopAutoModeling = useMemo(() => {
return externalApiUsages.some((externalApiUsage) =>
inProgressSignatures.has(externalApiUsage.signature),
);
}, [externalApiUsages, inProgressSignatures]);
return (
<LibraryContainer>
<TitleContainer onClick={toggleExpanded} aria-expanded={isExpanded}>
@@ -185,12 +202,18 @@ export const LibraryRow = ({
</ModeledPercentage>
{hasUnsavedChanges ? <VSCodeTag>UNSAVED</VSCodeTag> : null}
</NameContainer>
{viewState.showLlmButton && (
{viewState.showLlmButton && !canStopAutoModeling && (
<VSCodeButton appearance="icon" onClick={handleModelWithAI}>
<Codicon name="lightbulb-autofix" label="Model with AI" />
&nbsp;Model with AI
</VSCodeButton>
)}
{viewState.showLlmButton && canStopAutoModeling && (
<VSCodeButton appearance="icon" onClick={handleStopModelWithAI}>
<Codicon name="debug-stop" label="Stop model with AI" />
&nbsp;Stop
</VSCodeButton>
)}
{viewState.mode === Mode.Application && (
<VSCodeButton appearance="icon" onClick={handleModelFromSource}>
<Codicon name="code" label="Model from source" />

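Note (reviewer sketch, not part of the diff): LibraryRow chooses between the "Model with AI" and "Stop" buttons by checking whether any of its methods appear in inProgressSignatures. How that Set is populated from the setInProgressMethods message is not shown in this diff; the hook below is a hypothetical illustration of that wiring, with the message shape assumed from the extension side above.

import { useEffect, useState } from "react";

export function useInProgressSignatures(): Set<string> {
  const [inProgressSignatures, setInProgressSignatures] = useState<Set<string>>(
    () => new Set(),
  );

  useEffect(() => {
    const listener = (event: MessageEvent) => {
      const msg = event.data;
      if (msg?.t === "setInProgressMethods") {
        // msg.inProgressMethods is the string[] posted by the extension side.
        setInProgressSignatures(new Set<string>(msg.inProgressMethods));
      }
    };
    window.addEventListener("message", listener);
    return () => window.removeEventListener("message", listener);
  }, []);

  return inProgressSignatures;
}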
View File

@@ -31,6 +31,7 @@ type Props = {
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
) => void;
onStopGenerateFromLlmClick: (dependencyName: string) => void;
onGenerateFromSourceClick: () => void;
onModelDependencyClick: () => void;
};
@@ -49,6 +50,7 @@ export const ModeledMethodsList = ({
onChange,
onSaveModelClick,
onGenerateFromLlmClick,
onStopGenerateFromLlmClick,
onGenerateFromSourceClick,
onModelDependencyClick,
}: Props) => {
@@ -93,6 +95,7 @@ export const ModeledMethodsList = ({
onChange={onChange}
onSaveModelClick={onSaveModelClick}
onGenerateFromLlmClick={onGenerateFromLlmClick}
onStopGenerateFromLlmClick={onStopGenerateFromLlmClick}
onGenerateFromSourceClick={onGenerateFromSourceClick}
onModelDependencyClick={onModelDependencyClick}
/>

View File

@@ -165,28 +165,4 @@ describe("getCandidates", () => {
);
expect(candidates.length).toEqual(1);
});
it("respects the limit", () => {
const externalApiUsages: ExternalApiUsage[] = [];
for (let i = 0; i < 30; i++) {
externalApiUsages.push({
library: "my.jar",
signature: `org.my.A#x${i}()`,
packageName: "org.my",
typeName: "A",
methodName: `x${i}`,
methodParameters: "()",
supported: false,
supportedType: "none",
usages: [],
});
}
const modeledMethods = {};
const candidates = getCandidates(
Mode.Application,
externalApiUsages,
modeledMethods,
);
expect(candidates.length).toEqual(20);
});
});

View File

@@ -19,6 +19,7 @@ import { MethodSignature } from "../../../../src/data-extensions-editor/external
import { join } from "path";
import { exists, readFile } from "fs-extra";
import { load as loadYaml } from "js-yaml";
import { CancellationTokenSource } from "vscode-jsonrpc";
describe("runAutoModelQueries", () => {
const qlpack = {
@@ -142,6 +143,7 @@ describe("runAutoModelQueries", () => {
}),
queryStorageDir: "/tmp/queries",
progress: jest.fn(),
cancellationTokenSource: new CancellationTokenSource(),
};
const result = await runAutoModelQueries(options);