From 778e88cb5677dcc0658ea3ef85ed2707ff9d398a Mon Sep 17 00:00:00 2001 From: Fred Liang Date: Sun, 24 Dec 2023 02:15:30 +0800 Subject: [PATCH] chore: resolve conflict --- .env.template | 11 ++ app/api/common.ts | 2 +- app/api/google/[...path]/route.ts | 104 ++++++++++++++++ app/client/api.ts | 34 +++-- app/client/platforms/google.ts | 199 ++++++++++++++++++++++++++++++ app/components/exporter.tsx | 14 ++- app/components/home.tsx | 13 +- app/components/model-config.tsx | 2 +- app/components/settings.tsx | 67 +++++++++- app/config/server.ts | 3 + app/constant.ts | 95 ++++++++++++++ app/locales/cn.ts | 25 +++- app/store/access.ts | 11 ++ app/store/chat.ts | 54 +++++--- app/store/update.ts | 64 +++++----- app/utils/model.ts | 6 +- 16 files changed, 630 insertions(+), 74 deletions(-) create mode 100644 app/api/google/[...path]/route.ts create mode 100644 app/client/platforms/google.ts diff --git a/.env.template b/.env.template index 3e329036..89bab2cb 100644 --- a/.env.template +++ b/.env.template @@ -8,6 +8,16 @@ CODE=your-password # You can start service behind a proxy PROXY_URL=http://localhost:7890 +# (optional) +# Default: Empty +# Googel Gemini Pro API key, set if you want to use Google Gemini Pro API. +GOOGLE_API_KEY= + +# (optional) +# Default: https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent +# Googel Gemini Pro API url, set if you want to customize Google Gemini Pro API url. +GOOGLE_URL= + # Override openai api request base url. (optional) # Default: https://api.openai.com # Examples: http://your-openai-proxy.com @@ -36,3 +46,4 @@ ENABLE_BALANCE_QUERY= # Default: Empty # If you want to disable parse settings from url, set this value to 1. DISABLE_FAST_LINK= + diff --git a/app/api/common.ts b/app/api/common.ts index 6b0d619d..13cfab03 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,6 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; -import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant"; +import { DEFAULT_MODELS, OPENAI_BASE_URL, GEMINI_BASE_URL } from "../constant"; import { collectModelTable } from "../utils/model"; import { makeAzurePath } from "../azure"; diff --git a/app/api/google/[...path]/route.ts b/app/api/google/[...path]/route.ts new file mode 100644 index 00000000..5b19740a --- /dev/null +++ b/app/api/google/[...path]/route.ts @@ -0,0 +1,104 @@ +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "../../auth"; +import { getServerSideConfig } from "@/app/config/server"; +import { GEMINI_BASE_URL, Google } from "@/app/constant"; + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Google Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const controller = new AbortController(); + + const serverConfig = getServerSideConfig(); + + let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + let path = `${req.nextUrl.pathname}`.replaceAll("/api/google/", ""); + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + // this fix [Org ID] undefined in server side if not using custom point + if (serverConfig.openaiOrgId !== undefined) { + console.log("[Org ID]", serverConfig.openaiOrgId); + } + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}/${path}?key=${req.nextUrl.searchParams.get( + "key", + )}`; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + try { + const res = await fetch(fetchUrl, fetchOptions); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/client/api.ts b/app/client/api.ts index c7e33c71..50865d4b 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -1,8 +1,13 @@ import { getClientConfig } from "../config/client"; -import { ACCESS_CODE_PREFIX, Azure, ServiceProvider } from "../constant"; -import { ChatMessage, ModelType, useAccessStore } from "../store"; +import { + ACCESS_CODE_PREFIX, + Azure, + ModelProvider, + ServiceProvider, +} from "../constant"; +import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; - +import { GeminiApi } from "./platforms/google"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -40,7 +45,15 @@ export interface LLMUsage { export interface LLMModel { name: string; + displayName: string; available: boolean; + provider: LLMModelProvider; +} + +export interface LLMModelProvider { + id: string; + providerName: string; + providerType: string; } export abstract class LLMApi { @@ -73,7 +86,11 @@ interface ChatProvider { export class ClientApi { public llm: LLMApi; - constructor() { + constructor(provider: ModelProvider = ModelProvider.GPT) { + if (provider === ModelProvider.Gemini) { + this.llm = new GeminiApi(); + return; + } this.llm = new ChatGPTApi(); } @@ -123,8 +140,6 @@ export class ClientApi { } } -export const api = new ClientApi(); - export function getHeaders() { const accessStore = useAccessStore.getState(); const headers: Record = { @@ -132,9 +147,14 @@ export function getHeaders() { "x-requested-with": "XMLHttpRequest", }; + const isGoogle = accessStore.provider === ServiceProvider.Google; const isAzure = accessStore.provider === ServiceProvider.Azure; const authHeader = isAzure ? "api-key" : "Authorization"; - const apiKey = isAzure ? accessStore.azureApiKey : accessStore.openaiApiKey; + const apiKey = isGoogle + ? accessStore.googleApiKey + : isAzure + ? accessStore.azureApiKey + : accessStore.openaiApiKey; const makeBearer = (s: string) => `${isAzure ? "" : "Bearer "}${s.trim()}`; const validString = (x: string) => x && x.length > 0; diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts new file mode 100644 index 00000000..90584571 --- /dev/null +++ b/app/client/platforms/google.ts @@ -0,0 +1,199 @@ +import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; +import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import Locale from "../../locales"; +export class GeminiApi implements LLMApi { + extractMessage(res: any) { + console.log("[Response] gemini response: ", res); + return ( + res?.candidates?.at(0)?.content?.parts.at(0)?.text || + res?.error?.message || + "" + ); + } + async chat(options: ChatOptions): Promise { + const messages = options.messages.map((v) => ({ + role: v.role.replace("assistant", "model").replace("system", "model"), + parts: [{ text: v.content }], + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + const accessStore = useAccessStore.getState(); + + const requestPayload = { + contents: messages, + // stream: options.config.stream, + // model: modelConfig.model, + // temperature: modelConfig.temperature, + // presence_penalty: modelConfig.presence_penalty, + // frequency_penalty: modelConfig.frequency_penalty, + // top_p: modelConfig.top_p, + // max_tokens: Math.max(modelConfig.max_tokens, 1024), + // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. + }; + + console.log("[Request] openai payload: ", requestPayload); + + // todo: support stream later + const shouldStream = false; + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = + this.path(Google.ChatPath) + `?key=${accessStore.googleApiKey}`; + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log( + "[OpenAI] request response content type: ", + contentType, + ); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text) as { + choices: Array<{ + delta: { + content: string; + }; + }>; + }; + const delta = json.choices[0]?.delta?.content; + if (delta) { + remainText += delta; + } + } catch (e) { + console.error("[Request] parse error", text); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } else { + const res = await fetch(chatPath, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + const message = this.extractMessage(resJson); + options.onFinish(message); + } + } catch (e) { + console.log("[Request] failed to make a chat request", e); + options.onError?.(e as Error); + } + } + usage(): Promise { + throw new Error("Method not implemented."); + } + async models(): Promise { + return []; + } + path(path: string): string { + return "/api/google/" + path; + } +} diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 8cae7ac9..70b4ab91 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -29,10 +29,11 @@ import NextImage from "next/image"; import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; -import { api } from "../client/api"; + import { prettyObject } from "../utils/format"; -import { EXPORT_MESSAGE_CLASS_NAME } from "../constant"; +import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant"; import { getClientConfig } from "../config/client"; +import { ClientApi } from "../client/api"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -301,10 +302,17 @@ export function PreviewActions(props: { }) { const [loading, setLoading] = useState(false); const [shouldExport, setShouldExport] = useState(false); - + const config = useAppConfig(); const onRenderMsgs = (msgs: ChatMessage[]) => { setShouldExport(false); + var api: ClientApi; + if (config.modelConfig.model === "gemini") { + api = new ClientApi(ModelProvider.Gemini); + } else { + api = new ClientApi(ModelProvider.GPT); + } + api .share(msgs) .then((res) => { diff --git a/app/components/home.tsx b/app/components/home.tsx index 811cbdf5..928c2d90 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg"; import { getCSSVar, useMobileScreen } from "../utils"; import dynamic from "next/dynamic"; -import { Path, SlotID } from "../constant"; +import { ModelProvider, Path, SlotID } from "../constant"; import { ErrorBoundary } from "./error"; import { getISOLang, getLang } from "../locales"; @@ -27,7 +27,7 @@ import { SideBar } from "./sidebar"; import { useAppConfig } from "../store/config"; import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; -import { api } from "../client/api"; +import { ClientApi } from "../client/api"; import { useAccessStore } from "../store"; export function Loading(props: { noLogo?: boolean }) { @@ -128,7 +128,8 @@ function Screen() { const isHome = location.pathname === Path.Home; const isAuth = location.pathname === Path.Auth; const isMobileScreen = useMobileScreen(); - const shouldTightBorder = getClientConfig()?.isApp || (config.tightBorder && !isMobileScreen); + const shouldTightBorder = + getClientConfig()?.isApp || (config.tightBorder && !isMobileScreen); useEffect(() => { loadAsyncGoogleFont(); @@ -169,6 +170,12 @@ function Screen() { export function useLoadData() { const config = useAppConfig(); + var api: ClientApi; + if (config.modelConfig.model === "gemini") { + api = new ClientApi(ModelProvider.Gemini); + } else { + api = new ClientApi(ModelProvider.GPT); + } useEffect(() => { (async () => { const models = await api.llm.models(); diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 214a18c7..a077b82c 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -29,7 +29,7 @@ export function ModelConfigList(props: { .filter((v) => v.available) .map((v, i) => ( ))} diff --git a/app/components/settings.tsx b/app/components/settings.tsx index f53024d6..9a622af3 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -52,6 +52,7 @@ import { copyToClipboard } from "../utils"; import Link from "next/link"; import { Azure, + Google, OPENAI_BASE_URL, Path, RELEASE_URL, @@ -635,7 +636,8 @@ export function Settings() { navigate(Path.Home); } }; - if (clientConfig?.isApp) { // Force to set custom endpoint to true if it's app + if (clientConfig?.isApp) { + // Force to set custom endpoint to true if it's app accessStore.update((state) => { state.useCustomConfig = true; }); @@ -997,7 +999,7 @@ export function Settings() { /> - ) : ( + ) : accessStore.provider === "Azure" ? ( <> - )} + ) : accessStore.provider === "Google" ? ( + <> + + + accessStore.update( + (access) => + (access.googleUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.googleApiKey = e.currentTarget.value), + ); + }} + /> + + + + accessStore.update( + (access) => + (access.googleApiVersion = + e.currentTarget.value), + ) + } + > + + + ) : null} )} diff --git a/app/config/server.ts b/app/config/server.ts index 2398805a..becad842 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -80,6 +80,9 @@ export const getServerSideConfig = () => { azureApiKey: process.env.AZURE_API_KEY, azureApiVersion: process.env.AZURE_API_VERSION, + googleApiKey: process.env.GOOGLE_API_KEY, + googleUrl: process.env.GOOGLE_URL, + needCode: ACCESS_CODES.size > 0, code: process.env.CODE, codes: ACCESS_CODES, diff --git a/app/constant.ts b/app/constant.ts index 69d5c511..1f6a647d 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -12,6 +12,8 @@ export const DEFAULT_CORS_HOST = "https://a.nextweb.fun"; export const DEFAULT_API_HOST = `${DEFAULT_CORS_HOST}/api/proxy`; export const OPENAI_BASE_URL = "https://api.openai.com"; +export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; + export enum Path { Home = "/", Chat = "/chat", @@ -65,6 +67,12 @@ export const EXPORT_MESSAGE_CLASS_NAME = "export-markdown"; export enum ServiceProvider { OpenAI = "OpenAI", Azure = "Azure", + Google = "Google", +} + +export enum ModelProvider { + GPT = "GPT", + Gemini = "Gemini", } export const OpenaiPath = { @@ -78,6 +86,14 @@ export const Azure = { ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", }; +export const Google = { + ExampleEndpoint: + "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent", + ChatPath: "v1beta/models/gemini-pro:generateContent", + + // /api/openai/v1/chat/completions +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_SYSTEM_TEMPLATE = ` You are ChatGPT, a large language model trained by OpenAI. @@ -100,58 +116,137 @@ export const DEFAULT_MODELS = [ { name: "gpt-4", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-0314", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-0613", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-32k", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-32k-0314", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-32k-0613", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-1106-preview", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-4-vision-preview", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo-0301", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo-0613", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo-1106", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo-16k", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, }, { name: "gpt-3.5-turbo-16k-0613", available: true, + provider: { + id: "openai", + providerName: "OpenAI", + providerType: "openai", + }, + }, + { + name: "gemini", + available: true, + provider: { + id: "google", + providerName: "Google", + providerType: "google", + }, }, ] as const; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 50dd4428..42270b2f 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -312,6 +312,23 @@ const cn = { SubTitle: "选择指定的部分版本", }, }, + Google: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义 Google AI Studio API Key 绕过密码访问限制", + Placeholder: "Google AI Studio API Key", + }, + + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + + ApiVerion: { + Title: "接口版本 (gemini api version)", + SubTitle: "选择指定的部分版本", + }, + }, CustomModel: { Title: "自定义模型名", SubTitle: "增加自定义模型可选项,使用英文逗号隔开", @@ -347,7 +364,7 @@ const cn = { Prompt: { History: (content: string) => "这是历史聊天总结作为前情提要:" + content, Topic: - "使用四到五个字直接返回这句话的简要主题,不要解释、不要标点、不要语气词、不要多余文本,如果没有主题,请直接返回“闲聊”", + "使用四到五个字直接返回这句话的简要主题,不要解释、不要标点、不要语气词、不要多余文本,不要加粗,如果没有主题,请直接返回“闲聊”", Summarize: "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内", }, @@ -441,9 +458,9 @@ const cn = { Config: "配置", }, Exporter: { - Description : { - Title: "只有清除上下文之后的消息会被展示" - }, + Description: { + Title: "只有清除上下文之后的消息会被展示", + }, Model: "模型", Messages: "消息", Topic: "主题", diff --git a/app/store/access.ts b/app/store/access.ts index 3b9008ba..9e8024a6 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -29,6 +29,11 @@ const DEFAULT_ACCESS_STATE = { azureApiKey: "", azureApiVersion: "2023-08-01-preview", + // google ai studio + googleUrl: "", + googleApiKey: "", + googleApiVersion: "v1", + // server config needCode: true, hideUserApiKey: false, @@ -56,6 +61,10 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["azureUrl", "azureApiKey", "azureApiVersion"]); }, + isValidGoogle() { + return ensure(get(), ["googleApiKey"]); + }, + isAuthorized() { this.fetch(); @@ -63,6 +72,7 @@ export const useAccessStore = createPersistStore( return ( this.isValidOpenAI() || this.isValidAzure() || + this.isValidGoogle() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); @@ -99,6 +109,7 @@ export const useAccessStore = createPersistStore( token: string; openaiApiKey: string; azureApiVersion: string; + googleApiKey: string; }; state.openaiApiKey = state.token; state.azureApiVersion = "2023-08-01-preview"; diff --git a/app/store/chat.ts b/app/store/chat.ts index 66a39d2b..f53f6115 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -8,10 +8,11 @@ import { DEFAULT_INPUT_TEMPLATE, DEFAULT_SYSTEM_TEMPLATE, KnowledgeCutOffDate, + ModelProvider, StoreKey, SUMMARIZE_MODEL, } from "../constant"; -import { api, RequestMessage } from "../client/api"; +import { ClientApi, RequestMessage } from "../client/api"; import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; import { estimateTokenLength } from "../utils/token"; @@ -301,6 +302,13 @@ export const useChatStore = createPersistStore( ]); }); + var api: ClientApi; + if (modelConfig.model === "gemini") { + api = new ClientApi(ModelProvider.Gemini); + } else { + api = new ClientApi(ModelProvider.GPT); + } + // make request api.llm.chat({ messages: sendMessages, @@ -379,22 +387,26 @@ export const useChatStore = createPersistStore( // system prompts, to get close to OpenAI Web ChatGPT const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts; - const systemPrompts = shouldInjectSystemPrompts - ? [ - createMessage({ - role: "system", - content: fillTemplateWith("", { - ...modelConfig, - template: DEFAULT_SYSTEM_TEMPLATE, + + var systemPrompts: ChatMessage[] = []; + if (modelConfig.model !== "gemini") { + systemPrompts = shouldInjectSystemPrompts + ? [ + createMessage({ + role: "system", + content: fillTemplateWith("", { + ...modelConfig, + template: DEFAULT_SYSTEM_TEMPLATE, + }), }), - }), - ] - : []; - if (shouldInjectSystemPrompts) { - console.log( - "[Global System Prompt] ", - systemPrompts.at(0)?.content ?? "empty", - ); + ] + : []; + if (shouldInjectSystemPrompts) { + console.log( + "[Global System Prompt] ", + systemPrompts.at(0)?.content ?? "empty", + ); + } } // long term memory @@ -473,6 +485,14 @@ export const useChatStore = createPersistStore( summarizeSession() { const config = useAppConfig.getState(); const session = get().currentSession(); + const modelConfig = session.mask.modelConfig; + + var api: ClientApi; + if (modelConfig.model === "gemini") { + api = new ClientApi(ModelProvider.Gemini); + } else { + api = new ClientApi(ModelProvider.GPT); + } // remove error messages if any const messages = session.messages; @@ -504,8 +524,6 @@ export const useChatStore = createPersistStore( }, }); } - - const modelConfig = session.mask.modelConfig; const summarizeIndex = Math.max( session.lastSummarizeIndex, session.clearContextIndex ?? 0, diff --git a/app/store/update.ts b/app/store/update.ts index 2ab7ec19..3c88866d 100644 --- a/app/store/update.ts +++ b/app/store/update.ts @@ -1,5 +1,4 @@ import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant"; -import { api } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import ChatGptIcon from "../icons/chatgpt.png"; @@ -85,35 +84,40 @@ export const useUpdateStore = createPersistStore( })); if (window.__TAURI__?.notification && isApp) { // Check if notification permission is granted - await window.__TAURI__?.notification.isPermissionGranted().then((granted) => { - if (!granted) { - return; - } else { - // Request permission to show notifications - window.__TAURI__?.notification.requestPermission().then((permission) => { - if (permission === 'granted') { - if (version === remoteId) { - // Show a notification using Tauri - window.__TAURI__?.notification.sendNotification({ - title: "NextChat", - body: `${Locale.Settings.Update.IsLatest}`, - icon: `${ChatGptIcon.src}`, - sound: "Default" - }); - } else { - const updateMessage = Locale.Settings.Update.FoundUpdate(`${remoteId}`); - // Show a notification for the new version using Tauri - window.__TAURI__?.notification.sendNotification({ - title: "NextChat", - body: updateMessage, - icon: `${ChatGptIcon.src}`, - sound: "Default" - }); - } - } - }); - } - }); + await window.__TAURI__?.notification + .isPermissionGranted() + .then((granted) => { + if (!granted) { + return; + } else { + // Request permission to show notifications + window.__TAURI__?.notification + .requestPermission() + .then((permission) => { + if (permission === "granted") { + if (version === remoteId) { + // Show a notification using Tauri + window.__TAURI__?.notification.sendNotification({ + title: "NextChat", + body: `${Locale.Settings.Update.IsLatest}`, + icon: `${ChatGptIcon.src}`, + sound: "Default", + }); + } else { + const updateMessage = + Locale.Settings.Update.FoundUpdate(`${remoteId}`); + // Show a notification for the new version using Tauri + window.__TAURI__?.notification.sendNotification({ + title: "NextChat", + body: updateMessage, + icon: `${ChatGptIcon.src}`, + sound: "Default", + }); + } + } + }); + } + }); } console.log("[Got Upstream] ", remoteId); } catch (error) { diff --git a/app/utils/model.ts b/app/utils/model.ts index 74b28a66..16bcc19c 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -4,10 +4,7 @@ export function collectModelTable( models: readonly LLMModel[], customModels: string, ) { - const modelTable: Record< - string, - { available: boolean; name: string; displayName: string } - > = {}; + const modelTable: { [key: string]: LLMModel } = {}; // default models models.forEach( @@ -37,6 +34,7 @@ export function collectModelTable( name, displayName: displayName || name, available, + provider: modelTable[name].provider, }; }); return modelTable;