From 4131fccbe0c77832aa496825e9362a78797234ad Mon Sep 17 00:00:00 2001
From: Yidadaa
Date: Tue, 4 Jul 2023 23:16:24 +0800
Subject: [PATCH] feat: close #2192 use /list/models to get model ids

---
 app/api/config/route.ts           |   2 +-
 app/api/openai/[...path]/route.ts |  27 +++++++-
 app/client/api.ts                 |   6 ++
 app/client/platforms/openai.ts    |  32 +++++++++-
 app/components/chat.tsx           |   7 +--
 app/components/home.tsx           |  14 +++++
 app/components/model-config.tsx   |   6 +-
 app/components/settings.tsx       |  61 +++++++++---------
 app/config/server.ts              |   2 +-
 app/constant.ts                   |  68 ++++++++++++++++++++
 app/store/access.ts               |   9 ---
 app/store/config.ts               | 101 +++++++++-----------------
 12 files changed, 214 insertions(+), 121 deletions(-)

diff --git a/app/api/config/route.ts b/app/api/config/route.ts
index 6b956558..7749e6e9 100644
--- a/app/api/config/route.ts
+++ b/app/api/config/route.ts
@@ -9,7 +9,7 @@ const serverConfig = getServerSideConfig();
 const DANGER_CONFIG = {
   needCode: serverConfig.needCode,
   hideUserApiKey: serverConfig.hideUserApiKey,
-  enableGPT4: serverConfig.enableGPT4,
+  disableGPT4: serverConfig.disableGPT4,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
 };
 
diff --git a/app/api/openai/[...path]/route.ts b/app/api/openai/[...path]/route.ts
index 36f92d0f..9df005a3 100644
--- a/app/api/openai/[...path]/route.ts
+++ b/app/api/openai/[...path]/route.ts
@@ -1,3 +1,5 @@
+import { type OpenAIListModelResponse } from "@/app/client/platforms/openai";
+import { getServerSideConfig } from "@/app/config/server";
 import { OpenaiPath } from "@/app/constant";
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
@@ -6,6 +8,18 @@ import { requestOpenai } from "../../common";
 
 const ALLOWD_PATH = new Set(Object.values(OpenaiPath));
 
+function getModels(remoteModelRes: OpenAIListModelResponse) {
+  const config = getServerSideConfig();
+
+  if (config.disableGPT4) {
+    remoteModelRes.data = remoteModelRes.data.filter(
+      (m) => !m.id.startsWith("gpt-4"),
+    );
+  }
+
+  return remoteModelRes;
+}
+
 async function handle(
   req: NextRequest,
   { params }: { params: { path: string[] } },
@@ -39,7 +53,18 @@ async function handle(
   }
 
   try {
-    return await requestOpenai(req);
+    const response = await requestOpenai(req);
+
+    // list models
+    if (subpath === OpenaiPath.ListModelPath && response.status === 200) {
+      const resJson = (await response.json()) as OpenAIListModelResponse;
+      const availableModels = getModels(resJson);
+      return NextResponse.json(availableModels, {
+        status: response.status,
+      });
+    }
+
+    return response;
   } catch (e) {
     console.error("[OpenAI] ", e);
     return NextResponse.json(prettyObject(e));
diff --git a/app/client/api.ts b/app/client/api.ts
index a8960ff5..08c4bb92 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -38,9 +38,15 @@ export interface LLMUsage {
   total: number;
 }
 
+export interface LLMModel {
+  name: string;
+  available: boolean;
+}
+
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
   abstract usage(): Promise<LLMUsage>;
+  abstract models(): Promise<LLMModel[]>;
 }
 
 type ProviderName = "openai" | "azure" | "claude" | "palm";
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index bbd14d61..3384aeef 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -5,7 +5,7 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 
-import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
+import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
 import Locale from "../../locales";
 import {
   EventStreamContentType,
@@ -13,6 +13,15 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";
 
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
 export class ChatGPTApi implements LLMApi {
   path(path: string): string {
     let openaiUrl = useAccessStore.getState().openaiUrl;
@@ -22,6 +31,9 @@ export class ChatGPTApi implements LLMApi {
     if (openaiUrl.endsWith("/")) {
       openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1);
     }
+    if (!openaiUrl.startsWith("http") && !openaiUrl.startsWith("/api/openai")) {
+      openaiUrl = "https://" + openaiUrl;
+    }
     return [openaiUrl, path].join("/");
   }
 
@@ -232,5 +244,23 @@ export class ChatGPTApi implements LLMApi {
       total: total.hard_limit_usd,
     } as LLMUsage;
   }
+
+  async models(): Promise<LLMModel[]> {
+    const res = await fetch(this.path(OpenaiPath.ListModelPath), {
+      method: "GET",
+      headers: {
+        ...getHeaders(),
+      },
+    });
+
+    const resJson = (await res.json()) as OpenAIListModelResponse;
+    const chatModels = resJson.data.filter((m) => m.id.startsWith("gpt-"));
+    console.log("[Models]", chatModels);
+
+    return chatModels.map((m) => ({
+      name: m.id,
+      available: true,
+    }));
+  }
 }
 export { OpenaiPath };
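Review note (not part of the patch): the new proxy logic applies two string filters. Server-side, `getModels` drops every `gpt-4`-prefixed id when `DISABLE_GPT4` is set; client-side, `ChatGPTApi.models()` keeps only `gpt-`-prefixed ids. A minimal standalone TypeScript sketch of the server-side rule, with illustrative data:

```ts
// Mirrors the shape of OpenAIListModelResponse; all values below are illustrative.
interface ListModelResponse {
  object: string;
  data: Array<{ id: string; object: string; root: string }>;
}

// Same rule as getModels in app/api/openai/[...path]/route.ts.
function filterModels(res: ListModelResponse, disableGPT4: boolean): ListModelResponse {
  if (disableGPT4) {
    res.data = res.data.filter((m) => !m.id.startsWith("gpt-4"));
  }
  return res;
}

const sample: ListModelResponse = {
  object: "list",
  data: [
    { id: "gpt-4", object: "model", root: "gpt-4" },
    { id: "gpt-3.5-turbo", object: "model", root: "gpt-3.5-turbo" },
  ],
};
console.log(filterModels(sample, true).data.map((m) => m.id)); // ["gpt-3.5-turbo"]
```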
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 48742fcc..74c872de 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -42,12 +42,11 @@ import {
   Theme,
   useAppConfig,
   DEFAULT_TOPIC,
-  ALL_MODELS,
+  ModelType,
 } from "../store";
 
 import {
   copyToClipboard,
-  downloadAs,
   selectOrCopy,
   autoGrowTextArea,
   useMobileScreen,
@@ -387,12 +386,12 @@ export function ChatActions(props: {
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   function nextModel() {
-    const models = ALL_MODELS.filter((m) => m.available).map((m) => m.name);
+    const models = config.models.filter((m) => m.available).map((m) => m.name);
     const modelIndex = models.indexOf(currentModel);
     const nextIndex = (modelIndex + 1) % models.length;
     const nextModel = models[nextIndex];
     chatStore.updateCurrentSession((session) => {
-      session.mask.modelConfig.model = nextModel;
+      session.mask.modelConfig.model = nextModel as ModelType;
       session.mask.syncGlobalConfig = false;
     });
   }
diff --git a/app/components/home.tsx b/app/components/home.tsx
index b4b19028..96c1b838 100644
--- a/app/components/home.tsx
+++ b/app/components/home.tsx
@@ -27,6 +27,7 @@ import { SideBar } from "./sidebar";
 import { useAppConfig } from "../store/config";
 import { AuthPage } from "./auth";
 import { getClientConfig } from "../config/client";
+import { api } from "../client/api";
 
 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -152,8 +153,21 @@ function Screen() {
   );
 }
 
+export function useLoadData() {
+  const config = useAppConfig();
+
+  useEffect(() => {
+    (async () => {
+      const models = await api.llm.models();
+      config.mergeModels(models);
+    })();
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, []);
+}
+
 export function Home() {
   useSwitchTheme();
+  useLoadData();
 
   useEffect(() => {
     console.log("[Config] got config from build time", getClientConfig());
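Review note (not part of the patch): `nextModel` now cycles over `config.models` instead of the compile-time `ALL_MODELS`, so models fetched from `/v1/models` at startup join the rotation automatically. A self-contained sketch of the cycling rule (the simplified types are assumptions):

```ts
interface Model {
  name: string;
  available: boolean;
}

// Only available models participate; the modulo wraps the last model back to
// the first, and an unknown current model resolves to index 0.
function nextModelName(models: Model[], currentModel: string): string {
  const names = models.filter((m) => m.available).map((m) => m.name);
  const nextIndex = (names.indexOf(currentModel) + 1) % names.length;
  return names[nextIndex];
}

const models: Model[] = [
  { name: "gpt-3.5-turbo", available: true },
  { name: "gpt-4", available: true },
  { name: "llama", available: false },
];
console.log(nextModelName(models, "gpt-4")); // "gpt-3.5-turbo" (wraps around)
```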
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx
index 9fd4677e..0b81dd90 100644
--- a/app/components/model-config.tsx
+++ b/app/components/model-config.tsx
@@ -1,4 +1,4 @@
-import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
+import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";
 import Locale from "../locales";
 import { InputRange } from "./input-range";
 
@@ -8,6 +8,8 @@ export function ModelConfigList(props: {
   modelConfig: ModelConfig;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
+  const config = useAppConfig();
+
   return (
     <>
       <ListItem title={Locale.Settings.Model}>
@@ -22,7 +24,7 @@ export function ModelConfigList(props: {
           );
         }}
       >
-        {ALL_MODELS.map((v) => (
+        {config.models.map((v) => (
           <option value={v.name} key={v.name} disabled={!v.available}>
             {v.name}
           </option>
diff --git a/app/components/settings.tsx b/app/components/settings.tsx
index 1ee7316a..ed84825b 100644
--- a/app/components/settings.tsx
+++ b/app/components/settings.tsx
@@ -340,6 +340,10 @@ export function Settings() {
   };
   const [loadingUsage, setLoadingUsage] = useState(false);
   function checkUsage(force = false) {
+    if (accessStore.hideBalanceQuery) {
+      return;
+    }
+
     setLoadingUsage(true);
     updateStore.updateUsage(force).finally(() => {
       setLoadingUsage(false);
@@ -577,19 +581,34 @@ export function Settings() {
         )}
 
         {!accessStore.hideUserApiKey ? (
-          <ListItem
-            title={Locale.Settings.Token.Title}
-            subTitle={Locale.Settings.Token.SubTitle}
-          >
-            <PasswordInput
-              value={accessStore.token}
-              type="text"
-              placeholder={Locale.Settings.Token.Placeholder}
-              onChange={(e) => {
-                accessStore.updateToken(e.currentTarget.value);
-              }}
-            />
-          </ListItem>
+          <>
+            <ListItem
+              title={Locale.Settings.Endpoint.Title}
+              subTitle={Locale.Settings.Endpoint.SubTitle}
+            >
+              <input
+                type="text"
+                value={accessStore.openaiUrl}
+                placeholder="https://api.openai.com/"
+                onChange={(e) =>
+                  accessStore.updateOpenAiUrl(e.currentTarget.value)
+                }
+              ></input>
+            </ListItem>
+            <ListItem
+              title={Locale.Settings.Token.Title}
+              subTitle={Locale.Settings.Token.SubTitle}
+            >
+              <PasswordInput
+                value={accessStore.token}
+                type="text"
+                placeholder={Locale.Settings.Token.Placeholder}
+                onChange={(e) => {
+                  accessStore.updateToken(e.currentTarget.value);
+                }}
+              />
+            </ListItem>
+          </>
         ) : null}
 
         {!accessStore.hideBalanceQuery ? (
@@ -617,22 +636,6 @@ export function Settings() {
             )}
           </ListItem>
         ) : null}
-
-        {!accessStore.hideUserApiKey ? (
-          <ListItem
-            title={Locale.Settings.Endpoint.Title}
-            subTitle={Locale.Settings.Endpoint.SubTitle}
-          >
-            <input
-              type="text"
-              value={accessStore.openaiUrl}
-              placeholder="https://api.openai.com/"
-              onChange={(e) =>
-                accessStore.updateOpenAiUrl(e.currentTarget.value)
-              }
-            ></input>
-          </ListItem>
-        ) : null}
 
       </List>
diff --git a/app/config/server.ts b/app/config/server.ts
index 5479995e..6eab9ebe 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -46,7 +46,7 @@ export const getServerSideConfig = () => {
     proxyUrl: process.env.PROXY_URL,
     isVercel: !!process.env.VERCEL,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
-    enableGPT4: !process.env.DISABLE_GPT4,
+    disableGPT4: !!process.env.DISABLE_GPT4,
     hideBalanceQuery: !!process.env.HIDE_BALANCE_QUERY,
   };
 };
diff --git a/app/constant.ts b/app/constant.ts
index b01fd788..6cf3e645 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -53,6 +53,7 @@ export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
+  ListModelPath: "v1/models",
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -61,3 +62,70 @@ You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2021-09
 Current model: {{model}}
 Current time: {{time}}`;
+
+export const DEFAULT_MODELS = [
+  {
+    name: "gpt-4",
+    available: false,
+  },
+  {
+    name: "gpt-4-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-0613",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0613",
+    available: false,
+  },
+  {
+    name: "gpt-3.5-turbo",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0301",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0613",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k-0613",
+    available: true,
+  },
+  {
+    name: "qwen-v1", // 通义千问
+    available: false,
+  },
+  {
+    name: "ernie", // 文心一言
+    available: false,
+  },
+  {
+    name: "spark", // 讯飞星火
+    available: false,
+  },
+  {
+    name: "llama", // llama
+    available: false,
+  },
+  {
+    name: "chatglm", // chatglm-6b
+    available: false,
+  },
+] as const;
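Review note (not part of the patch): `DEFAULT_MODELS` is declared `as const` so that each `name` keeps its literal string type; `ModelType` in app/store/config.ts is then derived as a union of exactly those names. A minimal sketch of the pattern:

```ts
// Two entries stand in for the full sixteen-model list.
const DEFAULT_MODELS = [
  { name: "gpt-4", available: false },
  { name: "gpt-3.5-turbo", available: true },
] as const;

// Indexing by `number` unions all elements; indexing by "name" then unions
// the literal name types: "gpt-4" | "gpt-3.5-turbo".
type ModelType = (typeof DEFAULT_MODELS)[number]["name"];

const ok: ModelType = "gpt-3.5-turbo"; // compiles
// const bad: ModelType = "gpt-5"; // rejected at compile time
```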
diff --git a/app/store/access.ts b/app/store/access.ts
index e9d09bb8..d2806414 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -3,7 +3,6 @@ import { persist } from "zustand/middleware";
 import { DEFAULT_API_HOST, StoreKey } from "../constant";
 import { getHeaders } from "../client/api";
 import { BOT_HELLO } from "./chat";
-import { ALL_MODELS } from "./config";
 import { getClientConfig } from "../config/client";
 
 export interface AccessControlStore {
@@ -76,14 +75,6 @@ export const useAccessStore = create<AccessControlStore>()(
         console.log("[Config] got config from server", res);
         set(() => ({ ...res }));
 
-        if (!res.enableGPT4) {
-          ALL_MODELS.forEach((model) => {
-            if (model.name.startsWith("gpt-4")) {
-              (model as any).available = false;
-            }
-          });
-        }
-
         if ((res as any).botHello) {
           BOT_HELLO.content = (res as any).botHello;
         }
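Review note (not part of the patch): the GPT-4 switch no longer mutates a shared `ALL_MODELS` array in the client store; the `disableGPT4` flag is resolved server-side and enforced when `/v1/models` responses are filtered. A hypothetical sketch of a client reading the renamed flag (the request method and response shape are assumptions based on `DANGER_CONFIG` above):

```ts
interface DangerConfig {
  needCode: boolean;
  hideUserApiKey: boolean;
  disableGPT4: boolean; // renamed from enableGPT4, with inverted meaning
  hideBalanceQuery: boolean;
}

async function fetchDangerConfig(): Promise<DangerConfig> {
  const res = await fetch("/api/config", { method: "POST" });
  if (!res.ok) throw new Error(`config request failed: ${res.status}`);
  return (await res.json()) as DangerConfig;
}
```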
"gpt-3.5-turbo-16k", - available: true, - }, - { - name: "gpt-3.5-turbo-16k-0613", - available: true, - }, - { - name: "qwen-v1", // 通义千问 - available: false, - }, - { - name: "ernie", // 文心一言 - available: false, - }, - { - name: "spark", // 讯飞星火 - available: false, - }, - { - name: "llama", // llama - available: false, - }, - { - name: "chatglm", // chatglm-6b - available: false, - }, -] as const; - -export type ModelType = (typeof ALL_MODELS)[number]["name"]; - export function limitNumber( x: number, min: number, @@ -138,7 +73,8 @@ export function limitNumber( } export function limitModel(name: string) { - return ALL_MODELS.some((m) => m.name === name && m.available) + const allModels = useAppConfig.getState().models; + return allModels.some((m) => m.name === name && m.available) ? name : "gpt-3.5-turbo"; } @@ -178,6 +114,25 @@ export const useAppConfig = create()( updater(config); set(() => config); }, + + mergeModels(newModels) { + const oldModels = get().models; + const modelMap: Record = {}; + + for (const model of oldModels) { + model.available = false; + modelMap[model.name] = model; + } + + for (const model of newModels) { + model.available = true; + modelMap[model.name] = model; + } + + set(() => ({ + models: Object.values(modelMap), + })); + }, }), { name: StoreKey.Config,