Improve parsing of predicted classifications

This commit is contained in:
Koen Vlaswinkel
2023-05-26 16:18:59 +02:00
parent 4b54e4f31f
commit f52ad04afe
3 changed files with 163 additions and 32 deletions

View File

@@ -68,6 +68,61 @@ export function createAutoModelRequest(
return request;
}
/**
 * Collapses the predictions returned by the auto-model API into a single
 * modeled method per full method signature.
 *
 * The API may return several predictions for the same method (the fixtures in
 * the test suite show one per argument). Predictions without a classification
 * are ignored. For each remaining method:
 *
 * - if none of its predictions is classified as a sink, the method is modeled
 *   as neutral with empty kind, input and output;
 * - otherwise the sink with the lowest input is used. Inputs are compared with
 *   numeric awareness so that "Argument[2]" is ordered before "Argument[10]".
 *
 * @param predicted The predicted methods as returned by the auto-model API.
 * @returns The modeled methods, keyed by the value of `toFullMethodSignature`.
 */
export function parsePredictedClassifications(
  predicted: Method[],
): Record<string, ModeledMethod> {
  // A Map is used for grouping so that lookups never see keys inherited from
  // Object.prototype; it also preserves first-seen order of the signatures.
  const predictedBySignature = new Map<string, Method[]>();

  for (const method of predicted) {
    if (!method.classification) {
      continue;
    }

    const signature = toFullMethodSignature(method);
    const group = predictedBySignature.get(signature);
    if (group === undefined) {
      predictedBySignature.set(signature, [method]);
    } else {
      group.push(method);
    }
  }

  const modeledMethods: Record<string, ModeledMethod> = {};

  for (const [signature, predictedMethods] of predictedBySignature) {
    const sinks = predictedMethods.filter(
      (method) => method.classification?.type === ClassificationType.Sink,
    );

    if (sinks.length === 0) {
      // For now, model any method for which none of its arguments are modeled as sinks as neutral
      modeledMethods[signature] = {
        type: "neutral",
        kind: "",
        input: "",
        output: "",
      };
      continue;
    }

    // Order the sinks by their input so that the lowest argument comes first:
    // if we get back "Argument[1]" and "Argument[3]", "Argument[1]" is chosen.
    // The comparison must be numeric-aware: a plain lexicographic comparison
    // would put "Argument[10]" before "Argument[2]".
    // `sinks` is a fresh array produced by `filter`, so sorting in place does
    // not mutate the caller's data.
    sinks.sort((a, b) =>
      (a.input ?? "").localeCompare(b.input ?? "", undefined, {
        numeric: true,
      }),
    );

    const sink = sinks[0];

    modeledMethods[signature] = {
      type: "sink",
      kind: sink.classification?.kind ?? "",
      input: sink.input ?? "",
      output: sink.output ?? "",
    };
  }

  return modeledMethods;
}
function toMethodClassificationType(
type: ModeledMethodType,
): ClassificationType {
@@ -93,19 +148,6 @@ function toMethodClassification(modeledMethod: ModeledMethod): Classification {
};
}
export function classificationTypeToModeledMethodType(
type: ClassificationType,
): ModeledMethodType {
switch (type) {
case ClassificationType.Source:
return "source";
case ClassificationType.Sink:
return "sink";
case ClassificationType.Summary:
return "summary";
case ClassificationType.Neutral:
return "neutral";
default:
return "none";
}
/**
 * Builds the key used to identify a method: the package, type and name joined
 * by dots, immediately followed by the method's signature.
 */
function toFullMethodSignature(method: Method): string {
  const { package: packageName, type: typeName, name, signature } = method;
  return `${packageName}.${typeName}.${name}${signature}`;
}

View File

@@ -41,8 +41,8 @@ import { ModeledMethod } from "./modeled-method";
import { ExtensionPackModelFile } from "./shared/extension-pack";
import { autoModel } from "./auto-model-api";
import {
classificationTypeToModeledMethodType,
createAutoModelRequest,
parsePredictedClassifications,
} from "./auto-model";
import { showLlmGeneration } from "../config";
@@ -393,24 +393,13 @@ export class DataExtensionsEditorView extends AbstractWebview<
const response = await autoModel(this.app.credentials, request);
const modeledMethodsByName: Record<string, ModeledMethod> = {};
for (const method of response.predicted) {
if (method.classification === undefined) {
continue;
}
modeledMethodsByName[method.signature] = {
type: classificationTypeToModeledMethodType(method.classification.type),
kind: method.classification.kind,
input: method.input ?? "",
output: method.output ?? "",
};
}
const predictedModeledMethods = parsePredictedClassifications(
response.predicted,
);
await this.postMessage({
t: "addModeledMethods",
modeledMethods: modeledMethodsByName,
modeledMethods: predictedModeledMethods,
overrideNone: true,
});
}

