Add initial implementation of auto-modeling

This commit is contained in:
Koen Vlaswinkel
2023-05-25 13:21:20 +02:00
parent b8557d337c
commit a88e683ebf
9 changed files with 535 additions and 1 deletions

View File

@@ -711,3 +711,10 @@ const QUERIES_PANEL = new Setting("queriesPanel", ROOT_SETTING);
/**
 * Reports whether the queries panel should be shown.
 *
 * @returns `true` only when the `QUERIES_PANEL` setting currently holds a
 * truthy value; an unset or falsy value yields `false`.
 */
export function showQueriesPanel(): boolean {
  const configured = QUERIES_PANEL.getValue<boolean>();
  return Boolean(configured);
}
// "dataExtensions" setting, constructed with ROOT_SETTING as its second
// argument (presumably its parent setting — confirm against Setting).
const DATA_EXTENSIONS = new Setting("dataExtensions", ROOT_SETTING);
// "llmGeneration" setting, constructed with DATA_EXTENSIONS as its second
// argument; its value is read by showLlmGeneration().
const LLM_GENERATION = new Setting("llmGeneration", DATA_EXTENSIONS);
/**
 * Reports whether LLM-based generation is switched on in the settings.
 *
 * @returns `true` only when the `LLM_GENERATION` setting currently holds a
 * truthy value; an unset or falsy value yields `false`.
 */
export function showLlmGeneration(): boolean {
  const configured = LLM_GENERATION.getValue<boolean>();
  return Boolean(configured);
}

View File

@@ -0,0 +1,54 @@
import { Credentials } from "../common/authentication";
import { OctokitResponse } from "@octokit/types";
/**
 * Classification categories used in auto-model requests and responses.
 * Every member is string-valued, so these literals are what ends up in the
 * serialized payload.
 */
export enum ClassificationType {
Unknown = "CLASSIFICATION_TYPE_UNKNOWN",
Neutral = "CLASSIFICATION_TYPE_NEUTRAL",
Source = "CLASSIFICATION_TYPE_SOURCE",
Sink = "CLASSIFICATION_TYPE_SINK",
Summary = "CLASSIFICATION_TYPE_SUMMARY",
}
/** Classification attached to a method in an auto-model request or response. */
export interface Classification {
// Category of the classification.
type: ClassificationType;
// Free-form kind string, e.g. "jndi-injection" for a sink.
kind: string;
// Sent as an empty string for samples built from existing models.
explanation: string;
}
/**
 * A method as exchanged with the auto-model endpoint. Requests contain one
 * entry per argument position of a method.
 */
export interface Method {
// Package name, e.g. "org.sql2o".
package: string;
// Type name, e.g. "Connection".
type: string;
// Method name, e.g. "createQuery".
name: string;
// Parenthesized parameter list only, e.g. "(String)" — requests do not put
// the fully qualified signature here.
signature: string;
// Labels of the call sites, e.g. "createQuery(...)".
usages: string[];
// Set for samples (methods that already have a model); undefined for candidates.
classification?: Classification;
// Access path of the argument this entry is about, e.g. "Argument[0]".
input?: string;
// Not set when building requests; read from predictions in responses.
output?: string;
}
/** Body of an auto-model request. */
export interface ModelRequest {
// Language identifier, e.g. "java".
language: string;
// Methods without an existing model, for which predictions are wanted.
candidates: Method[];
// Methods that already have a model; each carries a classification.
samples: Method[];
}
/** Body of an auto-model response. */
export interface ModelResponse {
language: string;
// Predicted methods; entries may lack a classification.
predicted: Method[];
}
/**
 * Posts an auto-model request to the GitHub API and returns the response body.
 *
 * No error handling is done here: a rejection from obtaining the client or
 * from the request itself propagates to the caller.
 *
 * @param credentials Provides the Octokit client used for the call.
 * @param request Payload passed as the request's `data`.
 * @returns The `data` of the response.
 */
export async function autoModel(
  credentials: Credentials,
  request: ModelRequest,
): Promise<ModelResponse> {
  const client = await credentials.getOctokit();

  const { data: modelResponse }: OctokitResponse<ModelResponse> =
    await client.request(
      "POST /repos/github/codeql/code-scanning/codeql/auto-model",
      { data: request },
    );

  return modelResponse;
}

View File

