Merge pull request #2448 from github/koesie10/auto-model

Add initial implementation of auto-modeling
This commit is contained in:
Koen Vlaswinkel
2023-05-30 12:19:22 +02:00
committed by GitHub
9 changed files with 772 additions and 1 deletions

View File

@@ -711,3 +711,10 @@ const QUERIES_PANEL = new Setting("queriesPanel", ROOT_SETTING);
/**
 * Reads the `queriesPanel` setting and coerces the stored value to a boolean.
 *
 * @returns `true` when the stored value is truthy, `false` otherwise (including when it is unset).
 */
export function showQueriesPanel(): boolean {
  const enabled = QUERIES_PANEL.getValue<boolean>();
  return Boolean(enabled);
}
// Parent setting: the "dataExtensions" key under the root setting.
const DATA_EXTENSIONS = new Setting("dataExtensions", ROOT_SETTING);
// Child setting "llmGeneration", nested under "dataExtensions".
const LLM_GENERATION = new Setting("llmGeneration", DATA_EXTENSIONS);

/**
 * Reads the `llmGeneration` setting (nested under `dataExtensions`) and coerces it to a boolean.
 *
 * @returns `true` when the stored value is truthy, `false` otherwise (including when it is unset).
 */
export function showLlmGeneration(): boolean {
  const enabled = LLM_GENERATION.getValue<boolean>();
  return Boolean(enabled);
}

View File

@@ -0,0 +1,54 @@
import { Credentials } from "../common/authentication";
import { OctokitResponse } from "@octokit/types";
/**
 * Classification categories used in auto-model requests and responses.
 * This is a string enum: each member's value is its `CLASSIFICATION_TYPE_*` constant.
 */
export enum ClassificationType {
Unknown = "CLASSIFICATION_TYPE_UNKNOWN",
Neutral = "CLASSIFICATION_TYPE_NEUTRAL",
Source = "CLASSIFICATION_TYPE_SOURCE",
Sink = "CLASSIFICATION_TYPE_SINK",
Summary = "CLASSIFICATION_TYPE_SUMMARY",
}
/** A classification attached to a {@link Method}. */
export interface Classification {
/** Category of the classification. */
type: ClassificationType;
/** Kind label that accompanies the category (a plain string). */
kind: string;
/** Free-text explanation of the classification. */
explanation: string;
}
/**
 * A method entry in an auto-model request or response.
 * The optional `input`/`output` fields identify which part of the method the entry refers to.
 */
export interface Method {
/** Package name, e.g. `org.sql2o`. */
package: string;
/** Type name, e.g. `Sql2o`. */
type: string;
/** Method name, e.g. `createQuery`. */
name: string;
/** Parameter list of the method, e.g. `(Class,String[])`. */
signature: string;
/** Usage labels for the method, e.g. `createQuery(...)`. */
usages: string[];
/** Classification of the entry, when one is known. */
classification?: Classification;
/** Input access path, e.g. `Argument[0]`. */
input?: string;
/** Output access path. */
output?: string;
}
/** Payload posted to the auto-model endpoint. */
export interface ModelRequest {
/** Language identifier, e.g. `java`. */
language: string;
/** Methods to be classified. */
candidates: Method[];
/** Methods that already carry a classification, sent as examples. */
samples: Method[];
}
/** Body returned by the auto-model endpoint. */
export interface ModelResponse {
/** Language identifier. */
language: string;
/** Method entries returned by the endpoint; each may carry a classification. */
predicted: Method[];
}
/**
 * Posts a model request to the auto-model endpoint and returns the response body.
 *
 * Nothing is caught here, so a failure while obtaining the client or performing
 * the request rejects the returned promise.
 *
 * @param credentials Used to obtain the Octokit client that performs the request.
 * @param request Payload sent as the request's `data`.
 * @returns The `data` of the endpoint's response.
 */
