Merge pull request #2698 from github/charisk/remove-automodel-v1

Remove automodel v1 code
This commit is contained in:
Charis Kyriakou
2023-08-11 13:22:16 +01:00
committed by GitHub
10 changed files with 217 additions and 1414 deletions

View File

@@ -704,7 +704,6 @@ export function showQueriesPanel(): boolean {
const DATA_EXTENSIONS = new Setting("dataExtensions", ROOT_SETTING);
const LLM_GENERATION = new Setting("llmGeneration", DATA_EXTENSIONS);
const LLM_GENERATION_V2 = new Setting("llmGenerationV2", DATA_EXTENSIONS);
const FRAMEWORK_MODE = new Setting("frameworkMode", DATA_EXTENSIONS);
const DISABLE_AUTO_NAME_EXTENSION_PACK = new Setting(
"disableAutoNameExtensionPack",
@@ -724,10 +723,6 @@ export function showLlmGeneration(): boolean {
return !!LLM_GENERATION.getValue<boolean>();
}
/**
 * Reports whether the `llmGenerationV2` data extensions setting is turned on.
 *
 * @returns `true` only when the setting holds a truthy value.
 */
export function useLlmGenerationV2(): boolean {
  const configured = LLM_GENERATION_V2.getValue<boolean>();
  return Boolean(configured);
}
/**
 * Reports whether the `frameworkMode` data extensions setting is turned on.
 *
 * @returns `true` only when the setting holds a truthy value.
 */
export function enableFrameworkMode(): boolean {
  const configured = FRAMEWORK_MODE.getValue<boolean>();
  return configured ? true : false;
}

View File

@@ -1,34 +0,0 @@
import { Credentials } from "../common/authentication";
import { OctokitResponse } from "@octokit/types";
/**
 * Values allowed in the `mode` field of a `ModelRequest`. Each member maps to the string
 * constant that is placed in the request body.
 */
export enum AutomodelMode {
Unspecified = "AUTOMODEL_MODE_UNSPECIFIED",
Framework = "AUTOMODEL_MODE_FRAMEWORK",
Application = "AUTOMODEL_MODE_APPLICATION",
}
/**
 * Body of the request that `autoModelV2` posts to the auto-model endpoint.
 */
export interface ModelRequest {
// Mode the request is made in (application or framework).
mode: AutomodelMode;
// Base64-encoded GZIP-compressed SARIF log
candidates: string;
}
/**
 * Body of the response that `autoModelV2` returns from the auto-model endpoint.
 */
export interface ModelResponse {
// NOTE(review): presumably the generated models in a serialized text form - confirm the
// exact format against the consumer of this response.
models: string;
}
/**
 * Posts a model request to the auto-model endpoint and hands back the response body.
 *
 * @param credentials Source of the Octokit client used to make the call.
 * @param request Body to send with the request.
 * @returns The `data` of the endpoint's response.
 */
export async function autoModelV2(
  credentials: Credentials,
  request: ModelRequest,
): Promise<ModelResponse> {
  const client = await credentials.getOctokit();
  const { data }: OctokitResponse<ModelResponse> = await client.request(
    "POST /repos/github/codeql/code-scanning/codeql/auto-model",
    { data: request },
  );
  return data;
}

View File

@@ -1,40 +1,20 @@
import { Credentials } from "../common/authentication";
import { OctokitResponse } from "@octokit/types";
/**
 * Kinds of classification that can be attached to a `Method`. Each member maps to the
 * string constant used on the wire.
 */
export enum ClassificationType {
Unknown = "CLASSIFICATION_TYPE_UNKNOWN",
Neutral = "CLASSIFICATION_TYPE_NEUTRAL",
Source = "CLASSIFICATION_TYPE_SOURCE",
Sink = "CLASSIFICATION_TYPE_SINK",
Summary = "CLASSIFICATION_TYPE_SUMMARY",
}
/**
 * Classification of a single `Method` entry.
 */
export interface Classification {
// Whether the method is a source, sink, summary, neutral or unknown.
type: ClassificationType;
// NOTE(review): presumably the model kind (e.g. a sink kind) - confirm against the API.
kind: string;
// Free-form explanation; this file only ever sets it to an empty string.
explanation: string;
}
export interface Method {
package: string;
type: string;
name: string;
signature: string;
usages: string[];
classification?: Classification;
input?: string;
output?: string;
export enum AutomodelMode {
Unspecified = "AUTOMODEL_MODE_UNSPECIFIED",
Framework = "AUTOMODEL_MODE_FRAMEWORK",
Application = "AUTOMODEL_MODE_APPLICATION",
}
export interface ModelRequest {
language: string;
candidates: Method[];
samples: Method[];
mode: AutomodelMode;
// Base64-encoded GZIP-compressed SARIF log
candidates: string;
}
export interface ModelResponse {
language: string;
predicted?: Method[];
models: string;
}
export async function autoModel(

View File

@@ -1,140 +0,0 @@
import { CancellationTokenSource } from "vscode";
import { join } from "path";
import { runQuery } from "./external-api-usage-query";
import { CodeQLCliServer } from "../codeql-cli/cli";
import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases";
import { interpretResultsSarif } from "../query-results";
import { ProgressCallback } from "../common/vscode/progress";
import { Mode } from "./shared/mode";
/**
 * Everything `getAutoModelUsages` needs in order to run the usage query and interpret it.
 */
type Options = {
// CLI server used to run the query, look up the source location prefix and interpret results.
cliServer: CodeQLCliServer;
// Query runner handed through to `runQuery`.
queryRunner: QueryRunner;
// Database the query is run against; also provides the source archive.
databaseItem: DatabaseItem;
// Directory handed to `runQuery`; the interpreted SARIF file is also written here.
queryStorageDir: string;
// Directory handed through to `runQuery`.
queryDir: string;
// Callback that receives progress updates.
progress: ProgressCallback;
};
// Maps a method signature (taken from the SARIF result message text) to the source snippets
// of the places where that method is used.
export type UsageSnippetsBySignature = Record<string, string[]>;
/**
 * Collects source snippets for the external API usages in a database, grouped by method
 * signature.
 *
 * The function runs the external API usage query in application mode, interprets the resulting
 * BQRS file as SARIF with `--sarif-add-snippets`, and then groups the context-region snippet of
 * each result's first location by the result's message text (the method signature).
 *
 * @param options CLI server, query runner, database and directories to run the query with, plus
 * a callback that receives progress updates (always reported against a maximum of 1500 steps).
 * @returns A record from method signature to the snippets found for that signature. Results
 * without a signature or without a snippet are left out.
 * @throws Error with message "Query failed" when the query produces no result, and with message
 * "No results" when the first SARIF run has no results array.
 */
export async function getAutoModelUsages({
  cliServer,
  queryRunner,
  databaseItem,
  queryStorageDir,
  queryDir,
  progress,
}: Options): Promise<UsageSnippetsBySignature> {
  const maxStep = 1500;

  const cancellationTokenSource = new CancellationTokenSource();

  try {
    // This will re-run the query that was already run when opening the data extensions editor. This
    // might be unnecessary, but this makes it really easy to get the path to the BQRS file which we
    // need to interpret the results.
    const queryResult = await runQuery(Mode.Application, {
      cliServer,
      queryRunner,
      queryStorageDir,
      databaseItem,
      queryDir,
      progress: (update) =>
        progress({
          maxStep,
          step: update.step,
          message: update.message,
        }),
      token: cancellationTokenSource.token,
    });
    if (!queryResult) {
      throw new Error("Query failed");
    }

    progress({
      maxStep,
      step: 1100,
      message: "Retrieving source location prefix",
    });

    // CodeQL needs to have access to the database to be able to retrieve the
    // snippets from it. The source location prefix is used to determine the
    // base path of the database.
    const sourceLocationPrefix = await databaseItem.getSourceLocationPrefix(
      cliServer,
    );
    const sourceArchiveUri = databaseItem.sourceArchive;
    const sourceInfo =
      sourceArchiveUri === undefined
        ? undefined
        : {
            sourceArchive: sourceArchiveUri.fsPath,
            sourceLocationPrefix,
          };

    progress({
      maxStep,
      step: 1200,
      message: "Interpreting results",
    });

    // Convert the results to SARIF so that CodeQL will retrieve the snippets
    // from the database. This means we don't need to do that in the extension
    // and everything is handled by the CodeQL CLI.
    const sarif = await interpretResultsSarif(
      cliServer,
      {
        // To interpret the results we need to provide metadata about the query. We could do this using
        // `resolveMetadata` but that would be an extra call to the CodeQL CLI server and would require
        // us to know the path to the query on the filesystem. Since we know what the metadata should
        // look like and the only metadata that the CodeQL CLI requires is an ID and the kind, we can
        // simply use constants here.
        kind: "problem",
        id: "usage",
      },
      {
        resultsPath: queryResult.outputDir.bqrsPath,
        interpretedResultsPath: join(
          queryStorageDir,
          "interpreted-results.sarif",
        ),
      },
      sourceInfo,
      ["--sarif-add-snippets"],
    );

    progress({
      maxStep,
      step: 1400,
      message: "Parsing results",
    });

    const snippets: UsageSnippetsBySignature = {};

    const results = sarif.runs[0]?.results;
    if (!results) {
      throw new Error("No results");
    }

    // This will group the snippets by the method signature.
    for (const result of results) {
      const signature = result.message.text;
      const snippet =
        result.locations?.[0]?.physicalLocation?.contextRegion?.snippet?.text;

      if (!signature || !snippet) {
        continue;
      }

      // Only look at own properties: a plain `in` check would also match names inherited from
      // Object.prototype (e.g. "constructor"), and the push below would then fail.
      if (!Object.prototype.hasOwnProperty.call(snippets, signature)) {
        snippets[signature] = [];
      }

      snippets[signature].push(snippet);
    }

    return snippets;
  } finally {
    // The token source is local to this call, so release it on every exit path.
    cancellationTokenSource.dispose();
  }
}

View File

@@ -1,89 +0,0 @@
import { AutomodelMode, ModelRequest } from "./auto-model-api-v2";
import { Mode } from "./shared/mode";
import { AutoModelQueriesResult } from "./auto-model-codeml-queries";
import { assertNever } from "../common/helpers-pure";
import * as Sarif from "sarif";
import { gzipEncode } from "../common/zlib";
import { ExternalApiUsage, MethodSignature } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
/**
 * Select the methods that the model should be run on: every external API usage that is neither
 * modeled already nor supported from another source, in the same order as the UI lists them.
 * @param mode Whether it is application or framework mode.
 * @param externalApiUsages all external API usages.
 * @param modeledMethods the currently modeled methods, keyed by signature.
 * @returns the methods that are candidates for modeling.
 */
export function getCandidates(
  mode: Mode,
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): MethodSignature[] {
  // Mirror the UI ordering so that the methods listed first in the UI are sent first
  const methodsByGroup = groupMethods(externalApiUsages, mode);
  const usagesInUiOrder = sortGroupNames(methodsByGroup).flatMap((groupName) =>
    sortMethods(methodsByGroup[groupName]),
  );

  return usagesInUiOrder.filter((usage) => {
    // A missing entry counts as "not modeled"
    const existingModel: ModeledMethod = modeledMethods[usage.signature] ?? {
      type: "none",
    };
    const alreadyModeled = existingModel.type !== "none";

    // Supported methods are modeled outside of the model file, so they are excluded as well
    return !alreadyModeled && !usage.supported;
  });
}
/**
 * Turn a SARIF log into the wire format the server expects: the log is serialized as JSON,
 * GZIP-compressed and finally base64-encoded.
 * @param log SARIF log to encode
 * @returns base64 text of the GZIP-compressed JSON
 */
export async function encodeSarif(log: Sarif.Log): Promise<string> {
  const serialized = Buffer.from(JSON.stringify(log), "utf-8");
  const gzipped = await gzipEncode(serialized);
  return gzipped.toString("base64");
}
/**
 * Builds the request body for the auto-model endpoint.
 *
 * @param mode Editor mode, translated to the matching `AutomodelMode` value.
 * @param result Query results whose `candidates` SARIF log is encoded into the request.
 * @returns The request with the translated mode and the encoded candidates.
 */
export async function createAutoModelV2Request(
  mode: Mode,
  result: AutoModelQueriesResult,
): Promise<ModelRequest> {
  // Resolve the mode first so that an unexpected mode fails before any encoding work is done
  const toRequestMode = (): AutomodelMode => {
    switch (mode) {
      case Mode.Application:
        return AutomodelMode.Application;
      case Mode.Framework:
        return AutomodelMode.Framework;
      default:
        return assertNever(mode);
    }
  };
  const requestMode = toRequestMode();

  const candidates = await encodeSarif(result.candidates);
  return { mode: requestMode, candidates };
}

View File

@@ -1,34 +1,27 @@
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod, ModeledMethodType } from "./modeled-method";
import {
Classification,
ClassificationType,
Method,
ModelRequest,
} from "./auto-model-api";
import type { UsageSnippetsBySignature } from "./auto-model-usages-query";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
import { AutomodelMode, ModelRequest } from "./auto-model-api";
import { Mode } from "./shared/mode";
import { AutoModelQueriesResult } from "./auto-model-codeml-queries";
import { assertNever } from "../common/helpers-pure";
import * as Sarif from "sarif";
import { gzipEncode } from "../common/zlib";
import { ExternalApiUsage, MethodSignature } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
// Soft limit on the number of candidates to send to the model.
// Note that the model may return fewer than this number of candidates.
const candidateLimit = 20;
// Soft limit on the number of samples to send to the model.
const sampleLimit = 100;
export function createAutoModelRequest(
language: string,
/**
* Return the candidates that the model should be run on. This includes limiting the number of
* candidates to the candidate limit and filtering out anything that is already modeled and respecting
* the order in the UI.
* @param mode Whether it is application or framework mode.
* @param externalApiUsages all external API usages.
* @param modeledMethods the currently modeled methods.
* @returns list of methods that are candidates for modeling.
*/
export function getCandidates(
mode: Mode,
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
usages: UsageSnippetsBySignature,
mode: Mode,
): ModelRequest {
const request: ModelRequest = {
language,
samples: [],
candidates: [],
};
): MethodSignature[] {
// Sort the same way as the UI so we send the first ones listed in the UI first
const grouped = groupMethods(externalApiUsages, mode);
const sortedGroupNames = sortGroupNames(grouped);
@@ -36,6 +29,8 @@ export function createAutoModelRequest(
sortMethods(grouped[name]),
);
const candidates: MethodSignature[] = [];
for (const externalApiUsage of sortedExternalApiUsages) {
const modeledMethod: ModeledMethod = modeledMethods[
externalApiUsage.signature
@@ -43,220 +38,52 @@ export function createAutoModelRequest(
type: "none",
};
const usagesForMethod =
usages[externalApiUsage.signature] ??
externalApiUsage.usages.map((usage) => usage.label);
const numberOfArguments =
externalApiUsage.methodParameters === "()"
? 0
: externalApiUsage.methodParameters.split(",").length;
const candidates: Method[] = [];
const samples: Method[] = [];
for (
let argumentIndex = -1; // Start at -1 which means `this` as in `this.method()`
argumentIndex < numberOfArguments;
argumentIndex++
) {
const argumentInput: string =
argumentIndex === -1 ? "Argument[this]" : `Argument[${argumentIndex}]`;
const method: Method = {
package: externalApiUsage.packageName,
type: externalApiUsage.typeName,
name: externalApiUsage.methodName,
signature: externalApiUsage.methodParameters,
classification:
modeledMethod.type === "none"
? undefined
: toMethodClassification(modeledMethod),
usages: usagesForMethod.slice(0, 6), // At most 6 usages per argument
input: argumentInput,
};
// A method that is supported is modeled outside of the model file, so it is not a candidate.
// We also do not want it as a sample because we do not know the classification.
if (modeledMethod.type === "none" && externalApiUsage.supported) {
continue;
}
// Candidates are methods that are not currently modeled
if (modeledMethod.type === "none") {
candidates.push(method);
} else {
samples.push(method);
}
}
// If there is room for at least one candidate, add all candidates.
// This ensures that we send all arguments for a method together.
// NOTE: this might go above the candidate limit, but that's okay.
if (request.candidates.length < candidateLimit) {
request.candidates.push(...candidates);
}
// Same for samples
if (request.samples.length < sampleLimit) {
request.samples.push(...samples);
}
}
return request;
}
/**
* For now, we have a simplified model that only models methods as sinks. It does not model methods as neutral,
* so we aren't actually able to correctly determine that a method is neutral; it could still be a source or summary.
* However, to keep this method simple and give output to the user, we will model any method for which none of its
* arguments are modeled as sinks as neutral.
*
* If there are multiple arguments which are modeled as sinks, we will only model the first one.
*/
export function parsePredictedClassifications(
predicted: Method[],
): Record<string, ModeledMethod> {
const predictedBySignature: Record<string, Method[]> = {};
for (const method of predicted) {
const signature = toFullMethodSignature(method);
if (!(signature in predictedBySignature)) {
predictedBySignature[signature] = [];
}
predictedBySignature[signature].push(method);
}
const modeledMethods: Record<string, ModeledMethod> = {};
for (const signature in predictedBySignature) {
const predictedMethods = predictedBySignature[signature];
const sinks = predictedMethods.filter(
(method) => method.classification?.type === ClassificationType.Sink,
);
if (sinks.length === 0) {
// For now, model any method for which none of its arguments are modeled as sinks as neutral
modeledMethods[signature] = {
type: "neutral",
kind: "summary",
input: "",
output: "",
provenance: "ai-generated",
signature,
// predictedBySignature[signature] always has at least one element
packageName: predictedMethods[0].package,
typeName: predictedMethods[0].type,
methodName: predictedMethods[0].name,
methodParameters: predictedMethods[0].signature,
};
// Anything that is modeled is not a candidate
if (modeledMethod.type !== "none") {
continue;
}
// Order the sinks by the input alphabetically. This will ensure that the first argument is always
// first in the list of sinks, the second argument is always second, etc.
// If we get back "Argument[1]" and "Argument[3]", "Argument[1]" should always be first
sinks.sort((a, b) => compareInputOutput(a.input ?? "", b.input ?? ""));
// A method that is supported is modeled outside of the model file, so it is not a candidate.
if (externalApiUsage.supported) {
continue;
}
const sink = sinks[0];
modeledMethods[signature] = {
type: "sink",
kind: sink.classification?.kind ?? "",
input: sink.input ?? "",
output: sink.output ?? "",
provenance: "ai-generated",
signature,
packageName: sink.package,
typeName: sink.type,
methodName: sink.name,
methodParameters: sink.signature,
};
// The rest are candidates
candidates.push(externalApiUsage);
}
return modeledMethods;
return candidates;
}
/**
 * Translates a modeled method type into the classification type used by the API.
 *
 * @param type Type of the modeled method.
 * @returns The matching classification type, or `Unknown` for any other value.
 */
function toMethodClassificationType(
  type: ModeledMethodType,
): ClassificationType {
  if (type === "source") {
    return ClassificationType.Source;
  }
  if (type === "sink") {
    return ClassificationType.Sink;
  }
  if (type === "summary") {
    return ClassificationType.Summary;
  }
  if (type === "neutral") {
    return ClassificationType.Neutral;
  }
  return ClassificationType.Unknown;
}
/**
 * Builds the API classification for a modeled method.
 *
 * @param modeledMethod Method whose type and kind are copied into the classification.
 * @returns Classification with the translated type, the same kind and an empty explanation.
 */
function toMethodClassification(modeledMethod: ModeledMethod): Classification {
  const classificationType = toMethodClassificationType(modeledMethod.type);
  const classification: Classification = {
    type: classificationType,
    kind: modeledMethod.kind,
    explanation: "",
  };
  return classification;
}
/**
 * Formats a method as `<package>.<type>#<name><signature>`.
 *
 * @param method Method to format.
 * @returns The full signature string.
 */
function toFullMethodSignature(method: Method): string {
  const { package: packageName, type, name, signature } = method;
  const qualifiedType = `${packageName}.${type}`;
  return `${qualifiedType}#${name}${signature}`;
}
const argumentRegex = /^Argument\[(\d+)]$/;
// Argument[this] is before ReturnValue
const nonNumericArgumentOrder = ["Argument[this]", "ReturnValue"];
/**
* Compare two inputs or outputs matching `Argument[<number>]`, `Argument[this]`, or `ReturnValue`.
* If they are the same, return 0. If a is less than b, returns a negative number.
* If a is greater than b, returns a positive number.
* Encode a SARIF log to the format expected by the server: JSON, GZIP-compressed, base64-encoded
* @param log SARIF log to encode
* @returns base64-encoded GZIP-compressed SARIF log
*/
export function compareInputOutput(a: string, b: string): number {
if (a === b) {
return 0;
}
const aMatch = a.match(argumentRegex);
const bMatch = b.match(argumentRegex);
// Numeric arguments are always first
if (aMatch && !bMatch) {
return -1;
}
if (!aMatch && bMatch) {
return 1;
}
// Neither is an argument
if (!aMatch && !bMatch) {
const aIndex = nonNumericArgumentOrder.indexOf(a);
const bIndex = nonNumericArgumentOrder.indexOf(b);
// If either one is unknown, it is sorted last
if (aIndex === -1 && bIndex === -1) {
// Use en-US because these are well-known strings that are not localized
return a.localeCompare(b, "en-US");
}
if (aIndex === -1) {
return 1;
}
if (bIndex === -1) {
return -1;
}
return aIndex - bIndex;
}
// This case shouldn't happen, but makes TypeScript happy
if (!aMatch || !bMatch) {
return 0;
}
// Both are arguments
const aIndex = parseInt(aMatch[1]);
const bIndex = parseInt(bMatch[1]);
return aIndex - bIndex;
export async function encodeSarif(log: Sarif.Log): Promise<string> {
const json = JSON.stringify(log);
const buffer = Buffer.from(json, "utf-8");
const compressed = await gzipEncode(buffer);
return compressed.toString("base64");
}
/**
 * Builds the request body for the auto-model endpoint.
 *
 * @param mode Editor mode, translated to the matching `AutomodelMode` value.
 * @param result Query results whose `candidates` SARIF log is encoded into the request.
 * @returns The request with the translated mode and the encoded candidates.
 */
export async function createAutoModelRequest(
  mode: Mode,
  result: AutoModelQueriesResult,
): Promise<ModelRequest> {
  // Resolve the mode first so that an unexpected mode fails before any encoding work is done
  const toRequestMode = (): AutomodelMode => {
    switch (mode) {
      case Mode.Application:
        return AutomodelMode.Application;
      case Mode.Framework:
        return AutomodelMode.Framework;
      default:
        return assertNever(mode);
    }
  };
  const requestMode = toRequestMode();

  const candidates = await encodeSarif(result.candidates);
  return { mode: requestMode, candidates };
}

View File

@@ -3,10 +3,10 @@ import { ModeledMethod } from "./modeled-method";
import { extLogger } from "../common/logging/vscode";
import { load as loadYaml } from "js-yaml";
import { ProgressCallback, withProgress } from "../common/vscode/progress";
import { createAutoModelV2Request, getCandidates } from "./auto-model-v2";
import { createAutoModelRequest, getCandidates } from "./auto-model";
import { runAutoModelQueries } from "./auto-model-codeml-queries";
import { loadDataExtensionYaml } from "./yaml";
import { ModelRequest, ModelResponse, autoModelV2 } from "./auto-model-api-v2";
import { ModelRequest, ModelResponse, autoModel } from "./auto-model-api";
import { RequestError } from "@octokit/request-error";
import { showAndLogExceptionWithTelemetry } from "../common/logging";
import { redactableError } from "../common/errors";
@@ -183,7 +183,7 @@ export class AutoModeler {
return;
}
const request = await createAutoModelV2Request(mode, usages);
const request = await createAutoModelRequest(mode, usages);
void extLogger.log("Calling auto-model API");
@@ -229,7 +229,7 @@ export class AutoModeler {
request: ModelRequest,
): Promise<ModelResponse | null> {
try {
return await autoModelV2(this.app.credentials, request);
return await autoModel(this.app.credentials, request);
} catch (e) {
if (e instanceof RequestError && e.status === 429) {
void showAndLogExceptionWithTelemetry(

View File

@@ -5,7 +5,6 @@ import {
ViewColumn,
window,
} from "vscode";
import { RequestError } from "@octokit/request-error";
import {
AbstractWebview,
WebviewPanelConfig,
@@ -34,18 +33,11 @@ import { readQueryResults, runQuery } from "./external-api-usage-query";
import { ExternalApiUsage } from "./external-api-usage";
import { ModeledMethod } from "./modeled-method";
import { ExtensionPack } from "./shared/extension-pack";
import { autoModel, ModelRequest, ModelResponse } from "./auto-model-api";
import {
createAutoModelRequest,
parsePredictedClassifications,
} from "./auto-model";
import {
enableFrameworkMode,
showLlmGeneration,
showModelDetailsView,
useLlmGenerationV2,
} from "../config";
import { getAutoModelUsages } from "./auto-model-usages-query";
import { Mode } from "./shared/mode";
import { loadModeledMethods, saveModeledMethods } from "./modeled-method-fs";
import { join } from "path";
@@ -176,18 +168,11 @@ export class DataExtensionsEditorView extends AbstractWebview<
break;
case "generateExternalApiFromLlm":
if (useLlmGenerationV2()) {
await this.generateModeledMethodsFromLlmV2(
msg.packageName,
msg.externalApiUsages,
msg.modeledMethods,
);
} else {
await this.generateModeledMethodsFromLlmV1(
msg.externalApiUsages,
msg.modeledMethods,
);
}
await this.generateModeledMethodsFromLlm(
msg.packageName,
msg.externalApiUsages,
msg.modeledMethods,
);
break;
case "stopGeneratingExternalApiFromLlm":
await this.autoModeler.stopModeling(msg.packageName);
@@ -389,77 +374,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
);
}
/**
 * Generates modeled methods with the v1 auto-model flow: retrieves usage snippets, builds and
 * sends the request, parses the predictions and posts them to the webview. Progress is reported
 * against a maximum of 3000 steps. Nothing is posted when the API call yields no response.
 *
 * @param externalApiUsages External API usages to build the request from.
 * @param modeledMethods Currently modeled methods, keyed by signature.
 */
private async generateModeledMethodsFromLlmV1(
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): Promise<void> {
  await withProgress(async (progress) => {
    const maxStep = 3000;
    // Small wrapper so that each stage only has to name its step and message
    const reportStage = (step: number, message: string): void => {
      progress({ step, maxStep, message });
    };

    reportStage(0, "Retrieving usages");
    const usages = await getAutoModelUsages({
      cliServer: this.cliServer,
      queryRunner: this.queryRunner,
      queryStorageDir: this.queryStorageDir,
      queryDir: this.queryDir,
      databaseItem: this.databaseItem,
      progress: (update) => progress({ ...update, maxStep }),
    });

    reportStage(1800, "Creating request");
    const request = createAutoModelRequest(
      this.databaseItem.language,
      externalApiUsages,
      modeledMethods,
      usages,
      this.mode,
    );

    reportStage(2000, "Sending request");
    const response = await this.callAutoModelApi(request);
    if (!response) {
      return;
    }

    reportStage(2500, "Parsing response");
    const predictions = response.predicted || [];
    const predictedModeledMethods = parsePredictedClassifications(predictions);

    reportStage(2800, "Applying results");
    await this.postMessage({
      t: "addModeledMethods",
      modeledMethods: predictedModeledMethods,
    });
  });
}
private async generateModeledMethodsFromLlmV2(
private async generateModeledMethodsFromLlm(
packageName: string,
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
@@ -580,23 +495,4 @@ export class DataExtensionsEditorView extends AbstractWebview<
return addedDatabase;
}
/**
 * Calls the auto-model API and turns a rate-limit rejection into a user-visible message.
 *
 * @param request Request to send.
 * @returns The API response, or `null` when the API answered with HTTP status 429.
 * @throws Any other error raised by the API call is rethrown unchanged.
 */
private async callAutoModelApi(
  request: ModelRequest,
): Promise<ModelResponse | null> {
  try {
    return await autoModel(this.app.credentials, request);
  } catch (e) {
    // Only a rate-limit response is handled here; everything else is the caller's problem
    if (!(e instanceof RequestError) || e.status !== 429) {
      throw e;
    }

    void showAndLogExceptionWithTelemetry(
      this.app.logger,
      this.app.telemetry,
      redactableError(e)`Rate limit hit, please try again soon.`,
    );
    return null;
  }
}
}

View File

@@ -1,168 +0,0 @@
import {
createAutoModelV2Request,
encodeSarif,
getCandidates,
} from "../../../src/data-extensions-editor/auto-model-v2";
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
import { AutomodelMode } from "../../../src/data-extensions-editor/auto-model-api-v2";
import { AutoModelQueriesResult } from "../../../src/data-extensions-editor/auto-model-codeml-queries";
import * as sarif from "sarif";
import { gzipDecode } from "../../../src/common/zlib";
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
describe("createAutoModelV2Request", () => {
  // Minimal SARIF log: one run with a single rule and a single result that carries a snippet.
  const createSarifLog = (queryId: string): sarif.Log => ({
    version: "2.1.0",
    $schema: "http://json.schemastore.org/sarif-2.1.0-rtm.4",
    runs: [
      {
        tool: {
          driver: {
            name: "CodeQL",
            rules: [{ id: queryId }],
          },
        },
        results: [
          {
            message: { text: "msg" },
            locations: [
              {
                physicalLocation: {
                  contextRegion: {
                    startLine: 10,
                    endLine: 12,
                    snippet: { text: "Foo" },
                  },
                  region: {
                    startLine: 10,
                    startColumn: 1,
                    endColumn: 3,
                  },
                  artifactLocation: { uri: "foo.js" },
                },
              },
            ],
          },
        ],
      },
    ],
  });

  const queriesResult: AutoModelQueriesResult = {
    candidates: createSarifLog(
      "java/ml/extract-automodel-application-candidates",
    ),
  };

  it("creates a matching request", async () => {
    const request = await createAutoModelV2Request(
      Mode.Application,
      queriesResult,
    );
    const expectedCandidates = await encodeSarif(queriesResult.candidates);

    expect(request).toEqual({
      mode: AutomodelMode.Application,
      candidates: expectedCandidates,
    });
  });

  it("can decode the SARIF", async () => {
    const { candidates } = await createAutoModelV2Request(
      Mode.Application,
      queriesResult,
    );

    // Undo the encoding step by step: base64 -> gzip -> JSON
    const gzipped = Buffer.from(candidates, "base64");
    const jsonText = (await gzipDecode(gzipped)).toString("utf-8");

    expect(JSON.parse(jsonText)).toEqual(queriesResult.candidates);
  });
});
describe("getCandidates", () => {
  const signature = "org.my.A#x()";

  // Single usage of org.my.A#x(); only the `supported` flag differs between the tests.
  const createUsage = (supported: boolean): ExternalApiUsage => ({
    library: "my.jar",
    signature,
    packageName: "org.my",
    typeName: "A",
    methodName: "x",
    methodParameters: "()",
    supported,
    supportedType: "none",
    usages: [],
  });

  it("doesn't return methods that are already modelled", () => {
    const modeledMethods: Record<string, ModeledMethod> = {
      [signature]: {
        type: "neutral",
        kind: "",
        input: "",
        output: "",
        provenance: "manual",
        signature,
        packageName: "org.my",
        typeName: "A",
        methodName: "x",
        methodParameters: "()",
      },
    };

    const candidates = getCandidates(
      Mode.Application,
      [createUsage(false)],
      modeledMethods,
    );

    expect(candidates.length).toEqual(0);
  });

  it("doesn't return methods that are supported from other sources", () => {
    const candidates = getCandidates(
      Mode.Application,
      [createUsage(true)],
      {},
    );

    expect(candidates.length).toEqual(0);
  });

  it("returns methods that are neither modeled nor supported from other sources", () => {
    const candidates = getCandidates(
      Mode.Application,
      [createUsage(false)],
      {},
    );

    expect(candidates.length).toEqual(1);
  });
});

View File

@@ -1,632 +1,168 @@
import {
compareInputOutput,
createAutoModelRequest,
parsePredictedClassifications,
encodeSarif,
getCandidates,
} from "../../../src/data-extensions-editor/auto-model";
import {
CallClassification,
ExternalApiUsage,
} from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
import {
ClassificationType,
Method,
} from "../../../src/data-extensions-editor/auto-model-api";
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
import { AutomodelMode } from "../../../src/data-extensions-editor/auto-model-api";
import { AutoModelQueriesResult } from "../../../src/data-extensions-editor/auto-model-codeml-queries";
import * as sarif from "sarif";
import { gzipDecode } from "../../../src/common/zlib";
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
describe("createAutoModelRequest", () => {
const externalApiUsages: ExternalApiUsage[] = [
{
library: "spring-boot-3.0.2.jar",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
supported: false,
supportedType: "none",
usages: [
const createSarifLog = (queryId: string): sarif.Log => {
return {
version: "2.1.0",
$schema: "http://json.schemastore.org/sarif-2.1.0-rtm.4",
runs: [
{
label: "run(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/Sql2oExampleApplication.java",
startLine: 9,
startColumn: 9,
endLine: 9,
endColumn: 66,
tool: {
driver: {
name: "CodeQL",
rules: [
{
id: queryId,
},
],
},
},
classification: CallClassification.Source,
results: [
{
message: {
text: "msg",
},
locations: [
{
physicalLocation: {
contextRegion: {
startLine: 10,
endLine: 12,
snippet: {
text: "Foo",
},
},
region: {
startLine: 10,
startColumn: 1,
endColumn: 3,
},
artifactLocation: {
uri: "foo.js",
},
},
},
],
},
],
},
],
},
{
library: "sql2o-1.6.0.jar",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
supported: false,
supportedType: "none",
usages: [
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 56,
},
classification: CallClassification.Source,
},
{
label: "createQuery(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 39,
},
classification: CallClassification.Source,
},
],
},
{
library: "sql2o-1.6.0.jar",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
supported: false,
supportedType: "none",
usages: [
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 15,
startColumn: 13,
endLine: 15,
endColumn: 85,
},
classification: CallClassification.Source,
},
{
label: "executeScalar(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 26,
startColumn: 13,
endLine: 26,
endColumn: 68,
},
classification: CallClassification.Source,
},
],
},
{
library: "sql2o-1.6.0.jar",
signature: "org.sql2o.Sql2o#open()",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "open",
methodParameters: "()",
supported: false,
supportedType: "none",
usages: [
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 14,
startColumn: 24,
endLine: 14,
endColumn: 35,
},
classification: CallClassification.Source,
},
{
label: "open(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 25,
startColumn: 24,
endLine: 25,
endColumn: 35,
},
classification: CallClassification.Source,
},
],
},
{
library: "rt.jar",
signature: "java.io.PrintStream#println(String)",
packageName: "java.io",
typeName: "PrintStream",
methodName: "println",
methodParameters: "(String)",
supported: false,
supportedType: "none",
usages: [
{
label: "println(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 29,
startColumn: 9,
endLine: 29,
endColumn: 49,
},
classification: CallClassification.Source,
},
],
},
{
library: "sql2o-1.6.0.jar",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
supported: false,
supportedType: "none",
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 10,
startColumn: 33,
endLine: 10,
endColumn: 88,
},
classification: CallClassification.Source,
},
],
},
{
library: "sql2o-1.6.0.jar",
signature: "org.sql2o.Sql2o#Sql2o(String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String)",
supported: false,
supportedType: "none",
usages: [
{
label: "new Sql2o(...)",
url: {
uri: "file:/home/runner/work/sql2o-example/sql2o-example/src/main/java/org/example/HelloController.java",
startLine: 23,
startColumn: 23,
endLine: 23,
endColumn: 36,
},
classification: CallClassification.Source,
},
],
},
{
library: "test.jar",
signature: "org.test.MyClass#test()",
packageName: "org.test",
typeName: "MyClass",
methodName: "test",
methodParameters: "()",
supported: true,
supportedType: "neutral",
usages: [
{
label: "abc.test(...)",
url: {
uri: "file:/home/runner/work/test/Test.java",
startLine: 23,
startColumn: 23,
endLine: 23,
endColumn: 36,
},
classification: CallClassification.Source,
},
],
},
];
// Fixture of already-modeled methods, keyed by the same string that each
// entry carries in its `signature` field. Both entries have "manual"
// provenance:
//  - org.sql2o.Sql2o#open() is modeled as "neutral" with empty kind/input/output.
//  - org.sql2o.Sql2o#Sql2o(String) is modeled as a "jndi-injection" sink on
//    Argument[0].
const modeledMethods: Record<string, ModeledMethod> = {
"org.sql2o.Sql2o#open()": {
type: "neutral",
kind: "",
input: "",
output: "",
provenance: "manual",
signature: "org.sql2o.Sql2o#open()",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "open",
methodParameters: "()",
},
"org.sql2o.Sql2o#Sql2o(String)": {
type: "sink",
kind: "jndi-injection",
input: "Argument[0]",
output: "",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String)",
},
};
};
const usages: Record<string, string[]> = {
"org.springframework.boot.SpringApplication#run(Class,String[])": [
"public class Sql2oExampleApplication {\n public static void main(String[] args) {\n SpringApplication.run(Sql2oExampleApplication.class, args);\n }\n}",
],
"org.sql2o.Connection#createQuery(String)": [
' public String index(@RequestParam("id") String id) {\n try (var con = sql2o.open()) {\n con.createQuery("select 1 where id = " + id).executeScalar(Integer.class);\n }\n\n',
'\n try (var con = sql2o.open()) {\n con.createQuery("select 1").executeScalar(Integer.class);\n }\n\n',
],
"org.sql2o.Query#executeScalar(Class)": [
' public String index(@RequestParam("id") String id) {\n try (var con = sql2o.open()) {\n con.createQuery("select 1 where id = " + id).executeScalar(Integer.class);\n }\n\n',
'\n try (var con = sql2o.open()) {\n con.createQuery("select 1").executeScalar(Integer.class);\n }\n\n',
],
"org.sql2o.Sql2o#open()": [
' @GetMapping("/")\n public String index(@RequestParam("id") String id) {\n try (var con = sql2o.open()) {\n con.createQuery("select 1 where id = " + id).executeScalar(Integer.class);\n }\n',
' Sql2o sql2o = new Sql2o(url);\n\n try (var con = sql2o.open()) {\n con.createQuery("select 1").executeScalar(Integer.class);\n }\n',
],
"java.io.PrintStream#println(String)": [
' }\n\n System.out.println("Connected to " + url);\n\n return "Greetings from Spring Boot!";\n',
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
'@RestController\npublic class HelloController {\n private final Sql2o sql2o = new Sql2o("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1","sa", "");\n\n @GetMapping("/")\n',
],
"org.sql2o.Sql2o#Sql2o(String)": [
' @GetMapping("/connect")\n public String connect(@RequestParam("url") String url) {\n Sql2o sql2o = new Sql2o(url);\n\n try (var con = sql2o.open()) {\n',
],
// Fixture handed to createAutoModelRequest by the tests that follow. Its
// `candidates` value is whatever createSarifLog returns for the given
// identifier string (createSarifLog is defined outside this chunk;
// presumably it builds a SARIF log object — confirm against the helper).
const result: AutoModelQueriesResult = {
candidates: createSarifLog(
"java/ml/extract-automodel-application-candidates",
),
};
it("creates a matching request", () => {
expect(
createAutoModelRequest(
"java",
externalApiUsages,
modeledMethods,
usages,
Mode.Application,
),
).toEqual({
language: "java",
samples: [
{
package: "org.sql2o",
type: "Sql2o",
name: "open",
signature: "()",
classification: {
type: "CLASSIFICATION_TYPE_NEUTRAL",
kind: "",
explanation: "",
},
usages: usages["org.sql2o.Sql2o#open()"],
input: "Argument[this]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String)",
classification: {
type: "CLASSIFICATION_TYPE_SINK",
kind: "jndi-injection",
explanation: "",
},
usages: usages["org.sql2o.Sql2o#Sql2o(String)"],
input: "Argument[this]",
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String)",
classification: {
type: "CLASSIFICATION_TYPE_SINK",
kind: "jndi-injection",
explanation: "",
},
usages: usages["org.sql2o.Sql2o#Sql2o(String)"],
input: "Argument[0]",
},
],
candidates: [
{
package: "org.sql2o",
type: "Connection",
name: "createQuery",
signature: "(String)",
usages: usages["org.sql2o.Connection#createQuery(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Connection",
name: "createQuery",
signature: "(String)",
usages: usages["org.sql2o.Connection#createQuery(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Query",
name: "executeScalar",
signature: "(Class)",
usages: usages["org.sql2o.Query#executeScalar(Class)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Query",
name: "executeScalar",
signature: "(Class)",
usages: usages["org.sql2o.Query#executeScalar(Class)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: usages["org.sql2o.Sql2o#Sql2o(String,String,String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: usages["org.sql2o.Sql2o#Sql2o(String,String,String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: usages["org.sql2o.Sql2o#Sql2o(String,String,String)"],
input: "Argument[1]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: usages["org.sql2o.Sql2o#Sql2o(String,String,String)"],
input: "Argument[2]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[1]",
classification: undefined,
},
],
it("creates a matching request", async () => {
  // Produce the request and the independently encoded candidates up front,
  // then compare the request against the expected payload.
  const request = await createAutoModelRequest(Mode.Application, result);
  const encodedCandidates = await encodeSarif(result.candidates);

  expect(request).toEqual({
    mode: AutomodelMode.Application,
    candidates: encodedCandidates,
  });
});
it("can decode the SARIF", async () => {
  const { candidates } = await createAutoModelRequest(Mode.Application, result);

  // Undo the encoding one layer at a time: base64 -> gzip -> UTF-8 JSON text.
  const gzipped = Buffer.from(candidates, "base64");
  const jsonText = (await gzipDecode(gzipped)).toString("utf-8");

  // The round trip must give back the original SARIF log.
  expect(JSON.parse(jsonText)).toEqual(result.candidates);
});
});
describe("parsePredictedClassifications", () => {
const predictions: Method[] = [
{
package: "org.sql2o",
type: "Sql2o",
name: "createQuery",
signature: "(String)",
usages: ["createQuery(...)", "createQuery(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
describe("getCandidates", () => {
it("doesn't return methods that are already modelled", () => {
const externalApiUsages: ExternalApiUsage[] = [
{
library: "my.jar",
signature: "org.my.A#x()",
packageName: "org.my",
typeName: "A",
methodName: "x",
methodParameters: "()",
supported: false,
supportedType: "none",
usages: [],
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "executeScalar",
signature: "(Class)",
usages: ["executeScalar(...)", "executeScalar(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Neutral,
kind: "",
explanation: "not a sink",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[0]",
classification: {
type: ClassificationType.Neutral,
kind: "",
explanation: "not a sink",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[1]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
},
},
{
package: "org.sql2o",
type: "Sql2o",
name: "Sql2o",
signature: "(String,String,String)",
usages: ["new Sql2o(...)"],
input: "Argument[2]",
classification: {
type: ClassificationType.Sink,
kind: "sql injection sink",
explanation: "",
},
},
];
it("correctly parses the output", () => {
expect(parsePredictedClassifications(predictions)).toEqual({
"org.sql2o.Sql2o#createQuery(String)": {
type: "sink",
kind: "sql injection sink",
input: "Argument[0]",
output: "",
provenance: "ai-generated",
signature: "org.sql2o.Sql2o#createQuery(String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Sql2o#executeScalar(Class)": {
];
const modeledMethods: Record<string, ModeledMethod> = {
"org.my.A#x()": {
type: "neutral",
kind: "summary",
kind: "",
input: "",
output: "",
provenance: "ai-generated",
signature: "org.sql2o.Sql2o#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "executeScalar",
methodParameters: "(Class)",
provenance: "manual",
signature: "org.my.A#x()",
packageName: "org.my",
typeName: "A",
methodName: "x",
methodParameters: "()",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
kind: "sql injection sink",
input: "Argument[1]",
output: "",
provenance: "ai-generated",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
});
});
});
describe("compareInputOutput", () => {
it("with two small numeric arguments", () => {
  // Lower argument index on the left, higher on the right.
  const comparison = compareInputOutput("Argument[0]", "Argument[1]");

  expect(comparison).toBeLessThanOrEqual(-1);
});
it("with one larger non-alphabetic argument", () => {
  // As plain strings "Argument[10]" would precede "Argument[2]", yet the
  // expected result here is positive.
  const comparison = compareInputOutput("Argument[10]", "Argument[2]");

  expect(comparison).toBeGreaterThanOrEqual(1);
});
it("with one non-numeric arguments", () => {
  // Numeric index on the left, "this" on the right.
  const comparison = compareInputOutput("Argument[5]", "Argument[this]");

  expect(comparison).toBeLessThanOrEqual(-1);
});
it("with two non-numeric arguments", () => {
  // "ReturnValue" on the left, "Argument[this]" on the right.
  const comparison = compareInputOutput("ReturnValue", "Argument[this]");

  expect(comparison).toBeGreaterThanOrEqual(1);
});
it("with one unknown argument in the a position", () => {
  // An unrecognised value in the first position, a known one in the second.
  const comparison = compareInputOutput("FooBar", "Argument[this]");

  expect(comparison).toBeGreaterThanOrEqual(1);
});
it("with one unknown argument in the b position", () => {
expect(compareInputOutput("Argument[this]", "FooBar")).toBeLessThanOrEqual(
-1,
};
const candidates = getCandidates(
Mode.Application,
externalApiUsages,
modeledMethods,
);
expect(candidates.length).toEqual(0);
});
it("with one empty string arguments", () => {
expect(compareInputOutput("Argument[5]", "")).toBeLessThanOrEqual(-1);
it("doesn't return methods that are supported from other sources", () => {
  // One usage flagged `supported: true`, with no modeled methods at all.
  const supportedUsage: ExternalApiUsage = {
    library: "my.jar",
    signature: "org.my.A#x()",
    packageName: "org.my",
    typeName: "A",
    methodName: "x",
    methodParameters: "()",
    supported: true,
    supportedType: "none",
    usages: [],
  };

  const candidates = getCandidates(Mode.Application, [supportedUsage], {});

  expect(candidates.length).toEqual(0);
});
it("with two unknown arguments", () => {
expect(compareInputOutput("FooBar", "BarFoo")).toBeGreaterThanOrEqual(1);
it("returns methods that are neither modeled nor supported from other sources", () => {
  // One usage flagged `supported: false`, with no modeled methods at all.
  const externalApiUsages: ExternalApiUsage[] = [
    {
      library: "my.jar",
      signature: "org.my.A#x()",
      packageName: "org.my",
      typeName: "A",
      methodName: "x",
      methodParameters: "()",
      supported: false,
      supportedType: "none",
      usages: [],
    },
  ];

  const candidates = getCandidates(Mode.Application, externalApiUsages, {});

  expect(candidates.length).toEqual(1);
});
});