@@ -0,0 +1,117 @@
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod, ModeledMethodType } from "./modeled-method";
import {
Classification,
ClassificationType,
Method,
ModelRequest,
} from "./auto-model-api";
/**
 * Builds the auto-model request for a set of external API usages.
 *
 * One entry is produced per (method, argument index) pair, so a method whose
 * parameter list is `()` contributes nothing. Entries of methods that already
 * have a model (type other than "none") are sent as `samples`; all others are
 * sent as `candidates`.
 *
 * NOTE(review): a sample gets the method's classification for every argument
 * index, regardless of the `input` stored in the existing model — confirm this
 * is what the endpoint expects.
 *
 * @param language Language identifier copied into the request.
 * @param externalApiUsages Methods to describe; this array is not mutated.
 * @param modeledMethods Existing models keyed by `ExternalApiUsage.signature`.
 * @returns A request holding at most 20 samples and 100 candidates, ordered by
 * descending number of usages.
 */
export function createAutoModelRequest(
  language: string,
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): ModelRequest {
  const samples: Method[] = [];
  const candidates: Method[] = [];

  // Order a copy by descending usage count, so the truncation at the end
  // keeps the most used methods.
  const mostUsedFirst = externalApiUsages
    .slice()
    .sort((left, right) => right.usages.length - left.usages.length);

  mostUsedFirst.forEach((apiUsage) => {
    const existingModel: ModeledMethod = modeledMethods[
      apiUsage.signature
    ] ?? {
      type: "none",
    };
    const isUnmodeled = existingModel.type === "none";
    const destination = isUnmodeled ? candidates : samples;

    // The parameter list looks like "(A,B)"; the empty list "()" is handled
    // separately because splitting it would still yield one element.
    const parameterCount =
      apiUsage.methodParameters === "()"
        ? 0
        : apiUsage.methodParameters.split(",").length;

    for (let position = 0; position < parameterCount; position += 1) {
      destination.push({
        package: apiUsage.packageName,
        type: apiUsage.typeName,
        name: apiUsage.methodName,
        signature: apiUsage.methodParameters,
        classification: isUnmodeled
          ? undefined
          : toMethodClassification(existingModel),
        // Only the first 10 usage labels are sent.
        usages: apiUsage.usages.slice(0, 10).map((usage) => usage.label),
        input: `Argument[${position}]`,
      });
    }
  });

  return {
    language,
    samples: samples.slice(0, 20),
    candidates: candidates.slice(0, 100),
  };
}
/**
 * Maps a modeled-method type onto the classification type used by the
 * auto-model API.
 *
 * @param type The modeled-method type to convert.
 * @returns The matching classification type, or `ClassificationType.Unknown`
 * for any value other than "source", "sink", "summary" or "neutral".
 */
function toMethodClassificationType(
  type: ModeledMethodType,
): ClassificationType {
  if (type === "source") {
    return ClassificationType.Source;
  }
  if (type === "sink") {
    return ClassificationType.Sink;
  }
  if (type === "summary") {
    return ClassificationType.Summary;
  }
  if (type === "neutral") {
    return ClassificationType.Neutral;
  }
  return ClassificationType.Unknown;
}
/**
 * Converts an existing model into the classification sent with a sample.
 *
 * @param modeledMethod The model to convert; only `type` and `kind` are read.
 * @returns A classification with the converted type, the same kind, and an
 * empty explanation.
 */
function toMethodClassification(modeledMethod: ModeledMethod): Classification {
  const { type, kind } = modeledMethod;

  const classification: Classification = {
    type: toMethodClassificationType(type),
    kind,
    explanation: "",
  };

  return classification;
}
/**
 * Maps a classification type from the auto-model API back onto a
 * modeled-method type.
 *
 * @param type The classification type to convert.
 * @returns "source", "sink", "summary" or "neutral" for the corresponding
 * classification types, and "none" for anything else (including `Unknown`).
 */
export function classificationTypeToModeledMethodType(
  type: ClassificationType,
): ModeledMethodType {
  if (type === ClassificationType.Source) {
    return "source";
  }
  if (type === ClassificationType.Sink) {
    return "sink";
  }
  if (type === ClassificationType.Summary) {
    return "summary";
  }
  if (type === ClassificationType.Neutral) {
    return "neutral";
  }
  return "none";
}

View File

