diff --git a/README.md b/README.md
index 91e03d80..3973c84b 100644
--- a/README.md
+++ b/README.md
@@ -197,6 +197,13 @@ If you do want users to query balance, set this value to 1, or you should set it
If you want to disable parsing settings from the URL, set this to 1.
+### `CUSTOM_MODELS` (optional)
+
+> Default: Empty
+> Example: `+llama,+claude-2,-gpt-3.5-turbo` means adding `llama` and `claude-2` to the model list, and removing `gpt-3.5-turbo` from it.
+
+Use this to control the model list: `+` adds a model, `-` hides one, and entries are separated by commas (a short sketch of these semantics follows this file's diff).
+
## Requirements
NodeJS >= 18, Docker >= 20
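
As referenced above, here is a minimal sketch of the `CUSTOM_MODELS` semantics. The function name `parseCustomModels` is hypothetical; the actual implementation is the `collectModelTable` helper added in `app/utils/model.ts` later in this diff.

```ts
// Hypothetical helper illustrating the CUSTOM_MODELS grammar described above.
// It mirrors collectModelTable from app/utils/model.ts (added later in this diff).
function parseCustomModels(flags: string): Record<string, boolean> {
  const table: Record<string, boolean> = {};
  for (const entry of flags.split(",").filter((v) => v.length > 0)) {
    if (entry.startsWith("+")) {
      table[entry.slice(1)] = true; // add or re-enable a model
    } else if (entry.startsWith("-")) {
      table[entry.slice(1)] = false; // hide a model
    } else {
      table[entry] = true; // bare names are treated as additions
    }
  }
  return table;
}

console.log(parseCustomModels("+llama,+claude-2,-gpt-3.5-turbo"));
// => { llama: true, "claude-2": true, "gpt-3.5-turbo": false }
```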
diff --git a/README_CN.md b/README_CN.md
index 13b97417..d8e9553e 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -106,6 +106,12 @@ OpenAI API proxy URL. If you have manually configured an openai API proxy, please fill
If you want to disable parsing preset settings from the link, set this environment variable to 1.
+### `CUSTOM_MODELS` (optional)
+
+> Example: `+qwen-7b-chat,+glm-6b,-gpt-3.5-turbo` means adding `qwen-7b-chat` and `glm-6b` to the model list, and removing `gpt-3.5-turbo` from it.
+
+Use this to control the model list: `+` adds a model, `-` hides one, and entries are separated by commas.
+
## Development

Click the button below to start secondary development:
diff --git a/app/api/common.ts b/app/api/common.ts
index 0af7761d..a1decd42 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -1,10 +1,9 @@
import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "../config/server";
+import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant";
+import { collectModelTable } from "../utils/model";
-export const OPENAI_URL = "api.openai.com";
-const DEFAULT_PROTOCOL = "https";
-const PROTOCOL = process.env.PROTOCOL || DEFAULT_PROTOCOL;
-const BASE_URL = process.env.BASE_URL || OPENAI_URL;
-const DISABLE_GPT4 = !!process.env.DISABLE_GPT4;
+const serverConfig = getServerSideConfig();
export async function requestOpenai(req: NextRequest) {
const controller = new AbortController();
@@ -14,10 +13,10 @@ export async function requestOpenai(req: NextRequest) {
"",
);
- let baseUrl = BASE_URL;
+ let baseUrl = serverConfig.baseUrl ?? OPENAI_BASE_URL;
if (!baseUrl.startsWith("http")) {
- baseUrl = `${PROTOCOL}://${baseUrl}`;
+ baseUrl = `https://${baseUrl}`;
}
if (baseUrl.endsWith("/")) {
@@ -26,10 +25,7 @@ export async function requestOpenai(req: NextRequest) {
console.log("[Proxy] ", openaiPath);
console.log("[Base Url]", baseUrl);
-
- if (process.env.OPENAI_ORG_ID) {
- console.log("[Org ID]", process.env.OPENAI_ORG_ID);
- }
+ console.log("[Org ID]", serverConfig.openaiOrgId);
const timeoutId = setTimeout(
() => {
@@ -58,18 +54,23 @@ export async function requestOpenai(req: NextRequest) {
};
// #1815 try to refuse gpt4 request
- if (DISABLE_GPT4 && req.body) {
+ if (serverConfig.customModels && req.body) {
try {
+ const modelTable = collectModelTable(
+ DEFAULT_MODELS,
+ serverConfig.customModels,
+ );
const clonedBody = await req.text();
fetchOptions.body = clonedBody;
- const jsonBody = JSON.parse(clonedBody);
+ const jsonBody = JSON.parse(clonedBody) as { model?: string };
- if ((jsonBody?.model ?? "").includes("gpt-4")) {
+      // reject only if the model is known and explicitly disabled; unknown models pass through
+ if (modelTable[jsonBody?.model ?? ""] === false) {
return NextResponse.json(
{
error: true,
- message: "you are not allowed to use gpt-4 model",
+          message: `you are not allowed to use the ${jsonBody?.model} model`,
},
{
status: 403,
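
For clarity, the check above rejects a request only when the model appears in the table with an explicit `false`; an unknown model looks up as `undefined` and passes through. A minimal sketch of these lookup semantics, using an illustrative table:

```ts
// Illustrative model table; real entries come from collectModelTable(DEFAULT_MODELS, ...).
const modelTable: Record<string, boolean> = {
  "gpt-3.5-turbo": true,
  "gpt-4": false, // e.g. disabled via DISABLE_GPT4 or "-gpt-4"
};

// Mirrors the `=== false` comparison above: only an explicit false is rejected.
function isRejected(model: string): boolean {
  return modelTable[model] === false;
}

isRejected("gpt-4"); // true  -> request answered with a 403
isRejected("gpt-3.5-turbo"); // false -> request forwarded upstream
isRejected("some-unknown-model"); // false -> undefined lookup, also forwarded
```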
diff --git a/app/api/config/route.ts b/app/api/config/route.ts
index 44af8d3b..db84fba1 100644
--- a/app/api/config/route.ts
+++ b/app/api/config/route.ts
@@ -12,6 +12,7 @@ const DANGER_CONFIG = {
disableGPT4: serverConfig.disableGPT4,
hideBalanceQuery: serverConfig.hideBalanceQuery,
disableFastLink: serverConfig.disableFastLink,
+ customModels: serverConfig.customModels,
};
declare global {
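
With the added line above, the `/api/config` route now exposes the raw `customModels` string to the client. A partial sketch of the payload shape, showing only the fields visible in this hunk rather than guessing the rest:

```ts
// Partial sketch of DANGER_CONFIG as served by /api/config after this change;
// fields outside this hunk are omitted.
interface PartialDangerConfig {
  disableGPT4: boolean;
  hideBalanceQuery: boolean;
  disableFastLink: boolean;
  customModels: string; // raw flag string, e.g. "+llama,-gpt-3.5-turbo"
}
```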
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index a0b7307c..9afb49f7 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -88,6 +88,7 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command";
import { prettyObject } from "../utils/format";
import { ExportMessageModal } from "./exporter";
import { getClientConfig } from "../config/client";
+import { useAllModels } from "../utils/hooks";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
@@ -430,14 +431,9 @@ export function ChatActions(props: {
// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
- const models = useMemo(
- () =>
- config
- .allModels()
- .filter((m) => m.available)
- .map((m) => m.name),
- [config],
- );
+ const models = useAllModels()
+ .filter((m) => m.available)
+ .map((m) => m.name);
const [showModelSelector, setShowModelSelector] = useState(false);
return (
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx
index 63950a40..6e4c9bcb 100644
--- a/app/components/model-config.tsx
+++ b/app/components/model-config.tsx
@@ -1,14 +1,15 @@
-import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";
+import { ModalConfigValidator, ModelConfig } from "../store";
import Locale from "../locales";
import { InputRange } from "./input-range";
import { ListItem, Select } from "./ui-lib";
+import { useAllModels } from "../utils/hooks";
export function ModelConfigList(props: {
modelConfig: ModelConfig;
updateConfig: (updater: (config: ModelConfig) => void) => void;
}) {
- const config = useAppConfig();
+ const allModels = useAllModels();
return (
<>
@@ -24,7 +25,7 @@ export function ModelConfigList(props: {
);
}}
>
- {config.allModels().map((v, i) => (
+ {allModels.map((v, i) => (
diff --git a/app/config/server.ts b/app/config/server.ts
index 2df806fe..007c3973 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -1,4 +1,5 @@
import md5 from "spark-md5";
+import { DEFAULT_MODELS } from "../constant";
declare global {
namespace NodeJS {
@@ -7,6 +8,7 @@ declare global {
CODE?: string;
BASE_URL?: string;
PROXY_URL?: string;
+ OPENAI_ORG_ID?: string;
VERCEL?: string;
HIDE_USER_API_KEY?: string; // disable user's api key input
DISABLE_GPT4?: string; // allow user to use gpt-4 or not
@@ -14,6 +16,7 @@ declare global {
BUILD_APP?: string; // is building desktop app
ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
+ CUSTOM_MODELS?: string; // to control custom models
}
}
}
@@ -38,6 +41,16 @@ export const getServerSideConfig = () => {
);
}
+ let disableGPT4 = !!process.env.DISABLE_GPT4;
+ let customModels = process.env.CUSTOM_MODELS ?? "";
+
+ if (disableGPT4) {
+ if (customModels) customModels += ",";
+ customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
+ .map((m) => "-" + m.name)
+ .join(",");
+ }
+
return {
apiKey: process.env.OPENAI_API_KEY,
code: process.env.CODE,
@@ -45,10 +58,12 @@ export const getServerSideConfig = () => {
needCode: ACCESS_CODES.size > 0,
baseUrl: process.env.BASE_URL,
proxyUrl: process.env.PROXY_URL,
+ openaiOrgId: process.env.OPENAI_ORG_ID,
isVercel: !!process.env.VERCEL,
hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
- disableGPT4: !!process.env.DISABLE_GPT4,
+ disableGPT4,
hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
disableFastLink: !!process.env.DISABLE_FAST_LINK,
+ customModels,
};
};
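
The `disableGPT4` branch above folds the legacy flag into the new `customModels` grammar by appending a `-` entry for every default model whose name starts with `gpt-4`. A self-contained sketch of that rewriting, using an illustrative model list (not the exact contents of `DEFAULT_MODELS`):

```ts
// Illustrative stand-in for DEFAULT_MODELS; the real list lives in app/constant.
const DEFAULT_MODELS = [
  { name: "gpt-3.5-turbo", available: true },
  { name: "gpt-4", available: true },
  { name: "gpt-4-32k", available: true },
];

let customModels = "+llama"; // pretend CUSTOM_MODELS=+llama
const disableGPT4 = true; // pretend DISABLE_GPT4=1

// Same rewriting as in getServerSideConfig above.
if (disableGPT4) {
  if (customModels) customModels += ",";
  customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
    .map((m) => "-" + m.name)
    .join(",");
}

console.log(customModels); // "+llama,-gpt-4,-gpt-4-32k"
```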
diff --git a/app/store/access.ts b/app/store/access.ts
index 3d889f6e..f87e44a2 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -17,6 +17,7 @@ const DEFAULT_ACCESS_STATE = {
hideBalanceQuery: false,
disableGPT4: false,
disableFastLink: false,
+ customModels: "",
openaiUrl: DEFAULT_OPENAI_URL,
};
@@ -52,12 +53,6 @@ export const useAccessStore = createPersistStore(
.then((res: DangerConfig) => {
console.log("[Config] got config from server", res);
set(() => ({ ...res }));
-
- if (res.disableGPT4) {
- DEFAULT_MODELS.forEach(
- (m: any) => (m.available = !m.name.startsWith("gpt-4")),
- );
- }
})
.catch(() => {
console.error("[Config] failed to fetch config");
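
With the client-side GPT-4 filtering removed above, the access store simply merges whatever `/api/config` returns, including `customModels`, into its state; the filtering itself now happens in `useAllModels` (added below). A rough sketch of that merge, with a made-up response:

```ts
// Made-up server response and client state, only to show the merge direction.
const res = { disableGPT4: true, customModels: "-gpt-4,-gpt-4-32k" };
const state = { customModels: "", disableGPT4: false };

// zustand's set(() => ({ ...res })) shallow-merges the response into the store.
const next = { ...state, ...res };
console.log(next.customModels); // "-gpt-4,-gpt-4-32k"
```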
diff --git a/app/store/config.ts b/app/store/config.ts
index 0fbc26df..5fcd6ff5 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -128,15 +128,7 @@ export const useAppConfig = createPersistStore(
}));
},
- allModels() {
- const customModels = get()
- .customModels.split(",")
- .filter((v) => !!v && v.length > 0)
- .map((m) => ({ name: m, available: true }));
- const allModels = get().models.concat(customModels);
- allModels.sort((a, b) => (a.name < b.name ? -1 : 1));
- return allModels;
- },
+    allModels() {}, // emptied: the model list is now computed by useAllModels() in app/utils/hooks.ts
}),
{
name: StoreKey.Config,
diff --git a/app/utils/hooks.ts b/app/utils/hooks.ts
new file mode 100644
index 00000000..f6bfae67
--- /dev/null
+++ b/app/utils/hooks.ts
@@ -0,0 +1,16 @@
+import { useMemo } from "react";
+import { useAccessStore, useAppConfig } from "../store";
+import { collectModels } from "./model";
+
+export function useAllModels() {
+ const accessStore = useAccessStore();
+ const configStore = useAppConfig();
+ const models = useMemo(() => {
+ return collectModels(
+ configStore.models,
+ [accessStore.customModels, configStore.customModels].join(","),
+ );
+ }, [accessStore.customModels, configStore.customModels, configStore.models]);
+
+ return models;
+}
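
A hypothetical component showing how the new hook is consumed; this mirrors the `ChatActions` change earlier in the diff and assumes the standard React JSX transform:

```tsx
import { useAllModels } from "../utils/hooks";

// Hypothetical picker component; ChatActions above uses the same pattern.
export function ModelPicker() {
  const models = useAllModels()
    .filter((m) => m.available) // hidden models carry available: false
    .map((m) => m.name);

  return (
    <select>
      {models.map((name) => (
        <option key={name} value={name}>
          {name}
        </option>
      ))}
    </select>
  );
}
```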
diff --git a/app/utils/model.ts b/app/utils/model.ts
new file mode 100644
index 00000000..23090f9d
--- /dev/null
+++ b/app/utils/model.ts
@@ -0,0 +1,40 @@
+import { LLMModel } from "../client/api";
+
+export function collectModelTable(
+ models: readonly LLMModel[],
+ customModels: string,
+) {
+  const modelTable: Record<string, boolean> = {};
+
+ // default models
+ models.forEach((m) => (modelTable[m.name] = m.available));
+
+ // server custom models
+ customModels
+ .split(",")
+ .filter((v) => !!v && v.length > 0)
+    .forEach((m) => {
+      if (m.startsWith("+")) {
+        modelTable[m.slice(1)] = true;
+      } else if (m.startsWith("-")) {
+        modelTable[m.slice(1)] = false;
+      } else {
+        modelTable[m] = true;
+      }
+    });
+ return modelTable;
+}
+
+/**
+ * Generate the full model list from the default models and the custom model flags.
+ */
+export function collectModels(
+ models: readonly LLMModel[],
+ customModels: string,
+) {
+ const modelTable = collectModelTable(models, customModels);
+ const allModels = Object.keys(modelTable).map((m) => ({
+ name: m,
+ available: modelTable[m],
+ }));
+
+ return allModels;
+}