Merge pull request #2904 from github/koesie10/convert-yaml-modeled-methods

Convert `yaml.ts` to handle multiple models per method
This commit is contained in:
Koen Vlaswinkel
2023-10-04 14:01:02 +02:00
committed by GitHub
5 changed files with 300 additions and 213 deletions

View File

@@ -16,6 +16,7 @@ import { QueryRunner } from "../query-server";
import { DatabaseItem } from "../databases/local-databases";
import { Mode } from "./shared/mode";
import { CancellationTokenSource } from "vscode";
import { convertToLegacyModeledMethods } from "./modeled-methods-legacy";
// Limit the number of candidates we send to the model in each request
// to avoid long requests.
@@ -192,11 +193,13 @@ export class AutoModeler {
filename: "auto-model.yml",
});
const loadedMethods = loadDataExtensionYaml(models);
if (!loadedMethods) {
const rawLoadedMethods = loadDataExtensionYaml(models);
if (!rawLoadedMethods) {
return;
}
const loadedMethods = convertToLegacyModeledMethods(rawLoadedMethods);
// Any candidate that was part of the response is a negative result
// meaning that the candidate is not a sink for the kinds that the LLM is checking for.
// For now we model this as a sink neutral method, however this is subject

View File

@@ -10,6 +10,11 @@ import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders";
import { load as loadYaml } from "js-yaml";
import { CodeQLCliServer } from "../codeql-cli/cli";
import { pathsEqual } from "../common/files";
import {
convertFromLegacyModeledMethods,
convertFromLegacyModeledMethodsFiles,
convertToLegacyModeledMethods,
} from "./modeled-methods-legacy";
export async function saveModeledMethods(
extensionPack: ExtensionPack,
@@ -29,8 +34,8 @@ export async function saveModeledMethods(
const yamls = createDataExtensionYamls(
language,
methods,
modeledMethods,
existingModeledMethods,
convertFromLegacyModeledMethods(modeledMethods),
convertFromLegacyModeledMethodsFiles(existingModeledMethods),
mode,
);
@@ -68,7 +73,8 @@ async function loadModeledMethodFiles(
);
continue;
}
modeledMethodsByFile[modelFile] = modeledMethods;
modeledMethodsByFile[modelFile] =
convertToLegacyModeledMethods(modeledMethods);
}
return modeledMethodsByFile;

View File

@@ -0,0 +1,33 @@
import { ModeledMethod } from "./modeled-method";
/**
 * Converts the legacy "one modeled method per signature" record into the
 * "multiple modeled methods per signature" shape.
 *
 * @param modeledMethods Record keyed by method signature, holding a single
 *   modeled method per key.
 * @returns A new record with the same keys, where every value is a
 *   one-element array containing the original modeled method.
 */
export function convertFromLegacyModeledMethods(
  modeledMethods: Record<string, ModeledMethod>,
): Record<string, ModeledMethod[]> {
  // Wrap each single modeled method in an array of its own.
  const wrappedEntries = Object.entries(modeledMethods).map(
    ([signature, legacyModel]): [string, ModeledMethod[]] => [
      signature,
      [legacyModel],
    ],
  );

  return Object.fromEntries(wrappedEntries);
}
/**
 * Converts the "multiple modeled methods per signature" record into the
 * legacy "one modeled method per signature" shape.
 *
 * Only the first modeled method of each array is kept; any further modeled
 * methods for the same signature are dropped. Signatures whose array is empty
 * are omitted from the result, so the returned record never maps a signature
 * to `undefined`.
 *
 * @param modeledMethods Record keyed by method signature, holding all modeled
 *   methods for that signature.
 * @returns A new record mapping each signature that has at least one modeled
 *   method to the first of them.
 */
export function convertToLegacyModeledMethods(
  modeledMethods: Record<string, ModeledMethod[]>,
): Record<string, ModeledMethod> {
  const legacyEntries: Array<[string, ModeledMethod]> = [];

  for (const [signature, methodsForSignature] of Object.entries(
    modeledMethods,
  )) {
    // Always take the first modeled method in the array
    const firstMethod = methodsForSignature[0];

    // An empty array has no first element; skip it rather than storing
    // `undefined` under a key that is typed as holding a ModeledMethod.
    if (firstMethod !== undefined) {
      legacyEntries.push([signature, firstMethod]);
    }
  }

  return Object.fromEntries(legacyEntries);
}
/**
 * Applies `convertFromLegacyModeledMethods` to the modeled methods of every
 * file in a per-file record.
 *
 * @param modeledMethods Record keyed by filename; each value is a legacy
 *   record of one modeled method per signature.
 * @returns A new record with the same filenames, where each value is the
 *   result of `convertFromLegacyModeledMethods` for that file.
 */
export function convertFromLegacyModeledMethodsFiles(
  modeledMethods: Record<string, Record<string, ModeledMethod>>,
): Record<string, Record<string, ModeledMethod[]>> {
  const convertedFiles = Object.entries(modeledMethods).map(
    ([filename, legacyBySignature]): [
      string,
      Record<string, ModeledMethod[]>,
    ] => [filename, convertFromLegacyModeledMethods(legacyBySignature)],
  );

  return Object.fromEntries(convertedFiles);
}

View File

@@ -71,8 +71,8 @@ ${extensions.join("\n")}`;
export function createDataExtensionYamls(
language: string,
methods: Method[],
newModeledMethods: Record<string, ModeledMethod>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>,
newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
mode: Mode,
) {
switch (mode) {
@@ -98,11 +98,11 @@ export function createDataExtensionYamls(
function createDataExtensionYamlsByGrouping(
language: string,
methods: Method[],
newModeledMethods: Record<string, ModeledMethod>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>,
newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
createFilename: (method: Method) => string,
): Record<string, string> {
const methodsByFilename: Record<string, Record<string, ModeledMethod>> = {};
const methodsByFilename: Record<string, Record<string, ModeledMethod[]>> = {};
// We only want to generate a yaml file when it's a known external API usage
// and there are new modeled methods for it. This avoids us overwriting other
@@ -114,10 +114,12 @@ function createDataExtensionYamlsByGrouping(
}
// First populate methodsByFilename with any existing modeled methods.
for (const [filename, methods] of Object.entries(existingModeledMethods)) {
for (const [filename, methodsBySignature] of Object.entries(
existingModeledMethods,
)) {
if (filename in methodsByFilename) {
for (const [signature, method] of Object.entries(methods)) {
methodsByFilename[filename][signature] = method;
for (const [signature, methods] of Object.entries(methodsBySignature)) {
methodsByFilename[filename][signature] = methods;
}
}
}
@@ -125,10 +127,12 @@ function createDataExtensionYamlsByGrouping(
// Add the new modeled methods, potentially overwriting existing modeled methods
// but not removing existing modeled methods that are not in the new set.
for (const method of methods) {
const newMethod = newModeledMethods[method.signature];
if (newMethod) {
const newMethods = newModeledMethods[method.signature];
if (newMethods) {
const filename = createFilename(method);
methodsByFilename[filename][newMethod.signature] = newMethod;
// Override any existing modeled methods with the new ones.
methodsByFilename[filename][method.signature] = newMethods;
}
}
@@ -137,7 +141,7 @@ function createDataExtensionYamlsByGrouping(
for (const [filename, methods] of Object.entries(methodsByFilename)) {
result[filename] = createDataExtensionYaml(
language,
Object.values(methods),
Object.values(methods).flatMap((methods) => methods),
);
}
@@ -147,8 +151,8 @@ function createDataExtensionYamlsByGrouping(
export function createDataExtensionYamlsForApplicationMode(
language: string,
methods: Method[],
newModeledMethods: Record<string, ModeledMethod>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>,
newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
): Record<string, string> {
return createDataExtensionYamlsByGrouping(
language,
@@ -162,8 +166,8 @@ export function createDataExtensionYamlsForApplicationMode(
export function createDataExtensionYamlsForFrameworkMode(
language: string,
methods: Method[],
newModeledMethods: Record<string, ModeledMethod>,
existingModeledMethods: Record<string, Record<string, ModeledMethod>>,
newModeledMethods: Record<string, ModeledMethod[]>,
existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>,
): Record<string, string> {
return createDataExtensionYamlsByGrouping(
language,
@@ -228,14 +232,14 @@ function validateModelExtensionFile(data: unknown): data is ModelExtensionFile {
export function loadDataExtensionYaml(
data: unknown,
): Record<string, ModeledMethod> | undefined {
): Record<string, ModeledMethod[]> | undefined {
if (!validateModelExtensionFile(data)) {
return undefined;
}
const extensions = data.extensions;
const modeledMethods: Record<string, ModeledMethod> = {};
const modeledMethods: Record<string, ModeledMethod[]> = {};
for (const extension of extensions) {
const addsTo = extension.addsTo;
@@ -250,11 +254,16 @@ export function loadDataExtensionYaml(
}
for (const row of data) {
const modeledMethod = definition.readModeledMethod(row);
const modeledMethod: ModeledMethod = definition.readModeledMethod(row);
if (!modeledMethod) {
continue;
}
modeledMethods[modeledMethod.signature] = modeledMethod;
if (!(modeledMethod.signature in modeledMethods)) {
modeledMethods[modeledMethod.signature] = [];
}
modeledMethods[modeledMethod.signature].push(modeledMethod);
}
}

View File

@@ -225,43 +225,49 @@ describe("createDataExtensionYamlsForApplicationMode", () => {
},
],
{
"org.sql2o.Connection#createQuery(String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.springframework.boot.SpringApplication#run(Class,String[])": {
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
"org.sql2o.Connection#createQuery(String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.springframework.boot.SpringApplication#run(Class,String[])": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
},
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
},
{},
);
@@ -463,84 +469,97 @@ describe("createDataExtensionYamlsForApplicationMode", () => {
},
],
{
"org.sql2o.Connection#createQuery(String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.springframework.boot.SpringApplication#run(Class,String[])": {
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
},
{
"models/sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": {
type: "neutral",
input: "",
"org.sql2o.Connection#createQuery(String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "summary",
provenance: "manual",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Query#executeScalar(Class)": {
],
"org.springframework.boot.SpringApplication#run(Class,String[])": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
signature:
"org.springframework.boot.SpringApplication#run(Class,String[])",
packageName: "org.springframework.boot",
typeName: "SpringApplication",
methodName: "run",
methodParameters: "(Class,String[])",
},
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
},
{
"models/sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.sql2o.Query#executeScalar(Class)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
},
],
},
"models/gson.model.yml": {
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": {
type: "summary",
input: "Argument[this]",
output: "ReturnValue",
kind: "taint",
provenance: "df-generated",
signature: "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
packageName: "com.google.gson",
typeName: "TypeAdapter",
methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [
{
type: "summary",
input: "Argument[this]",
output: "ReturnValue",
kind: "taint",
provenance: "df-generated",
signature:
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
packageName: "com.google.gson",
typeName: "TypeAdapter",
methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
],
},
},
);
@@ -694,30 +713,34 @@ describe("createDataExtensionYamlsForFrameworkMode", () => {
},
],
{
"org.sql2o.Connection#createQuery(String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
"org.sql2o.Connection#createQuery(String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
},
{},
);
@@ -846,71 +869,82 @@ describe("createDataExtensionYamlsForFrameworkMode", () => {
},
],
{
"org.sql2o.Connection#createQuery(String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Sql2o#Sql2o(String,String,String)": {
type: "sink",
input: "Argument[0]",
output: "",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
},
{
"models/org.sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": {
type: "neutral",
input: "",
"org.sql2o.Connection#createQuery(String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "summary",
provenance: "manual",
kind: "sql",
provenance: "df-generated",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Query#executeScalar(Class)": {
type: "neutral",
input: "",
],
"org.sql2o.Sql2o#Sql2o(String,String,String)": [
{
type: "sink",
input: "Argument[0]",
output: "",
kind: "summary",
kind: "jndi",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
signature: "org.sql2o.Sql2o#Sql2o(String,String,String)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
typeName: "Sql2o",
methodName: "Sql2o",
methodParameters: "(String,String,String)",
},
],
},
{
"models/org.sql2o.model.yml": {
"org.sql2o.Connection#createQuery(String)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
"org.sql2o.Query#executeScalar(Class)": [
{
type: "neutral",
input: "",
output: "",
kind: "summary",
provenance: "manual",
signature: "org.sql2o.Query#executeScalar(Class)",
packageName: "org.sql2o",
typeName: "Query",
methodName: "executeScalar",
methodParameters: "(Class)",
},
],
},
"models/gson.model.yml": {
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": {
type: "summary",
input: "Argument[this]",
output: "ReturnValue",
kind: "taint",
provenance: "df-generated",
signature: "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
packageName: "com.google.gson",
typeName: "TypeAdapter",
methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [
{
type: "summary",
input: "Argument[this]",
output: "ReturnValue",
kind: "taint",
provenance: "df-generated",
signature:
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)",
packageName: "com.google.gson",
typeName: "TypeAdapter",
methodName: "fromJsonTree",
methodParameters: "(JsonElement)",
},
],
},
},
);
@@ -980,18 +1014,20 @@ describe("loadDataExtensionYaml", () => {
});
expect(data).toEqual({
"org.sql2o.Connection#createQuery(String)": {
input: "Argument[0]",
kind: "sql",
output: "",
type: "sink",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
"org.sql2o.Connection#createQuery(String)": [
{
input: "Argument[0]",
kind: "sql",
output: "",
type: "sink",
provenance: "manual",
signature: "org.sql2o.Connection#createQuery(String)",
packageName: "org.sql2o",
typeName: "Connection",
methodName: "createQuery",
methodParameters: "(String)",
},
],
});
});