feat: close #864 improve long term history

This commit is contained in:
Yifei Zhang 2023-04-18 11:42:08 +08:00 committed by GitHub
parent 146ef1bf49
commit d75b7d49b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 23 deletions

View File

@ -38,12 +38,12 @@ const cn = {
MessageFromChatGPT: "来自 ChatGPT 的消息", MessageFromChatGPT: "来自 ChatGPT 的消息",
}, },
Memory: { Memory: {
Title: "历史记忆", Title: "历史摘要",
EmptyContent: "尚未记忆", EmptyContent: "尚未总结",
Send: "发送记忆", Send: "启用总结并发送摘要",
Copy: "复制记忆", Copy: "复制摘要",
Reset: "重置对话", Reset: "重置对话",
ResetConfirm: "重置后将清空当前对话记录以及历史记忆,确认重置?", ResetConfirm: "重置后将清空当前对话记录以及历史摘要,确认重置?",
}, },
Home: { Home: {
NewChat: "新的聊天", NewChat: "新的聊天",

View File

@ -102,7 +102,7 @@ export function limitNumber(
x: number, x: number,
min: number, min: number,
max: number, max: number,
defaultValue: number, defaultValue: number
) { ) {
if (typeof x !== "number" || isNaN(x)) { if (typeof x !== "number" || isNaN(x)) {
return defaultValue; return defaultValue;
@ -217,7 +217,7 @@ interface ChatStore {
updateMessage: ( updateMessage: (
sessionIndex: number, sessionIndex: number,
messageIndex: number, messageIndex: number,
updater: (message?: Message) => void, updater: (message?: Message) => void
) => void; ) => void;
resetSession: () => void; resetSession: () => void;
getMessagesWithMemory: () => Message[]; getMessagesWithMemory: () => Message[];
@ -345,12 +345,12 @@ export const useChatStore = create<ChatStore>()(
.slice(0, index) .slice(0, index)
.concat([deletedSession]) .concat([deletedSession])
.concat( .concat(
state.sessions.slice(index + Number(isLastSession)), state.sessions.slice(index + Number(isLastSession))
), ),
})); }));
}, },
}, },
5000, 5000
); );
} }
}, },
@ -412,7 +412,7 @@ export const useChatStore = create<ChatStore>()(
get().onNewMessage(botMessage); get().onNewMessage(botMessage);
ControllerPool.remove( ControllerPool.remove(
sessionIndex, sessionIndex,
botMessage.id ?? messageIndex, botMessage.id ?? messageIndex
); );
} else { } else {
botMessage.content = content; botMessage.content = content;
@ -436,7 +436,7 @@ export const useChatStore = create<ChatStore>()(
ControllerPool.addController( ControllerPool.addController(
sessionIndex, sessionIndex,
botMessage.id ?? messageIndex, botMessage.id ?? messageIndex,
controller, controller
); );
}, },
filterBot: !get().config.sendBotMessages, filterBot: !get().config.sendBotMessages,
@ -462,6 +462,7 @@ export const useChatStore = create<ChatStore>()(
const context = session.context.slice(); const context = session.context.slice();
// long term memory
if ( if (
session.sendMemory && session.sendMemory &&
session.memoryPrompt && session.memoryPrompt &&
@ -471,9 +472,33 @@ export const useChatStore = create<ChatStore>()(
context.push(memoryPrompt); context.push(memoryPrompt);
} }
const recentMessages = context.concat( // get short term and unmemoried long term memory
messages.slice(Math.max(0, n - config.historyMessageCount)), const shortTermMemoryMessageIndex = Math.max(
0,
n - config.historyMessageCount
); );
const longTermMemoryMessageIndex = config.lastSummarizeIndex;
const oldestIndex = Math.min(
shortTermMemoryMessageIndex,
longTermMemoryMessageIndex
);
const threshold = config.compressMessageLengthThreshold;
// get recent messages as many as possible
const reversedRecentMessages = [];
for (
let i = n - 1, count = 0;
i >= oldestIndex && count < threshold;
i -= 1
) {
const msg = messages[i];
if (!msg || msg.isError) continue;
count += msg.content.length;
reversedRecentMessages.push(msg);
}
// concat
const recentMessages = context.concat(reversedRecentMessages.reverse());
return recentMessages; return recentMessages;
}, },
@ -481,7 +506,7 @@ export const useChatStore = create<ChatStore>()(
updateMessage( updateMessage(
sessionIndex: number, sessionIndex: number,
messageIndex: number, messageIndex: number,
updater: (message?: Message) => void, updater: (message?: Message) => void
) { ) {
const sessions = get().sessions; const sessions = get().sessions;
const session = sessions.at(sessionIndex); const session = sessions.at(sessionIndex);
@ -510,15 +535,15 @@ export const useChatStore = create<ChatStore>()(
(res) => { (res) => {
get().updateCurrentSession( get().updateCurrentSession(
(session) => (session) =>
(session.topic = res ? trimTopic(res) : DEFAULT_TOPIC), (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC)
); );
}, }
); );
} }
const config = get().config; const config = get().config;
let toBeSummarizedMsgs = session.messages.slice( let toBeSummarizedMsgs = session.messages.slice(
session.lastSummarizeIndex, session.lastSummarizeIndex
); );
const historyMsgLength = countMessages(toBeSummarizedMsgs); const historyMsgLength = countMessages(toBeSummarizedMsgs);
@ -526,7 +551,7 @@ export const useChatStore = create<ChatStore>()(
if (historyMsgLength > get().config?.modelConfig?.max_tokens ?? 4000) { if (historyMsgLength > get().config?.modelConfig?.max_tokens ?? 4000) {
const n = toBeSummarizedMsgs.length; const n = toBeSummarizedMsgs.length;
toBeSummarizedMsgs = toBeSummarizedMsgs.slice( toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
Math.max(0, n - config.historyMessageCount), Math.max(0, n - config.historyMessageCount)
); );
} }
@ -539,10 +564,13 @@ export const useChatStore = create<ChatStore>()(
"[Chat History] ", "[Chat History] ",
toBeSummarizedMsgs, toBeSummarizedMsgs,
historyMsgLength, historyMsgLength,
config.compressMessageLengthThreshold, config.compressMessageLengthThreshold
); );
if (historyMsgLength > config.compressMessageLengthThreshold) { if (
historyMsgLength > config.compressMessageLengthThreshold &&
session.sendMemory
) {
requestChatStream( requestChatStream(
toBeSummarizedMsgs.concat({ toBeSummarizedMsgs.concat({
role: "system", role: "system",
@ -561,7 +589,7 @@ export const useChatStore = create<ChatStore>()(
onError(error) { onError(error) {
console.error("[Summarize] ", error); console.error("[Summarize] ", error);
}, },
}, }
); );
} }
}, },
@ -603,6 +631,6 @@ export const useChatStore = create<ChatStore>()(
return state; return state;
}, },
}, }
), )
); );