Merge pull request #2904 from github/koesie10/convert-yaml-modeled-methods

Convert `yaml.ts` to handle multiple models per method
This commit is contained in:
Koen Vlaswinkel
2023-10-04 14:01:02 +02:00
committed by GitHub
5 changed files with 300 additions and 213 deletions

View File

@@ -16,6 +16,7 @@ import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases"; import { DatabaseItem } from "../databases/local-databases";
import { Mode } from "./shared/mode"; import { Mode } from "./shared/mode";
import { CancellationTokenSource } from "vscode"; import { CancellationTokenSource } from "vscode";
import { convertToLegacyModeledMethods } from "./modeled-methods-legacy";
// Limit the number of candidates we send to the model in each request // Limit the number of candidates we send to the model in each request
// to avoid long requests. // to avoid long requests.
@@ -192,11 +193,13 @@ export class AutoModeler {
filename: "auto-model.yml", filename: "auto-model.yml",
}); });
const loadedMethods = loadDataExtensionYaml(models); const rawLoadedMethods = loadDataExtensionYaml(models);
if (!loadedMethods) { if (!rawLoadedMethods) {
return; return;
} }
const loadedMethods = convertToLegacyModeledMethods(rawLoadedMethods);
// Any candidate that was part of the response is a negative result // Any candidate that was part of the response is a negative result
// meaning that the canidate is not a sink for the kinds that the LLM is checking for. // meaning that the canidate is not a sink for the kinds that the LLM is checking for.
// For now we model this as a sink neutral method, however this is subject // For now we model this as a sink neutral method, however this is subject

View File

@@ -10,6 +10,11 @@ import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders";
import { load as loadYaml } from "js-yaml"; import { load as loadYaml } from "js-yaml";
import { CodeQLCliServer } from "../codeql-cli/cli"; import { CodeQLCliServer } from "../codeql-cli/cli";
import { pathsEqual } from "../common/files"; import { pathsEqual } from "../common/files";
import {
convertFromLegacyModeledMethods,
convertFromLegacyModeledMethodsFiles,
convertToLegacyModeledMethods,
} from "./modeled-methods-legacy";
export async function saveModeledMethods( export async function saveModeledMethods(
extensionPack: ExtensionPack, extensionPack: ExtensionPack,
@@ -29,8 +34,8 @@ export async function saveModeledMethods(
const yamls = createDataExtensionYamls( const yamls = createDataExtensionYamls(
language, language,
methods, methods,
modeledMethods, convertFromLegacyModeledMethods(modeledMethods),
existingModeledMethods, convertFromLegacyModeledMethodsFiles(existingModeledMethods),
mode, mode,
); );
@@ -68,7 +73,8 @@ async function loadModeledMethodFiles(
); );
continue; continue;
} }
modeledMethodsByFile[modelFile] = modeledMethods; modeledMethodsByFile[modelFile] =
convertToLegacyModeledMethods(modeledMethods);
} }
return modeledMethodsByFile; return modeledMethodsByFile;

View File

@@ -0,0 +1,33 @@
import { ModeledMethod } from "./modeled-method";
/**
 * Converts the legacy "one model per signature" shape into the
 * "many models per signature" shape.
 *
 * Every value of the input record is wrapped in a single-element array; the
 * keys and their order are kept as-is.
 *
 * @param modeledMethods Record whose values are single modeled methods.
 * @returns A new record with the same keys, each mapped to a one-element
 *   array holding the original value.
 */
export function convertFromLegacyModeledMethods(
  modeledMethods: Record<string, ModeledMethod>,
): Record<string, ModeledMethod[]> {
  const wrappedEntries: Array<[string, ModeledMethod[]]> = [];

  for (const [signature, legacyMethod] of Object.entries(modeledMethods)) {
    // A legacy entry always becomes exactly one model for its signature.
    wrappedEntries.push([signature, [legacyMethod]]);
  }

  return Object.fromEntries(wrappedEntries);
}
/**
 * Converts the "many models per signature" shape back into the legacy
 * "one model per signature" shape.
 *
 * Only the first modeled method of each array is kept; any further models
 * for the same signature are dropped. Signatures whose array is empty are
 * omitted from the result, so the returned record never contains an
 * `undefined` value.
 *
 * @param modeledMethods Record whose values are arrays of modeled methods.
 * @returns A new record mapping each signature that has at least one model
 *   to the first model in its array.
 */
