import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant";
import {
  getMessageTextContent,
  getMessageImages,
  isVisionModel,
} from "@/app/utils";

/**
 * `LLMApi` implementation that talks to the Google chat endpoints selected by
 * `Google.ChatPath` / `Google.VisionChatPath`, in either one-shot or streaming
 * mode.
 */
export class GeminiProApi implements LLMApi {
  /**
   * Pulls the reply text out of a non-streaming response.
   *
   * @param res Parsed JSON response body.
   * @returns The text of the first part of the first candidate; otherwise the
   *   response's `error.message`; otherwise an empty string.
   */
  extractMessage(res: any) {
    console.log("[Response] gemini-pro response: ", res);

    // `parts` is optional-chained too: a candidate may carry `content`
    // without any `parts`, which previously threw a TypeError here.
    return (
      res?.candidates?.at(0)?.content?.parts?.at(0)?.text ||
      res?.error?.message ||
      ""
    );
  }

  /**
   * Sends the conversation in `options.messages` and reports the outcome
   * through the callbacks on `options` (`onUpdate`, `onFinish`, `onError`,
   * `onController`).
   *
   * When `options.config.stream` is truthy the reply is streamed and fed to
   * `onUpdate` in small slices on each animation frame; otherwise a single
   * request is made and `onFinish` receives the extracted message.
   *
   * @param options Chat request description and callbacks.
   * @returns A promise that settles once the request has been issued
   *   (streaming) or fully handled (non-streaming). Failures are reported via
   *   `options.onError` rather than by rejecting.
   */
  async chat(options: ChatOptions): Promise<void> {
    const visionModel = isVisionModel(options.config.model);

    const messages = options.messages.map((v) => {
      let parts: Record<string, unknown>[] = [
        { text: getMessageTextContent(v) },
      ];

      if (visionModel) {
        const images = getMessageImages(v);
        if (images.length > 0) {
          parts = parts.concat(
            images.map((image) => {
              // Presumably `image` is a data URL of the form
              // `data:<mime>;base64,<payload>` -- verify against
              // getMessageImages.
              const imageType = image.split(";")[0].split(":")[1];
              const imageData = image.split(",")[1];
              return {
                inline_data: {
                  mime_type: imageType,
                  data: imageData,
                },
              };
            }),
          );
        }
      }

      return {
        role: v.role.replace("assistant", "model").replace("system", "user"),
        parts: parts,
      };
    });

    // google requires that role in neighboring messages must not be the same
    for (let i = 0; i < messages.length - 1; ) {
      if (messages[i].role === messages[i + 1].role) {
        // Same role twice in a row (any role): fold the next message's parts
        // into the current one and drop the next message. `i` is not advanced
        // so a run of three or more collapses into one.
        messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
        messages.splice(i + 1, 1);
      } else {
        i++;
      }
    }

    // if (visionModel && messages.length > 1) {
    //   options.onError?.(new Error("Multiturn chat is not enabled for models/gemini-pro-vision"));
    // }

    // Later spreads win: app defaults < current session mask < requested model.
    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      model: options.config.model,
    };

    const requestPayload = {
      contents: messages,
      generationConfig: {
        // stopSequences: [
        //   "Title"
        // ],
        temperature: modelConfig.temperature,
        maxOutputTokens: modelConfig.max_tokens,
        topP: modelConfig.top_p,
        // "topK": modelConfig.top_k,
      },
      safetySettings: [
        {
          category: "HARM_CATEGORY_HARASSMENT",
          threshold: "BLOCK_ONLY_HIGH",
        },
        {
          category: "HARM_CATEGORY_HATE_SPEECH",
          threshold: "BLOCK_ONLY_HIGH",
        },
        {
          category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          threshold: "BLOCK_ONLY_HIGH",
        },
        {
          category: "HARM_CATEGORY_DANGEROUS_CONTENT",
          threshold: "BLOCK_ONLY_HIGH",
        },
      ],
    };

    const accessStore = useAccessStore.getState();
    let baseUrl = accessStore.googleUrl;
    const isApp = !!getClientConfig()?.isApp;

