diff --git a/app/components/error.tsx b/app/components/error.tsx index b38341e2..914740f9 100644 --- a/app/components/error.tsx +++ b/app/components/error.tsx @@ -4,8 +4,8 @@ import GithubIcon from "../icons/github.svg"; import ResetIcon from "../icons/reload.svg"; import { ISSUE_URL } from "../constant"; import Locale from "../locales"; -import { downloadAs } from "../utils"; import { showConfirm } from "./ui-lib"; +import { useSyncStore } from "../store/sync"; interface IErrorBoundaryState { hasError: boolean; @@ -26,10 +26,7 @@ export class ErrorBoundary extends React.Component { clearAndSaveData() { try { - downloadAs( - JSON.stringify(localStorage), - "chatgpt-next-web-snapshot.json", - ); + useSyncStore.getState().export(); } finally { localStorage.clear(); location.reload(); diff --git a/app/components/mask.tsx b/app/components/mask.tsx index 3d8ce3a2..1ee1c239 100644 --- a/app/components/mask.tsx +++ b/app/components/mask.tsx @@ -410,7 +410,7 @@ export function MaskPage() { const closeMaskModal = () => setEditingMaskId(undefined); const downloadAll = () => { - downloadAs(JSON.stringify(masks), FileName.Masks); + downloadAs(JSON.stringify(masks.filter((v) => !v.builtin)), FileName.Masks); }; const importFromFile = () => { @@ -452,11 +452,13 @@ export function MaskPage() { icon={} bordered onClick={downloadAll} + text={Locale.UI.Export} />
} + text={Locale.UI.Import} bordered onClick={() => importFromFile()} /> @@ -604,7 +606,7 @@ export function MaskPage() { - maskStore.update(editingMaskId!, updater) + maskStore.updateMask(editingMaskId!, updater) } readonly={editingMask.builtin} /> diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 1e6ef713..19c54515 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -10,6 +10,9 @@ import ClearIcon from "../icons/clear.svg"; import LoadingIcon from "../icons/three-dots.svg"; import EditIcon from "../icons/edit.svg"; import EyeIcon from "../icons/eye.svg"; +import DownloadIcon from "../icons/download.svg"; +import UploadIcon from "../icons/upload.svg"; + import { Input, List, @@ -49,6 +52,7 @@ import { Avatar, AvatarPicker } from "./emoji"; import { getClientConfig } from "../config/client"; import { useSyncStore } from "../store/sync"; import { nanoid } from "nanoid"; +import { useMaskStore } from "../store/mask"; function EditPromptModal(props: { id: string; onClose: () => void }) { const promptStore = usePromptStore(); @@ -75,7 +79,7 @@ function EditPromptModal(props: { id: string; onClose: () => void }) { readOnly={!prompt.isUser} className={styles["edit-prompt-title"]} onInput={(e) => - promptStore.update( + promptStore.updatePrompt( props.id, (prompt) => (prompt.title = e.currentTarget.value), ) @@ -87,7 +91,7 @@ function EditPromptModal(props: { id: string; onClose: () => void }) { className={styles["edit-prompt-content"]} rows={10} onInput={(e) => - promptStore.update( + promptStore.updatePrompt( props.id, (prompt) => (prompt.content = e.currentTarget.value), ) @@ -127,14 +131,15 @@ function UserPromptModal(props: { onClose?: () => void }) { actions={[ - promptStore.add({ + onClick={() => { + const promptId = promptStore.add({ id: nanoid(), createdAt: Date.now(), title: "Empty Prompt", content: "Empty Prompt Content", - }) - } + }); + setEditingPromptId(promptId); + }} icon={} bordered text={Locale.Settings.Prompt.Modal.Add} @@ -244,19 +249,31 @@ function DangerItems() { function SyncItems() { const syncStore = useSyncStore(); const webdav = syncStore.webDavConfig; + const chatStore = useChatStore(); + const promptStore = usePromptStore(); + const maskStore = useMaskStore(); - // not ready: https://github.com/Yidadaa/ChatGPT-Next-Web/issues/920#issuecomment-1609866332 - return null; + const stateOverview = useMemo(() => { + const sessions = chatStore.sessions; + const messageCount = sessions.reduce((p, c) => p + c.messages.length, 0); + + return { + chat: sessions.length, + message: messageCount, + prompt: Object.keys(promptStore.prompts).length, + mask: Object.keys(maskStore.masks).length, + }; + }, [chatStore.sessions, maskStore.masks, promptStore.prompts]); return ( } - text="同步" + text={Locale.UI.Sync} onClick={() => { syncStore.check().then(console.log); }} @@ -264,50 +281,25 @@ function SyncItems() { - - - { - syncStore.update( - (config) => (config.server = e.currentTarget.value), - ); - }} - /> - - - - { - syncStore.update( - (config) => (config.username = e.currentTarget.value), - ); - }} - /> - - - - { - syncStore.update( - (config) => (config.password = e.currentTarget.value), - ); - }} - /> +
+ } + text={Locale.UI.Export} + onClick={() => { + syncStore.export(); + }} + /> + } + text={Locale.UI.Import} + onClick={() => { + syncStore.import(); + }} + /> +
); @@ -562,6 +554,8 @@ export function Settings() { + + - - { + return `${overview.chat} 次对话,${overview.message} 条消息,${overview.prompt} 条提示词,${overview.mask} 个面具`; + }, + ImportFailed: "导入失败", + }, Mask: { Splash: { Title: "面具启动页", @@ -355,6 +363,9 @@ const cn = { Close: "关闭", Create: "新建", Edit: "编辑", + Export: "导出", + Import: "导入", + Sync: "同步", }, Exporter: { Model: "模型", diff --git a/app/locales/en.ts b/app/locales/en.ts index 98135727..e3129578 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -180,6 +180,14 @@ const en: LocaleType = { Title: "Auto Generate Title", SubTitle: "Generate a suitable title based on the conversation content", }, + Sync: { + LastUpdate: "Last Update", + LocalState: "Local Data", + Overview: (overview: any) => { + return `${overview.chat} chats,${overview.message} messages,${overview.prompt} prompts,${overview.mask} masks`; + }, + ImportFailed: "Failed to import from file", + }, Mask: { Splash: { Title: "Mask Splash Screen", @@ -355,6 +363,9 @@ const en: LocaleType = { Close: "Close", Create: "Create", Edit: "Edit", + Export: "Export", + Import: "Import", + Sync: "Sync", }, Exporter: { Model: "Model", diff --git a/app/store/access.ts b/app/store/access.ts index b6021163..9eaa81e5 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -1,28 +1,7 @@ -import { create } from "zustand"; -import { persist } from "zustand/middleware"; import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant"; import { getHeaders } from "../client/api"; -import { BOT_HELLO } from "./chat"; import { getClientConfig } from "../config/client"; - -export interface AccessControlStore { - accessCode: string; - token: string; - - needCode: boolean; - hideUserApiKey: boolean; - hideBalanceQuery: boolean; - disableGPT4: boolean; - - openaiUrl: string; - - updateToken: (_: string) => void; - updateCode: (_: string) => void; - updateOpenAiUrl: (_: string) => void; - enabledAccessControl: () => boolean; - isAuthorized: () => boolean; - fetch: () => void; -} +import { createPersistStore } from "../utils/store"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -30,72 +9,74 @@ const DEFAULT_OPENAI_URL = getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/"; console.log("[API] default openai url", DEFAULT_OPENAI_URL); -export const useAccessStore = create()( - persist( - (set, get) => ({ - token: "", - accessCode: "", - needCode: true, - hideUserApiKey: false, - hideBalanceQuery: false, - disableGPT4: false, +const DEFAULT_ACCESS_STATE = { + token: "", + accessCode: "", + needCode: true, + hideUserApiKey: false, + hideBalanceQuery: false, + disableGPT4: false, - openaiUrl: DEFAULT_OPENAI_URL, + openaiUrl: DEFAULT_OPENAI_URL, +}; - enabledAccessControl() { - get().fetch(); +export const useAccessStore = createPersistStore( + { ...DEFAULT_ACCESS_STATE }, - return get().needCode; - }, - updateCode(code: string) { - set(() => ({ accessCode: code?.trim() })); - }, - updateToken(token: string) { - set(() => ({ token: token?.trim() })); - }, - updateOpenAiUrl(url: string) { - set(() => ({ openaiUrl: url?.trim() })); - }, - isAuthorized() { - get().fetch(); + (set, get) => ({ + enabledAccessControl() { + this.fetch(); - // has token or has code or disabled access control - return ( - !!get().token || !!get().accessCode || !get().enabledAccessControl() - ); - }, - fetch() { - if (fetchState > 0 || getClientConfig()?.buildMode === "export") return; - fetchState = 1; - fetch("/api/config", { - method: "post", - body: null, - headers: { - ...getHeaders(), - }, - }) - .then((res) => res.json()) - .then((res: DangerConfig) => { - console.log("[Config] got config from server", res); - set(() => ({ ...res })); - - if (res.disableGPT4) { - DEFAULT_MODELS.forEach( - (m: any) => (m.available = !m.name.startsWith("gpt-4")), - ); - } - }) - .catch(() => { - console.error("[Config] failed to fetch config"); - }) - .finally(() => { - fetchState = 2; - }); - }, - }), - { - name: StoreKey.Access, - version: 1, + return get().needCode; }, - ), + updateCode(code: string) { + set(() => ({ accessCode: code?.trim() })); + }, + updateToken(token: string) { + set(() => ({ token: token?.trim() })); + }, + updateOpenAiUrl(url: string) { + set(() => ({ openaiUrl: url?.trim() })); + }, + isAuthorized() { + this.fetch(); + + // has token or has code or disabled access control + return ( + !!get().token || !!get().accessCode || !this.enabledAccessControl() + ); + }, + fetch() { + if (fetchState > 0 || getClientConfig()?.buildMode === "export") return; + fetchState = 1; + fetch("/api/config", { + method: "post", + body: null, + headers: { + ...getHeaders(), + }, + }) + .then((res) => res.json()) + .then((res: DangerConfig) => { + console.log("[Config] got config from server", res); + set(() => ({ ...res })); + + if (res.disableGPT4) { + DEFAULT_MODELS.forEach( + (m: any) => (m.available = !m.name.startsWith("gpt-4")), + ); + } + }) + .catch(() => { + console.error("[Config] failed to fetch config"); + }) + .finally(() => { + fetchState = 2; + }); + }, + }), + { + name: StoreKey.Access, + version: 1, + }, ); diff --git a/app/store/chat.ts b/app/store/chat.ts index 20603fe4..269cc4a3 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -18,6 +18,7 @@ import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; import { estimateTokenLength } from "../utils/token"; import { nanoid } from "nanoid"; +import { createPersistStore } from "../utils/store"; export type ChatMessage = RequestMessage & { date: string; @@ -140,12 +141,22 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) { return output; } -export const useChatStore = create()( - persist( - (set, get) => ({ - sessions: [createEmptySession()], - currentSessionIndex: 0, +const DEFAULT_CHAT_STATE = { + sessions: [createEmptySession()], + currentSessionIndex: 0, +}; +export const useChatStore = createPersistStore( + DEFAULT_CHAT_STATE, + (set, _get) => { + function get() { + return { + ..._get(), + ...methods, + }; + } + + const methods = { clearSessions() { set(() => ({ sessions: [createEmptySession()], @@ -184,7 +195,7 @@ export const useChatStore = create()( }); }, - newSession(mask) { + newSession(mask?: Mask) { const session = createEmptySession(); if (mask) { @@ -207,14 +218,14 @@ export const useChatStore = create()( })); }, - nextSession(delta) { + nextSession(delta: number) { const n = get().sessions.length; const limit = (x: number) => (x + n) % n; const i = get().currentSessionIndex; get().selectSession(limit(i + delta)); }, - deleteSession(index) { + deleteSession(index: number) { const deletingLastSession = get().sessions.length === 1; const deletedSession = get().sessions.at(index); @@ -271,7 +282,7 @@ export const useChatStore = create()( return session; }, - onNewMessage(message) { + onNewMessage(message: ChatMessage) { get().updateCurrentSession((session) => { session.messages = session.messages.concat(); session.lastUpdate = Date.now(); @@ -280,7 +291,7 @@ export const useChatStore = create()( get().summarizeSession(); }, - async onUserInput(content) { + async onUserInput(content: string) { const session = get().currentSession(); const modelConfig = session.mask.modelConfig; @@ -580,14 +591,14 @@ export const useChatStore = create()( } }, - updateStat(message) { + updateStat(message: ChatMessage) { get().updateCurrentSession((session) => { session.stat.charCount += message.content.length; // TODO: should update chat count and word count }); }, - updateCurrentSession(updater) { + updateCurrentSession(updater: (session: ChatSession) => void) { const sessions = get().sessions; const index = get().currentSessionIndex; updater(sessions[index]); @@ -598,56 +609,60 @@ export const useChatStore = create()( localStorage.clear(); location.reload(); }, - }), - { - name: StoreKey.Chat, - version: 3.1, - migrate(persistedState, version) { - const state = persistedState as any; - const newState = JSON.parse(JSON.stringify(state)) as ChatStore; + }; - if (version < 2) { - newState.sessions = []; + return methods; + }, + { + name: StoreKey.Chat, + version: 3.1, + migrate(persistedState, version) { + const state = persistedState as any; + const newState = JSON.parse( + JSON.stringify(state), + ) as typeof DEFAULT_CHAT_STATE; - const oldSessions = state.sessions; - for (const oldSession of oldSessions) { - const newSession = createEmptySession(); - newSession.topic = oldSession.topic; - newSession.messages = [...oldSession.messages]; - newSession.mask.modelConfig.sendMemory = true; - newSession.mask.modelConfig.historyMessageCount = 4; - newSession.mask.modelConfig.compressMessageLengthThreshold = 1000; - newState.sessions.push(newSession); + if (version < 2) { + newState.sessions = []; + + const oldSessions = state.sessions; + for (const oldSession of oldSessions) { + const newSession = createEmptySession(); + newSession.topic = oldSession.topic; + newSession.messages = [...oldSession.messages]; + newSession.mask.modelConfig.sendMemory = true; + newSession.mask.modelConfig.historyMessageCount = 4; + newSession.mask.modelConfig.compressMessageLengthThreshold = 1000; + newState.sessions.push(newSession); + } + } + + if (version < 3) { + // migrate id to nanoid + newState.sessions.forEach((s) => { + s.id = nanoid(); + s.messages.forEach((m) => (m.id = nanoid())); + }); + } + + // Enable `enableInjectSystemPrompts` attribute for old sessions. + // Resolve issue of old sessions not automatically enabling. + if (version < 3.1) { + newState.sessions.forEach((s) => { + if ( + // Exclude those already set by user + !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts") + ) { + // Because users may have changed this configuration, + // the user's current configuration is used instead of the default + const config = useAppConfig.getState(); + s.mask.modelConfig.enableInjectSystemPrompts = + config.modelConfig.enableInjectSystemPrompts; } - } + }); + } - if (version < 3) { - // migrate id to nanoid - newState.sessions.forEach((s) => { - s.id = nanoid(); - s.messages.forEach((m) => (m.id = nanoid())); - }); - } - - // Enable `enableInjectSystemPrompts` attribute for old sessions. - // Resolve issue of old sessions not automatically enabling. - if (version < 3.1) { - newState.sessions.forEach((s) => { - if ( - // Exclude those already set by user - !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts") - ) { - // Because users may have changed this configuration, - // the user's current configuration is used instead of the default - const config = useAppConfig.getState(); - s.mask.modelConfig.enableInjectSystemPrompts = - config.modelConfig.enableInjectSystemPrompts; - } - }); - } - - return newState; - }, + return newState as any; }, - ), + }, ); diff --git a/app/store/config.ts b/app/store/config.ts index 7070ea05..5fa136a0 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -3,6 +3,7 @@ import { persist } from "zustand/middleware"; import { LLMModel } from "../client/api"; import { getClientConfig } from "../config/client"; import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, StoreKey } from "../constant"; +import { createPersistStore } from "../utils/store"; export type ModelType = (typeof DEFAULT_MODELS)[number]["name"]; @@ -21,6 +22,8 @@ export enum Theme { } export const DEFAULT_CONFIG = { + lastUpdate: Date.now(), // timestamp, to merge state + submitKey: SubmitKey.CtrlEnter as SubmitKey, avatar: "1f603", fontSize: 14, @@ -55,13 +58,6 @@ export const DEFAULT_CONFIG = { export type ChatConfig = typeof DEFAULT_CONFIG; -export type ChatConfigStore = ChatConfig & { - reset: () => void; - update: (updater: (config: ChatConfig) => void) => void; - mergeModels: (newModels: LLMModel[]) => void; - allModels: () => LLMModel[]; -}; - export type ModelConfig = ChatConfig["modelConfig"]; export function limitNumber( @@ -98,85 +94,80 @@ export const ModalConfigValidator = { }, }; -export const useAppConfig = create()( - persist( - (set, get) => ({ - ...DEFAULT_CONFIG, - - reset() { - set(() => ({ ...DEFAULT_CONFIG })); - }, - - update(updater) { - const config = { ...get() }; - updater(config); - set(() => config); - }, - - mergeModels(newModels) { - if (!newModels || newModels.length === 0) { - return; - } - - const oldModels = get().models; - const modelMap: Record = {}; - - for (const model of oldModels) { - model.available = false; - modelMap[model.name] = model; - } - - for (const model of newModels) { - model.available = true; - modelMap[model.name] = model; - } - - set(() => ({ - models: Object.values(modelMap), - })); - }, - - allModels() { - const customModels = get() - .customModels.split(",") - .filter((v) => !!v && v.length > 0) - .map((m) => ({ name: m, available: true })); - - const models = get().models.concat(customModels); - return models; - }, - }), - { - name: StoreKey.Config, - version: 3.7, - migrate(persistedState, version) { - const state = persistedState as ChatConfig; - - if (version < 3.4) { - state.modelConfig.sendMemory = true; - state.modelConfig.historyMessageCount = 4; - state.modelConfig.compressMessageLengthThreshold = 1000; - state.modelConfig.frequency_penalty = 0; - state.modelConfig.top_p = 1; - state.modelConfig.template = DEFAULT_INPUT_TEMPLATE; - state.dontShowMaskSplashScreen = false; - state.hideBuiltinMasks = false; - } - - if (version < 3.5) { - state.customModels = "claude,claude-100k"; - } - - if (version < 3.6) { - state.modelConfig.enableInjectSystemPrompts = true; - } - - if (version < 3.7) { - state.enableAutoGenerateTitle = true; - } - - return state as any; - }, +export const useAppConfig = createPersistStore( + { ...DEFAULT_CONFIG }, + (set, get) => ({ + reset() { + set(() => ({ ...DEFAULT_CONFIG })); }, - ), + + mergeModels(newModels: LLMModel[]) { + if (!newModels || newModels.length === 0) { + return; + } + + const oldModels = get().models; + const modelMap: Record = {}; + + for (const model of oldModels) { + model.available = false; + modelMap[model.name] = model; + } + + for (const model of newModels) { + model.available = true; + modelMap[model.name] = model; + } + + set(() => ({ + models: Object.values(modelMap), + })); + }, + + allModels() { + const customModels = get() + .customModels.split(",") + .filter((v) => !!v && v.length > 0) + .map((m) => ({ name: m, available: true })); + + const models = get().models.concat(customModels); + return models; + }, + }), + { + name: StoreKey.Config, + version: 3.8, + migrate(persistedState, version) { + const state = persistedState as ChatConfig; + + if (version < 3.4) { + state.modelConfig.sendMemory = true; + state.modelConfig.historyMessageCount = 4; + state.modelConfig.compressMessageLengthThreshold = 1000; + state.modelConfig.frequency_penalty = 0; + state.modelConfig.top_p = 1; + state.modelConfig.template = DEFAULT_INPUT_TEMPLATE; + state.dontShowMaskSplashScreen = false; + state.hideBuiltinMasks = false; + } + + if (version < 3.5) { + state.customModels = "claude,claude-100k"; + } + + if (version < 3.6) { + state.modelConfig.enableInjectSystemPrompts = true; + } + + if (version < 3.7) { + state.enableAutoGenerateTitle = true; + } + + if (version < 3.8) { + state.lastUpdate = Date.now(); + } + + return state as any; + }, + }, ); diff --git a/app/store/mask.ts b/app/store/mask.ts index 02132b77..82c41fec 100644 --- a/app/store/mask.ts +++ b/app/store/mask.ts @@ -1,11 +1,10 @@ -import { create } from "zustand"; -import { persist } from "zustand/middleware"; import { BUILTIN_MASKS } from "../masks"; import { getLang, Lang } from "../locales"; import { DEFAULT_TOPIC, ChatMessage } from "./chat"; import { ModelConfig, useAppConfig } from "./config"; import { StoreKey } from "../constant"; import { nanoid } from "nanoid"; +import { createPersistStore } from "../utils/store"; export type Mask = { id: string; @@ -25,14 +24,6 @@ export const DEFAULT_MASK_STATE = { }; export type MaskState = typeof DEFAULT_MASK_STATE; -type MaskStore = MaskState & { - create: (mask?: Partial) => Mask; - update: (id: string, updater: (mask: Mask) => void) => void; - delete: (id: string) => void; - search: (text: string) => Mask[]; - get: (id?: string) => Mask | null; - getAll: () => Mask[]; -}; export const DEFAULT_MASK_AVATAR = "gpt-bot"; export const createEmptyMask = () => @@ -46,89 +37,92 @@ export const createEmptyMask = () => lang: getLang(), builtin: false, createdAt: Date.now(), - } as Mask); + }) as Mask; -export const useMaskStore = create()( - persist( - (set, get) => ({ - ...DEFAULT_MASK_STATE, +export const useMaskStore = createPersistStore( + { ...DEFAULT_MASK_STATE }, - create(mask) { - const masks = get().masks; - const id = nanoid(); - masks[id] = { - ...createEmptyMask(), - ...mask, - id, - builtin: false, - }; + (set, get) => ({ + ...DEFAULT_MASK_STATE, - set(() => ({ masks })); + create(mask?: Partial) { + const masks = get().masks; + const id = nanoid(); + masks[id] = { + ...createEmptyMask(), + ...mask, + id, + builtin: false, + }; - return masks[id]; - }, - update(id, updater) { - const masks = get().masks; - const mask = masks[id]; - if (!mask) return; - const updateMask = { ...mask }; - updater(updateMask); - masks[id] = updateMask; - set(() => ({ masks })); - }, - delete(id) { - const masks = get().masks; - delete masks[id]; - set(() => ({ masks })); - }, + set(() => ({ masks })); + get().markUpdate(); - get(id) { - return get().masks[id ?? 1145141919810]; - }, - getAll() { - const userMasks = Object.values(get().masks).sort( - (a, b) => b.createdAt - a.createdAt, - ); - const config = useAppConfig.getState(); - if (config.hideBuiltinMasks) return userMasks; - const buildinMasks = BUILTIN_MASKS.map( - (m) => - ({ - ...m, - modelConfig: { - ...config.modelConfig, - ...m.modelConfig, - }, - } as Mask), - ); - return userMasks.concat(buildinMasks); - }, - search(text) { - return Object.values(get().masks); - }, - }), - { - name: StoreKey.Mask, - version: 3.1, - - migrate(state, version) { - const newState = JSON.parse(JSON.stringify(state)) as MaskState; - - // migrate mask id to nanoid - if (version < 3) { - Object.values(newState.masks).forEach((m) => (m.id = nanoid())); - } - - if (version < 3.1) { - const updatedMasks: Record = {}; - Object.values(newState.masks).forEach((m) => { - updatedMasks[m.id] = m; - }); - newState.masks = updatedMasks; - } - - return newState as any; - }, + return masks[id]; }, - ), + updateMask(id: string, updater: (mask: Mask) => void) { + const masks = get().masks; + const mask = masks[id]; + if (!mask) return; + const updateMask = { ...mask }; + updater(updateMask); + masks[id] = updateMask; + set(() => ({ masks })); + get().markUpdate(); + }, + delete(id: string) { + const masks = get().masks; + delete masks[id]; + set(() => ({ masks })); + get().markUpdate(); + }, + + get(id?: string) { + return get().masks[id ?? 1145141919810]; + }, + getAll() { + const userMasks = Object.values(get().masks).sort( + (a, b) => b.createdAt - a.createdAt, + ); + const config = useAppConfig.getState(); + if (config.hideBuiltinMasks) return userMasks; + const buildinMasks = BUILTIN_MASKS.map( + (m) => + ({ + ...m, + modelConfig: { + ...config.modelConfig, + ...m.modelConfig, + }, + }) as Mask, + ); + return userMasks.concat(buildinMasks); + }, + search(text: string) { + return Object.values(get().masks); + }, + }), + { + name: StoreKey.Mask, + version: 3.1, + + migrate(state, version) { + const newState = JSON.parse(JSON.stringify(state)) as MaskState; + + // migrate mask id to nanoid + if (version < 3) { + Object.values(newState.masks).forEach((m) => (m.id = nanoid())); + } + + if (version < 3.1) { + const updatedMasks: Record = {}; + Object.values(newState.masks).forEach((m) => { + updatedMasks[m.id] = m; + }); + newState.masks = updatedMasks; + } + + return newState as any; + }, + }, ); diff --git a/app/store/prompt.ts b/app/store/prompt.ts index e743f914..c6cff1a6 100644 --- a/app/store/prompt.ts +++ b/app/store/prompt.ts @@ -1,9 +1,8 @@ -import { create } from "zustand"; -import { persist } from "zustand/middleware"; import Fuse from "fuse.js"; import { getLang } from "../locales"; import { StoreKey } from "../constant"; import { nanoid } from "nanoid"; +import { createPersistStore } from "../utils/store"; export interface Prompt { id: string; @@ -13,19 +12,6 @@ export interface Prompt { createdAt: number; } -export interface PromptStore { - counter: number; - prompts: Record; - - add: (prompt: Prompt) => string; - get: (id: string) => Prompt | undefined; - remove: (id: string) => void; - search: (text: string) => Prompt[]; - update: (id: string, updater: (prompt: Prompt) => void) => void; - - getUserPrompts: () => Prompt[]; -} - export const SearchService = { ready: false, builtinEngine: new Fuse([], { keys: ["title"] }), @@ -62,130 +48,136 @@ export const SearchService = { }, }; -export const usePromptStore = create()( - persist( - (set, get) => ({ - counter: 0, - latestId: 0, - prompts: {}, +export const usePromptStore = createPersistStore( + { + counter: 0, + prompts: {} as Record, + }, - add(prompt) { - const prompts = get().prompts; - prompt.id = nanoid(); - prompt.isUser = true; - prompt.createdAt = Date.now(); - prompts[prompt.id] = prompt; + (set, get) => ({ + add(prompt: Prompt) { + const prompts = get().prompts; + prompt.id = nanoid(); + prompt.isUser = true; + prompt.createdAt = Date.now(); + prompts[prompt.id] = prompt; - set(() => ({ - latestId: prompt.id!, - prompts: prompts, - })); + set(() => ({ + prompts: prompts, + })); - return prompt.id!; - }, - - get(id) { - const targetPrompt = get().prompts[id]; - - if (!targetPrompt) { - return SearchService.builtinPrompts.find((v) => v.id === id); - } - - return targetPrompt; - }, - - remove(id) { - const prompts = get().prompts; - delete prompts[id]; - SearchService.remove(id); - - set(() => ({ - prompts, - counter: get().counter + 1, - })); - }, - - getUserPrompts() { - const userPrompts = Object.values(get().prompts ?? {}); - userPrompts.sort((a, b) => - b.id && a.id ? b.createdAt - a.createdAt : 0, - ); - return userPrompts; - }, - - update(id, updater) { - const prompt = get().prompts[id] ?? { - title: "", - content: "", - id, - }; - - SearchService.remove(id); - updater(prompt); - const prompts = get().prompts; - prompts[id] = prompt; - set(() => ({ prompts })); - SearchService.add(prompt); - }, - - search(text) { - if (text.length === 0) { - // return all rompts - return get().getUserPrompts().concat(SearchService.builtinPrompts); - } - return SearchService.search(text) as Prompt[]; - }, - }), - { - name: StoreKey.Prompt, - version: 3, - - migrate(state, version) { - const newState = JSON.parse(JSON.stringify(state)) as PromptStore; - - if (version < 3) { - Object.values(newState.prompts).forEach((p) => (p.id = nanoid())); - } - - return newState; - }, - - onRehydrateStorage(state) { - const PROMPT_URL = "./prompts.json"; - - type PromptList = Array<[string, string]>; - - fetch(PROMPT_URL) - .then((res) => res.json()) - .then((res) => { - let fetchPrompts = [res.en, res.cn]; - if (getLang() === "cn") { - fetchPrompts = fetchPrompts.reverse(); - } - const builtinPrompts = fetchPrompts.map( - (promptList: PromptList) => { - return promptList.map( - ([title, content]) => - ({ - id: nanoid(), - title, - content, - createdAt: Date.now(), - } as Prompt), - ); - }, - ); - - const userPrompts = - usePromptStore.getState().getUserPrompts() ?? []; - - const allPromptsForSearch = builtinPrompts - .reduce((pre, cur) => pre.concat(cur), []) - .filter((v) => !!v.title && !!v.content); - SearchService.count.builtin = res.en.length + res.cn.length; - SearchService.init(allPromptsForSearch, userPrompts); - }); - }, + return prompt.id!; }, - ), + + get(id: string) { + const targetPrompt = get().prompts[id]; + + if (!targetPrompt) { + return SearchService.builtinPrompts.find((v) => v.id === id); + } + + return targetPrompt; + }, + + remove(id: string) { + const prompts = get().prompts; + delete prompts[id]; + + Object.entries(prompts).some(([key, prompt]) => { + if (prompt.id === id) { + delete prompts[key]; + return true; + } + return false; + }); + + SearchService.remove(id); + + set(() => ({ + prompts, + counter: get().counter + 1, + })); + }, + + getUserPrompts() { + const userPrompts = Object.values(get().prompts ?? {}); + userPrompts.sort((a, b) => + b.id && a.id ? b.createdAt - a.createdAt : 0, + ); + return userPrompts; + }, + + updatePrompt(id: string, updater: (prompt: Prompt) => void) { + const prompt = get().prompts[id] ?? { + title: "", + content: "", + id, + }; + + SearchService.remove(id); + updater(prompt); + const prompts = get().prompts; + prompts[id] = prompt; + set(() => ({ prompts })); + SearchService.add(prompt); + }, + + search(text: string) { + if (text.length === 0) { + // return all rompts + return this.getUserPrompts().concat(SearchService.builtinPrompts); + } + return SearchService.search(text) as Prompt[]; + }, + }), + { + name: StoreKey.Prompt, + version: 3, + + migrate(state, version) { + const newState = JSON.parse(JSON.stringify(state)) as { + prompts: Record; + }; + + if (version < 3) { + Object.values(newState.prompts).forEach((p) => (p.id = nanoid())); + } + + return newState as any; + }, + + onRehydrateStorage(state) { + const PROMPT_URL = "./prompts.json"; + + type PromptList = Array<[string, string]>; + + fetch(PROMPT_URL) + .then((res) => res.json()) + .then((res) => { + let fetchPrompts = [res.en, res.cn]; + if (getLang() === "cn") { + fetchPrompts = fetchPrompts.reverse(); + } + const builtinPrompts = fetchPrompts.map((promptList: PromptList) => { + return promptList.map( + ([title, content]) => + ({ + id: nanoid(), + title, + content, + createdAt: Date.now(), + }) as Prompt, + ); + }); + + const userPrompts = usePromptStore.getState().getUserPrompts() ?? []; + + const allPromptsForSearch = builtinPrompts + .reduce((pre, cur) => pre.concat(cur), []) + .filter((v) => !!v.title && !!v.content); + SearchService.count.builtin = res.en.length + res.cn.length; + SearchService.init(allPromptsForSearch, userPrompts); + }); + }, + }, ); diff --git a/app/store/sync.ts b/app/store/sync.ts index 1a111f75..fc602809 100644 --- a/app/store/sync.ts +++ b/app/store/sync.ts @@ -1,7 +1,15 @@ import { Updater } from "../typing"; -import { create } from "zustand"; -import { persist } from "zustand/middleware"; import { StoreKey } from "../constant"; +import { createPersistStore } from "../utils/store"; +import { + AppState, + getLocalAppState, + mergeAppState, + setLocalAppState, +} from "../utils/sync"; +import { downloadAs, readFromFile } from "../utils"; +import { showToast } from "../components/ui-lib"; +import Locale from "../locales"; export interface WebDavConfig { server: string; @@ -20,68 +28,86 @@ export interface SyncStore { headers: () => { Authorization: string }; } -const FILE = { - root: "/chatgpt-next-web/", -}; - -export const useSyncStore = create()( - persist( - (set, get) => ({ - webDavConfig: { - server: "", - username: "", - password: "", - }, - - lastSyncTime: 0, - - update(updater) { - const config = { ...get().webDavConfig }; - updater(config); - set({ webDavConfig: config }); - }, - - async check() { - try { - const res = await fetch(this.path(""), { - method: "PROFIND", - headers: this.headers(), - }); - console.log(res); - return res.status === 207; - } catch (e) { - console.error("[Sync] ", e); - return false; - } - }, - - path(path: string) { - let url = get().webDavConfig.server; - - if (!url.endsWith("/")) { - url += "/"; - } - - if (path.startsWith("/")) { - path = path.slice(1); - } - - return url + path; - }, - - headers() { - const auth = btoa( - [get().webDavConfig.username, get().webDavConfig.password].join(":"), - ); - - return { - Authorization: `Basic ${auth}`, - }; - }, - }), - { - name: StoreKey.Sync, - version: 1, +export const useSyncStore = createPersistStore( + { + webDavConfig: { + server: "", + username: "", + password: "", }, - ), + + lastSyncTime: 0, + }, + (set, get) => ({ + webDavConfig: { + server: "", + username: "", + password: "", + }, + + lastSyncTime: 0, + + export() { + const state = getLocalAppState(); + const fileName = `Backup-${new Date().toLocaleString()}.json`; + downloadAs(JSON.stringify(state), fileName); + }, + + async import() { + const rawContent = await readFromFile(); + + try { + const remoteState = JSON.parse(rawContent) as AppState; + const localState = getLocalAppState(); + mergeAppState(localState, remoteState); + setLocalAppState(localState); + location.reload(); + } catch (e) { + console.error("[Import]", e); + showToast(Locale.Settings.Sync.ImportFailed); + } + }, + + async check() { + try { + const res = await fetch(this.path(""), { + method: "PROFIND", + headers: this.headers(), + }); + console.log(res); + return res.status === 207; + } catch (e) { + console.error("[Sync] ", e); + return false; + } + }, + + path(path: string) { + let url = get().webDavConfig.server; + + if (!url.endsWith("/")) { + url += "/"; + } + + if (path.startsWith("/")) { + path = path.slice(1); + } + + return url + path; + }, + + headers() { + const auth = btoa( + [get().webDavConfig.username, get().webDavConfig.password].join(":"), + ); + + return { + Authorization: `Basic ${auth}`, + }; + }, + }), + { + name: StoreKey.Sync, + version: 1, + }, ); diff --git a/app/store/update.ts b/app/store/update.ts index dd4d3c72..42b86586 100644 --- a/app/store/update.ts +++ b/app/store/update.ts @@ -1,24 +1,7 @@ -import { create } from "zustand"; -import { persist } from "zustand/middleware"; import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant"; import { api } from "../client/api"; import { getClientConfig } from "../config/client"; - -export interface UpdateStore { - versionType: "date" | "tag"; - lastUpdate: number; - version: string; - remoteVersion: string; - - used?: number; - subscription?: number; - lastUpdateUsage: number; - - getLatestVersion: (force?: boolean) => Promise; - updateUsage: (force?: boolean) => Promise; - - formatVersion: (version: string) => string; -} +import { createPersistStore } from "../utils/store"; const ONE_MINUTE = 60 * 1000; @@ -35,7 +18,9 @@ function formatVersionDate(t: string) { ].join(""); } -async function getVersion(type: "date" | "tag") { +type VersionType = "date" | "tag"; + +async function getVersion(type: VersionType) { if (type === "date") { const data = (await (await fetch(FETCH_COMMIT_URL)).json()) as { commit: { @@ -55,75 +40,76 @@ async function getVersion(type: "date" | "tag") { } } -export const useUpdateStore = create()( - persist( - (set, get) => ({ - versionType: "tag", - lastUpdate: 0, - version: "unknown", - remoteVersion: "", +export const useUpdateStore = createPersistStore( + { + versionType: "tag" as VersionType, + lastUpdate: 0, + version: "unknown", + remoteVersion: "", + used: 0, + subscription: 0, - lastUpdateUsage: 0, - - formatVersion(version: string) { - if (get().versionType === "date") { - version = formatVersionDate(version); - } - return version; - }, - - async getLatestVersion(force = false) { - const versionType = get().versionType; - let version = - versionType === "date" - ? getClientConfig()?.commitDate - : getClientConfig()?.version; - - set(() => ({ version })); - - const shouldCheck = Date.now() - get().lastUpdate > 2 * 60 * ONE_MINUTE; - if (!force && !shouldCheck) return; - - set(() => ({ - lastUpdate: Date.now(), - })); - - try { - const remoteId = await getVersion(versionType); - set(() => ({ - remoteVersion: remoteId, - })); - console.log("[Got Upstream] ", remoteId); - } catch (error) { - console.error("[Fetch Upstream Commit Id]", error); - } - }, - - async updateUsage(force = false) { - const overOneMinute = Date.now() - get().lastUpdateUsage >= ONE_MINUTE; - if (!overOneMinute && !force) return; - - set(() => ({ - lastUpdateUsage: Date.now(), - })); - - try { - const usage = await api.llm.usage(); - - if (usage) { - set(() => ({ - used: usage.used, - subscription: usage.total, - })); - } - } catch (e) { - console.error((e as Error).message); - } - }, - }), - { - name: StoreKey.Update, - version: 1, + lastUpdateUsage: 0, + }, + (set, get) => ({ + formatVersion(version: string) { + if (get().versionType === "date") { + version = formatVersionDate(version); + } + return version; }, - ), + + async getLatestVersion(force = false) { + const versionType = get().versionType; + let version = + versionType === "date" + ? getClientConfig()?.commitDate + : getClientConfig()?.version; + + set(() => ({ version })); + + const shouldCheck = Date.now() - get().lastUpdate > 2 * 60 * ONE_MINUTE; + if (!force && !shouldCheck) return; + + set(() => ({ + lastUpdate: Date.now(), + })); + + try { + const remoteId = await getVersion(versionType); + set(() => ({ + remoteVersion: remoteId, + })); + console.log("[Got Upstream] ", remoteId); + } catch (error) { + console.error("[Fetch Upstream Commit Id]", error); + } + }, + + async updateUsage(force = false) { + const overOneMinute = Date.now() - get().lastUpdateUsage >= ONE_MINUTE; + if (!overOneMinute && !force) return; + + set(() => ({ + lastUpdateUsage: Date.now(), + })); + + try { + const usage = await api.llm.usage(); + + if (usage) { + set(() => ({ + used: usage.used, + subscription: usage.total, + })); + } + } catch (e) { + console.error((e as Error).message); + } + }, + }), + { + name: StoreKey.Update, + version: 1, + }, ); diff --git a/app/utils/clone.ts b/app/utils/clone.ts new file mode 100644 index 00000000..2958b6b9 --- /dev/null +++ b/app/utils/clone.ts @@ -0,0 +1,3 @@ +export function deepClone(obj: T) { + return JSON.parse(JSON.stringify(obj)); +} diff --git a/app/utils/store.ts b/app/utils/store.ts new file mode 100644 index 00000000..cd151dc4 --- /dev/null +++ b/app/utils/store.ts @@ -0,0 +1,55 @@ +import { create } from "zustand"; +import { persist } from "zustand/middleware"; +import { Updater } from "../typing"; +import { deepClone } from "./clone"; + +type SecondParam = T extends ( + _f: infer _F, + _s: infer S, + ...args: infer _U +) => any + ? S + : never; + +type MakeUpdater = { + lastUpdateTime: number; + + markUpdate: () => void; + update: Updater; +}; + +type SetStoreState = ( + partial: T | Partial | ((state: T) => T | Partial), + replace?: boolean | undefined, +) => void; + +export function createPersistStore( + defaultState: T, + methods: ( + set: SetStoreState>, + get: () => T & MakeUpdater, + ) => M, + persistOptions: SecondParam>>, +) { + return create>()( + persist((set, get) => { + return { + ...defaultState, + ...methods(set as any, get), + + lastUpdateTime: 0, + markUpdate() { + set({ lastUpdateTime: Date.now() } as Partial< + T & M & MakeUpdater + >); + }, + update(updater) { + const state = deepClone(get()); + updater(state); + get().markUpdate(); + set(state); + }, + }; + }, persistOptions), + ); +} diff --git a/app/utils/sync.ts b/app/utils/sync.ts new file mode 100644 index 00000000..ab1f1f44 --- /dev/null +++ b/app/utils/sync.ts @@ -0,0 +1,162 @@ +import { + ChatSession, + useAccessStore, + useAppConfig, + useChatStore, +} from "../store"; +import { useMaskStore } from "../store/mask"; +import { usePromptStore } from "../store/prompt"; +import { StoreKey } from "../constant"; +import { merge } from "./merge"; + +type NonFunctionKeys = { + [K in keyof T]: T[K] extends (...args: any[]) => any ? never : K; +}[keyof T]; +type NonFunctionFields = Pick>; + +export function getNonFunctionFileds(obj: T) { + const ret: any = {}; + + Object.entries(obj).map(([k, v]) => { + if (typeof v !== "function") { + ret[k] = v; + } + }); + + return ret as NonFunctionFields; +} + +export type GetStoreState = T extends { getState: () => infer U } + ? NonFunctionFields + : never; + +const LocalStateSetters = { + [StoreKey.Chat]: useChatStore.setState, + [StoreKey.Access]: useAccessStore.setState, + [StoreKey.Config]: useAppConfig.setState, + [StoreKey.Mask]: useMaskStore.setState, + [StoreKey.Prompt]: usePromptStore.setState, +} as const; + +const LocalStateGetters = { + [StoreKey.Chat]: () => getNonFunctionFileds(useChatStore.getState()), + [StoreKey.Access]: () => getNonFunctionFileds(useAccessStore.getState()), + [StoreKey.Config]: () => getNonFunctionFileds(useAppConfig.getState()), + [StoreKey.Mask]: () => getNonFunctionFileds(useMaskStore.getState()), + [StoreKey.Prompt]: () => getNonFunctionFileds(usePromptStore.getState()), +} as const; + +export type AppState = { + [k in keyof typeof LocalStateGetters]: ReturnType< + (typeof LocalStateGetters)[k] + >; +}; + +type Merger = ( + localState: U, + remoteState: U, +) => U; + +type StateMerger = { + [K in keyof AppState]: Merger; +}; + +// we merge remote state to local state +const MergeStates: StateMerger = { + [StoreKey.Chat]: (localState, remoteState) => { + // merge sessions + const localSessions: Record = {}; + localState.sessions.forEach((s) => (localSessions[s.id] = s)); + + remoteState.sessions.forEach((remoteSession) => { + const localSession = localSessions[remoteSession.id]; + if (!localSession) { + // if remote session is new, just merge it + localState.sessions.push(remoteSession); + } else { + // if both have the same session id, merge the messages + const localMessageIds = new Set(localSession.messages.map((v) => v.id)); + remoteSession.messages.forEach((m) => { + if (!localMessageIds.has(m.id)) { + localSession.messages.push(m); + } + }); + + // sort local messages with date field in asc order + localSession.messages.sort( + (a, b) => new Date(a.date).getTime() - new Date(b.date).getTime(), + ); + } + }); + + // sort local sessions with date field in desc order + localState.sessions.sort( + (a, b) => + new Date(b.lastUpdate).getTime() - new Date(a.lastUpdate).getTime(), + ); + + return localState; + }, + [StoreKey.Prompt]: (localState, remoteState) => { + localState.prompts = { + ...remoteState.prompts, + ...localState.prompts, + }; + return localState; + }, + [StoreKey.Mask]: (localState, remoteState) => { + localState.masks = { + ...remoteState.masks, + ...localState.masks, + }; + return localState; + }, + [StoreKey.Config]: mergeWithUpdate, + [StoreKey.Access]: mergeWithUpdate, +}; + +export function getLocalAppState() { + const appState = Object.fromEntries( + Object.entries(LocalStateGetters).map(([key, getter]) => { + return [key, getter()]; + }), + ) as AppState; + + return appState; +} + +export function setLocalAppState(appState: AppState) { + Object.entries(LocalStateSetters).forEach(([key, setter]) => { + setter(appState[key as keyof AppState]); + }); +} + +export function mergeAppState(localState: AppState, remoteState: AppState) { + Object.keys(localState).forEach((k: string) => { + const key = k as T; + const localStoreState = localState[key]; + const remoteStoreState = remoteState[key]; + MergeStates[key](localStoreState, remoteStoreState); + }); + + return localState; +} + +/** + * Merge state with `lastUpdateTime`, older state will be override + */ +export function mergeWithUpdate( + localState: T, + remoteState: T, +) { + const localUpdateTime = localState.lastUpdateTime ?? 0; + const remoteUpdateTime = localState.lastUpdateTime ?? 1; + + if (localUpdateTime < remoteUpdateTime) { + merge(remoteState, localState); + return { ...remoteState }; + } else { + merge(localState, remoteState); + return { ...localState }; + } +}