View File

@@ -1,6 +1,13 @@
import { createAutoModelRequest } from "../../../src/data-extensions-editor/auto-model";
import {
createAutoModelRequest,
parsePredictedClassifications,
} from "../../../src/data-extensions-editor/auto-model";
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
import {
ClassificationType,
Method,
} from "../../../src/data-extensions-editor/auto-model-api";
describe("createAutoModelRequest", () => {
const externalApiUsages: ExternalApiUsage[] = [
@@ -281,3 +288,96 @@ describe("createAutoModelRequest", () => {
});
});
});
describe("parsePredictedClassifications", () => {
  // Every fixture below describes a member of org.sql2o.Sql2o.
  const sql2oMember = { package: "org.sql2o", type: "Sql2o" };

  // Builds the prediction for one argument of the three-string Sql2o
  // constructor; only the input and the classification differ between them.
  const constructorPrediction = (
    input: string,
    classificationType: ClassificationType,
    kind: string,
  ): Method => ({
    ...sql2oMember,
    name: "Sql2o",
    signature: "(String,String,String)",
    usages: ["new Sql2o(...)"],
    input,
    classification: {
      type: classificationType,
      kind,
      explanation: "not a sink",
    },
  });

  const predictions: Method[] = [
    // A method with a single prediction that is a sink.
    {
      ...sql2oMember,
      name: "createQuery",
      signature: "(String)",
      usages: ["createQuery(...)", "createQuery(...)"],
      input: "Argument[0]",
      classification: {
        type: ClassificationType.Sink,
        kind: "sql injection sink",
        explanation: "",
      },
    },
    // A method with a single prediction that is not a sink.
    {
      ...sql2oMember,
      name: "executeScalar",
      signature: "(Class)",
      usages: ["executeScalar(...)", "executeScalar(...)"],
      input: "Argument[0]",
      classification: {
        type: ClassificationType.Neutral,
        kind: "",
        explanation: "not a sink",
      },
    },
    // A method with one prediction per argument: one neutral and two sinks.
    constructorPrediction("Argument[0]", ClassificationType.Neutral, ""),
    constructorPrediction(
      "Argument[1]",
      ClassificationType.Sink,
      "sql injection sink",
    ),
    constructorPrediction(
      "Argument[2]",
      ClassificationType.Sink,
      "sql injection sink",
    ),
  ];

  // One entry per full signature; for the constructor the sink with the
  // lowest argument ("Argument[1]") is the one expected to be reported.
  const expectedModeledMethods: Record<string, ModeledMethod> = {
    "org.sql2o.Sql2o.createQuery(String)": {
      type: "sink",
      kind: "sql injection sink",
      input: "Argument[0]",
      output: "",
    },
    "org.sql2o.Sql2o.executeScalar(Class)": {
      type: "neutral",
      kind: "",
      input: "",
      output: "",
    },
    "org.sql2o.Sql2o.Sql2o(String,String,String)": {
      type: "sink",
      kind: "sql injection sink",
      input: "Argument[1]",
      output: "",
    },
  };

  it("correctly parses the output", () => {
    const actualModeledMethods = parsePredictedClassifications(predictions);

    expect(actualModeledMethods).toEqual(expectedModeledMethods);
  });
});