chroe: update model name

This commit is contained in:
Fred Liang 2023-12-24 04:22:12 +08:00
parent 7d9a2132cb
commit ae19a0dc5f
No known key found for this signature in database
GPG Key ID: 4DABDA85EF70EC71
8 changed files with 19 additions and 19 deletions

View File

@ -7,7 +7,7 @@ import {
} from "../constant"; } from "../constant";
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
import { ChatGPTApi } from "./platforms/openai"; import { ChatGPTApi } from "./platforms/openai";
import { GeminiApi } from "./platforms/google"; import { GeminiProApi } from "./platforms/google";
export const ROLES = ["system", "user", "assistant"] as const; export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number]; export type MessageRole = (typeof ROLES)[number];
@ -86,8 +86,8 @@ export class ClientApi {
public llm: LLMApi; public llm: LLMApi;
constructor(provider: ModelProvider = ModelProvider.GPT) { constructor(provider: ModelProvider = ModelProvider.GPT) {
if (provider === ModelProvider.Gemini) { if (provider === ModelProvider.GeminiPro) {
this.llm = new GeminiApi(); this.llm = new GeminiProApi();
return; return;
} }
this.llm = new ChatGPTApi(); this.llm = new ChatGPTApi();
@ -146,7 +146,7 @@ export function getHeaders() {
"x-requested-with": "XMLHttpRequest", "x-requested-with": "XMLHttpRequest",
}; };
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig; const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model === "gemini"; const isGoogle = modelConfig.model === "gemini-pro";
const isAzure = accessStore.provider === ServiceProvider.Azure; const isAzure = accessStore.provider === ServiceProvider.Azure;
const authHeader = isAzure ? "api-key" : "Authorization"; const authHeader = isAzure ? "api-key" : "Authorization";
const apiKey = isGoogle const apiKey = isGoogle

View File

@ -9,9 +9,9 @@ import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client"; import { getClientConfig } from "@/app/config/client";
import Locale from "../../locales"; import Locale from "../../locales";
import { getServerSideConfig } from "@/app/config/server"; import { getServerSideConfig } from "@/app/config/server";
export class GeminiApi implements LLMApi { export class GeminiProApi implements LLMApi {
extractMessage(res: any) { extractMessage(res: any) {
console.log("[Response] gemini response: ", res); console.log("[Response] gemini-pro response: ", res);
return ( return (
res?.candidates?.at(0)?.content?.parts.at(0)?.text || res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
res?.error?.message || res?.error?.message ||

View File

@ -307,8 +307,8 @@ export function PreviewActions(props: {
setShouldExport(false); setShouldExport(false);
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model === "gemini") { if (config.modelConfig.model === "gemini-pro") {
api = new ClientApi(ModelProvider.Gemini); api = new ClientApi(ModelProvider.GeminiPro);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }

View File

@ -171,8 +171,8 @@ export function useLoadData() {
const config = useAppConfig(); const config = useAppConfig();
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model === "gemini") { if (config.modelConfig.model === "gemini-pro") {
api = new ClientApi(ModelProvider.Gemini); api = new ClientApi(ModelProvider.GeminiPro);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }

View File

@ -72,7 +72,7 @@ export enum ServiceProvider {
export enum ModelProvider { export enum ModelProvider {
GPT = "GPT", GPT = "GPT",
Gemini = "Gemini", GeminiPro = "GeminiPro",
} }
export const OpenaiPath = { export const OpenaiPath = {
@ -240,7 +240,7 @@ export const DEFAULT_MODELS = [
}, },
}, },
{ {
name: "gemini", name: "gemini-pro",
available: true, available: true,
provider: { provider: {
id: "google", id: "google",

View File

@ -325,7 +325,7 @@ const cn = {
}, },
ApiVerion: { ApiVerion: {
Title: "接口版本 (gemini api version)", Title: "接口版本 (gemini-pro api version)",
SubTitle: "选择指定的部分版本", SubTitle: "选择指定的部分版本",
}, },
}, },

View File

@ -333,7 +333,7 @@ const en: LocaleType = {
}, },
ApiVerion: { ApiVerion: {
Title: "API Version (gemini api version)", Title: "API Version (gemini-pro api version)",
SubTitle: "Select a specific part version", SubTitle: "Select a specific part version",
}, },
}, },

View File

@ -303,8 +303,8 @@ export const useChatStore = createPersistStore(
}); });
var api: ClientApi; var api: ClientApi;
if (modelConfig.model === "gemini") { if (modelConfig.model === "gemini-pro") {
api = new ClientApi(ModelProvider.Gemini); api = new ClientApi(ModelProvider.GeminiPro);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }
@ -389,7 +389,7 @@ export const useChatStore = createPersistStore(
const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts; const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
var systemPrompts: ChatMessage[] = []; var systemPrompts: ChatMessage[] = [];
if (modelConfig.model !== "gemini") { if (modelConfig.model !== "gemini-pro") {
systemPrompts = shouldInjectSystemPrompts systemPrompts = shouldInjectSystemPrompts
? [ ? [
createMessage({ createMessage({
@ -488,8 +488,8 @@ export const useChatStore = createPersistStore(
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
var api: ClientApi; var api: ClientApi;
if (modelConfig.model === "gemini") { if (modelConfig.model === "gemini-pro") {
api = new ClientApi(ModelProvider.Gemini); api = new ClientApi(ModelProvider.GeminiPro);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }