From a5a1f2e8ad781e0c82a6f775746286477d806545 Mon Sep 17 00:00:00 2001 From: Yidadaa Date: Sun, 12 Nov 2023 00:46:21 +0800 Subject: [PATCH] feat: CUSTOM_MODELS support mapper --- app/api/common.ts | 2 +- app/components/chat.tsx | 10 +++++----- app/utils/model.ts | 32 +++++++++++++++++++++----------- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/app/api/common.ts b/app/api/common.ts index adec611b..dd1cc0bb 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -81,7 +81,7 @@ export async function requestOpenai(req: NextRequest) { const jsonBody = JSON.parse(clonedBody) as { model?: string }; // not undefined and is false - if (modelTable[jsonBody?.model ?? ""] === false) { + if (modelTable[jsonBody?.model ?? ""].available === false) { return NextResponse.json( { error: true, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 48f76e8a..a088483e 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -433,7 +433,7 @@ export function ChatActions(props: { const currentModel = chatStore.currentSession().mask.modelConfig.model; const allModels = useAllModels(); const models = useMemo( - () => allModels.filter((m) => m.available).map((m) => m.name), + () => allModels.filter((m) => m.available), [allModels], ); const [showModelSelector, setShowModelSelector] = useState(false); @@ -441,9 +441,9 @@ export function ChatActions(props: { useEffect(() => { // if current model is not available // switch to first available model - const isUnavaliableModel = !models.includes(currentModel); + const isUnavaliableModel = !models.some((m) => m.name === currentModel); if (isUnavaliableModel && models.length > 0) { - const nextModel = models[0] as ModelType; + const nextModel = models[0].name as ModelType; chatStore.updateCurrentSession( (session) => (session.mask.modelConfig.model = nextModel), ); @@ -531,8 +531,8 @@ export function ChatActions(props: { ({ - title: m, - value: m, + title: m.displayName, + value: m.name, }))} onClose={() => setShowModelSelector(false)} onSelection={(s) => { diff --git a/app/utils/model.ts b/app/utils/model.ts index 23090f9d..d5c009c0 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -4,21 +4,34 @@ export function collectModelTable( models: readonly LLMModel[], customModels: string, ) { - const modelTable: Record = {}; + const modelTable: Record< + string, + { available: boolean; name: string; displayName: string } + > = {}; // default models - models.forEach((m) => (modelTable[m.name] = m.available)); + models.forEach( + (m) => + (modelTable[m.name] = { + ...m, + displayName: m.name, + }), + ); // server custom models customModels .split(",") .filter((v) => !!v && v.length > 0) .map((m) => { - if (m.startsWith("+")) { - modelTable[m.slice(1)] = true; - } else if (m.startsWith("-")) { - modelTable[m.slice(1)] = false; - } else modelTable[m] = true; + const available = !m.startsWith("-"); + const nameConfig = + m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; + const [name, displayName] = nameConfig.split(":"); + modelTable[name] = { + name, + displayName: displayName || name, + available, + }; }); return modelTable; } @@ -31,10 +44,7 @@ export function collectModels( customModels: string, ) { const modelTable = collectModelTable(models, customModels); - const allModels = Object.keys(modelTable).map((m) => ({ - name: m, - available: modelTable[m], - })); + const allModels = Object.values(modelTable); return allModels; }