Merge pull request #2549 from github/koesie10/consistent-sorting

Sort methods sent to LLM the same way as the UI
This commit is contained in:
Koen Vlaswinkel
2023-06-26 16:11:38 +02:00
committed by GitHub
10 changed files with 179 additions and 134 deletions

View File

@@ -7,6 +7,8 @@ import {
ModelRequest,
} from "./auto-model-api";
import type { UsageSnippetsBySignature } from "./auto-model-usages-query";
import { groupMethods, sortGroupNames, sortMethods } from "./shared/sorting";
import { Mode } from "./shared/mode";
// Soft limit on the number of candidates to send to the model.
// Note that the model may return fewer than this number of candidates.
@@ -19,6 +21,7 @@ export function createAutoModelRequest(
externalApiUsages: ExternalApiUsage[],
modeledMethods: Record<string, ModeledMethod>,
usages: UsageSnippetsBySignature,
mode: Mode,
): ModelRequest {
const request: ModelRequest = {
language,
@@ -26,11 +29,14 @@ export function createAutoModelRequest(
candidates: [],
};
// Sort by number of usages so we always send the most used methods first
externalApiUsages = [...externalApiUsages];
externalApiUsages.sort((a, b) => b.usages.length - a.usages.length);
// Sort the same way as the UI so we send the first ones listed in the UI first
const grouped = groupMethods(externalApiUsages, mode);
const sortedGroupNames = sortGroupNames(grouped);
const sortedExternalApiUsages = sortedGroupNames.flatMap((name) =>
sortMethods(grouped[name]),
);
for (const externalApiUsage of externalApiUsages) {
for (const externalApiUsage of sortedExternalApiUsages) {
const modeledMethod: ModeledMethod = modeledMethods[
externalApiUsage.signature
] ?? {

View File

@@ -457,6 +457,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
externalApiUsages,
modeledMethods,
usages,
this.mode,
);
await this.showProgress({

View File

@@ -1,4 +1,4 @@
import { ExternalApiUsage } from "../../data-extensions-editor/external-api-usage";
import { ExternalApiUsage } from "../external-api-usage";
export function calculateModeledPercentage(
externalApiUsages: Array<Pick<ExternalApiUsage, "supported">>,

View File

@@ -0,0 +1,88 @@
import { ExternalApiUsage } from "../external-api-usage";
import { Mode } from "./mode";
import { calculateModeledPercentage } from "./modeled-percentage";
export function groupMethods(
externalApiUsages: ExternalApiUsage[],
mode: Mode,
): Record<string, ExternalApiUsage[]> {
const groupedByLibrary: Record<string, ExternalApiUsage[]> = {};
for (const externalApiUsage of externalApiUsages) {
// Group by package if using framework mode
const key =
mode === Mode.Framework
? externalApiUsage.packageName
: externalApiUsage.library;
groupedByLibrary[key] ??= [];
groupedByLibrary[key].push(externalApiUsage);
}
return groupedByLibrary;
}
export function sortGroupNames(
methods: Record<string, ExternalApiUsage[]>,
): string[] {
return Object.keys(methods).sort((a, b) =>
compareGroups(methods[a], a, methods[b], b),
);
}
export function sortMethods(
externalApiUsages: ExternalApiUsage[],
): ExternalApiUsage[] {
const sortedExternalApiUsages = [...externalApiUsages];
sortedExternalApiUsages.sort((a, b) => compareMethod(a, b));
return sortedExternalApiUsages;
}
function compareGroups(
a: ExternalApiUsage[],
aName: string,
b: ExternalApiUsage[],
bName: string,
): number {
const supportedPercentageA = calculateModeledPercentage(a);
const supportedPercentageB = calculateModeledPercentage(b);
// Sort first by supported percentage ascending
if (supportedPercentageA > supportedPercentageB) {
return 1;
}
if (supportedPercentageA < supportedPercentageB) {
return -1;
}
const numberOfUsagesA = a.reduce((acc, curr) => acc + curr.usages.length, 0);
const numberOfUsagesB = b.reduce((acc, curr) => acc + curr.usages.length, 0);
// If the number of usages is equal, sort by number of methods descending
if (numberOfUsagesA === numberOfUsagesB) {
const numberOfMethodsA = a.length;
const numberOfMethodsB = b.length;
// If the number of methods is equal, sort by library name ascending
if (numberOfMethodsA === numberOfMethodsB) {
return aName.localeCompare(bName);
}
return numberOfMethodsB - numberOfMethodsA;
}
// Then sort by number of usages descending
return numberOfUsagesB - numberOfUsagesA;
}
function compareMethod(a: ExternalApiUsage, b: ExternalApiUsage): number {
// Sort first by supported, putting unmodeled methods first.
if (a.supported && !b.supported) {
return 1;
}
if (!a.supported && b.supported) {
return -1;
}
// Then sort by number of usages descending
return b.usages.length - a.usages.length;
}

View File

@@ -10,7 +10,7 @@ import { ExternalApiUsage } from "../../data-extensions-editor/external-api-usag
import { ModeledMethod } from "../../data-extensions-editor/modeled-method";
import { assertNever } from "../../common/helpers-pure";
import { vscode } from "../vscode-api";
import { calculateModeledPercentage } from "./modeled";
import { calculateModeledPercentage } from "../../data-extensions-editor/shared/modeled-percentage";
import { LinkIconButton } from "../variant-analysis/LinkIconButton";
import { ViewTitle } from "../common";
import { DataExtensionEditorViewState } from "../../data-extensions-editor/shared/view-state";

View File

@@ -5,7 +5,7 @@ import { ExternalApiUsage } from "../../data-extensions-editor/external-api-usag
import { ModeledMethod } from "../../data-extensions-editor/modeled-method";
import { pluralize } from "../../common/word";
import { ModeledMethodDataGrid } from "./ModeledMethodDataGrid";
import { calculateModeledPercentage } from "./modeled";
import { calculateModeledPercentage } from "../../data-extensions-editor/shared/modeled-percentage";
import { decimalFormatter, percentFormatter } from "./formatters";
import { Codicon } from "../common";
import { Mode } from "../../data-extensions-editor/shared/mode";

View File

@@ -9,6 +9,7 @@ import { ExternalApiUsage } from "../../data-extensions-editor/external-api-usag
import { ModeledMethod } from "../../data-extensions-editor/modeled-method";
import { useMemo } from "react";
import { Mode } from "../../data-extensions-editor/shared/mode";
import { sortMethods } from "../../data-extensions-editor/shared/sorting";
type Props = {
externalApiUsages: ExternalApiUsage[];
@@ -26,21 +27,10 @@ export const ModeledMethodDataGrid = ({
mode,
onChange,
}: Props) => {
const sortedExternalApiUsages = useMemo(() => {
const sortedExternalApiUsages = [...externalApiUsages];
sortedExternalApiUsages.sort((a, b) => {
// Sort first by supported, putting unmodeled methods first.
if (a.supported && !b.supported) {
return 1;
}
if (!a.supported && b.supported) {
return -1;
}
// Then sort by number of usages descending
return b.usages.length - a.usages.length;
});
return sortedExternalApiUsages;
}, [externalApiUsages]);
const sortedExternalApiUsages = useMemo(
() => sortMethods(externalApiUsages),
[externalApiUsages],
);
return (
<VSCodeDataGrid>

View File

@@ -2,9 +2,12 @@ import * as React from "react";
import { useMemo } from "react";
import { ExternalApiUsage } from "../../data-extensions-editor/external-api-usage";
import { ModeledMethod } from "../../data-extensions-editor/modeled-method";
import { calculateModeledPercentage } from "./modeled";
import { LibraryRow } from "./LibraryRow";
import { Mode } from "../../data-extensions-editor/shared/mode";
import {
groupMethods,
sortGroupNames,
} from "../../data-extensions-editor/shared/sorting";
type Props = {
externalApiUsages: ExternalApiUsage[];
@@ -22,62 +25,12 @@ export const ModeledMethodsList = ({
mode,
onChange,
}: Props) => {
const grouped = useMemo(() => {
const groupedByLibrary: Record<string, ExternalApiUsage[]> = {};
const grouped = useMemo(
() => groupMethods(externalApiUsages, mode),
[externalApiUsages, mode],
);
for (const externalApiUsage of externalApiUsages) {
// Group by package if using framework mode
const key =
mode === Mode.Framework
? externalApiUsage.packageName
: externalApiUsage.library;
groupedByLibrary[key] ??= [];
groupedByLibrary[key].push(externalApiUsage);
}
return groupedByLibrary;
}, [externalApiUsages, mode]);
const sortedGroupNames = useMemo(() => {
return Object.keys(grouped).sort((a, b) => {
const supportedPercentageA = calculateModeledPercentage(grouped[a]);
const supportedPercentageB = calculateModeledPercentage(grouped[b]);
// Sort first by supported percentage ascending
if (supportedPercentageA > supportedPercentageB) {
return 1;
}
if (supportedPercentageA < supportedPercentageB) {
return -1;
}
const numberOfUsagesA = grouped[a].reduce(
(acc, curr) => acc + curr.usages.length,
0,
);
const numberOfUsagesB = grouped[b].reduce(
(acc, curr) => acc + curr.usages.length,
0,
);
// If the number of usages is equal, sort by number of methods descending
if (numberOfUsagesA === numberOfUsagesB) {
const numberOfMethodsA = grouped[a].length;
const numberOfMethodsB = grouped[b].length;
// If the number of methods is equal, sort by library name ascending
if (numberOfMethodsA === numberOfMethodsB) {
return a.localeCompare(b);
}
return numberOfMethodsB - numberOfMethodsA;
}
// Then sort by number of usages descending
return numberOfUsagesB - numberOfUsagesA;
});
}, [grouped]);
const sortedGroupNames = useMemo(() => sortGroupNames(grouped), [grouped]);
return (
<>

View File

@@ -9,6 +9,7 @@ import {
ClassificationType,
Method,
} from "../../../src/data-extensions-editor/auto-model-api";
import { Mode } from "../../../src/data-extensions-editor/shared/mode";
describe("createAutoModelRequest", () => {
const externalApiUsages: ExternalApiUsage[] = [
@@ -259,7 +260,13 @@ describe("createAutoModelRequest", () => {
it("creates a matching request", () => {
expect(
createAutoModelRequest("java", externalApiUsages, modeledMethods, usages),
createAutoModelRequest(
"java",
externalApiUsages,
modeledMethods,
usages,
Mode.Application,
),
).toEqual({
language: "java",
samples: [
@@ -340,60 +347,6 @@ describe("createAutoModelRequest", () => {
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[1]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.sql2o",
type: "Sql2o",
@@ -430,6 +383,60 @@ describe("createAutoModelRequest", () => {
input: "Argument[2]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[this]",
classification: undefined,
},
{
package: "java.io",
type: "PrintStream",
name: "println",
signature: "(String)",
usages: usages["java.io.PrintStream#println(String)"],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[this]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[0]",
classification: undefined,
},
{
package: "org.springframework.boot",
type: "SpringApplication",
name: "run",
signature: "(Class,String[])",
usages:
usages[
"org.springframework.boot.SpringApplication#run(Class,String[])"
],
input: "Argument[1]",
classification: undefined,
},
],
});
});

View File

@@ -1,4 +1,4 @@
import { calculateModeledPercentage } from "../modeled";
import { calculateModeledPercentage } from "../../../../src/data-extensions-editor/shared/modeled-percentage";
describe("calculateModeledPercentage", () => {
it("when there are no external API usages", () => {