export async function autoModel(
  credentials: Credentials,
  request: ModelRequest,
): Promise<ModelResponse> {
  const client = await credentials.getOctokit();
  const { data: modelResponse }: OctokitResponse<ModelResponse> =
    await client.request(
      "POST /repos/github/codeql/code-scanning/codeql/auto-model",
      { data: request },
    );
  return modelResponse;
}

View File

@@ -0,0 +1,218 @@
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod, ModeledMethodType } from "./modeled-method";
import {
Classification,
ClassificationType,
Method,
ModelRequest,
} from "./auto-model-api";
/**
 * Builds the auto-model request for a set of external API usages.
 *
 * One entry is produced per positional argument of each method (`Argument[0]`,
 * `Argument[1]`, ...), so a method whose parameter list is `()` contributes no
 * entries. Methods that already have a model whose type is not `"none"` become
 * samples; all others become candidates. Methods are processed in descending
 * order of usage count, and the result is capped at 20 candidates and 100
 * samples. Each entry carries at most 10 usage labels.
 *
 * @param language Language identifier copied into the request.
 * @param externalApiUsages Usages to convert; the array itself is not modified.
 * @param modeledMethods Existing models, looked up by each usage's `signature`.
 * @returns The assembled request.
 */
export function createAutoModelRequest(
  language: string,
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): ModelRequest {
  const samples: Method[] = [];
  const candidates: Method[] = [];

  // Work on a copy so the caller's array keeps its order; most used methods come first.
  const byUsageCountDescending = Array.from(externalApiUsages);
  byUsageCountDescending.sort(
    (left, right) => right.usages.length - left.usages.length,
  );

  for (const apiUsage of byUsageCountDescending) {
    const existingModel: ModeledMethod = modeledMethods[apiUsage.signature] ?? {
      type: "none",
    };
    const isUnmodeled = existingModel.type === "none";
    const destination = isUnmodeled ? candidates : samples;

    // The parameter count is derived by splitting the parameter list on commas.
    let parameterCount = 0;
    if (apiUsage.methodParameters !== "()") {
      parameterCount = apiUsage.methodParameters.split(",").length;
    }

    for (let position = 0; position < parameterCount; position += 1) {
      const entry: Method = {
        package: apiUsage.packageName,
        type: apiUsage.typeName,
        name: apiUsage.methodName,
        signature: apiUsage.methodParameters,
        classification: isUnmodeled
          ? undefined
          : toMethodClassification(existingModel),
        usages: apiUsage.usages.slice(0, 10).map((usage) => usage.label),
        input: `Argument[${position}]`,
      };
      destination.push(entry);
    }
  }

  return {
    language,
    samples: samples.slice(0, 100),
    candidates: candidates.slice(0, 20),
  };
}
/**
 * Converts predicted method entries into modeled methods keyed by full method signature.
 *
 * The current model is deliberately simple and only produces sinks. Entries
 * without a classification are ignored. The remaining entries are grouped per
 * method:
 *
 * - If no entry of a method is classified as a sink, the method is modeled as
 *   neutral. This is a simplification: such a method could still be a source
 *   or a summary, but it gives the user some output.
 * - If several entries are classified as sinks, only the one that sorts first
 *   by input (see `compareInputOutput`) is used.
 */
