feat: close #2192, use /v1/models to fetch model ids

Yidadaa 2023-07-04 23:16:24 +08:00
parent f2d748cfe4
commit 4131fccbe0
12 changed files with 214 additions and 121 deletions

View File

@@ -9,7 +9,7 @@ const serverConfig = getServerSideConfig();
 const DANGER_CONFIG = {
   needCode: serverConfig.needCode,
   hideUserApiKey: serverConfig.hideUserApiKey,
-  enableGPT4: serverConfig.enableGPT4,
+  disableGPT4: serverConfig.disableGPT4,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
 };

View File

@@ -1,3 +1,5 @@
+import { type OpenAIListModelResponse } from "@/app/client/platforms/openai";
+import { getServerSideConfig } from "@/app/config/server";
 import { OpenaiPath } from "@/app/constant";
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
@@ -6,6 +8,18 @@ import { requestOpenai } from "../../common";
 const ALLOWD_PATH = new Set(Object.values(OpenaiPath));

+function getModels(remoteModelRes: OpenAIListModelResponse) {
+  const config = getServerSideConfig();
+
+  if (config.disableGPT4) {
+    remoteModelRes.data = remoteModelRes.data.filter(
+      (m) => !m.id.startsWith("gpt-4"),
+    );
+  }
+
+  return remoteModelRes;
+}
+
 async function handle(
   req: NextRequest,
   { params }: { params: { path: string[] } },
@@ -39,7 +53,18 @@ async function handle(
   }

   try {
-    return await requestOpenai(req);
+    const response = await requestOpenai(req);
+
+    // list models
+    if (subpath === OpenaiPath.ListModelPath && response.status === 200) {
+      const resJson = (await response.json()) as OpenAIListModelResponse;
+      const availableModels = getModels(resJson);
+      return NextResponse.json(availableModels, {
+        status: response.status,
+      });
+    }
+
+    return response;
   } catch (e) {
     console.error("[OpenAI] ", e);
     return NextResponse.json(prettyObject(e));

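Note: with this change the proxy route inspects successful v1/models responses and, when DISABLE_GPT4 is set server-side, strips every model whose id starts with "gpt-4" before re-serializing the JSON for the client. A minimal standalone sketch of that filter (the sample ids and the assertion harness are illustrative, not part of the commit):

import { strict as assert } from "node:assert";

type ModelList = {
  object: string;
  data: Array<{ id: string; object: string; root: string }>;
};

// Same predicate as getModels above: drop gpt-4* when GPT-4 is disabled.
function filterModels(res: ModelList, disableGPT4: boolean): ModelList {
  if (disableGPT4) {
    res.data = res.data.filter((m) => !m.id.startsWith("gpt-4"));
  }
  return res;
}

const upstream: ModelList = {
  object: "list",
  data: [
    { id: "gpt-3.5-turbo", object: "model", root: "gpt-3.5-turbo" },
    { id: "gpt-4", object: "model", root: "gpt-4" },
  ],
};

// With DISABLE_GPT4 set, only non-gpt-4 ids survive the proxy.
assert.deepEqual(
  filterModels(upstream, true).data.map((m) => m.id),
  ["gpt-3.5-turbo"],
);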
View File

@@ -38,9 +38,15 @@ export interface LLMUsage {
   total: number;
 }

+export interface LLMModel {
+  name: string;
+  available: boolean;
+}
+
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
   abstract usage(): Promise<LLMUsage>;
+  abstract models(): Promise<LLMModel[]>;
 }

 type ProviderName = "openai" | "azure" | "claude" | "palm";

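Every concrete client must now implement models() alongside chat() and usage(). A hedged sketch of a caller, assuming an LLMApi instance such as the api.llm used in home.tsx below:

// Returns the names of models the provider reports as usable.
async function listAvailableModelNames(llm: LLMApi): Promise<string[]> {
  const models: LLMModel[] = await llm.models();
  return models.filter((m) => m.available).map((m) => m.name);
}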
View File

@@ -5,7 +5,7 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";

-import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
+import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
 import Locale from "../../locales";
 import {
   EventStreamContentType,
@@ -13,6 +13,15 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";

+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
 export class ChatGPTApi implements LLMApi {
   path(path: string): string {
     let openaiUrl = useAccessStore.getState().openaiUrl;
@@ -22,6 +31,9 @@ export class ChatGPTApi implements LLMApi {
     if (openaiUrl.endsWith("/")) {
       openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1);
     }
+    if (!openaiUrl.startsWith("http") && !openaiUrl.startsWith("/api/openai")) {
+      openaiUrl = "https://" + openaiUrl;
+    }
     return [openaiUrl, path].join("/");
   }
@@ -232,5 +244,23 @@
       total: total.hard_limit_usd,
     } as LLMUsage;
   }
+
+  async models(): Promise<LLMModel[]> {
+    const res = await fetch(this.path(OpenaiPath.ListModelPath), {
+      method: "GET",
+      headers: {
+        ...getHeaders(),
+      },
+    });
+
+    const resJson = (await res.json()) as OpenAIListModelResponse;
+    const chatModels = resJson.data.filter((m) => m.id.startsWith("gpt-"));
+    console.log("[Models]", chatModels);
+
+    return chatModels.map((m) => ({
+      name: m.id,
+      available: true,
+    }));
+  }
 }

 export { OpenaiPath };

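For reference, OpenAIListModelResponse types only the fields this client reads; the upstream v1/models payload carries more. An abridged, illustrative value showing why models() filters on the "gpt-" prefix:

const example: OpenAIListModelResponse = {
  object: "list",
  data: [
    { id: "gpt-3.5-turbo", object: "model", root: "gpt-3.5-turbo" },
    { id: "gpt-4", object: "model", root: "gpt-4" },
    // non-chat models such as "whisper-1" are dropped by the "gpt-" filter
    { id: "whisper-1", object: "model", root: "whisper-1" },
  ],
};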
View File

@@ -42,12 +42,11 @@ import {
   Theme,
   useAppConfig,
   DEFAULT_TOPIC,
-  ALL_MODELS,
+  ModelType,
 } from "../store";
 import {
   copyToClipboard,
-  downloadAs,
   selectOrCopy,
   autoGrowTextArea,
   useMobileScreen,
@@ -387,12 +386,12 @@ export function ChatActions(props: {
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   function nextModel() {
-    const models = ALL_MODELS.filter((m) => m.available).map((m) => m.name);
+    const models = config.models.filter((m) => m.available).map((m) => m.name);
     const modelIndex = models.indexOf(currentModel);
     const nextIndex = (modelIndex + 1) % models.length;
     const nextModel = models[nextIndex];
     chatStore.updateCurrentSession((session) => {
-      session.mask.modelConfig.model = nextModel;
+      session.mask.modelConfig.model = nextModel as ModelType;
       session.mask.syncGlobalConfig = false;
     });
   }

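A subtlety in nextModel worth noting: when the current model is not in the available list, indexOf returns -1 and (-1 + 1) % length is 0, so the switch falls back to the first available model rather than throwing. A tiny sketch of the arithmetic (names illustrative):

const available = ["gpt-3.5-turbo", "gpt-4"]; // illustrative

function next(current: string): string {
  const i = available.indexOf(current); // -1 when current is unavailable
  return available[(i + 1) % available.length];
}

// next("gpt-4")  -> "gpt-3.5-turbo"  (wraps around)
// next("ernie")  -> "gpt-3.5-turbo"  (unknown model lands on index 0)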
View File

@@ -27,6 +27,7 @@ import { SideBar } from "./sidebar";
 import { useAppConfig } from "../store/config";
 import { AuthPage } from "./auth";
 import { getClientConfig } from "../config/client";
+import { api } from "../client/api";

 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -152,8 +153,21 @@
   );
 }

+export function useLoadData() {
+  const config = useAppConfig();
+
+  useEffect(() => {
+    (async () => {
+      const models = await api.llm.models();
+      config.mergeModels(models);
+    })();
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, []);
+}
+
 export function Home() {
   useSwitchTheme();
+  useLoadData();

   useEffect(() => {
     console.log("[Config] got config from build time", getClientConfig());

View File

@@ -1,4 +1,4 @@
-import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
+import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";

 import Locale from "../locales";
 import { InputRange } from "./input-range";
@@ -8,6 +8,8 @@ export function ModelConfigList(props: {
   modelConfig: ModelConfig;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
+  const config = useAppConfig();
+
   return (
     <>
       <ListItem title={Locale.Settings.Model}>
@@ -22,7 +24,7 @@ export function ModelConfigList(props: {
           );
         }}
       >
-        {ALL_MODELS.map((v) => (
+        {config.models.map((v) => (
           <option value={v.name} key={v.name} disabled={!v.available}>
             {v.name}
           </option>

View File

@@ -340,6 +340,10 @@ export function Settings() {
   };

   const [loadingUsage, setLoadingUsage] = useState(false);
   function checkUsage(force = false) {
+    if (accessStore.hideBalanceQuery) {
+      return;
+    }
+
     setLoadingUsage(true);
     updateStore.updateUsage(force).finally(() => {
       setLoadingUsage(false);
@@ -577,19 +581,34 @@
         )}

         {!accessStore.hideUserApiKey ? (
-          <ListItem
-            title={Locale.Settings.Token.Title}
-            subTitle={Locale.Settings.Token.SubTitle}
-          >
-            <PasswordInput
-              value={accessStore.token}
-              type="text"
-              placeholder={Locale.Settings.Token.Placeholder}
-              onChange={(e) => {
-                accessStore.updateToken(e.currentTarget.value);
-              }}
-            />
-          </ListItem>
+          <>
+            <ListItem
+              title={Locale.Settings.Endpoint.Title}
+              subTitle={Locale.Settings.Endpoint.SubTitle}
+            >
+              <input
+                type="text"
+                value={accessStore.openaiUrl}
+                placeholder="https://api.openai.com/"
+                onChange={(e) =>
+                  accessStore.updateOpenAiUrl(e.currentTarget.value)
+                }
+              ></input>
+            </ListItem>
+            <ListItem
+              title={Locale.Settings.Token.Title}
+              subTitle={Locale.Settings.Token.SubTitle}
+            >
+              <PasswordInput
+                value={accessStore.token}
+                type="text"
+                placeholder={Locale.Settings.Token.Placeholder}
+                onChange={(e) => {
+                  accessStore.updateToken(e.currentTarget.value);
+                }}
+              />
+            </ListItem>
+          </>
         ) : null}
@@ -617,22 +636,6 @@
           )}
         </ListItem>
       ) : null}
-
-      {!accessStore.hideUserApiKey ? (
-        <ListItem
-          title={Locale.Settings.Endpoint.Title}
-          subTitle={Locale.Settings.Endpoint.SubTitle}
-        >
-          <input
-            type="text"
-            value={accessStore.openaiUrl}
-            placeholder="https://api.openai.com/"
-            onChange={(e) =>
-              accessStore.updateOpenAiUrl(e.currentTarget.value)
-            }
-          ></input>
-        </ListItem>
-      ) : null}
       </List>

       <List>

View File

@@ -46,7 +46,7 @@ export const getServerSideConfig = () => {
     proxyUrl: process.env.PROXY_URL,
     isVercel: !!process.env.VERCEL,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
-    enableGPT4: !process.env.DISABLE_GPT4,
+    disableGPT4: !!process.env.DISABLE_GPT4,
    hideBalanceQuery: !!process.env.HIDE_BALANCE_QUERY,
   };
 };

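The polarity flip matters: the config used to expose enableGPT4 = !DISABLE_GPT4 and now carries disableGPT4 = !!DISABLE_GPT4, matching the env var directly. With !!, any non-empty value turns the flag on:

// Illustrative coercion, not part of the commit:
const disableGPT4 = !!process.env.DISABLE_GPT4;
// DISABLE_GPT4 unset              -> !!undefined -> false
// DISABLE_GPT4=1 (or any string)  -> !!"1"       -> true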
View File

@@ -53,6 +53,7 @@ export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
+  ListModelPath: "v1/models",
 };

 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -61,3 +62,70 @@
 You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2021-09
 Current model: {{model}}
 Current time: {{time}}`;
+
+export const DEFAULT_MODELS = [
+  {
+    name: "gpt-4",
+    available: false,
+  },
+  {
+    name: "gpt-4-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-0613",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0613",
+    available: false,
+  },
+  {
+    name: "gpt-3.5-turbo",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0301",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0613",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k-0613",
+    available: true,
+  },
+  {
+    name: "qwen-v1", // 通义千问
+    available: false,
+  },
+  {
+    name: "ernie", // 文心一言
+    available: false,
+  },
+  {
+    name: "spark", // 讯飞星火
+    available: false,
+  },
+  {
+    name: "llama", // llama
+    available: false,
+  },
+  {
+    name: "chatglm", // chatglm-6b
+    available: false,
+  },
+] as const;

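Because the array is declared as const, each name stays a string literal, which is what lets the config store (below) derive ModelType as a union of exactly these ids:

// Equivalent of the derivation in app/store/config.ts:
type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
// = "gpt-4" | "gpt-4-0314" | ... | "gpt-3.5-turbo" | ... | "chatglm"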
View File

@@ -3,7 +3,6 @@ import { persist } from "zustand/middleware";
 import { DEFAULT_API_HOST, StoreKey } from "../constant";
 import { getHeaders } from "../client/api";
 import { BOT_HELLO } from "./chat";
-import { ALL_MODELS } from "./config";
 import { getClientConfig } from "../config/client";

 export interface AccessControlStore {
@@ -76,14 +75,6 @@
         console.log("[Config] got config from server", res);
         set(() => ({ ...res }));

-        if (!res.enableGPT4) {
-          ALL_MODELS.forEach((model) => {
-            if (model.name.startsWith("gpt-4")) {
-              (model as any).available = false;
-            }
-          });
-        }
-
         if ((res as any).botHello) {
           BOT_HELLO.content = (res as any).botHello;
         }

View File

@@ -1,7 +1,10 @@
 import { create } from "zustand";
 import { persist } from "zustand/middleware";
+import { LLMModel } from "../client/api";
 import { getClientConfig } from "../config/client";
-import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, StoreKey } from "../constant";
+
+export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];

 export enum SubmitKey {
   Enter = "Enter",
@@ -30,6 +33,8 @@ export const DEFAULT_CONFIG = {
   dontShowMaskSplashScreen: false, // dont show splash screen when create chat

+  models: DEFAULT_MODELS as any as LLMModel[],
+
   modelConfig: {
     model: "gpt-3.5-turbo" as ModelType,
     temperature: 0.5,
@@ -49,81 +54,11 @@ export type ChatConfig = typeof DEFAULT_CONFIG;
 export type ChatConfigStore = ChatConfig & {
   reset: () => void;
   update: (updater: (config: ChatConfig) => void) => void;
+  mergeModels: (newModels: LLMModel[]) => void;
 };

 export type ModelConfig = ChatConfig["modelConfig"];

-const ENABLE_GPT4 = true;
-
-export const ALL_MODELS = [
-  {
-    name: "gpt-4",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-0314",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-0613",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k-0314",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k-0613",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-3.5-turbo",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-0301",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-0613",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-16k",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-16k-0613",
-    available: true,
-  },
-  {
-    name: "qwen-v1", // 通义千问
-    available: false,
-  },
-  {
-    name: "ernie", // 文心一言
-    available: false,
-  },
-  {
-    name: "spark", // 讯飞星火
-    available: false,
-  },
-  {
-    name: "llama", // llama
-    available: false,
-  },
-  {
-    name: "chatglm", // chatglm-6b
-    available: false,
-  },
-] as const;
-
-export type ModelType = (typeof ALL_MODELS)[number]["name"];
-
 export function limitNumber(
   x: number,
   min: number,
@@ -138,7 +73,8 @@ export function limitNumber(
 }

 export function limitModel(name: string) {
-  return ALL_MODELS.some((m) => m.name === name && m.available)
+  const allModels = useAppConfig.getState().models;
+  return allModels.some((m) => m.name === name && m.available)
     ? name
     : "gpt-3.5-turbo";
 }
@@ -178,6 +114,25 @@ export const useAppConfig = create<ChatConfigStore>()(
       updater(config);
       set(() => config);
     },
+
+    mergeModels(newModels) {
+      const oldModels = get().models;
+      const modelMap: Record<string, LLMModel> = {};
+
+      for (const model of oldModels) {
+        model.available = false;
+        modelMap[model.name] = model;
+      }
+
+      for (const model of newModels) {
+        model.available = true;
+        modelMap[model.name] = model;
+      }
+
+      set(() => ({
+        models: Object.values(modelMap),
+      }));
+    },
   }),
   {
     name: StoreKey.Config,
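
Merge semantics: mergeModels first marks every locally known model unavailable, then marks every model returned by v1/models available and unions by name, so remote availability always wins. A worked example with the defaults above (illustrative):

// old (defaults):  gpt-4 (false), gpt-3.5-turbo (true), ernie (false)
// new (remote):    gpt-3.5-turbo, gpt-4
//
// after mergeModels:
//   gpt-4         -> available: true   (reported by the endpoint)
//   gpt-3.5-turbo -> available: true   (reported by the endpoint)
//   ernie         -> available: false  (local default, not reported)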