forked from XiaoMo/ChatGPT-Next-Web
feat: #close 1789 add user input template
This commit is contained in:
parent
fa9ceb5875
commit
be597a551d
@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
|
|||||||
|
|
||||||
import Locale from "../locales";
|
import Locale from "../locales";
|
||||||
import { InputRange } from "./input-range";
|
import { InputRange } from "./input-range";
|
||||||
import { List, ListItem, Select } from "./ui-lib";
|
import { ListItem, Select } from "./ui-lib";
|
||||||
|
|
||||||
export function ModelConfigList(props: {
|
export function ModelConfigList(props: {
|
||||||
modelConfig: ModelConfig;
|
modelConfig: ModelConfig;
|
||||||
@ -109,6 +109,21 @@ export function ModelConfigList(props: {
|
|||||||
></InputRange>
|
></InputRange>
|
||||||
</ListItem>
|
</ListItem>
|
||||||
|
|
||||||
|
<ListItem
|
||||||
|
title={Locale.Settings.InputTemplate.Title}
|
||||||
|
subTitle={Locale.Settings.InputTemplate.SubTitle}
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={props.modelConfig.template}
|
||||||
|
onChange={(e) =>
|
||||||
|
props.updateConfig(
|
||||||
|
(config) => (config.template = e.currentTarget.value),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
></input>
|
||||||
|
</ListItem>
|
||||||
|
|
||||||
<ListItem
|
<ListItem
|
||||||
title={Locale.Settings.HistoryCount.Title}
|
title={Locale.Settings.HistoryCount.Title}
|
||||||
subTitle={Locale.Settings.HistoryCount.SubTitle}
|
subTitle={Locale.Settings.HistoryCount.SubTitle}
|
||||||
|
@ -52,3 +52,10 @@ export const OpenaiPath = {
|
|||||||
UsagePath: "dashboard/billing/usage",
|
UsagePath: "dashboard/billing/usage",
|
||||||
SubsPath: "dashboard/billing/subscription",
|
SubsPath: "dashboard/billing/subscription",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const DEFAULT_INPUT_TEMPLATE = `
|
||||||
|
Act as a virtual assistant powered by model: '{{model}}', my input is:
|
||||||
|
'''
|
||||||
|
{{input}}
|
||||||
|
'''
|
||||||
|
`;
|
||||||
|
@ -115,6 +115,11 @@ const cn = {
|
|||||||
SubTitle: "聊天内容的字体大小",
|
SubTitle: "聊天内容的字体大小",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
InputTemplate: {
|
||||||
|
Title: "用户输入预处理",
|
||||||
|
SubTitle: "用户最新的一条消息会填充到此模板",
|
||||||
|
},
|
||||||
|
|
||||||
Update: {
|
Update: {
|
||||||
Version: (x: string) => `当前版本:${x}`,
|
Version: (x: string) => `当前版本:${x}`,
|
||||||
IsLatest: "已是最新版本",
|
IsLatest: "已是最新版本",
|
||||||
|
@ -116,6 +116,12 @@ const en: LocaleType = {
|
|||||||
Title: "Font Size",
|
Title: "Font Size",
|
||||||
SubTitle: "Adjust font size of chat content",
|
SubTitle: "Adjust font size of chat content",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
InputTemplate: {
|
||||||
|
Title: "Input Template",
|
||||||
|
SubTitle: "Newest message will be filled to this template",
|
||||||
|
},
|
||||||
|
|
||||||
Update: {
|
Update: {
|
||||||
Version: (x: string) => `Version: ${x}`,
|
Version: (x: string) => `Version: ${x}`,
|
||||||
IsLatest: "Latest version",
|
IsLatest: "Latest version",
|
||||||
|
@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
|
|||||||
|
|
||||||
import { trimTopic } from "../utils";
|
import { trimTopic } from "../utils";
|
||||||
|
|
||||||
import Locale from "../locales";
|
import Locale, { getLang } from "../locales";
|
||||||
import { showToast } from "../components/ui-lib";
|
import { showToast } from "../components/ui-lib";
|
||||||
import { ModelType } from "./config";
|
import { ModelConfig, ModelType } from "./config";
|
||||||
import { createEmptyMask, Mask } from "./mask";
|
import { createEmptyMask, Mask } from "./mask";
|
||||||
import { StoreKey } from "../constant";
|
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
|
||||||
import { api, RequestMessage } from "../client/api";
|
import { api, RequestMessage } from "../client/api";
|
||||||
import { ChatControllerPool } from "../client/controller";
|
import { ChatControllerPool } from "../client/controller";
|
||||||
import { prettyObject } from "../utils/format";
|
import { prettyObject } from "../utils/format";
|
||||||
@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
|
|||||||
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
|
||||||
|
const vars = {
|
||||||
|
model: modelConfig.model,
|
||||||
|
time: new Date().toLocaleString(),
|
||||||
|
lang: getLang(),
|
||||||
|
input: input,
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
|
||||||
|
|
||||||
|
// must contains {{input}}
|
||||||
|
const inputVar = "{{input}}";
|
||||||
|
if (!output.includes(inputVar)) {
|
||||||
|
output += "\n" + inputVar;
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.entries(vars).forEach(([name, value]) => {
|
||||||
|
output = output.replaceAll(`{{${name}}}`, value);
|
||||||
|
});
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
export const useChatStore = create<ChatStore>()(
|
export const useChatStore = create<ChatStore>()(
|
||||||
persist(
|
persist(
|
||||||
(set, get) => ({
|
(set, get) => ({
|
||||||
@ -238,9 +261,12 @@ export const useChatStore = create<ChatStore>()(
|
|||||||
const session = get().currentSession();
|
const session = get().currentSession();
|
||||||
const modelConfig = session.mask.modelConfig;
|
const modelConfig = session.mask.modelConfig;
|
||||||
|
|
||||||
|
const userContent = fillTemplateWith(content, modelConfig);
|
||||||
|
console.log("[User Input] fill with template: ", userContent);
|
||||||
|
|
||||||
const userMessage: ChatMessage = createMessage({
|
const userMessage: ChatMessage = createMessage({
|
||||||
role: "user",
|
role: "user",
|
||||||
content,
|
content: userContent,
|
||||||
});
|
});
|
||||||
|
|
||||||
const botMessage: ChatMessage = createMessage({
|
const botMessage: ChatMessage = createMessage({
|
||||||
@ -250,31 +276,22 @@ export const useChatStore = create<ChatStore>()(
|
|||||||
model: modelConfig.model,
|
model: modelConfig.model,
|
||||||
});
|
});
|
||||||
|
|
||||||
const systemInfo = createMessage({
|
|
||||||
role: "system",
|
|
||||||
content: `IMPORTANT: You are a virtual assistant powered by the ${
|
|
||||||
modelConfig.model
|
|
||||||
} model, now time is ${new Date().toLocaleString()}}`,
|
|
||||||
id: botMessage.id! + 1,
|
|
||||||
});
|
|
||||||
|
|
||||||
// get recent messages
|
// get recent messages
|
||||||
const systemMessages = [];
|
|
||||||
// if user define a mask with context prompts, wont send system info
|
|
||||||
if (session.mask.context.length === 0) {
|
|
||||||
systemMessages.push(systemInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
const recentMessages = get().getMessagesWithMemory();
|
const recentMessages = get().getMessagesWithMemory();
|
||||||
const sendMessages = systemMessages.concat(
|
const sendMessages = recentMessages.concat(userMessage);
|
||||||
recentMessages.concat(userMessage),
|
|
||||||
);
|
|
||||||
const sessionIndex = get().currentSessionIndex;
|
const sessionIndex = get().currentSessionIndex;
|
||||||
const messageIndex = get().currentSession().messages.length + 1;
|
const messageIndex = get().currentSession().messages.length + 1;
|
||||||
|
|
||||||
// save user's and bot's message
|
// save user's and bot's message
|
||||||
get().updateCurrentSession((session) => {
|
get().updateCurrentSession((session) => {
|
||||||
session.messages = session.messages.concat([userMessage, botMessage]);
|
const savedUserMessage = {
|
||||||
|
...userMessage,
|
||||||
|
content,
|
||||||
|
};
|
||||||
|
session.messages = session.messages.concat([
|
||||||
|
savedUserMessage,
|
||||||
|
botMessage,
|
||||||
|
]);
|
||||||
});
|
});
|
||||||
|
|
||||||
// make request
|
// make request
|
||||||
@ -350,55 +367,62 @@ export const useChatStore = create<ChatStore>()(
|
|||||||
getMessagesWithMemory() {
|
getMessagesWithMemory() {
|
||||||
const session = get().currentSession();
|
const session = get().currentSession();
|
||||||
const modelConfig = session.mask.modelConfig;
|
const modelConfig = session.mask.modelConfig;
|
||||||
|
const clearContextIndex = session.clearContextIndex ?? 0;
|
||||||
|
const messages = session.messages.slice();
|
||||||
|
const totalMessageCount = session.messages.length;
|
||||||
|
|
||||||
// wont send cleared context messages
|
// in-context prompts
|
||||||
const clearedContextMessages = session.messages.slice(
|
const contextPrompts = session.mask.context.slice();
|
||||||
session.clearContextIndex ?? 0,
|
|
||||||
);
|
|
||||||
const messages = clearedContextMessages.filter((msg) => !msg.isError);
|
|
||||||
const n = messages.length;
|
|
||||||
|
|
||||||
const context = session.mask.context.slice();
|
|
||||||
|
|
||||||
// long term memory
|
// long term memory
|
||||||
if (
|
const shouldSendLongTermMemory =
|
||||||
modelConfig.sendMemory &&
|
modelConfig.sendMemory &&
|
||||||
session.memoryPrompt &&
|
session.memoryPrompt &&
|
||||||
session.memoryPrompt.length > 0
|
session.memoryPrompt.length > 0 &&
|
||||||
) {
|
session.lastSummarizeIndex <= clearContextIndex;
|
||||||
const memoryPrompt = get().getMemoryPrompt();
|
const longTermMemoryPrompts = shouldSendLongTermMemory
|
||||||
context.push(memoryPrompt);
|
? [get().getMemoryPrompt()]
|
||||||
}
|
: [];
|
||||||
|
const longTermMemoryStartIndex = session.lastSummarizeIndex;
|
||||||
|
|
||||||
// get short term and unmemorized long term memory
|
// short term memory
|
||||||
const shortTermMemoryMessageIndex = Math.max(
|
const shortTermMemoryStartIndex = Math.max(
|
||||||
0,
|
0,
|
||||||
n - modelConfig.historyMessageCount,
|
totalMessageCount - modelConfig.historyMessageCount,
|
||||||
);
|
);
|
||||||
const longTermMemoryMessageIndex = session.lastSummarizeIndex;
|
|
||||||
|
|
||||||
// try to concat history messages
|
// lets concat send messages, including 4 parts:
|
||||||
|
// 1. long term memory: summarized memory messages
|
||||||
|
// 2. pre-defined in-context prompts
|
||||||
|
// 3. short term memory: latest n messages
|
||||||
|
// 4. newest input message
|
||||||
const memoryStartIndex = Math.min(
|
const memoryStartIndex = Math.min(
|
||||||
shortTermMemoryMessageIndex,
|
longTermMemoryStartIndex,
|
||||||
longTermMemoryMessageIndex,
|
shortTermMemoryStartIndex,
|
||||||
);
|
);
|
||||||
const threshold = modelConfig.max_tokens;
|
// and if user has cleared history messages, we should exclude the memory too.
|
||||||
|
const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
|
||||||
|
const maxTokenThreshold = modelConfig.max_tokens;
|
||||||
|
|
||||||
// get recent messages as many as possible
|
// get recent messages as much as possible
|
||||||
const reversedRecentMessages = [];
|
const reversedRecentMessages = [];
|
||||||
for (
|
for (
|
||||||
let i = n - 1, count = 0;
|
let i = totalMessageCount - 1, tokenCount = 0;
|
||||||
i >= memoryStartIndex && count < threshold;
|
i >= contextStartIndex && tokenCount < maxTokenThreshold;
|
||||||
i -= 1
|
i -= 1
|
||||||
) {
|
) {
|
||||||
const msg = messages[i];
|
const msg = messages[i];
|
||||||
if (!msg || msg.isError) continue;
|
if (!msg || msg.isError) continue;
|
||||||
count += estimateTokenLength(msg.content);
|
tokenCount += estimateTokenLength(msg.content);
|
||||||
reversedRecentMessages.push(msg);
|
reversedRecentMessages.push(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// concat
|
// concat all messages
|
||||||
const recentMessages = context.concat(reversedRecentMessages.reverse());
|
const recentMessages = [
|
||||||
|
...longTermMemoryPrompts,
|
||||||
|
...contextPrompts,
|
||||||
|
...reversedRecentMessages.reverse(),
|
||||||
|
];
|
||||||
|
|
||||||
return recentMessages;
|
return recentMessages;
|
||||||
},
|
},
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { create } from "zustand";
|
import { create } from "zustand";
|
||||||
import { persist } from "zustand/middleware";
|
import { persist } from "zustand/middleware";
|
||||||
import { getClientConfig } from "../config/client";
|
import { getClientConfig } from "../config/client";
|
||||||
import { StoreKey } from "../constant";
|
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
|
||||||
|
|
||||||
export enum SubmitKey {
|
export enum SubmitKey {
|
||||||
Enter = "Enter",
|
Enter = "Enter",
|
||||||
@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
|
|||||||
sendMemory: true,
|
sendMemory: true,
|
||||||
historyMessageCount: 4,
|
historyMessageCount: 4,
|
||||||
compressMessageLengthThreshold: 1000,
|
compressMessageLengthThreshold: 1000,
|
||||||
|
template: DEFAULT_INPUT_TEMPLATE,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -176,15 +177,16 @@ export const useAppConfig = create<ChatConfigStore>()(
|
|||||||
}),
|
}),
|
||||||
{
|
{
|
||||||
name: StoreKey.Config,
|
name: StoreKey.Config,
|
||||||
version: 3,
|
version: 3.1,
|
||||||
migrate(persistedState, version) {
|
migrate(persistedState, version) {
|
||||||
if (version === 3) return persistedState as any;
|
if (version === 3.1) return persistedState as any;
|
||||||
|
|
||||||
const state = persistedState as ChatConfig;
|
const state = persistedState as ChatConfig;
|
||||||
state.modelConfig.sendMemory = true;
|
state.modelConfig.sendMemory = true;
|
||||||
state.modelConfig.historyMessageCount = 4;
|
state.modelConfig.historyMessageCount = 4;
|
||||||
state.modelConfig.compressMessageLengthThreshold = 1000;
|
state.modelConfig.compressMessageLengthThreshold = 1000;
|
||||||
state.modelConfig.frequency_penalty = 0;
|
state.modelConfig.frequency_penalty = 0;
|
||||||
|
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
|
||||||
state.dontShowMaskSplashScreen = false;
|
state.dontShowMaskSplashScreen = false;
|
||||||
|
|
||||||
return state;
|
return state;
|
||||||
|
Loading…
Reference in New Issue
Block a user