@@ -39,6 +39,12 @@ import { createDataExtensionYaml, loadDataExtensionYaml } from "./yaml";
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { ExtensionPackModelFile } from "./shared/extension-pack";
import { autoModel } from "./auto-model-api";
import {
classificationTypeToModeledMethodType,
createAutoModelRequest,
} from "./auto-model";
import { showLlmGeneration } from "../config";
function getQlSubmoduleFolder(): WorkspaceFolder | undefined {
const workspaceFolder = workspace.workspaceFolders?.find(
@@ -127,6 +133,13 @@ export class DataExtensionsEditorView extends AbstractWebview<
case "generateExternalApi":
await this.generateModeledMethods();
break;
case "generateExternalApiFromLlm":
await this.generateModeledMethodsFromLlm(
msg.externalApiUsages,
msg.modeledMethods,
);
break;
default:
assertNever(msg);
@@ -149,6 +162,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
viewState: {
extensionPackModelFile: this.modelFile,
modelFileExists: await pathExists(this.modelFile.filename),
showLlmButton: showLlmGeneration(),
},
});
}
@@ -367,6 +381,40 @@ export class DataExtensionsEditorView extends AbstractWebview<
await this.clearProgress();
}
/**
 * Requests predictions from the auto-model endpoint for the given usages and
 * posts the resulting models to the webview as an "addModeledMethods"
 * message (with `overrideNone: true`).
 *
 * Predictions without a classification are ignored. Errors from the request
 * are not caught here and propagate to the caller.
 *
 * @param externalApiUsages Usages currently shown in the editor.
 * @param modeledMethods Current models keyed by `ExternalApiUsage.signature`.
 */
private async generateModeledMethodsFromLlm(
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): Promise<void> {
  const request = createAutoModelRequest(
    this.databaseItem.language,
    externalApiUsages,
    modeledMethods,
  );

  const response = await autoModel(this.app.credentials, request);

  const modeledMethodsByName: Record<string, ModeledMethod> = {};

  for (const method of response.predicted) {
    if (method.classification === undefined) {
      continue;
    }

    // Requests put only the parameter list (e.g. "(String)") in
    // `method.signature`, whereas modeled methods are keyed by the full
    // signature "package.Type#name(params)". Rebuild that key so the
    // prediction lands on the right method instead of colliding with every
    // other method that has the same parameter list.
    // NOTE(review): assumes the response echoes package/type/name/signature
    // as they were sent — confirm against the API.
    const fullSignature = `${method.package}.${method.type}#${method.name}${method.signature}`;

    // NOTE(review): requests contain one entry per argument, so several
    // predictions may map to the same method; the last one wins.
    modeledMethodsByName[fullSignature] = {
      type: classificationTypeToModeledMethodType(method.classification.type),
      kind: method.classification.kind,
      input: method.input ?? "",
      output: method.output ?? "",
    };
  }

  await this.postMessage({
    t: "addModeledMethods",
    modeledMethods: modeledMethodsByName,
    overrideNone: true,
  });
}
/*
* Progress in this class is a bit weird. Most of the progress is based on running the query.
* Query progress is always between 0 and 1000. However, we still have some steps that need

View File

@@ -3,4 +3,5 @@ import { ExtensionPackModelFile } from "./extension-pack";
/** State sent to the data extensions editor webview. */
export interface DataExtensionEditorViewState {
// The model file being edited.
extensionPackModelFile: ExtensionPackModelFile;
// Whether that model file already exists on disk.
modelFileExists: boolean;
// Whether the webview renders the "Generate using LLM" button.
showLlmButton: boolean;
}

View File

@@ -544,6 +544,12 @@ export interface GenerateExternalApiMessage {
t: "generateExternalApi";
}
/**
 * Message posted by the data extensions editor webview to request
 * LLM-generated models for the methods it currently shows.
 */
export interface GenerateExternalApiFromLlmMessage {
t: "generateExternalApiFromLlm";
// Usages currently held by the editor.
externalApiUsages: ExternalApiUsage[];
// Current models, keyed by method signature.
modeledMethods: Record<string, ModeledMethod>;
}
export type ToDataExtensionsEditorMessage =
| SetExtensionPackStateMessage
| SetExternalApiUsagesMessage
@@ -556,4 +562,5 @@ export type FromDataExtensionsEditorMessage =
| OpenExtensionPackMessage
| JumpToUsageMessage
| SaveModeledMethods
| GenerateExternalApiMessage;
| GenerateExternalApiMessage
| GenerateExternalApiFromLlmMessage;

View File

