Sort methods sent to LLM the same way as the UI
This changes the sorting of the methods sent to LLM to match the order shown in the data extensions editor. This will ensure that the methods which are shown first in the data extensions editor are also modeled first.
This commit is contained in:
@@ -7,6 +7,8 @@ import {
|
||||
ModelRequest,
|
||||
} from "./auto-model-api";
|
||||
import type { UsageSnippetsBySignature } from "./auto-model-usages-query";
|
||||
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
|
||||
import { Mode } from "./shared/mode";
|
||||
|
||||
// Soft limit on the number of candidates to send to the model.
|
||||
// Note that the model may return fewer than this number of candidates.
|
||||
@@ -19,6 +21,7 @@ export function createAutoModelRequest(
|
||||
externalApiUsages: ExternalApiUsage[],
|
||||
modeledMethods: Record<string, ModeledMethod>,
|
||||
usages: UsageSnippetsBySignature,
|
||||
mode: Mode,
|
||||
): ModelRequest {
|
||||
const request: ModelRequest = {
|
||||
language,
|
||||
@@ -26,11 +29,14 @@ export function createAutoModelRequest(
|
||||
candidates: [],
|
||||
};
|
||||
|
||||
// Sort by number of usages so we always send the most used methods first
|
||||
externalApiUsages = [...externalApiUsages];
|
||||
externalApiUsages.sort((a, b) => b.usages.length - a.usages.length);
|
||||
// Sort the same way as the UI so we send the first ones listed in the UI first
|
||||
const grouped = groupMethods(externalApiUsages, mode);
|
||||
const sortedGroupNames = sortGroupNames(grouped);
|
||||
const sortedExternalApiUsages = sortedGroupNames.flatMap((name) =>
|
||||
sortMethods(grouped[name]),
|
||||
);
|
||||
|
||||
for (const externalApiUsage of externalApiUsages) {
|
||||
for (const externalApiUsage of sortedExternalApiUsages) {
|
||||
const modeledMethod: ModeledMethod = modeledMethods[
|
||||
externalApiUsage.signature
|
||||
] ?? {
|
||||
|
||||
@@ -457,6 +457,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
|
||||
externalApiUsages,
|
||||
modeledMethods,
|
||||
usages,
|
||||
this.mode,
|
||||
);
|
||||
|
||||
await this.showProgress({
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
ClassificationType,
|
||||
Method,
|
||||
} from "../../../src/data-extensions-editor/auto-model-api";
|
||||
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
|
||||
|
||||
describe("createAutoModelRequest", () => {
|
||||
const externalApiUsages: ExternalApiUsage[] = [
|
||||
@@ -259,7 +260,13 @@ describe("createAutoModelRequest", () => {
|
||||
|
||||
it("creates a matching request", () => {
|
||||
expect(
|
||||
createAutoModelRequest("java", externalApiUsages, modeledMethods, usages),
|
||||
createAutoModelRequest(
|
||||
"java",
|
||||
externalApiUsages,
|
||||
modeledMethods,
|
||||
usages,
|
||||
Mode.Application,
|
||||
),
|
||||
).toEqual({
|
||||
language: "java",
|
||||
samples: [
|
||||
@@ -340,60 +347,6 @@ describe("createAutoModelRequest", () => {
|
||||
input: "Argument[0]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[this]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[0]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[1]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "java.io",
|
||||
type: "PrintStream",
|
||||
name: "println",
|
||||
signature: "(String)",
|
||||
usages: usages["java.io.PrintStream#println(String)"],
|
||||
input: "Argument[this]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "java.io",
|
||||
type: "PrintStream",
|
||||
name: "println",
|
||||
signature: "(String)",
|
||||
usages: usages["java.io.PrintStream#println(String)"],
|
||||
input: "Argument[0]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.sql2o",
|
||||
type: "Sql2o",
|
||||
@@ -430,6 +383,60 @@ describe("createAutoModelRequest", () => {
|
||||
input: "Argument[2]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "java.io",
|
||||
type: "PrintStream",
|
||||
name: "println",
|
||||
signature: "(String)",
|
||||
usages: usages["java.io.PrintStream#println(String)"],
|
||||
input: "Argument[this]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "java.io",
|
||||
type: "PrintStream",
|
||||
name: "println",
|
||||
signature: "(String)",
|
||||
usages: usages["java.io.PrintStream#println(String)"],
|
||||
input: "Argument[0]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[this]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[0]",
|
||||
classification: undefined,
|
||||
},
|
||||
{
|
||||
package: "org.springframework.boot",
|
||||
type: "SpringApplication",
|
||||
name: "run",
|
||||
signature: "(Class,String[])",
|
||||
usages:
|
||||
usages[
|
||||
"org.springframework.boot.SpringApplication#run(Class,String[])"
|
||||
],
|
||||
input: "Argument[1]",
|
||||
classification: undefined,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user