forked from XiaoMo/ChatGPT-Next-Web
feat: #close 1789 add user input template
This commit is contained in:
parent
fa9ceb5875
commit
be597a551d
@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
|
||||
|
||||
import Locale from "../locales";
|
||||
import { InputRange } from "./input-range";
|
||||
import { List, ListItem, Select } from "./ui-lib";
|
||||
import { ListItem, Select } from "./ui-lib";
|
||||
|
||||
export function ModelConfigList(props: {
|
||||
modelConfig: ModelConfig;
|
||||
@ -109,6 +109,21 @@ export function ModelConfigList(props: {
|
||||
></InputRange>
|
||||
</ListItem>
|
||||
|
||||
<ListItem
|
||||
title={Locale.Settings.InputTemplate.Title}
|
||||
subTitle={Locale.Settings.InputTemplate.SubTitle}
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
value={props.modelConfig.template}
|
||||
onChange={(e) =>
|
||||
props.updateConfig(
|
||||
(config) => (config.template = e.currentTarget.value),
|
||||
)
|
||||
}
|
||||
></input>
|
||||
</ListItem>
|
||||
|
||||
<ListItem
|
||||
title={Locale.Settings.HistoryCount.Title}
|
||||
subTitle={Locale.Settings.HistoryCount.SubTitle}
|
||||
|
@ -52,3 +52,10 @@ export const OpenaiPath = {
|
||||
UsagePath: "dashboard/billing/usage",
|
||||
SubsPath: "dashboard/billing/subscription",
|
||||
};
|
||||
|
||||
export const DEFAULT_INPUT_TEMPLATE = `
|
||||
Act as a virtual assistant powered by model: '{{model}}', my input is:
|
||||
'''
|
||||
{{input}}
|
||||
'''
|
||||
`;
|
||||
|
@ -115,6 +115,11 @@ const cn = {
|
||||
SubTitle: "聊天内容的字体大小",
|
||||
},
|
||||
|
||||
InputTemplate: {
|
||||
Title: "用户输入预处理",
|
||||
SubTitle: "用户最新的一条消息会填充到此模板",
|
||||
},
|
||||
|
||||
Update: {
|
||||
Version: (x: string) => `当前版本:${x}`,
|
||||
IsLatest: "已是最新版本",
|
||||
|
@ -116,6 +116,12 @@ const en: LocaleType = {
|
||||
Title: "Font Size",
|
||||
SubTitle: "Adjust font size of chat content",
|
||||
},
|
||||
|
||||
InputTemplate: {
|
||||
Title: "Input Template",
|
||||
SubTitle: "Newest message will be filled to this template",
|
||||
},
|
||||
|
||||
Update: {
|
||||
Version: (x: string) => `Version: ${x}`,
|
||||
IsLatest: "Latest version",
|
||||
|
@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
|
||||
|
||||
import { trimTopic } from "../utils";
|
||||
|
||||
import Locale from "../locales";
|
||||
import Locale, { getLang } from "../locales";
|
||||
import { showToast } from "../components/ui-lib";
|
||||
import { ModelType } from "./config";
|
||||
import { ModelConfig, ModelType } from "./config";
|
||||
import { createEmptyMask, Mask } from "./mask";
|
||||
import { StoreKey } from "../constant";
|
||||
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
|
||||
import { api, RequestMessage } from "../client/api";
|
||||
import { ChatControllerPool } from "../client/controller";
|
||||
import { prettyObject } from "../utils/format";
|
||||
@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
|
||||
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
||||
}
|
||||
|
||||
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
|
||||
const vars = {
|
||||
model: modelConfig.model,
|
||||
time: new Date().toLocaleString(),
|
||||
lang: getLang(),
|
||||
input: input,
|
||||
};
|
||||
|
||||
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
|
||||
|
||||
// must contains {{input}}
|
||||
const inputVar = "{{input}}";
|
||||
if (!output.includes(inputVar)) {
|
||||
output += "\n" + inputVar;
|
||||
}
|
||||
|
||||
Object.entries(vars).forEach(([name, value]) => {
|
||||
output = output.replaceAll(`{{${name}}}`, value);
|
||||
});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
export const useChatStore = create<ChatStore>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
@ -238,9 +261,12 @@ export const useChatStore = create<ChatStore>()(
|
||||
const session = get().currentSession();
|
||||
const modelConfig = session.mask.modelConfig;
|
||||
|
||||
const userContent = fillTemplateWith(content, modelConfig);
|
||||
console.log("[User Input] fill with template: ", userContent);
|
||||
|
||||
const userMessage: ChatMessage = createMessage({
|
||||
role: "user",
|
||||
content,
|
||||
content: userContent,
|
||||
});
|
||||
|
||||
const botMessage: ChatMessage = createMessage({
|
||||
@ -250,31 +276,22 @@ export const useChatStore = create<ChatStore>()(
|
||||
model: modelConfig.model,
|
||||
});
|
||||
|
||||
const systemInfo = createMessage({
|
||||
role: "system",
|
||||
content: `IMPORTANT: You are a virtual assistant powered by the ${
|
||||
modelConfig.model
|
||||
} model, now time is ${new Date().toLocaleString()}}`,
|
||||
id: botMessage.id! + 1,
|
||||
});
|
||||
|
||||
// get recent messages
|
||||
const systemMessages = [];
|
||||
// if user define a mask with context prompts, wont send system info
|
||||
if (session.mask.context.length === 0) {
|
||||
systemMessages.push(systemInfo);
|
||||
}
|
||||
|
||||
const recentMessages = get().getMessagesWithMemory();
|
||||
const sendMessages = systemMessages.concat(
|
||||
recentMessages.concat(userMessage),
|
||||
);
|
||||
const sendMessages = recentMessages.concat(userMessage);
|
||||
const sessionIndex = get().currentSessionIndex;
|
||||
const messageIndex = get().currentSession().messages.length + 1;
|
||||
|
||||
// save user's and bot's message
|
||||
get().updateCurrentSession((session) => {
|
||||
session.messages = session.messages.concat([userMessage, botMessage]);
|
||||
const savedUserMessage = {
|
||||
...userMessage,
|
||||
content,
|
||||
};
|
||||
session.messages = session.messages.concat([
|
||||
savedUserMessage,
|
||||
botMessage,
|
||||
]);
|
||||
});
|
||||
|
||||
// make request
|
||||
@ -350,55 +367,62 @@ export const useChatStore = create<ChatStore>()(
|
||||
getMessagesWithMemory() {
|
||||
const session = get().currentSession();
|
||||
const modelConfig = session.mask.modelConfig;
|
||||
const clearContextIndex = session.clearContextIndex ?? 0;
|
||||
const messages = session.messages.slice();
|
||||
const totalMessageCount = session.messages.length;
|
||||
|
||||
// wont send cleared context messages
|
||||
const clearedContextMessages = session.messages.slice(
|
||||
session.clearContextIndex ?? 0,
|
||||
);
|
||||
const messages = clearedContextMessages.filter((msg) => !msg.isError);
|
||||
const n = messages.length;
|
||||
|
||||
const context = session.mask.context.slice();
|
||||
// in-context prompts
|
||||
const contextPrompts = session.mask.context.slice();
|
||||
|
||||
// long term memory
|
||||
if (
|
||||
const shouldSendLongTermMemory =
|
||||
modelConfig.sendMemory &&
|
||||
session.memoryPrompt &&
|
||||
session.memoryPrompt.length > 0
|
||||
) {
|
||||
const memoryPrompt = get().getMemoryPrompt();
|
||||
context.push(memoryPrompt);
|
||||
}
|
||||
session.memoryPrompt.length > 0 &&
|
||||
session.lastSummarizeIndex <= clearContextIndex;
|
||||
const longTermMemoryPrompts = shouldSendLongTermMemory
|
||||
? [get().getMemoryPrompt()]
|
||||
: [];
|
||||
const longTermMemoryStartIndex = session.lastSummarizeIndex;
|
||||
|
||||
// get short term and unmemorized long term memory
|
||||
const shortTermMemoryMessageIndex = Math.max(
|
||||
// short term memory
|
||||
const shortTermMemoryStartIndex = Math.max(
|
||||
0,
|
||||
n - modelConfig.historyMessageCount,
|
||||
totalMessageCount - modelConfig.historyMessageCount,
|
||||
);
|
||||
const longTermMemoryMessageIndex = session.lastSummarizeIndex;
|
||||
|
||||
// try to concat history messages
|
||||
// lets concat send messages, including 4 parts:
|
||||
// 1. long term memory: summarized memory messages
|
||||
// 2. pre-defined in-context prompts
|
||||
// 3. short term memory: latest n messages
|
||||
// 4. newest input message
|
||||
const memoryStartIndex = Math.min(
|
||||
shortTermMemoryMessageIndex,
|
||||
longTermMemoryMessageIndex,
|
||||
longTermMemoryStartIndex,
|
||||
shortTermMemoryStartIndex,
|
||||
);
|
||||
const threshold = modelConfig.max_tokens;
|
||||
// and if user has cleared history messages, we should exclude the memory too.
|
||||
const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
|
||||
const maxTokenThreshold = modelConfig.max_tokens;
|
||||
|
||||
// get recent messages as many as possible
|
||||
// get recent messages as much as possible
|
||||
const reversedRecentMessages = [];
|
||||
for (
|
||||
let i = n - 1, count = 0;
|
||||
i >= memoryStartIndex && count < threshold;
|
||||
let i = totalMessageCount - 1, tokenCount = 0;
|
||||
i >= contextStartIndex && tokenCount < maxTokenThreshold;
|
||||
i -= 1
|
||||
) {
|
||||
const msg = messages[i];
|
||||
if (!msg || msg.isError) continue;
|
||||
count += estimateTokenLength(msg.content);
|
||||
tokenCount += estimateTokenLength(msg.content);
|
||||
reversedRecentMessages.push(msg);
|
||||
}
|
||||
|
||||
// concat
|
||||
const recentMessages = context.concat(reversedRecentMessages.reverse());
|
||||
// concat all messages
|
||||
const recentMessages = [
|
||||
...longTermMemoryPrompts,
|
||||
...contextPrompts,
|
||||
...reversedRecentMessages.reverse(),
|
||||
];
|
||||
|
||||
return recentMessages;
|
||||
},
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { create } from "zustand";
|
||||
import { persist } from "zustand/middleware";
|
||||
import { getClientConfig } from "../config/client";
|
||||
import { StoreKey } from "../constant";
|
||||
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
|
||||
|
||||
export enum SubmitKey {
|
||||
Enter = "Enter",
|
||||
@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
|
||||
sendMemory: true,
|
||||
historyMessageCount: 4,
|
||||
compressMessageLengthThreshold: 1000,
|
||||
template: DEFAULT_INPUT_TEMPLATE,
|
||||
},
|
||||
};
|
||||
|
||||
@ -176,15 +177,16 @@ export const useAppConfig = create<ChatConfigStore>()(
|
||||
}),
|
||||
{
|
||||
name: StoreKey.Config,
|
||||
version: 3,
|
||||
version: 3.1,
|
||||
migrate(persistedState, version) {
|
||||
if (version === 3) return persistedState as any;
|
||||
if (version === 3.1) return persistedState as any;
|
||||
|
||||
const state = persistedState as ChatConfig;
|
||||
state.modelConfig.sendMemory = true;
|
||||
state.modelConfig.historyMessageCount = 4;
|
||||
state.modelConfig.compressMessageLengthThreshold = 1000;
|
||||
state.modelConfig.frequency_penalty = 0;
|
||||
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
|
||||
state.dontShowMaskSplashScreen = false;
|
||||
|
||||
return state;
|
||||
|
Loading…
Reference in New Issue
Block a user