Merge pull request #2796 from Yidadaa/backup

This commit is contained in:
Yifei Zhang 2023-09-11 00:27:51 +08:00 committed by GitHub
commit 1487762925
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 891 additions and 673 deletions

View File

@ -4,8 +4,8 @@ import GithubIcon from "../icons/github.svg";
import ResetIcon from "../icons/reload.svg"; import ResetIcon from "../icons/reload.svg";
import { ISSUE_URL } from "../constant"; import { ISSUE_URL } from "../constant";
import Locale from "../locales"; import Locale from "../locales";
import { downloadAs } from "../utils";
import { showConfirm } from "./ui-lib"; import { showConfirm } from "./ui-lib";
import { useSyncStore } from "../store/sync";
interface IErrorBoundaryState { interface IErrorBoundaryState {
hasError: boolean; hasError: boolean;
@ -26,10 +26,7 @@ export class ErrorBoundary extends React.Component<any, IErrorBoundaryState> {
clearAndSaveData() { clearAndSaveData() {
try { try {
downloadAs( useSyncStore.getState().export();
JSON.stringify(localStorage),
"chatgpt-next-web-snapshot.json",
);
} finally { } finally {
localStorage.clear(); localStorage.clear();
location.reload(); location.reload();

View File

@ -410,7 +410,7 @@ export function MaskPage() {
const closeMaskModal = () => setEditingMaskId(undefined); const closeMaskModal = () => setEditingMaskId(undefined);
const downloadAll = () => { const downloadAll = () => {
downloadAs(JSON.stringify(masks), FileName.Masks); downloadAs(JSON.stringify(masks.filter((v) => !v.builtin)), FileName.Masks);
}; };
const importFromFile = () => { const importFromFile = () => {
@ -452,11 +452,13 @@ export function MaskPage() {
icon={<DownloadIcon />} icon={<DownloadIcon />}
bordered bordered
onClick={downloadAll} onClick={downloadAll}
text={Locale.UI.Export}
/> />
</div> </div>
<div className="window-action-button"> <div className="window-action-button">
<IconButton <IconButton
icon={<UploadIcon />} icon={<UploadIcon />}
text={Locale.UI.Import}
bordered bordered
onClick={() => importFromFile()} onClick={() => importFromFile()}
/> />
@ -604,7 +606,7 @@ export function MaskPage() {
<MaskConfig <MaskConfig
mask={editingMask} mask={editingMask}
updateMask={(updater) => updateMask={(updater) =>
maskStore.update(editingMaskId!, updater) maskStore.updateMask(editingMaskId!, updater)
} }
readonly={editingMask.builtin} readonly={editingMask.builtin}
/> />

View File

@ -10,6 +10,9 @@ import ClearIcon from "../icons/clear.svg";
import LoadingIcon from "../icons/three-dots.svg"; import LoadingIcon from "../icons/three-dots.svg";
import EditIcon from "../icons/edit.svg"; import EditIcon from "../icons/edit.svg";
import EyeIcon from "../icons/eye.svg"; import EyeIcon from "../icons/eye.svg";
import DownloadIcon from "../icons/download.svg";
import UploadIcon from "../icons/upload.svg";
import { import {
Input, Input,
List, List,
@ -49,6 +52,7 @@ import { Avatar, AvatarPicker } from "./emoji";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { useSyncStore } from "../store/sync"; import { useSyncStore } from "../store/sync";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { useMaskStore } from "../store/mask";
function EditPromptModal(props: { id: string; onClose: () => void }) { function EditPromptModal(props: { id: string; onClose: () => void }) {
const promptStore = usePromptStore(); const promptStore = usePromptStore();
@ -75,7 +79,7 @@ function EditPromptModal(props: { id: string; onClose: () => void }) {
readOnly={!prompt.isUser} readOnly={!prompt.isUser}
className={styles["edit-prompt-title"]} className={styles["edit-prompt-title"]}
onInput={(e) => onInput={(e) =>
promptStore.update( promptStore.updatePrompt(
props.id, props.id,
(prompt) => (prompt.title = e.currentTarget.value), (prompt) => (prompt.title = e.currentTarget.value),
) )
@ -87,7 +91,7 @@ function EditPromptModal(props: { id: string; onClose: () => void }) {
className={styles["edit-prompt-content"]} className={styles["edit-prompt-content"]}
rows={10} rows={10}
onInput={(e) => onInput={(e) =>
promptStore.update( promptStore.updatePrompt(
props.id, props.id,
(prompt) => (prompt.content = e.currentTarget.value), (prompt) => (prompt.content = e.currentTarget.value),
) )
@ -127,14 +131,15 @@ function UserPromptModal(props: { onClose?: () => void }) {
actions={[ actions={[
<IconButton <IconButton
key="add" key="add"
onClick={() => onClick={() => {
promptStore.add({ const promptId = promptStore.add({
id: nanoid(), id: nanoid(),
createdAt: Date.now(), createdAt: Date.now(),
title: "Empty Prompt", title: "Empty Prompt",
content: "Empty Prompt Content", content: "Empty Prompt Content",
}) });
} setEditingPromptId(promptId);
}}
icon={<AddIcon />} icon={<AddIcon />}
bordered bordered
text={Locale.Settings.Prompt.Modal.Add} text={Locale.Settings.Prompt.Modal.Add}
@ -244,19 +249,31 @@ function DangerItems() {
function SyncItems() { function SyncItems() {
const syncStore = useSyncStore(); const syncStore = useSyncStore();
const webdav = syncStore.webDavConfig; const webdav = syncStore.webDavConfig;
const chatStore = useChatStore();
const promptStore = usePromptStore();
const maskStore = useMaskStore();
// not ready: https://github.com/Yidadaa/ChatGPT-Next-Web/issues/920#issuecomment-1609866332 const stateOverview = useMemo(() => {
return null; const sessions = chatStore.sessions;
const messageCount = sessions.reduce((p, c) => p + c.messages.length, 0);
return {
chat: sessions.length,
message: messageCount,
prompt: Object.keys(promptStore.prompts).length,
mask: Object.keys(maskStore.masks).length,
};
}, [chatStore.sessions, maskStore.masks, promptStore.prompts]);
return ( return (
<List> <List>
<ListItem <ListItem
title={"上次同步:" + new Date().toLocaleString()} title={Locale.Settings.Sync.LastUpdate}
subTitle={"20 次对话100 条消息200 提示词20 面具"} subTitle={new Date().toLocaleString()}
> >
<IconButton <IconButton
icon={<ResetIcon />} icon={<ResetIcon />}
text="同步" text={Locale.UI.Sync}
onClick={() => { onClick={() => {
syncStore.check().then(console.log); syncStore.check().then(console.log);
}} }}
@ -264,50 +281,25 @@ function SyncItems() {
</ListItem> </ListItem>
<ListItem <ListItem
title={"本地备份"} title={Locale.Settings.Sync.LocalState}
subTitle={"20 次对话100 条消息200 提示词20 面具"} subTitle={Locale.Settings.Sync.Overview(stateOverview)}
></ListItem>
<ListItem
title={"Web Dav Server"}
subTitle={Locale.Settings.AccessCode.SubTitle}
> >
<input <div style={{ display: "flex" }}>
value={webdav.server} <IconButton
type="text" icon={<UploadIcon />}
placeholder={"https://example.com"} text={Locale.UI.Export}
onChange={(e) => { onClick={() => {
syncStore.update( syncStore.export();
(config) => (config.server = e.currentTarget.value), }}
); />
}} <IconButton
/> icon={<DownloadIcon />}
</ListItem> text={Locale.UI.Import}
onClick={() => {
<ListItem title="Web Dav User Name" subTitle="user name here"> syncStore.import();
<input }}
value={webdav.username} />
type="text" </div>
placeholder={"username"}
onChange={(e) => {
syncStore.update(
(config) => (config.username = e.currentTarget.value),
);
}}
/>
</ListItem>
<ListItem title="Web Dav Password" subTitle="password here">
<input
value={webdav.password}
type="text"
placeholder={"password"}
onChange={(e) => {
syncStore.update(
(config) => (config.password = e.currentTarget.value),
);
}}
/>
</ListItem> </ListItem>
</List> </List>
); );
@ -562,6 +554,8 @@ export function Settings() {
</ListItem> </ListItem>
</List> </List>
<SyncItems />
<List> <List>
<ListItem <ListItem
title={Locale.Settings.Mask.Splash.Title} title={Locale.Settings.Mask.Splash.Title}
@ -722,8 +716,6 @@ export function Settings() {
</ListItem> </ListItem>
</List> </List>
<SyncItems />
<List> <List>
<ModelConfigList <ModelConfigList
modelConfig={config.modelConfig} modelConfig={config.modelConfig}

View File

@ -178,6 +178,14 @@ const cn = {
Title: "自动生成标题", Title: "自动生成标题",
SubTitle: "根据对话内容生成合适的标题", SubTitle: "根据对话内容生成合适的标题",
}, },
Sync: {
LastUpdate: "上次同步",
LocalState: "本地数据",
Overview: (overview: any) => {
return `${overview.chat} 次对话,${overview.message} 条消息,${overview.prompt} 条提示词,${overview.mask} 个面具`;
},
ImportFailed: "导入失败",
},
Mask: { Mask: {
Splash: { Splash: {
Title: "面具启动页", Title: "面具启动页",
@ -355,6 +363,9 @@ const cn = {
Close: "关闭", Close: "关闭",
Create: "新建", Create: "新建",
Edit: "编辑", Edit: "编辑",
Export: "导出",
Import: "导入",
Sync: "同步",
}, },
Exporter: { Exporter: {
Model: "模型", Model: "模型",

View File

@ -180,6 +180,14 @@ const en: LocaleType = {
Title: "Auto Generate Title", Title: "Auto Generate Title",
SubTitle: "Generate a suitable title based on the conversation content", SubTitle: "Generate a suitable title based on the conversation content",
}, },
Sync: {
LastUpdate: "Last Update",
LocalState: "Local Data",
Overview: (overview: any) => {
return `${overview.chat} chats${overview.message} messages${overview.prompt} prompts${overview.mask} masks`;
},
ImportFailed: "Failed to import from file",
},
Mask: { Mask: {
Splash: { Splash: {
Title: "Mask Splash Screen", Title: "Mask Splash Screen",
@ -355,6 +363,9 @@ const en: LocaleType = {
Close: "Close", Close: "Close",
Create: "Create", Create: "Create",
Edit: "Edit", Edit: "Edit",
Export: "Export",
Import: "Import",
Sync: "Sync",
}, },
Exporter: { Exporter: {
Model: "Model", Model: "Model",

View File

@ -1,28 +1,7 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant"; import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant";
import { getHeaders } from "../client/api"; import { getHeaders } from "../client/api";
import { BOT_HELLO } from "./chat";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
export interface AccessControlStore {
accessCode: string;
token: string;
needCode: boolean;
hideUserApiKey: boolean;
hideBalanceQuery: boolean;
disableGPT4: boolean;
openaiUrl: string;
updateToken: (_: string) => void;
updateCode: (_: string) => void;
updateOpenAiUrl: (_: string) => void;
enabledAccessControl: () => boolean;
isAuthorized: () => boolean;
fetch: () => void;
}
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
@ -30,72 +9,74 @@ const DEFAULT_OPENAI_URL =
getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/"; getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/";
console.log("[API] default openai url", DEFAULT_OPENAI_URL); console.log("[API] default openai url", DEFAULT_OPENAI_URL);
export const useAccessStore = create<AccessControlStore>()( const DEFAULT_ACCESS_STATE = {
persist( token: "",
(set, get) => ({ accessCode: "",
token: "", needCode: true,
accessCode: "", hideUserApiKey: false,
needCode: true, hideBalanceQuery: false,
hideUserApiKey: false, disableGPT4: false,
hideBalanceQuery: false,
disableGPT4: false,
openaiUrl: DEFAULT_OPENAI_URL, openaiUrl: DEFAULT_OPENAI_URL,
};
enabledAccessControl() { export const useAccessStore = createPersistStore(
get().fetch(); { ...DEFAULT_ACCESS_STATE },
return get().needCode; (set, get) => ({
}, enabledAccessControl() {
updateCode(code: string) { this.fetch();
set(() => ({ accessCode: code?.trim() }));
},
updateToken(token: string) {
set(() => ({ token: token?.trim() }));
},
updateOpenAiUrl(url: string) {
set(() => ({ openaiUrl: url?.trim() }));
},
isAuthorized() {
get().fetch();
// has token or has code or disabled access control return get().needCode;
return (
!!get().token || !!get().accessCode || !get().enabledAccessControl()
);
},
fetch() {
if (fetchState > 0 || getClientConfig()?.buildMode === "export") return;
fetchState = 1;
fetch("/api/config", {
method: "post",
body: null,
headers: {
...getHeaders(),
},
})
.then((res) => res.json())
.then((res: DangerConfig) => {
console.log("[Config] got config from server", res);
set(() => ({ ...res }));
if (res.disableGPT4) {
DEFAULT_MODELS.forEach(
(m: any) => (m.available = !m.name.startsWith("gpt-4")),
);
}
})
.catch(() => {
console.error("[Config] failed to fetch config");
})
.finally(() => {
fetchState = 2;
});
},
}),
{
name: StoreKey.Access,
version: 1,
}, },
), updateCode(code: string) {
set(() => ({ accessCode: code?.trim() }));
},
updateToken(token: string) {
set(() => ({ token: token?.trim() }));
},
updateOpenAiUrl(url: string) {
set(() => ({ openaiUrl: url?.trim() }));
},
isAuthorized() {
this.fetch();
// has token or has code or disabled access control
return (
!!get().token || !!get().accessCode || !this.enabledAccessControl()
);
},
fetch() {
if (fetchState > 0 || getClientConfig()?.buildMode === "export") return;
fetchState = 1;
fetch("/api/config", {
method: "post",
body: null,
headers: {
...getHeaders(),
},
})
.then((res) => res.json())
.then((res: DangerConfig) => {
console.log("[Config] got config from server", res);
set(() => ({ ...res }));
if (res.disableGPT4) {
DEFAULT_MODELS.forEach(
(m: any) => (m.available = !m.name.startsWith("gpt-4")),
);
}
})
.catch(() => {
console.error("[Config] failed to fetch config");
})
.finally(() => {
fetchState = 2;
});
},
}),
{
name: StoreKey.Access,
version: 1,
},
); );

View File

@ -18,6 +18,7 @@ import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format"; import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token"; import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
export type ChatMessage = RequestMessage & { export type ChatMessage = RequestMessage & {
date: string; date: string;
@ -140,12 +141,22 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) {
return output; return output;
} }
export const useChatStore = create<ChatStore>()( const DEFAULT_CHAT_STATE = {
persist( sessions: [createEmptySession()],
(set, get) => ({ currentSessionIndex: 0,
sessions: [createEmptySession()], };
currentSessionIndex: 0,
export const useChatStore = createPersistStore(
DEFAULT_CHAT_STATE,
(set, _get) => {
function get() {
return {
..._get(),
...methods,
};
}
const methods = {
clearSessions() { clearSessions() {
set(() => ({ set(() => ({
sessions: [createEmptySession()], sessions: [createEmptySession()],
@ -184,7 +195,7 @@ export const useChatStore = create<ChatStore>()(
}); });
}, },
newSession(mask) { newSession(mask?: Mask) {
const session = createEmptySession(); const session = createEmptySession();
if (mask) { if (mask) {
@ -207,14 +218,14 @@ export const useChatStore = create<ChatStore>()(
})); }));
}, },
nextSession(delta) { nextSession(delta: number) {
const n = get().sessions.length; const n = get().sessions.length;
const limit = (x: number) => (x + n) % n; const limit = (x: number) => (x + n) % n;
const i = get().currentSessionIndex; const i = get().currentSessionIndex;
get().selectSession(limit(i + delta)); get().selectSession(limit(i + delta));
}, },
deleteSession(index) { deleteSession(index: number) {
const deletingLastSession = get().sessions.length === 1; const deletingLastSession = get().sessions.length === 1;
const deletedSession = get().sessions.at(index); const deletedSession = get().sessions.at(index);
@ -271,7 +282,7 @@ export const useChatStore = create<ChatStore>()(
return session; return session;
}, },
onNewMessage(message) { onNewMessage(message: ChatMessage) {
get().updateCurrentSession((session) => { get().updateCurrentSession((session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
session.lastUpdate = Date.now(); session.lastUpdate = Date.now();
@ -280,7 +291,7 @@ export const useChatStore = create<ChatStore>()(
get().summarizeSession(); get().summarizeSession();
}, },
async onUserInput(content) { async onUserInput(content: string) {
const session = get().currentSession(); const session = get().currentSession();
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
@ -580,14 +591,14 @@ export const useChatStore = create<ChatStore>()(
} }
}, },
updateStat(message) { updateStat(message: ChatMessage) {
get().updateCurrentSession((session) => { get().updateCurrentSession((session) => {
session.stat.charCount += message.content.length; session.stat.charCount += message.content.length;
// TODO: should update chat count and word count // TODO: should update chat count and word count
}); });
}, },
updateCurrentSession(updater) { updateCurrentSession(updater: (session: ChatSession) => void) {
const sessions = get().sessions; const sessions = get().sessions;
const index = get().currentSessionIndex; const index = get().currentSessionIndex;
updater(sessions[index]); updater(sessions[index]);
@ -598,56 +609,60 @@ export const useChatStore = create<ChatStore>()(
localStorage.clear(); localStorage.clear();
location.reload(); location.reload();
}, },
}), };
{
name: StoreKey.Chat,
version: 3.1,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(JSON.stringify(state)) as ChatStore;
if (version < 2) { return methods;
newState.sessions = []; },
{
name: StoreKey.Chat,
version: 3.1,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
JSON.stringify(state),
) as typeof DEFAULT_CHAT_STATE;
const oldSessions = state.sessions; if (version < 2) {
for (const oldSession of oldSessions) { newState.sessions = [];
const newSession = createEmptySession();
newSession.topic = oldSession.topic; const oldSessions = state.sessions;
newSession.messages = [...oldSession.messages]; for (const oldSession of oldSessions) {
newSession.mask.modelConfig.sendMemory = true; const newSession = createEmptySession();
newSession.mask.modelConfig.historyMessageCount = 4; newSession.topic = oldSession.topic;
newSession.mask.modelConfig.compressMessageLengthThreshold = 1000; newSession.messages = [...oldSession.messages];
newState.sessions.push(newSession); newSession.mask.modelConfig.sendMemory = true;
newSession.mask.modelConfig.historyMessageCount = 4;
newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
newState.sessions.push(newSession);
}
}
if (version < 3) {
// migrate id to nanoid
newState.sessions.forEach((s) => {
s.id = nanoid();
s.messages.forEach((m) => (m.id = nanoid()));
});
}
// Enable `enableInjectSystemPrompts` attribute for old sessions.
// Resolve issue of old sessions not automatically enabling.
if (version < 3.1) {
newState.sessions.forEach((s) => {
if (
// Exclude those already set by user
!s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
) {
// Because users may have changed this configuration,
// the user's current configuration is used instead of the default
const config = useAppConfig.getState();
s.mask.modelConfig.enableInjectSystemPrompts =
config.modelConfig.enableInjectSystemPrompts;
} }
} });
}
if (version < 3) { return newState as any;
// migrate id to nanoid
newState.sessions.forEach((s) => {
s.id = nanoid();
s.messages.forEach((m) => (m.id = nanoid()));
});
}
// Enable `enableInjectSystemPrompts` attribute for old sessions.
// Resolve issue of old sessions not automatically enabling.
if (version < 3.1) {
newState.sessions.forEach((s) => {
if (
// Exclude those already set by user
!s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
) {
// Because users may have changed this configuration,
// the user's current configuration is used instead of the default
const config = useAppConfig.getState();
s.mask.modelConfig.enableInjectSystemPrompts =
config.modelConfig.enableInjectSystemPrompts;
}
});
}
return newState;
},
}, },
), },
); );

View File

@ -3,6 +3,7 @@ import { persist } from "zustand/middleware";
import { LLMModel } from "../client/api"; import { LLMModel } from "../client/api";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, StoreKey } from "../constant"; import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, StoreKey } from "../constant";
import { createPersistStore } from "../utils/store";
export type ModelType = (typeof DEFAULT_MODELS)[number]["name"]; export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
@ -21,6 +22,8 @@ export enum Theme {
} }
export const DEFAULT_CONFIG = { export const DEFAULT_CONFIG = {
lastUpdate: Date.now(), // timestamp, to merge state
submitKey: SubmitKey.CtrlEnter as SubmitKey, submitKey: SubmitKey.CtrlEnter as SubmitKey,
avatar: "1f603", avatar: "1f603",
fontSize: 14, fontSize: 14,
@ -55,13 +58,6 @@ export const DEFAULT_CONFIG = {
export type ChatConfig = typeof DEFAULT_CONFIG; export type ChatConfig = typeof DEFAULT_CONFIG;
export type ChatConfigStore = ChatConfig & {
reset: () => void;
update: (updater: (config: ChatConfig) => void) => void;
mergeModels: (newModels: LLMModel[]) => void;
allModels: () => LLMModel[];
};
export type ModelConfig = ChatConfig["modelConfig"]; export type ModelConfig = ChatConfig["modelConfig"];
export function limitNumber( export function limitNumber(
@ -98,85 +94,80 @@ export const ModalConfigValidator = {
}, },
}; };
export const useAppConfig = create<ChatConfigStore>()( export const useAppConfig = createPersistStore(
persist( { ...DEFAULT_CONFIG },
(set, get) => ({ (set, get) => ({
...DEFAULT_CONFIG, reset() {
set(() => ({ ...DEFAULT_CONFIG }));
reset() {
set(() => ({ ...DEFAULT_CONFIG }));
},
update(updater) {
const config = { ...get() };
updater(config);
set(() => config);
},
mergeModels(newModels) {
if (!newModels || newModels.length === 0) {
return;
}
const oldModels = get().models;
const modelMap: Record<string, LLMModel> = {};
for (const model of oldModels) {
model.available = false;
modelMap[model.name] = model;
}
for (const model of newModels) {
model.available = true;
modelMap[model.name] = model;
}
set(() => ({
models: Object.values(modelMap),
}));
},
allModels() {
const customModels = get()
.customModels.split(",")
.filter((v) => !!v && v.length > 0)
.map((m) => ({ name: m, available: true }));
const models = get().models.concat(customModels);
return models;
},
}),
{
name: StoreKey.Config,
version: 3.7,
migrate(persistedState, version) {
const state = persistedState as ChatConfig;
if (version < 3.4) {
state.modelConfig.sendMemory = true;
state.modelConfig.historyMessageCount = 4;
state.modelConfig.compressMessageLengthThreshold = 1000;
state.modelConfig.frequency_penalty = 0;
state.modelConfig.top_p = 1;
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
state.dontShowMaskSplashScreen = false;
state.hideBuiltinMasks = false;
}
if (version < 3.5) {
state.customModels = "claude,claude-100k";
}
if (version < 3.6) {
state.modelConfig.enableInjectSystemPrompts = true;
}
if (version < 3.7) {
state.enableAutoGenerateTitle = true;
}
return state as any;
},
}, },
),
mergeModels(newModels: LLMModel[]) {
if (!newModels || newModels.length === 0) {
return;
}
const oldModels = get().models;
const modelMap: Record<string, LLMModel> = {};
for (const model of oldModels) {
model.available = false;
modelMap[model.name] = model;
}
for (const model of newModels) {
model.available = true;
modelMap[model.name] = model;
}
set(() => ({
models: Object.values(modelMap),
}));
},
allModels() {
const customModels = get()
.customModels.split(",")
.filter((v) => !!v && v.length > 0)
.map((m) => ({ name: m, available: true }));
const models = get().models.concat(customModels);
return models;
},
}),
{
name: StoreKey.Config,
version: 3.8,
migrate(persistedState, version) {
const state = persistedState as ChatConfig;
if (version < 3.4) {
state.modelConfig.sendMemory = true;
state.modelConfig.historyMessageCount = 4;
state.modelConfig.compressMessageLengthThreshold = 1000;
state.modelConfig.frequency_penalty = 0;
state.modelConfig.top_p = 1;
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
state.dontShowMaskSplashScreen = false;
state.hideBuiltinMasks = false;
}
if (version < 3.5) {
state.customModels = "claude,claude-100k";
}
if (version < 3.6) {
state.modelConfig.enableInjectSystemPrompts = true;
}
if (version < 3.7) {
state.enableAutoGenerateTitle = true;
}
if (version < 3.8) {
state.lastUpdate = Date.now();
}
return state as any;
},
},
); );

View File

@ -1,11 +1,10 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks"; import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales"; import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat"; import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, useAppConfig } from "./config"; import { ModelConfig, useAppConfig } from "./config";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
export type Mask = { export type Mask = {
id: string; id: string;
@ -25,14 +24,6 @@ export const DEFAULT_MASK_STATE = {
}; };
export type MaskState = typeof DEFAULT_MASK_STATE; export type MaskState = typeof DEFAULT_MASK_STATE;
type MaskStore = MaskState & {
create: (mask?: Partial<Mask>) => Mask;
update: (id: string, updater: (mask: Mask) => void) => void;
delete: (id: string) => void;
search: (text: string) => Mask[];
get: (id?: string) => Mask | null;
getAll: () => Mask[];
};
export const DEFAULT_MASK_AVATAR = "gpt-bot"; export const DEFAULT_MASK_AVATAR = "gpt-bot";
export const createEmptyMask = () => export const createEmptyMask = () =>
@ -46,89 +37,92 @@ export const createEmptyMask = () =>
lang: getLang(), lang: getLang(),
builtin: false, builtin: false,
createdAt: Date.now(), createdAt: Date.now(),
} as Mask); }) as Mask;
export const useMaskStore = create<MaskStore>()( export const useMaskStore = createPersistStore(
persist( { ...DEFAULT_MASK_STATE },
(set, get) => ({
...DEFAULT_MASK_STATE,
create(mask) { (set, get) => ({
const masks = get().masks; ...DEFAULT_MASK_STATE,
const id = nanoid();
masks[id] = {
...createEmptyMask(),
...mask,
id,
builtin: false,
};
set(() => ({ masks })); create(mask?: Partial<Mask>) {
const masks = get().masks;
const id = nanoid();
masks[id] = {
...createEmptyMask(),
...mask,
id,
builtin: false,
};
return masks[id]; set(() => ({ masks }));
}, get().markUpdate();
update(id, updater) {
const masks = get().masks;
const mask = masks[id];
if (!mask) return;
const updateMask = { ...mask };
updater(updateMask);
masks[id] = updateMask;
set(() => ({ masks }));
},
delete(id) {
const masks = get().masks;
delete masks[id];
set(() => ({ masks }));
},
get(id) { return masks[id];
return get().masks[id ?? 1145141919810];
},
getAll() {
const userMasks = Object.values(get().masks).sort(
(a, b) => b.createdAt - a.createdAt,
);
const config = useAppConfig.getState();
if (config.hideBuiltinMasks) return userMasks;
const buildinMasks = BUILTIN_MASKS.map(
(m) =>
({
...m,
modelConfig: {
...config.modelConfig,
...m.modelConfig,
},
} as Mask),
);
return userMasks.concat(buildinMasks);
},
search(text) {
return Object.values(get().masks);
},
}),
{
name: StoreKey.Mask,
version: 3.1,
migrate(state, version) {
const newState = JSON.parse(JSON.stringify(state)) as MaskState;
// migrate mask id to nanoid
if (version < 3) {
Object.values(newState.masks).forEach((m) => (m.id = nanoid()));
}
if (version < 3.1) {
const updatedMasks: Record<string, Mask> = {};
Object.values(newState.masks).forEach((m) => {
updatedMasks[m.id] = m;
});
newState.masks = updatedMasks;
}
return newState as any;
},
}, },
), updateMask(id: string, updater: (mask: Mask) => void) {
const masks = get().masks;
const mask = masks[id];
if (!mask) return;
const updateMask = { ...mask };
updater(updateMask);
masks[id] = updateMask;
set(() => ({ masks }));
get().markUpdate();
},
delete(id: string) {
const masks = get().masks;
delete masks[id];
set(() => ({ masks }));
get().markUpdate();
},
get(id?: string) {
return get().masks[id ?? 1145141919810];
},
getAll() {
const userMasks = Object.values(get().masks).sort(
(a, b) => b.createdAt - a.createdAt,
);
const config = useAppConfig.getState();
if (config.hideBuiltinMasks) return userMasks;
const buildinMasks = BUILTIN_MASKS.map(
(m) =>
({
...m,
modelConfig: {
...config.modelConfig,
...m.modelConfig,
},
}) as Mask,
);
return userMasks.concat(buildinMasks);
},
search(text: string) {
return Object.values(get().masks);
},
}),
{
name: StoreKey.Mask,
version: 3.1,
migrate(state, version) {
const newState = JSON.parse(JSON.stringify(state)) as MaskState;
// migrate mask id to nanoid
if (version < 3) {
Object.values(newState.masks).forEach((m) => (m.id = nanoid()));
}
if (version < 3.1) {
const updatedMasks: Record<string, Mask> = {};
Object.values(newState.masks).forEach((m) => {
updatedMasks[m.id] = m;
});
newState.masks = updatedMasks;
}
return newState as any;
},
},
); );

View File

@ -1,9 +1,8 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import Fuse from "fuse.js"; import Fuse from "fuse.js";
import { getLang } from "../locales"; import { getLang } from "../locales";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
export interface Prompt { export interface Prompt {
id: string; id: string;
@ -13,19 +12,6 @@ export interface Prompt {
createdAt: number; createdAt: number;
} }
export interface PromptStore {
counter: number;
prompts: Record<string, Prompt>;
add: (prompt: Prompt) => string;
get: (id: string) => Prompt | undefined;
remove: (id: string) => void;
search: (text: string) => Prompt[];
update: (id: string, updater: (prompt: Prompt) => void) => void;
getUserPrompts: () => Prompt[];
}
export const SearchService = { export const SearchService = {
ready: false, ready: false,
builtinEngine: new Fuse<Prompt>([], { keys: ["title"] }), builtinEngine: new Fuse<Prompt>([], { keys: ["title"] }),
@ -62,130 +48,136 @@ export const SearchService = {
}, },
}; };
export const usePromptStore = create<PromptStore>()( export const usePromptStore = createPersistStore(
persist( {
(set, get) => ({ counter: 0,
counter: 0, prompts: {} as Record<string, Prompt>,
latestId: 0, },
prompts: {},
add(prompt) { (set, get) => ({
const prompts = get().prompts; add(prompt: Prompt) {
prompt.id = nanoid(); const prompts = get().prompts;
prompt.isUser = true; prompt.id = nanoid();
prompt.createdAt = Date.now(); prompt.isUser = true;
prompts[prompt.id] = prompt; prompt.createdAt = Date.now();
prompts[prompt.id] = prompt;
set(() => ({ set(() => ({
latestId: prompt.id!, prompts: prompts,
prompts: prompts, }));
}));
return prompt.id!; return prompt.id!;
},
get(id) {
const targetPrompt = get().prompts[id];
if (!targetPrompt) {
return SearchService.builtinPrompts.find((v) => v.id === id);
}
return targetPrompt;
},
remove(id) {
const prompts = get().prompts;
delete prompts[id];
SearchService.remove(id);
set(() => ({
prompts,
counter: get().counter + 1,
}));
},
getUserPrompts() {
const userPrompts = Object.values(get().prompts ?? {});
userPrompts.sort((a, b) =>
b.id && a.id ? b.createdAt - a.createdAt : 0,
);
return userPrompts;
},
update(id, updater) {
const prompt = get().prompts[id] ?? {
title: "",
content: "",
id,
};
SearchService.remove(id);
updater(prompt);
const prompts = get().prompts;
prompts[id] = prompt;
set(() => ({ prompts }));
SearchService.add(prompt);
},
search(text) {
if (text.length === 0) {
// return all rompts
return get().getUserPrompts().concat(SearchService.builtinPrompts);
}
return SearchService.search(text) as Prompt[];
},
}),
{
name: StoreKey.Prompt,
version: 3,
migrate(state, version) {
const newState = JSON.parse(JSON.stringify(state)) as PromptStore;
if (version < 3) {
Object.values(newState.prompts).forEach((p) => (p.id = nanoid()));
}
return newState;
},
onRehydrateStorage(state) {
const PROMPT_URL = "./prompts.json";
type PromptList = Array<[string, string]>;
fetch(PROMPT_URL)
.then((res) => res.json())
.then((res) => {
let fetchPrompts = [res.en, res.cn];
if (getLang() === "cn") {
fetchPrompts = fetchPrompts.reverse();
}
const builtinPrompts = fetchPrompts.map(
(promptList: PromptList) => {
return promptList.map(
([title, content]) =>
({
id: nanoid(),
title,
content,
createdAt: Date.now(),
} as Prompt),
);
},
);
const userPrompts =
usePromptStore.getState().getUserPrompts() ?? [];
const allPromptsForSearch = builtinPrompts
.reduce((pre, cur) => pre.concat(cur), [])
.filter((v) => !!v.title && !!v.content);
SearchService.count.builtin = res.en.length + res.cn.length;
SearchService.init(allPromptsForSearch, userPrompts);
});
},
}, },
),
get(id: string) {
const targetPrompt = get().prompts[id];
if (!targetPrompt) {
return SearchService.builtinPrompts.find((v) => v.id === id);
}
return targetPrompt;
},
remove(id: string) {
const prompts = get().prompts;
delete prompts[id];
Object.entries(prompts).some(([key, prompt]) => {
if (prompt.id === id) {
delete prompts[key];
return true;
}
return false;
});
SearchService.remove(id);
set(() => ({
prompts,
counter: get().counter + 1,
}));
},
getUserPrompts() {
const userPrompts = Object.values(get().prompts ?? {});
userPrompts.sort((a, b) =>
b.id && a.id ? b.createdAt - a.createdAt : 0,
);
return userPrompts;
},
updatePrompt(id: string, updater: (prompt: Prompt) => void) {
const prompt = get().prompts[id] ?? {
title: "",
content: "",
id,
};
SearchService.remove(id);
updater(prompt);
const prompts = get().prompts;
prompts[id] = prompt;
set(() => ({ prompts }));
SearchService.add(prompt);
},
search(text: string) {
if (text.length === 0) {
// return all rompts
return this.getUserPrompts().concat(SearchService.builtinPrompts);
}
return SearchService.search(text) as Prompt[];
},
}),
{
name: StoreKey.Prompt,
version: 3,
migrate(state, version) {
const newState = JSON.parse(JSON.stringify(state)) as {
prompts: Record<string, Prompt>;
};
if (version < 3) {
Object.values(newState.prompts).forEach((p) => (p.id = nanoid()));
}
return newState as any;
},
onRehydrateStorage(state) {
const PROMPT_URL = "./prompts.json";
type PromptList = Array<[string, string]>;
fetch(PROMPT_URL)
.then((res) => res.json())
.then((res) => {
let fetchPrompts = [res.en, res.cn];
if (getLang() === "cn") {
fetchPrompts = fetchPrompts.reverse();
}
const builtinPrompts = fetchPrompts.map((promptList: PromptList) => {
return promptList.map(
([title, content]) =>
({
id: nanoid(),
title,
content,
createdAt: Date.now(),
}) as Prompt,
);
});
const userPrompts = usePromptStore.getState().getUserPrompts() ?? [];
const allPromptsForSearch = builtinPrompts
.reduce((pre, cur) => pre.concat(cur), [])
.filter((v) => !!v.title && !!v.content);
SearchService.count.builtin = res.en.length + res.cn.length;
SearchService.init(allPromptsForSearch, userPrompts);
});
},
},
); );

View File

@ -1,7 +1,15 @@
import { Updater } from "../typing"; import { Updater } from "../typing";
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
import { createPersistStore } from "../utils/store";
import {
AppState,
getLocalAppState,
mergeAppState,
setLocalAppState,
} from "../utils/sync";
import { downloadAs, readFromFile } from "../utils";
import { showToast } from "../components/ui-lib";
import Locale from "../locales";
export interface WebDavConfig { export interface WebDavConfig {
server: string; server: string;
@ -20,68 +28,86 @@ export interface SyncStore {
headers: () => { Authorization: string }; headers: () => { Authorization: string };
} }
const FILE = { export const useSyncStore = createPersistStore(
root: "/chatgpt-next-web/", {
}; webDavConfig: {
server: "",
export const useSyncStore = create<SyncStore>()( username: "",
persist( password: "",
(set, get) => ({
webDavConfig: {
server: "",
username: "",
password: "",
},
lastSyncTime: 0,
update(updater) {
const config = { ...get().webDavConfig };
updater(config);
set({ webDavConfig: config });
},
async check() {
try {
const res = await fetch(this.path(""), {
method: "PROFIND",
headers: this.headers(),
});
console.log(res);
return res.status === 207;
} catch (e) {
console.error("[Sync] ", e);
return false;
}
},
path(path: string) {
let url = get().webDavConfig.server;
if (!url.endsWith("/")) {
url += "/";
}
if (path.startsWith("/")) {
path = path.slice(1);
}
return url + path;
},
headers() {
const auth = btoa(
[get().webDavConfig.username, get().webDavConfig.password].join(":"),
);
return {
Authorization: `Basic ${auth}`,
};
},
}),
{
name: StoreKey.Sync,
version: 1,
}, },
),
lastSyncTime: 0,
},
(set, get) => ({
webDavConfig: {
server: "",
username: "",
password: "",
},
lastSyncTime: 0,
export() {
const state = getLocalAppState();
const fileName = `Backup-${new Date().toLocaleString()}.json`;
downloadAs(JSON.stringify(state), fileName);
},
async import() {
const rawContent = await readFromFile();
try {
const remoteState = JSON.parse(rawContent) as AppState;
const localState = getLocalAppState();
mergeAppState(localState, remoteState);
setLocalAppState(localState);
location.reload();
} catch (e) {
console.error("[Import]", e);
showToast(Locale.Settings.Sync.ImportFailed);
}
},
async check() {
try {
const res = await fetch(this.path(""), {
method: "PROFIND",
headers: this.headers(),
});
console.log(res);
return res.status === 207;
} catch (e) {
console.error("[Sync] ", e);
return false;
}
},
path(path: string) {
let url = get().webDavConfig.server;
if (!url.endsWith("/")) {
url += "/";
}
if (path.startsWith("/")) {
path = path.slice(1);
}
return url + path;
},
headers() {
const auth = btoa(
[get().webDavConfig.username, get().webDavConfig.password].join(":"),
);
return {
Authorization: `Basic ${auth}`,
};
},
}),
{
name: StoreKey.Sync,
version: 1,
},
); );

View File

@ -1,24 +1,7 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant"; import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant";
import { api } from "../client/api"; import { api } from "../client/api";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
export interface UpdateStore {
versionType: "date" | "tag";
lastUpdate: number;
version: string;
remoteVersion: string;
used?: number;
subscription?: number;
lastUpdateUsage: number;
getLatestVersion: (force?: boolean) => Promise<void>;
updateUsage: (force?: boolean) => Promise<void>;
formatVersion: (version: string) => string;
}
const ONE_MINUTE = 60 * 1000; const ONE_MINUTE = 60 * 1000;
@ -35,7 +18,9 @@ function formatVersionDate(t: string) {
].join(""); ].join("");
} }
async function getVersion(type: "date" | "tag") { type VersionType = "date" | "tag";
async function getVersion(type: VersionType) {
if (type === "date") { if (type === "date") {
const data = (await (await fetch(FETCH_COMMIT_URL)).json()) as { const data = (await (await fetch(FETCH_COMMIT_URL)).json()) as {
commit: { commit: {
@ -55,75 +40,76 @@ async function getVersion(type: "date" | "tag") {
} }
} }
export const useUpdateStore = create<UpdateStore>()( export const useUpdateStore = createPersistStore(
persist( {
(set, get) => ({ versionType: "tag" as VersionType,
versionType: "tag", lastUpdate: 0,
lastUpdate: 0, version: "unknown",
version: "unknown", remoteVersion: "",
remoteVersion: "", used: 0,
subscription: 0,
lastUpdateUsage: 0, lastUpdateUsage: 0,
},
formatVersion(version: string) { (set, get) => ({
if (get().versionType === "date") { formatVersion(version: string) {
version = formatVersionDate(version); if (get().versionType === "date") {
} version = formatVersionDate(version);
return version; }
}, return version;
async getLatestVersion(force = false) {
const versionType = get().versionType;
let version =
versionType === "date"
? getClientConfig()?.commitDate
: getClientConfig()?.version;
set(() => ({ version }));
const shouldCheck = Date.now() - get().lastUpdate > 2 * 60 * ONE_MINUTE;
if (!force && !shouldCheck) return;
set(() => ({
lastUpdate: Date.now(),
}));
try {
const remoteId = await getVersion(versionType);
set(() => ({
remoteVersion: remoteId,
}));
console.log("[Got Upstream] ", remoteId);
} catch (error) {
console.error("[Fetch Upstream Commit Id]", error);
}
},
async updateUsage(force = false) {
const overOneMinute = Date.now() - get().lastUpdateUsage >= ONE_MINUTE;
if (!overOneMinute && !force) return;
set(() => ({
lastUpdateUsage: Date.now(),
}));
try {
const usage = await api.llm.usage();
if (usage) {
set(() => ({
used: usage.used,
subscription: usage.total,
}));
}
} catch (e) {
console.error((e as Error).message);
}
},
}),
{
name: StoreKey.Update,
version: 1,
}, },
),
async getLatestVersion(force = false) {
const versionType = get().versionType;
let version =
versionType === "date"
? getClientConfig()?.commitDate
: getClientConfig()?.version;
set(() => ({ version }));
const shouldCheck = Date.now() - get().lastUpdate > 2 * 60 * ONE_MINUTE;
if (!force && !shouldCheck) return;
set(() => ({
lastUpdate: Date.now(),
}));
try {
const remoteId = await getVersion(versionType);
set(() => ({
remoteVersion: remoteId,
}));
console.log("[Got Upstream] ", remoteId);
} catch (error) {
console.error("[Fetch Upstream Commit Id]", error);
}
},
async updateUsage(force = false) {
const overOneMinute = Date.now() - get().lastUpdateUsage >= ONE_MINUTE;
if (!overOneMinute && !force) return;
set(() => ({
lastUpdateUsage: Date.now(),
}));
try {
const usage = await api.llm.usage();
if (usage) {
set(() => ({
used: usage.used,
subscription: usage.total,
}));
}
} catch (e) {
console.error((e as Error).message);
}
},
}),
{
name: StoreKey.Update,
version: 1,
},
); );

3
app/utils/clone.ts Normal file
View File

@ -0,0 +1,3 @@
export function deepClone<T>(obj: T) {
return JSON.parse(JSON.stringify(obj));
}

55
app/utils/store.ts Normal file
View File

@ -0,0 +1,55 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { Updater } from "../typing";
import { deepClone } from "./clone";
type SecondParam<T> = T extends (
_f: infer _F,
_s: infer S,
...args: infer _U
) => any
? S
: never;
type MakeUpdater<T> = {
lastUpdateTime: number;
markUpdate: () => void;
update: Updater<T>;
};
type SetStoreState<T> = (
partial: T | Partial<T> | ((state: T) => T | Partial<T>),
replace?: boolean | undefined,
) => void;
export function createPersistStore<T, M>(
defaultState: T,
methods: (
set: SetStoreState<T & MakeUpdater<T>>,
get: () => T & MakeUpdater<T>,
) => M,
persistOptions: SecondParam<typeof persist<T & M & MakeUpdater<T>>>,
) {
return create<T & M & MakeUpdater<T>>()(
persist((set, get) => {
return {
...defaultState,
...methods(set as any, get),
lastUpdateTime: 0,
markUpdate() {
set({ lastUpdateTime: Date.now() } as Partial<
T & M & MakeUpdater<T>
>);
},
update(updater) {
const state = deepClone(get());
updater(state);
get().markUpdate();
set(state);
},
};
}, persistOptions),
);
}

162
app/utils/sync.ts Normal file
View File

@ -0,0 +1,162 @@
import {
ChatSession,
useAccessStore,
useAppConfig,
useChatStore,
} from "../store";
import { useMaskStore } from "../store/mask";
import { usePromptStore } from "../store/prompt";
import { StoreKey } from "../constant";
import { merge } from "./merge";
type NonFunctionKeys<T> = {
[K in keyof T]: T[K] extends (...args: any[]) => any ? never : K;
}[keyof T];
type NonFunctionFields<T> = Pick<T, NonFunctionKeys<T>>;
export function getNonFunctionFileds<T extends object>(obj: T) {
const ret: any = {};
Object.entries(obj).map(([k, v]) => {
if (typeof v !== "function") {
ret[k] = v;
}
});
return ret as NonFunctionFields<T>;
}
export type GetStoreState<T> = T extends { getState: () => infer U }
? NonFunctionFields<U>
: never;
const LocalStateSetters = {
[StoreKey.Chat]: useChatStore.setState,
[StoreKey.Access]: useAccessStore.setState,
[StoreKey.Config]: useAppConfig.setState,
[StoreKey.Mask]: useMaskStore.setState,
[StoreKey.Prompt]: usePromptStore.setState,
} as const;
const LocalStateGetters = {
[StoreKey.Chat]: () => getNonFunctionFileds(useChatStore.getState()),
[StoreKey.Access]: () => getNonFunctionFileds(useAccessStore.getState()),
[StoreKey.Config]: () => getNonFunctionFileds(useAppConfig.getState()),
[StoreKey.Mask]: () => getNonFunctionFileds(useMaskStore.getState()),
[StoreKey.Prompt]: () => getNonFunctionFileds(usePromptStore.getState()),
} as const;
export type AppState = {
[k in keyof typeof LocalStateGetters]: ReturnType<
(typeof LocalStateGetters)[k]
>;
};
type Merger<T extends keyof AppState, U = AppState[T]> = (
localState: U,
remoteState: U,
) => U;
type StateMerger = {
[K in keyof AppState]: Merger<K>;
};
// we merge remote state to local state
const MergeStates: StateMerger = {
[StoreKey.Chat]: (localState, remoteState) => {
// merge sessions
const localSessions: Record<string, ChatSession> = {};
localState.sessions.forEach((s) => (localSessions[s.id] = s));
remoteState.sessions.forEach((remoteSession) => {
const localSession = localSessions[remoteSession.id];
if (!localSession) {
// if remote session is new, just merge it
localState.sessions.push(remoteSession);
} else {
// if both have the same session id, merge the messages
const localMessageIds = new Set(localSession.messages.map((v) => v.id));
remoteSession.messages.forEach((m) => {
if (!localMessageIds.has(m.id)) {
localSession.messages.push(m);
}
});
// sort local messages with date field in asc order
localSession.messages.sort(
(a, b) => new Date(a.date).getTime() - new Date(b.date).getTime(),
);
}
});
// sort local sessions with date field in desc order
localState.sessions.sort(
(a, b) =>
new Date(b.lastUpdate).getTime() - new Date(a.lastUpdate).getTime(),
);
return localState;
},
[StoreKey.Prompt]: (localState, remoteState) => {
localState.prompts = {
...remoteState.prompts,
...localState.prompts,
};
return localState;
},
[StoreKey.Mask]: (localState, remoteState) => {
localState.masks = {
...remoteState.masks,
...localState.masks,
};
return localState;
},
[StoreKey.Config]: mergeWithUpdate<AppState[StoreKey.Config]>,
[StoreKey.Access]: mergeWithUpdate<AppState[StoreKey.Access]>,
};
export function getLocalAppState() {
const appState = Object.fromEntries(
Object.entries(LocalStateGetters).map(([key, getter]) => {
return [key, getter()];
}),
) as AppState;
return appState;
}
export function setLocalAppState(appState: AppState) {
Object.entries(LocalStateSetters).forEach(([key, setter]) => {
setter(appState[key as keyof AppState]);
});
}
export function mergeAppState(localState: AppState, remoteState: AppState) {
Object.keys(localState).forEach(<T extends keyof AppState>(k: string) => {
const key = k as T;
const localStoreState = localState[key];
const remoteStoreState = remoteState[key];
MergeStates[key](localStoreState, remoteStoreState);
});
return localState;
}
/**
* Merge state with `lastUpdateTime`, older state will be override
*/
export function mergeWithUpdate<T extends { lastUpdateTime?: number }>(
localState: T,
remoteState: T,
) {
const localUpdateTime = localState.lastUpdateTime ?? 0;
const remoteUpdateTime = localState.lastUpdateTime ?? 1;
if (localUpdateTime < remoteUpdateTime) {
merge(remoteState, localState);
return { ...remoteState };
} else {
merge(localState, remoteState);
return { ...localState };
}
}