export function convertToLegacyModeledMethods(
  modeledMethods: Record<string, ModeledMethod[]>,
): Record<string, ModeledMethod> {
  const legacyEntries: Array<[string, ModeledMethod]> = [];

  for (const [signature, methodsForSignature] of Object.entries(
    modeledMethods,
  )) {
    const [firstMethod] = methodsForSignature;
    if (firstMethod === undefined) {
      // Nothing to represent in the legacy shape; emitting the key would
      // store `undefined` under a type that promises a ModeledMethod.
      continue;
    }

    legacyEntries.push([signature, firstMethod]);
  }

  return Object.fromEntries(legacyEntries);
}
/**
 * Applies {@link convertFromLegacyModeledMethods} to every value of a nested
 * record, leaving the outer keys untouched.
 *
 * @param modeledMethods Record keyed by filename, where each value is a
 *   legacy record of single modeled methods keyed by signature.
 * @returns A new record with the same outer keys, each mapped to the
 *   converted inner record.
 */
export function convertFromLegacyModeledMethodsFiles(
  modeledMethods: Record<string, Record<string, ModeledMethod>>,
): Record<string, Record<string, ModeledMethod[]>> {
  const convertedFiles: Array<[string, Record<string, ModeledMethod[]>]> = [];

  for (const [filename, legacyMethodsBySignature] of Object.entries(
    modeledMethods,
  )) {
    convertedFiles.push([
      filename,
      convertFromLegacyModeledMethods(legacyMethodsBySignature),
    ]);
  }

  return Object.fromEntries(convertedFiles);
}

View File

