Sort methods sent to LLM the same way as the UI

This changes the sorting of the methods sent to LLM to match the order
shown in the data extensions editor. This will ensure that the methods
which are shown first in the data extensions editor are also modeled
first.
This commit is contained in:
Koen Vlaswinkel
2023-06-26 14:26:26 +02:00
parent a8aee6a8e1
commit 90c8391fea
3 changed files with 73 additions and 59 deletions

View File

@@ -7,6 +7,8 @@ import {
ModelRequest,
} from "./auto-model-api";
import type { UsageSnippetsBySignature } from "./auto-model-usages-query";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
import { Mode } from "./shared/mode";
// Soft limit on the number of candidates to send to the model.
// Note that the model may return fewer than this number of candidates.
@@ -19,6 +21,7 @@ export function createAutoModelRequest(
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
usages: UsageSnippetsBySignature,
mode: Mode,
): ModelRequest {
const request: ModelRequest = {
language,
@@ -26,11 +29,14 @@ export function createAutoModelRequest(
candidates: [],
};
// Sort by number of usages so we always send the most used methods first
externalApiUsages = [...externalApiUsages];
externalApiUsages.sort((a, b) => b.usages.length - a.usages.length);
// Sort the same way as the UI so we send the first ones listed in the UI first
const grouped = groupMethods(externalApiUsages, mode);
const sortedGroupNames = sortGroupNames(grouped);
const sortedExternalApiUsages = sortedGroupNames.flatMap((name) =>
sortMethods(grouped[name]),
);
for (const externalApiUsage of externalApiUsages) {
for (const externalApiUsage of sortedExternalApiUsages) {
const modeledMethod: ModeledMethod = modeledMethods[
externalApiUsage.signature
] ?? {

View File

@@ -457,6 +457,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
externalApiUsages,
modeledMethods,
usages,
this.mode,
);
await this.showProgress({

View File

@@ -9,6 +9,7 @@ import {
ClassificationType,
Method,
} from "../../../src/data-extensions-editor/auto-model-api";
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
describe("createAutoModelRequest", () => {
const externalApiUsages: ExternalApiUsage[] = [
@@ -259,7 +260,13 @@ describe("createAutoModelRequest", () => {
it("creates a matching request", () => {
expect(
createAutoModelRequest("java", externalApiUsages, modeledMethods, usages),
createAutoModelRequest(
"java",
externalApiUsages,
modeledMethods,
usages,
Mode.Application,
),
).toEqual({
language: "java",
samples: [
@@ -340,60 +347,6 @@ describe("createAutoModelRequest", () => {
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[1]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
@@ -430,6 +383,60 @@ describe("createAutoModelRequest", () => {
input: "Argument[2]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[1]",
classification: undefined,
},
],
});
});