import { create } from "zustand";
import { persist } from "zustand/middleware";

import { trimTopic } from "../utils";

import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";

/**
 * A message stored in a chat session: the request payload plus UI bookkeeping.
 * `streaming` is true while a bot reply is still arriving; `isError` marks
 * failed exchanges, which are skipped when building request context.
 */
export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id?: number;
  model?: ModelType;
};

/**
 * Builds a chat message with defaults (an empty user message stamped with the
 * current time) and applies `override` on top of them.
 */
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  return {
    id: Date.now(),
    date: new Date().toLocaleString(),
    role: "user",
    content: "",
    ...override,
  };
}

/** Per-session counters. Currently only `charCount` is maintained (see `updateStat`). */
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

export interface ChatSession {
  id: number;
  topic: string;

  /** Rolling summary of older messages, sent as a system prompt. */
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  /** Number of messages already folded into `memoryPrompt`. */
  lastSummarizeIndex: number;
  /** Messages before this index are excluded from the request context. */
  clearContextIndex?: number;

  mask: Mask;
}

export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});

/**
 * Creates a blank session. Its id is time-based; `newSession` replaces it
 * with the store's incrementing `globalId`.
 */
function createEmptySession(): ChatSession {
  return {
    id: Date.now() + Math.random(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat: {
      tokenCount: 0,
      wordCount: 0,
      charCount: 0,
    },
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,

    mask: createEmptyMask(),
  };
}

interface ChatStore {
  sessions: ChatSession[];
  currentSessionIndex: number;
  globalId: number;
  clearSessions: () => void;
  moveSession: (from: number, to: number) => void;
  selectSession: (index: number) => void;
  newSession: (mask?: Mask) => void;
  deleteSession: (index: number) => void;
  currentSession: () => ChatSession;
  nextSession: (delta: number) => void;
  onNewMessage: (message: ChatMessage) => void;
  onUserInput: (content: string) => Promise<void>;
  summarizeSession: () => void;
  updateStat: (message: ChatMessage) => void;
  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  updateMessage: (
    sessionIndex: number,
    messageIndex: number,
    updater: (message?: ChatMessage) => void,
  ) => void;
  resetSession: () => void;
  getMessagesWithMemory: () => ChatMessage[];
  getMemoryPrompt: () => ChatMessage;

  clearAllData: () => void;
}

/** Sums the estimated token length of every message's content. */
function countMessages(msgs: ChatMessage[]) {
  return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
}

/**
 * Renders the model's input template, substituting the `{{model}}`,
 * `{{time}}`, `{{lang}}` and `{{input}}` placeholders. If the template has no
 * `{{input}}` placeholder, one is appended on a new line so the user's text is
 * never dropped.
 */
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  const vars = {
    model: modelConfig.model,
    time: new Date().toLocaleString(),
    lang: getLang(),
    input: input,
  };

  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;

  // must contain {{input}}
  const inputVar = "{{input}}";
  if (!output.includes(inputVar)) {
    output += "\n" + inputVar;
  }

  Object.entries(vars).forEach(([name, value]) => {
    output = output.replaceAll(`{{${name}}}`, value);
  });

  return output;
}