export function parsePredictedClassifications(
  predicted: Method[],
): Record<string, ModeledMethod> {
  // Group the classified predictions per method, keeping first-seen order.
  const groupedBySignature = new Map<string, Method[]>();
  for (const prediction of predicted) {
    if (!prediction.classification) {
      continue;
    }

    const key = toFullMethodSignature(prediction);
    const group = groupedBySignature.get(key);
    if (group === undefined) {
      groupedBySignature.set(key, [prediction]);
    } else {
      group.push(prediction);
    }
  }

  const result: Record<string, ModeledMethod> = {};
  for (const [signature, predictions] of groupedBySignature) {
    // Sort sinks by input so that e.g. "Argument[1]" always precedes "Argument[3]",
    // then keep only the first one.
    const [firstSink] = predictions
      .filter(
        (prediction) =>
          prediction.classification?.type === ClassificationType.Sink,
      )
      .sort((left, right) =>
        compareInputOutput(left.input ?? "", right.input ?? ""),
      );

    result[signature] =
      firstSink === undefined
        ? // No argument was predicted to be a sink: fall back to neutral for now.
          {
            type: "neutral",
            kind: "",
            input: "",
            output: "",
          }
        : {
            type: "sink",
            kind: firstSink.classification?.kind ?? "",
            input: firstSink.input ?? "",
            output: firstSink.output ?? "",
          };
  }

  return result;
}
/**
 * Maps a modeled method type onto the classification type used by the auto-model API.
 *
 * @param type Modeled method type to translate.
 * @returns The matching classification type, or `ClassificationType.Unknown`
 *   for any value other than `"source"`, `"sink"`, `"summary"` or `"neutral"`.
 */
function toMethodClassificationType(
  type: ModeledMethodType,
): ClassificationType {
  if (type === "source") {
    return ClassificationType.Source;
  }
  if (type === "sink") {
    return ClassificationType.Sink;
  }
  if (type === "summary") {
    return ClassificationType.Summary;
  }
  if (type === "neutral") {
    return ClassificationType.Neutral;
  }
  return ClassificationType.Unknown;
}
/**
 * Builds the API classification for a modeled method.
 *
 * @param modeledMethod Source of the classification's type and kind.
 * @returns A classification with the translated type, the method's kind and an empty explanation.
 */
function toMethodClassification(modeledMethod: ModeledMethod): Classification {
  const classificationType = toMethodClassificationType(modeledMethod.type);
  const classification: Classification = {
    type: classificationType,
    kind: modeledMethod.kind,
    explanation: "",
  };
  return classification;
}
/**
 * Formats a method as `<package>.<type>#<name><signature>`,
 * e.g. `org.sql2o.Sql2o#createQuery(String)`.
 */
function toFullMethodSignature(method: Method): string {
  const { package: packageName, type: typeName, name, signature } = method;
  return `${packageName}.${typeName}#${name}${signature}`;
}
// Matches a positional argument such as `Argument[3]` and captures its index.
const argumentRegex = /^Argument\[(\d+)]$/;

// Fixed order of the known non-positional values: Argument[this] is before ReturnValue.
const nonNumericArgumentOrder = ["Argument[this]", "ReturnValue"];

/**
 * Comparator for inputs or outputs of the form `Argument[<number>]`, `Argument[this]`
 * or `ReturnValue`.
 *
 * Ordering rules:
 * 1. Identical strings compare equal.
 * 2. Positional arguments come before everything else and are ordered by index.
 * 3. Known non-positional values follow, in the order `Argument[this]`, `ReturnValue`.
 * 4. Any other value is sorted last; two such values are ordered with `localeCompare`.
 *
 * @returns `0` when equal, a negative number when `a` sorts before `b`, and a
 *   positive number when `a` sorts after `b`.
 */
export function compareInputOutput(a: string, b: string): number {
  if (a === b) {
    return 0;
  }

  const aPositional = argumentRegex.exec(a);
  const bPositional = argumentRegex.exec(b);

  // Both are positional arguments: order by index.
  if (aPositional !== null && bPositional !== null) {
    return parseInt(aPositional[1], 10) - parseInt(bPositional[1], 10);
  }

  // Exactly one is positional: it goes first.
  if (aPositional !== null) {
    return -1;
  }
  if (bPositional !== null) {
    return 1;
  }

  // Neither is positional: use the fixed order, with unknown values last.
  const aRank = nonNumericArgumentOrder.indexOf(a);
  const bRank = nonNumericArgumentOrder.indexOf(b);
  const aKnown = aRank !== -1;
  const bKnown = bRank !== -1;

  if (aKnown && bKnown) {
    return aRank - bRank;
  }
  if (aKnown) {
    return -1;
  }
  if (bKnown) {
    return 1;
  }
  return a.localeCompare(b);
}