@@ -30,6 +30,7 @@ DataExtensionsEditor.args = {
"/home/user/vscode-codeql-starter/codeql-custom-queries-java/sql2o/models/sql2o.yml",
},
modelFileExists: true,
showLlmButton: true,
},
initialExternalApiUsages: [
{

View File

@@ -157,6 +157,14 @@ export function DataExtensionsEditor({
});
}, []);
// Click handler for the "Generate using LLM" button: posts the current usages
// and models through vscode.postMessage. Both values are listed as
// dependencies so the callback never sends stale state.
const onGenerateFromLlmClick = useCallback(() => {
vscode.postMessage({
t: "generateExternalApiFromLlm",
externalApiUsages,
modeledMethods,
});
}, [externalApiUsages, modeledMethods]);
const onOpenExtensionPackClick = useCallback(() => {
vscode.postMessage({
t: "openExtensionPack",
@@ -214,6 +222,14 @@ export function DataExtensionsEditor({
<VSCodeButton onClick={onGenerateClick}>
Download and generate
</VSCodeButton>
{viewState?.showLlmButton && (
<>
&nbsp;
<VSCodeButton onClick={onGenerateFromLlmClick}>
Generate using LLM
</VSCodeButton>
</>
)}
<br />
<br />
<VSCodeDataGrid>

View File

@@ -0,0 +1,283 @@
import { createAutoModelRequest } from "../../../src/data-extensions-editor/auto-model";
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
describe("createAutoModelRequest", () => {
// Fixture: seven external API methods with one or two usages each. The entries
// are deliberately not ordered by usage count, and include a method without
// parameters (Sql2o#open) as well as methods with one, two and three parameters.
const externalApiUsages: ExternalApiUsage[] = [
{
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
supported: false,
usages: [
{
label: "run(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/Sql2oExampleApplication.java",
startLine: 9,
startColumn: 9,
endLine: 9,
endColumn: 66,
},
},
],
},
{
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 56,
},
},
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 39,
},
},
],
},
{
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
supported: true,
usages: [
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 85,
},
},
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 68,
},
},
],
},
{
signature: "org.sql2o.Sql2o#open()",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "open",
methodParameters: "()",
supported: true,
usages: [
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 14,
startColumn: 24,
endLine: 14,
endColumn: 35,
},
},
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 25,
startColumn: 24,
endLine: 25,
endColumn: 35,
},
},
],
},
{
signature: "java.io.PrintStream#println(String)",
packageName: "java.io",
typeName: "PrintStream",
methodName: "println",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "println(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 29,
startColumn: 9,
endLine: 29,
endColumn: 49,
},
},
],
},
{
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
supported: true,
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 10,
startColumn: 33,
endLine: 10,
endColumn: 88,
},
},
],
},
{
signature: "org.sql2o.Sql2o#Sql2o(String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 23,
startColumn: 23,
endLine: 23,
endColumn: 36,
},
},
],
},
];
// Fixture: existing models keyed by full method signature. Only these two
// methods count as modeled; every other usage is treated as unmodeled.
const modeledMethods: Record<string, ModeledMethod> = {
// Modeled, but the method has no parameters.
"org.sql2o.Sql2o#open()": {
type: "neutral",
kind: "",
input: "",
output: "",
},
// Modeled as a sink on its single argument.
"org.sql2o.Sql2o#Sql2o(String)": {
type: "sink",
kind: "jndi-injection",
input: "Argument[0]",
output: "",
},
};
// Checks the full request built from the fixtures: one entry per argument of
// each method, modeled methods under `samples`, unmodeled ones under
// `candidates`, and methods with more usages listed first.
it("creates a matching request", () => {
expect(
createAutoModelRequest("java", externalApiUsages, modeledMethods),
).toEqual({
language: "java",
// Sql2o#open() is modeled too, but has no parameters and so yields no entry.
samples: [
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String)",
classification: {
type: "CLASSIFICATION_TYPE_SINK",
kind: "jndi-injection",
explanation: "",
},
usages: ["new Sql2o(...)"],
input: "Argument[0]",
},
],
candidates: [
// Methods with two usages come before those with one.
{
package: "org.sql2o",
type: "Connection",
name: "createQuery",
signature: "(String)",
usages: ["createQuery(...)", "createQuery(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Query",
name: "executeScalar",
signature: "(Class)",
usages: ["executeScalar(...)", "executeScalar(...)"],
input: "Argument[0]",
},
// A two-parameter method produces one entry per argument index.
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages: ["run(...)"],
input: "Argument[0]",
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages: ["run(...)"],
input: "Argument[1]",
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: ["println(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[1]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[2]",
},
],
});
});
});