@@ -71,8 +71,8 @@ ${extensions.join("\n")}`;
export function createDataExtensionYamls( export function createDataExtensionYamls(
language: string, language: string,
methods: Method[], methods: Method[],
newModeledMethods: Record<string, ModeledMethod>, newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>, existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
mode: Mode, mode: Mode,
) { ) {
switch (mode) { switch (mode) {
@@ -98,11 +98,11 @@ export function createDataExtensionYamls(
function createDataExtensionYamlsByGrouping( function createDataExtensionYamlsByGrouping(
language: string, language: string,
methods: Method[], methods: Method[],
newModeledMethods: Record<string, ModeledMethod>, newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>, existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
createFilename: (method: Method) => string, createFilename: (method: Method) => string,
): Record<string, string> { ): Record<string, string> {
const methodsByFilename: Record<string, Record<string, ModeledMethod>> = {}; const methodsByFilename: Record<string, Record<string, ModeledMethod[]>> = {};
// We only want to generate a yaml file when it's a known external API usage // We only want to generate a yaml file when it's a known external API usage
// and there are new modeled methods for it. This avoids us overwriting other // and there are new modeled methods for it. This avoids us overwriting other
@@ -114,10 +114,12 @@ function createDataExtensionYamlsByGrouping(
} }
// First populate methodsByFilename with any existing modeled methods. // First populate methodsByFilename with any existing modeled methods.
for (const [filename, methods] of Object.entries(existingModeledMethods)) { for (const [filename, methodsBySignature] of Object.entries(
existingModeledMethods,
)) {
if (filename in methodsByFilename) { if (filename in methodsByFilename) {
for (const [signature, method] of Object.entries(methods)) { for (const [signature, methods] of Object.entries(methodsBySignature)) {
methodsByFilename[filename][signature] = method; methodsByFilename[filename][signature] = methods;
} }
} }
} }
@@ -125,10 +127,12 @@ function createDataExtensionYamlsByGrouping(
// Add the new modeled methods, potentially overwriting existing modeled methods // Add the new modeled methods, potentially overwriting existing modeled methods
// but not removing existing modeled methods that are not in the new set. // but not removing existing modeled methods that are not in the new set.
for (const method of methods) { for (const method of methods) {
const newMethod = newModeledMethods[method.signature]; const newMethods = newModeledMethods[method.signature];
if (newMethod) { if (newMethods) {
const filename = createFilename(method); const filename = createFilename(method);
methodsByFilename[filename][newMethod.signature] = newMethod;
// Override any existing modeled methods with the new ones.
methodsByFilename[filename][method.signature] = newMethods;
} }
} }
@@ -137,7 +141,7 @@ function createDataExtensionYamlsByGrouping(
for (const [filename, methods] of Object.entries(methodsByFilename)) { for (const [filename, methods] of Object.entries(methodsByFilename)) {
result[filename] = createDataExtensionYaml( result[filename] = createDataExtensionYaml(
language, language,
Object.values(methods), Object.values(methods).flatMap((methods) => methods),
); );
} }
@@ -147,8 +151,8 @@ function createDataExtensionYamlsByGrouping(
export function createDataExtensionYamlsForApplicationMode( export function createDataExtensionYamlsForApplicationMode(
language: string, language: string,
methods: Method[], methods: Method[],
newModeledMethods: Record<string, ModeledMethod>, newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>, existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
): Record<string, string> { ): Record<string, string> {
return createDataExtensionYamlsByGrouping( return createDataExtensionYamlsByGrouping(
language, language,
@@ -162,8 +166,8 @@ export function createDataExtensionYamlsForApplicationMode(
export function createDataExtensionYamlsForFrameworkMode( export function createDataExtensionYamlsForFrameworkMode(
language: string, language: string,
methods: Method[], methods: Method[],
newModeledMethods: Record<string, ModeledMethod>, newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>, existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
): Record<string, string> { ): Record<string, string> {
return createDataExtensionYamlsByGrouping( return createDataExtensionYamlsByGrouping(
language, language,
@@ -228,14 +232,14 @@ function validateModelExtensionFile(data: unknown): data is ModelExtensionFile {
export function loadDataExtensionYaml( export function loadDataExtensionYaml(
data: unknown, data: unknown,
): Record<string, ModeledMethod> | undefined { ): Record<string, ModeledMethod[]> | undefined {
if (!validateModelExtensionFile(data)) { if (!validateModelExtensionFile(data)) {
return undefined; return undefined;
} }
const extensions = data.extensions; const extensions = data.extensions;
const modeledMethods: Record<string, ModeledMethod> = {}; const modeledMethods: Record<string, ModeledMethod[]> = {};
for (const extension of extensions) { for (const extension of extensions) {
const addsTo = extension.addsTo; const addsTo = extension.addsTo;
@@ -250,11 +254,16 @@ export function loadDataExtensionYaml(
} }
for (const row of data) { for (const row of data) {
const modeledMethod = definition.readModeledMethod(row); const modeledMethod: ModeledMethod = definition.readModeledMethod(row);
if (!modeledMethod) { if (!modeledMethod) {
continue; continue;
} }
modeledMethods[modeledMethod.signature] = modeledMethod;
if (!(modeledMethod.signature in modeledMethods)) {
modeledMethods[modeledMethod.signature] = [];
}
modeledMethods[modeledMethod.signature].push(modeledMethod);
} }
} }

View File

@@ -225,43 +225,49 @@ describe("createDataExtensionYamlsForApplicationMode", () => {
}, },
], ],
{ {
"org.sql2o.Connection#createQuery(String)": { "org.sql2o.Connection#createQuery(String)": [
type: "sink", {
input: "Argument[0]", type: "sink",
output: "", input: "Argument[0]",
kind: "sql", output: "",
provenance: "df-generated", kind: "sql",
signature: "org.sql2o.Connection#createQuery(String)", provenance: "df-generated",
packageName: "org.sql2o", signature: "org.sql2o.Connection#createQuery(String)",
typeName: "Connection", packageName: "org.sql2o",
methodName: "createQuery", typeName: "Connection",
methodParameters: "(String)", methodName: "createQuery",
}, methodParameters: "(String)",
"org.springframework.boot.SpringApplication#run(Class,String[])": { },
type: "neutral", ],
input: "", "org.springframework.boot.SpringApplication#run(Class,String[])": [
output: "", {
kind: "summary", type: "neutral",
provenance: "manual", input: "",
signature: output: "",
"org.springframework.boot.SpringApplication#run(Class,String[])", kind: "summary",
packageName: "org.springframework.boot", provenance: "manual",
typeName: "SpringApplication", signature:
methodName: "run", "org.springframework.boot.SpringApplication#run(Class,String[])",
methodParameters: "(Class,String[])", packageName: "org.springframework.boot",
}, typeName: "SpringApplication",
"org.sql2o.Sql2o#Sql2o(String,String,String)": { methodName: "run",
type: "sink", methodParameters: "(Class,String[])",
input: "Argument[0]", },
output: "", ],
kind: "jndi", "org.sql2o.Sql2o#Sql2o(String,String,String)": [
provenance: "manual", {
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", type: "sink",
packageName: "org.sql2o", input: "Argument[0]",
typeName: "Sql2o", output: "",
methodName: "Sql2o", kind: "jndi",
methodParameters: "(String,String,String)", provenance: "manual",
}, signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
}, },
{}, {},
); );
@@ -463,84 +469,97 @@ describe("createDataExtensionYamlsForApplicationMode", () => {
}, },
], ],
{ {
"org.sql2o.Connection#createQuery(String)": { "org.sql2o.Connection#createQuery(String)": [
type: "sink", {
input: "Argument[0]", type: "sink",
output: "", input: "Argument[0]",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.springframework.boot.SpringApplication#run(Class,String[])": {
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
},
{
"models/sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": {
type: "neutral",
input: "",
output: "", output: "",
kind: "summary", kind: "sql",
provenance: "manual", provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)", signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o", packageName: "org.sql2o",
typeName: "Connection", typeName: "Connection",
methodName: "createQuery", methodName: "createQuery",
methodParameters: "(String)", methodParameters: "(String)",
}, },
"org.sql2o.Query#executeScalar(Class)": { ],
"org.springframework.boot.SpringApplication#run(Class,String[])": [
{
type: "neutral", type: "neutral",
input: "", input: "",
output: "", output: "",
kind: "summary", kind: "summary",
provenance: "manual", provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)", signature:
packageName: "org.sql2o", "org.springframework.boot.SpringApplication#run(Class,String[])",
typeName: "Query", packageName: "org.springframework.boot",
methodName: "executeScalar", typeName: "SpringApplication",
methodParameters: "(Class)", methodName: "run",
methodParameters: "(Class,String[])",
}, },
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
},
{
"models/sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.sql2o.Query#executeScalar(Class)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
},
],
}, },
"models/gson.model.yml": { "models/gson.model.yml": {
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": { "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [
type: "summary", {
input: "Argument[this]", type: "summary",
output: "ReturnValue", input: "Argument[this]",
kind: "taint", output: "ReturnValue",
provenance: "df-generated", kind: "taint",
signature: "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", provenance: "df-generated",
packageName: "com.google.gson", signature:
typeName: "TypeAdapter", "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
methodName: "fromJsonTree", packageName: "com.google.gson",
methodParameters: "(JsonElement)", typeName: "TypeAdapter",
}, methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
],
}, },
}, },
); );
@@ -694,30 +713,34 @@ describe("createDataExtensionYamlsForFrameworkMode", () => {
}, },
], ],
{ {
"org.sql2o.Connection#createQuery(String)": { "org.sql2o.Connection#createQuery(String)": [
type: "sink", {
input: "Argument[0]", type: "sink",
output: "", input: "Argument[0]",
kind: "sql", output: "",
provenance: "df-generated", kind: "sql",
signature: "org.sql2o.Connection#createQuery(String)", provenance: "df-generated",
packageName: "org.sql2o", signature: "org.sql2o.Connection#createQuery(String)",
typeName: "Connection", packageName: "org.sql2o",
methodName: "createQuery", typeName: "Connection",
methodParameters: "(String)", methodName: "createQuery",
}, methodParameters: "(String)",
"org.sql2o.Sql2o#Sql2o(String,String,String)": { },
type: "sink", ],
input: "Argument[0]", "org.sql2o.Sql2o#Sql2o(String,String,String)": [
output: "", {
kind: "jndi", type: "sink",
provenance: "manual", input: "Argument[0]",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", output: "",
packageName: "org.sql2o", kind: "jndi",
typeName: "Sql2o", provenance: "manual",
methodName: "Sql2o", signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
methodParameters: "(String,String,String)", packageName: "org.sql2o",
}, typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
}, },
{}, {},
); );
@@ -846,71 +869,82 @@ describe("createDataExtensionYamlsForFrameworkMode", () => {
}, },
], ],
{ {
"org.sql2o.Connection#createQuery(String)": { "org.sql2o.Connection#createQuery(String)": [
type: "sink", {
input: "Argument[0]", type: "sink",
output: "", input: "Argument[0]",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
},
{
"models/org.sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": {
type: "neutral",
input: "",
output: "", output: "",
kind: "summary", kind: "sql",
provenance: "manual", provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)", signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o", packageName: "org.sql2o",
typeName: "Connection", typeName: "Connection",
methodName: "createQuery", methodName: "createQuery",
methodParameters: "(String)", methodParameters: "(String)",
}, },
"org.sql2o.Query#executeScalar(Class)": { ],
type: "neutral", "org.sql2o.Sql2o#Sql2o(String,String,String)": [
input: "", {
type: "sink",
input: "Argument[0]",
output: "", output: "",
kind: "summary", kind: "jndi",
provenance: "manual", provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)", signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o", packageName: "org.sql2o",
typeName: "Query", typeName: "Sql2o",
methodName: "executeScalar", methodName: "Sql2o",
methodParameters: "(Class)", methodParameters: "(String,String,String)",
}, },
],
},
{
"models/org.sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.sql2o.Query#executeScalar(Class)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
},
],
}, },
"models/gson.model.yml": { "models/gson.model.yml": {
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": { "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [
type: "summary", {
input: "Argument[this]", type: "summary",
output: "ReturnValue", input: "Argument[this]",
kind: "taint", output: "ReturnValue",
provenance: "df-generated", kind: "taint",
signature: "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", provenance: "df-generated",
packageName: "com.google.gson", signature:
typeName: "TypeAdapter", "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
methodName: "fromJsonTree", packageName: "com.google.gson",
methodParameters: "(JsonElement)", typeName: "TypeAdapter",
}, methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
],
}, },
}, },
); );
@@ -980,18 +1014,20 @@ describe("loadDataExtensionYaml", () => {
}); });
expect(data).toEqual({ expect(data).toEqual({
"org.sql2o.Connection#createQuery(String)": { "org.sql2o.Connection#createQuery(String)": [
input: "Argument[0]", {
kind: "sql", input: "Argument[0]",
output: "", kind: "sql",
type: "sink", output: "",
provenance: "manual", type: "sink",
signature: "org.sql2o.Connection#createQuery(String)", provenance: "manual",
packageName: "org.sql2o", signature: "org.sql2o.Connection#createQuery(String)",
typeName: "Connection", packageName: "org.sql2o",
methodName: "createQuery", typeName: "Connection",
methodParameters: "(String)", methodName: "createQuery",
}, methodParameters: "(String)",
},
],
}); });
}); });