View File

@@ -39,6 +39,12 @@ import { createDataExtensionYaml, loadDataExtensionYaml } from "./yaml";
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { ExtensionPackModelFile } from "./shared/extension-pack";
import { autoModel } from "./auto-model-api";
import {
createAutoModelRequest,
parsePredictedClassifications,
} from "./auto-model";
import { showLlmGeneration } from "../config";
function getQlSubmoduleFolder(): WorkspaceFolder | undefined {
const workspaceFolder = workspace.workspaceFolders?.find(
@@ -127,6 +133,13 @@ export class DataExtensionsEditorView extends AbstractWebview<
case "generateExternalApi":
await this.generateModeledMethods();
break;
case "generateExternalApiFromLlm":
await this.generateModeledMethodsFromLlm(
msg.externalApiUsages,
msg.modeledMethods,
);
break;
default:
assertNever(msg);
@@ -149,6 +162,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
viewState: {
extensionPackModelFile: this.modelFile,
modelFileExists: await pathExists(this.modelFile.filename),
showLlmButton: showLlmGeneration(),
},
});
}
@@ -367,6 +381,29 @@ export class DataExtensionsEditorView extends AbstractWebview<
await this.clearProgress();
}
/**
 * Requests predictions from the auto-model endpoint and posts them to the webview.
 *
 * Builds the request from the given usages and models for the database's
 * language, converts the response's predictions into modeled methods, and
 * sends them in an `addModeledMethods` message with `overrideNone` set.
 *
 * @param externalApiUsages External API usages to build the request from.
 * @param modeledMethods Existing modeled methods, keyed by signature.
 */
private async generateModeledMethodsFromLlm(
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): Promise<void> {
  const request = createAutoModelRequest(
    this.databaseItem.language,
    externalApiUsages,
    modeledMethods,
  );
  const { predicted } = await autoModel(this.app.credentials, request);

  await this.postMessage({
    t: "addModeledMethods",
    modeledMethods: parsePredictedClassifications(predicted),
    overrideNone: true,
  });
}
/*
* Progress in this class is a bit weird. Most of the progress is based on running the query.
* Query progress is always between 0 and 1000. However, we still have some steps that need

View File

@@ -3,4 +3,5 @@ import { ExtensionPackModelFile } from "./extension-pack";
/** State handed to the data extensions editor webview. */
export interface DataExtensionEditorViewState {
/** The extension pack model file being edited. */
extensionPackModelFile: ExtensionPackModelFile;
/** Whether the model file exists. */
modelFileExists: boolean;
/** Whether the "Generate using LLM" button should be rendered. */
showLlmButton: boolean;
}

View File

@@ -544,6 +544,12 @@ export interface GenerateExternalApiMessage {
t: "generateExternalApi";
}
/**
 * Message asking for modeled methods to be generated using the LLM.
 * It carries the usages and models to build the request from.
 */
export interface GenerateExternalApiFromLlmMessage {
/** Message discriminator. */
t: "generateExternalApiFromLlm";
/** External API usages to generate models for. */
externalApiUsages: ExternalApiUsage[];
/** Existing modeled methods, keyed by signature. */
modeledMethods: Record<string, ModeledMethod>;
}
export type ToDataExtensionsEditorMessage =
| SetExtensionPackStateMessage
| SetExternalApiUsagesMessage
@@ -556,4 +562,5 @@ export type FromDataExtensionsEditorMessage =
| OpenExtensionPackMessage
| JumpToUsageMessage
| SaveModeledMethods
| GenerateExternalApiMessage;
| GenerateExternalApiMessage
| GenerateExternalApiFromLlmMessage;

View File