/** Persisted zustand store holding all chat sessions and the chat workflow. */
export const useChatStore = create<ChatStore>()(
  persist(
    (set, get) => ({
      sessions: [createEmptySession()],
      currentSessionIndex: 0,
      globalId: 0,

      clearSessions() {
        set(() => ({
          sessions: [createEmptySession()],
          currentSessionIndex: 0,
        }));
      },

      selectSession(index: number) {
        set({
          currentSessionIndex: index,
        });
      },

      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // keep the selection pointing at the same session after the move
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },

      newSession(mask) {
        const session = createEmptySession();

        set(() => ({ globalId: get().globalId + 1 }));
        session.id = get().globalId;

        if (mask) {
          const config = useAppConfig.getState();
          const globalModelConfig = config.modelConfig;

          // the mask's own settings take precedence over the global defaults
          session.mask = {
            ...mask,
            modelConfig: {
              ...globalModelConfig,
              ...mask.modelConfig,
            },
          };
          session.topic = mask.name;
        }

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [session].concat(state.sessions),
        }));
      },

      nextSession(delta) {
        const n = get().sessions.length;
        // wrap around in both directions, even when |delta| exceeds n
        const limit = (x: number) => ((x % n) + n) % n;
        const i = get().currentSessionIndex;
        get().selectSession(limit(i + delta));
      },

      deleteSession(index) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // for undo delete action
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },

      currentSession() {
        let index = get().currentSessionIndex;
        const sessions = get().sessions;

        // clamp an out-of-range selection and write the clamped value back
        if (index < 0 || index >= sessions.length) {
          index = Math.min(sessions.length - 1, Math.max(0, index));
          set(() => ({ currentSessionIndex: index }));
        }

        const session = sessions[index];

        return session;
      },

      onNewMessage(message) {
        get().updateCurrentSession((session) => {
          // new array reference so subscribers see the change
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message);
        get().summarizeSession();
      },

      async onUserInput(content) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] fill with template: ", userContent);

        const userMessage: ChatMessage = createMessage({
          role: "user",
          content: userContent,
        });

        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          id: userMessage.id! + 1,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const sessionIndex = get().currentSessionIndex;
        const messageIndex = get().currentSession().messages.length + 1;

        // save user's and bot's message; the stored user message keeps the raw
        // input, while the request carries the template-filled text
        get().updateCurrentSession((session) => {
          const savedUserMessage = {
            ...userMessage,
            content,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        // make request
        console.log("[User Input] ", sendMessages);
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              get().onNewMessage(botMessage);
            } else {
              // nothing to record, but still publish the cleared streaming flag
              get().updateCurrentSession((session) => {
                session.messages = session.messages.concat();
              });
            }
            ChatControllerPool.remove(
              sessionIndex,
              botMessage.id ?? messageIndex,
            );
          },
          onError(error) {
            const isAborted = error.message.includes("aborted");
            botMessage.content =
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            // a user-initiated abort is not treated as an error
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              sessionIndex,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              sessionIndex,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },

      getMemoryPrompt() {
        const session = get().currentSession();

        return {
          role: "system",
          content:
            session.memoryPrompt.length > 0
              ? Locale.Store.Prompt.History(session.memoryPrompt)
              : "",
          date: "",
        } as ChatMessage;
      },

      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // long term memory: only useful if the summary covers messages after
        // the last context clear; otherwise it describes cleared history
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts = shouldSendLongTermMemory
          ? [get().getMemoryPrompt()]
          : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message (appended by the caller)
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible, newest first, until the
        // estimated token budget is reached
        const reversedRecentMessages: ChatMessage[] = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(msg.content);
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },

      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },

      resetSession() {
        get().updateCurrentSession((session) => {
          session.messages = [];
          session.memoryPrompt = "";
          // indices into the discarded message list are no longer meaningful
          session.lastSummarizeIndex = 0;
          session.clearContextIndex = undefined;
        });
      },

      summarizeSession() {
        const session = get().currentSession();

        // Async callbacks below mutate `session` directly (it may no longer be
        // the current session by then) and then publish/persist the change.
        const commit = () => set(() => ({ sessions: get().sessions }));

        const messages = session.messages;

        // summarize the topic once the chat reaches ~50 estimated tokens
        const SUMMARIZE_MIN_LEN = 50;
        if (
          session.topic === DEFAULT_TOPIC &&
          countMessages(messages) >= SUMMARIZE_MIN_LEN
        ) {
          const topicMessages = messages.concat(
            createMessage({
              role: "user",
              content: Locale.Store.Prompt.Topic,
            }),
          );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model: "gpt-3.5-turbo",
            },
            onFinish(message) {
              session.topic =
                message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC;
              commit();
            },
          });
        }

        const modelConfig = session.mask.modelConfig;
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        // slice BEFORE filtering: summarizeIndex indexes the unfiltered list;
        // error messages are then removed from the summary input
        let toBeSummarizedMsgs = messages
          .slice(summarizeIndex)
          .filter((msg) => !msg.isError);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);

        // `??` binds looser than `>`, so the default must be parenthesized
        if (historyMsgLength > (modelConfig.max_tokens ?? 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }

        // add memory prompt
        toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat({
              role: "system",
              content: Locale.Store.Prompt.Summarize,
              date: "",
            }),
            config: { ...modelConfig, stream: true },
            onUpdate(message) {
              session.memoryPrompt = message;
            },
            onFinish(message) {
              console.log("[Memory] ", message);
              session.lastSummarizeIndex = lastSummarizeIndex;
              // keep the complete summary rather than the last streamed chunk
              if (message) {
                session.memoryPrompt = message;
              }
              commit();
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },

      updateStat(message) {
        get().updateCurrentSession((session) => {
          session.stat.charCount += message.content.length;
          // TODO: should update chat count and word count
        });
      },

      updateCurrentSession(updater) {
        const sessions = get().sessions;
        const index = get().currentSessionIndex;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },

      clearAllData() {
        localStorage.clear();
        location.reload();
      },
    }),
    {
      name: StoreKey.Chat,
      version: 2,
      migrate(persistedState, version) {
        const state = persistedState as any;
        const newState = JSON.parse(JSON.stringify(state)) as ChatStore;

        if (version < 2) {
          // v1 -> v2: rebuild every session on the new schema, keeping only
          // topic and messages, with fixed memory settings
          newState.globalId = 0;
          newState.sessions = [];

          const oldSessions = state.sessions;
          for (const oldSession of oldSessions) {
            const newSession = createEmptySession();
            newSession.topic = oldSession.topic;
            newSession.messages = [...oldSession.messages];
            newSession.mask.modelConfig.sendMemory = true;
            newSession.mask.modelConfig.historyMessageCount = 4;
            newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
            newState.sessions.push(newSession);
          }
        }

        return newState;
      },
    },
  ),
);