    const shouldStream = !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);

    try {
      const googleChatPath = visionModel
        ? Google.VisionChatPath
        : Google.ChatPath;
      const chatPath = this.path(googleChatPath);

      // A user-configured googleUrl is used verbatim; otherwise fall back to
      // the proxy route (absolute when running as the packaged app).
      if (!baseUrl) {
        baseUrl = isApp
          ? DEFAULT_API_HOST + "/api/proxy/google/" + googleChatPath
          : chatPath;
      }

      if (isApp) {
        baseUrl += `?key=${accessStore.googleApiKey}`;
      }

      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      // Abort the request if no response arrives within the timeout.
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );

      if (shouldStream) {
        let responseText = "";
        let remainText = "";
        let finished = false;

        // One entry per candidate text seen so far across all parsed chunks.
        let existingTexts: string[] = [];
        const finish = () => {
          finished = true;
          options.onFinish(existingTexts.join(""));
        };

        // animate response to make it looks smooth
        const animateResponseText = () => {
          if (finished || controller.signal.aborted) {
            responseText += remainText;
            finish();
            return;
          }

          if (remainText.length > 0) {
            // Emit roughly 1/60th of the backlog per frame, at least one char.
            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
            const fetchText = remainText.slice(0, fetchCount);
            responseText += fetchText;
            remainText = remainText.slice(fetchCount);
            options.onUpdate?.(responseText, fetchText);
          }

          requestAnimationFrame(animateResponseText);
        };

        // start animation
        animateResponseText();

        fetch(
          baseUrl.replace("generateContent", "streamGenerateContent"),
          chatPayload,
        )
          .then((response) => {
            // The server answered, so the connect timeout has done its job.
            // Leaving it armed would abort any stream that outlives
            // REQUEST_TIMEOUT_MS (the non-streaming branch clears it at the
            // same point).
            clearTimeout(requestTimeoutId);

            const reader = response?.body?.getReader();
            if (!reader) {
              // No body to read: report it and let the animation loop finish
              // instead of spinning forever.
              options.onError?.(new Error("Request failed"));
              finished = true;
              return;
            }

            const decoder = new TextDecoder();
            // Everything received so far; re-parsed in full on every chunk.
            let partialData = "";

            return reader.read().then(function processText({
              done,
              value,
            }): Promise<void> {
              if (done) {
                if (response.status !== 200) {
                  // Try to surface the server's own error message; the body
                  // may be an array of objects or a single object.
                  let errorMessage = "Request failed";
                  try {
                    const data = JSON.parse(ensureProperEnding(partialData));
                    const first = Array.isArray(data) ? data[0] : data;
                    if (first?.error?.message) {
                      errorMessage = first.error.message;
                    }
                  } catch (_) {
                    // Unparseable error body: keep the generic message.
                  }
                  options.onError?.(new Error(errorMessage));
                }

                console.log("Stream complete");
                finished = true;
                return Promise.resolve();
              }

              partialData += decoder.decode(value, { stream: true });

              try {
                const data = JSON.parse(ensureProperEnding(partialData));

                // Tolerate items without `candidates` and candidates without
                // `content.parts`; previously one such item made every later
                // parse of the accumulated buffer throw, dropping all
                // following text.
                const textArray: string[] = data.flatMap(
                  (item: { candidates?: any[] }) =>
                    (item?.candidates ?? []).map((candidate) =>
                      (candidate?.content?.parts ?? [])
                        .map((part: { text?: string }) => part?.text)
                        .join(""),
                    ),
                );

                if (textArray.length > existingTexts.length) {
                  const deltaArray = textArray.slice(existingTexts.length);
                  existingTexts = textArray;
                  remainText += deltaArray.join("");
                }
              } catch (error) {
                // The buffer is usually incomplete JSON mid-stream; skip and
                // retry once more data has arrived.
              }

              return reader.read().then(processText);
            });
          })
          .catch((error) => {
            console.error("Error:", error);
            clearTimeout(requestTimeoutId);
            // An abort is already handled by the animation loop (it calls
            // finish()); anything else is a real failure the caller must see.
            if (!controller.signal.aborted) {
              options.onError?.(
                error instanceof Error ? error : new Error(String(error)),
              );
            }
            finished = true;
          });
      } else {
        let res: Response;
        try {
          res = await fetch(baseUrl, chatPayload);
        } finally {
          // Clear the timer whether the request resolved or rejected.
          clearTimeout(requestTimeoutId);
        }

        const resJson = await res.json();

        if (resJson?.promptFeedback?.blockReason) {
          // being blocked
          options.onError?.(
            new Error(
              "Message is being blocked for reason: " +
                resJson.promptFeedback.blockReason,
            ),
          );
        }
        const message = this.extractMessage(resJson);
        options.onFinish(message);
      }
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }

  /**
   * Usage reporting is not available for this provider.
   *
   * @throws Error always ("Method not implemented.").
   */
  usage(): Promise<LLMUsage> {
    throw new Error("Method not implemented.");
  }

  /** @returns An empty model list. */
  async models(): Promise<LLMModel[]> {
    return [];
  }

  /**
   * Builds the same-origin route for a Google API path.
   *
   * @param path Path to append after `/api/google/`.
   */
  path(path: string): string {
    return "/api/google/" + path;
  }
}

/**
 * Makes a possibly truncated JSON array text parseable by closing it.
 *
 * If `str` starts with `[` and, ignoring trailing whitespace, does not end
 * with `]`, the trailing whitespace is dropped and a `]` is appended.
 * Otherwise `str` is returned unchanged. Trailing whitespace is ignored in the
 * check so that an already complete `"[...]\n"` is not given a second `]`.
 */
function ensureProperEnding(str: string) {
  const trimmed = str.trimEnd();
  if (trimmed.startsWith("[") && !trimmed.endsWith("]")) {
    return trimmed + "]";
  }
  return str;
}