@@ -30,6 +30,7 @@ DataExtensionsEditor.args = {
"/home/user/vscode-codeql-starter/codeql-custom-queries-java/sql2o/models/sql2o.yml",
},
modelFileExists: true,
showLlmButton: true,
},
initialExternalApiUsages: [
{

View File

@@ -157,6 +157,14 @@ export function DataExtensionsEditor({
});
}, []);
// Posts a "generateExternalApiFromLlm" message carrying the current usages and models.
// Both values are listed as dependencies so the callback always sends the latest state.
const onGenerateFromLlmClick = useCallback(() => {
vscode.postMessage({
t: "generateExternalApiFromLlm",
externalApiUsages,
modeledMethods,
});
}, [externalApiUsages, modeledMethods]);
const onOpenExtensionPackClick = useCallback(() => {
vscode.postMessage({
t: "openExtensionPack",
@@ -214,6 +222,14 @@ export function DataExtensionsEditor({
<VSCodeButton onClick={onGenerateClick}>
Download and generate
</VSCodeButton>
{viewState?.showLlmButton && (
<>
&nbsp;
<VSCodeButton onClick={onGenerateFromLlmClick}>
Generate using LLM
</VSCodeButton>
</>
)}
<br />
<br />
<VSCodeDataGrid>

View File

@@ -0,0 +1,430 @@
import {
compareInputOutput,
createAutoModelRequest,
parsePredictedClassifications,
} from "../../../src/data-extensions-editor/auto-model";
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
import {
ClassificationType,
Method,
} from "../../../src/data-extensions-editor/auto-model-api";
describe("createAutoModelRequest", () => {
// Fixture: seven external API usages with one or two usages each.
// They are deliberately not ordered by usage count, so the test also covers sorting.
const externalApiUsages: ExternalApiUsage[] = [
{
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
supported: false,
usages: [
{
label: "run(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/Sql2oExampleApplication.java",
startLine: 9,
startColumn: 9,
endLine: 9,
endColumn: 66,
},
},
],
},
{
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 56,
},
},
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 39,
},
},
],
},
{
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
supported: true,
usages: [
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 85,
},
},
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 68,
},
},
],
},
{
signature: "org.sql2o.Sql2o#open()",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "open",
methodParameters: "()",
supported: true,
usages: [
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 14,
startColumn: 24,
endLine: 14,
endColumn: 35,
},
},
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 25,
startColumn: 24,
endLine: 25,
endColumn: 35,
},
},
],
},
{
signature: "java.io.PrintStream#println(String)",
packageName: "java.io",
typeName: "PrintStream",
methodName: "println",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "println(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 29,
startColumn: 9,
endLine: 29,
endColumn: 49,
},
},
],
},
{
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
supported: true,
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 10,
startColumn: 33,
endLine: 10,
endColumn: 88,
},
},
],
},
{
signature: "org.sql2o.Sql2o#Sql2o(String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String)",
supported: true,
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 23,
startColumn: 23,
endLine: 23,
endColumn: 36,
},
},
],
},
];
// Fixture: existing models. `open()` is modeled as neutral and `Sql2o(String)` as a sink.
// Every other method in the usages fixture has no model.
const modeledMethods: Record<string, ModeledMethod> = {
"org.sql2o.Sql2o#open()": {
type: "neutral",
kind: "",
input: "",
output: "",
},
"org.sql2o.Sql2o#Sql2o(String)": {
type: "sink",
kind: "jndi-injection",
input: "Argument[0]",
output: "",
},
};
// Expected request:
// - The modeled `Sql2o(String)` constructor is the only sample; the modeled
//   `open()` has no parameters and therefore produces no entries at all.
// - Each unmodeled method yields one candidate per argument, with methods that
//   have more usages listed first.
it("creates a matching request", () => {
expect(
createAutoModelRequest("java", externalApiUsages, modeledMethods),
).toEqual({
language: "java",
samples: [
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String)",
classification: {
type: "CLASSIFICATION_TYPE_SINK",
kind: "jndi-injection",
explanation: "",
},
usages: ["new Sql2o(...)"],
input: "Argument[0]",
},
],
candidates: [
{
package: "org.sql2o",
type: "Connection",
name: "createQuery",
signature: "(String)",
usages: ["createQuery(...)", "createQuery(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Query",
name: "executeScalar",
signature: "(Class)",
usages: ["executeScalar(...)", "executeScalar(...)"],
input: "Argument[0]",
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages: ["run(...)"],
input: "Argument[0]",
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages: ["run(...)"],
input: "Argument[1]",
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: ["println(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[0]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[1]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[2]",
},
],
});
});
});
// Tests for parsePredictedClassifications. The predictions cover three methods:
// - one with a single sink argument,
// - one whose only argument is neutral,
// - one with a neutral argument followed by two sink arguments.
describe("parsePredictedClassifications", () => {
const predictions: Method[] = [
{
package: "org.sql2o",
type: "Sql2o",
name: "createQuery",
signature: "(String)",
usages: ["createQuery(...)", "createQuery(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "executeScalar",
signature: "(Class)",
usages: ["executeScalar(...)", "executeScalar(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Neutral,
kind: "",
explanation: "not a sink",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Neutral,
kind: "",
explanation: "not a sink",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[1]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[2]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
},
},
];
// A method without any sink argument is expected to come back as neutral.
// When several arguments are sinks, only the lowest-numbered one is expected.
it("correctly parses the output", () => {
expect(parsePredictedClassifications(predictions)).toEqual({
"org.sql2o.Sql2o#createQuery(String)": {
type: "sink",
kind: "sql injection sink",
input: "Argument[0]",
output: "",
},
"org.sql2o.Sql2o#executeScalar(Class)": {
type: "neutral",
kind: "",
input: "",
output: "",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
kind: "sql injection sink",
input: "Argument[1]",
output: "",
},
});
});
});
// Tests for compareInputOutput. Only the sign of the result is checked:
// "<= -1" means the first value sorts before the second, ">= 1" means it sorts after.
describe("compareInputOutput", () => {
it("with two small numeric arguments", () => {
expect(
compareInputOutput("Argument[0]", "Argument[1]"),
).toBeLessThanOrEqual(-1);
});
// Indices are compared as numbers, so 10 sorts after 2 (a string comparison would say otherwise).
it("with one larger non-alphabetic argument", () => {
expect(
compareInputOutput("Argument[10]", "Argument[2]"),
).toBeGreaterThanOrEqual(1);
});
// Numeric arguments sort before Argument[this].
it("with one non-numeric arguments", () => {
expect(
compareInputOutput("Argument[5]", "Argument[this]"),
).toBeLessThanOrEqual(-1);
});
// Argument[this] sorts before ReturnValue.
it("with two non-numeric arguments", () => {
expect(
compareInputOutput("ReturnValue", "Argument[this]"),
).toBeGreaterThanOrEqual(1);
});
// Unrecognized values sort after known ones, whichever side they are on.
it("with one unknown argument in the a position", () => {
expect(
compareInputOutput("FooBar", "Argument[this]"),
).toBeGreaterThanOrEqual(1);
});
it("with one unknown argument in the b position", () => {
expect(compareInputOutput("Argument[this]", "FooBar")).toBeLessThanOrEqual(
-1,
);
});
// The empty string is an unrecognized value and therefore sorts after a numeric argument.
it("with one empty string arguments", () => {
expect(compareInputOutput("Argument[5]", "")).toBeLessThanOrEqual(-1);
});
// Two unrecognized values are expected to be ordered relative to each other ("FooBar" after "BarFoo").
it("with two unknown arguments", () => {
expect(compareInputOutput("FooBar", "BarFoo")).toBeGreaterThanOrEqual(1);
});
});