Merge pull request #2109 from Yidadaa/bugfix-0623

feat: close #1789 add user input template
This commit is contained in:
Yifei Zhang 2023-06-24 00:20:05 +08:00 committed by GitHub
commit 15ce114440
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 143 additions and 62 deletions

View File

@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
import Locale from "../locales"; import Locale from "../locales";
import { InputRange } from "./input-range"; import { InputRange } from "./input-range";
import { List, ListItem, Select } from "./ui-lib"; import { ListItem, Select } from "./ui-lib";
export function ModelConfigList(props: { export function ModelConfigList(props: {
modelConfig: ModelConfig; modelConfig: ModelConfig;
@ -109,6 +109,21 @@ export function ModelConfigList(props: {
></InputRange> ></InputRange>
</ListItem> </ListItem>
<ListItem
title={Locale.Settings.InputTemplate.Title}
subTitle={Locale.Settings.InputTemplate.SubTitle}
>
<input
type="text"
value={props.modelConfig.template}
onChange={(e) =>
props.updateConfig(
(config) => (config.template = e.currentTarget.value),
)
}
></input>
</ListItem>
<ListItem <ListItem
title={Locale.Settings.HistoryCount.Title} title={Locale.Settings.HistoryCount.Title}
subTitle={Locale.Settings.HistoryCount.SubTitle} subTitle={Locale.Settings.HistoryCount.SubTitle}

View File

@ -52,3 +52,10 @@ export const OpenaiPath = {
UsagePath: "dashboard/billing/usage", UsagePath: "dashboard/billing/usage",
SubsPath: "dashboard/billing/subscription", SubsPath: "dashboard/billing/subscription",
}; };
export const DEFAULT_INPUT_TEMPLATE = `
Act as a virtual assistant powered by model: '{{model}}', my input is:
'''
{{input}}
'''
`;

View File

@ -115,6 +115,11 @@ const cn = {
SubTitle: "聊天内容的字体大小", SubTitle: "聊天内容的字体大小",
}, },
InputTemplate: {
Title: "用户输入预处理",
SubTitle: "用户最新的一条消息会填充到此模板",
},
Update: { Update: {
Version: (x: string) => `当前版本:${x}`, Version: (x: string) => `当前版本:${x}`,
IsLatest: "已是最新版本", IsLatest: "已是最新版本",

View File

@ -116,6 +116,12 @@ const en: LocaleType = {
Title: "Font Size", Title: "Font Size",
SubTitle: "Adjust font size of chat content", SubTitle: "Adjust font size of chat content",
}, },
InputTemplate: {
Title: "Input Template",
SubTitle: "Newest message will be filled to this template",
},
Update: { Update: {
Version: (x: string) => `Version: ${x}`, Version: (x: string) => `Version: ${x}`,
IsLatest: "Latest version", IsLatest: "Latest version",

View File

@ -9,7 +9,7 @@ export const BUILTIN_MASK_ID = 100000;
export const BUILTIN_MASK_STORE = { export const BUILTIN_MASK_STORE = {
buildinId: BUILTIN_MASK_ID, buildinId: BUILTIN_MASK_ID,
masks: {} as Record<number, Mask>, masks: {} as Record<number, BuiltinMask>,
get(id?: number) { get(id?: number) {
if (!id) return undefined; if (!id) return undefined;
return this.masks[id] as Mask | undefined; return this.masks[id] as Mask | undefined;
@ -21,6 +21,6 @@ export const BUILTIN_MASK_STORE = {
}, },
}; };
export const BUILTIN_MASKS: Mask[] = [...CN_MASKS, ...EN_MASKS].map((m) => export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...EN_MASKS].map(
BUILTIN_MASK_STORE.add(m), (m) => BUILTIN_MASK_STORE.add(m),
); );

View File

@ -1,5 +1,7 @@
import { ModelConfig } from "../store";
import { type Mask } from "../store/mask"; import { type Mask } from "../store/mask";
export type BuiltinMask = Omit<Mask, "id"> & { export type BuiltinMask = Omit<Mask, "id" | "modelConfig"> & {
builtin: true; builtin: Boolean;
modelConfig: Partial<ModelConfig>;
}; };

View File

@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
import { trimTopic } from "../utils"; import { trimTopic } from "../utils";
import Locale from "../locales"; import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib"; import { showToast } from "../components/ui-lib";
import { ModelType } from "./config"; import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask"; import { createEmptyMask, Mask } from "./mask";
import { StoreKey } from "../constant"; import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
import { api, RequestMessage } from "../client/api"; import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller"; import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format"; import { prettyObject } from "../utils/format";
@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0); return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
} }
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
const vars = {
model: modelConfig.model,
time: new Date().toLocaleString(),
lang: getLang(),
input: input,
};
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
// must contains {{input}}
const inputVar = "{{input}}";
if (!output.includes(inputVar)) {
output += "\n" + inputVar;
}
Object.entries(vars).forEach(([name, value]) => {
output = output.replaceAll(`{{${name}}}`, value);
});
return output;
}
export const useChatStore = create<ChatStore>()( export const useChatStore = create<ChatStore>()(
persist( persist(
(set, get) => ({ (set, get) => ({
@ -158,7 +181,16 @@ export const useChatStore = create<ChatStore>()(
session.id = get().globalId; session.id = get().globalId;
if (mask) { if (mask) {
session.mask = { ...mask }; const config = useAppConfig.getState();
const globalModelConfig = config.modelConfig;
session.mask = {
...mask,
modelConfig: {
...globalModelConfig,
...mask.modelConfig,
},
};
session.topic = mask.name; session.topic = mask.name;
} }
@ -238,9 +270,12 @@ export const useChatStore = create<ChatStore>()(
const session = get().currentSession(); const session = get().currentSession();
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
const userContent = fillTemplateWith(content, modelConfig);
console.log("[User Input] fill with template: ", userContent);
const userMessage: ChatMessage = createMessage({ const userMessage: ChatMessage = createMessage({
role: "user", role: "user",
content, content: userContent,
}); });
const botMessage: ChatMessage = createMessage({ const botMessage: ChatMessage = createMessage({
@ -250,31 +285,22 @@ export const useChatStore = create<ChatStore>()(
model: modelConfig.model, model: modelConfig.model,
}); });
const systemInfo = createMessage({
role: "system",
content: `IMPORTANT: You are a virtual assistant powered by the ${
modelConfig.model
} model, now time is ${new Date().toLocaleString()}}`,
id: botMessage.id! + 1,
});
// get recent messages // get recent messages
const systemMessages = [];
// if user define a mask with context prompts, wont send system info
if (session.mask.context.length === 0) {
systemMessages.push(systemInfo);
}
const recentMessages = get().getMessagesWithMemory(); const recentMessages = get().getMessagesWithMemory();
const sendMessages = systemMessages.concat( const sendMessages = recentMessages.concat(userMessage);
recentMessages.concat(userMessage),
);
const sessionIndex = get().currentSessionIndex; const sessionIndex = get().currentSessionIndex;
const messageIndex = get().currentSession().messages.length + 1; const messageIndex = get().currentSession().messages.length + 1;
// save user's and bot's message // save user's and bot's message
get().updateCurrentSession((session) => { get().updateCurrentSession((session) => {
session.messages = session.messages.concat([userMessage, botMessage]); const savedUserMessage = {
...userMessage,
content,
};
session.messages = session.messages.concat([
savedUserMessage,
botMessage,
]);
}); });
// make request // make request
@ -350,55 +376,62 @@ export const useChatStore = create<ChatStore>()(
getMessagesWithMemory() { getMessagesWithMemory() {
const session = get().currentSession(); const session = get().currentSession();
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
const clearContextIndex = session.clearContextIndex ?? 0;
const messages = session.messages.slice();
const totalMessageCount = session.messages.length;
// wont send cleared context messages // in-context prompts
const clearedContextMessages = session.messages.slice( const contextPrompts = session.mask.context.slice();
session.clearContextIndex ?? 0,
);
const messages = clearedContextMessages.filter((msg) => !msg.isError);
const n = messages.length;
const context = session.mask.context.slice();
// long term memory // long term memory
if ( const shouldSendLongTermMemory =
modelConfig.sendMemory && modelConfig.sendMemory &&
session.memoryPrompt && session.memoryPrompt &&
session.memoryPrompt.length > 0 session.memoryPrompt.length > 0 &&
) { session.lastSummarizeIndex <= clearContextIndex;
const memoryPrompt = get().getMemoryPrompt(); const longTermMemoryPrompts = shouldSendLongTermMemory
context.push(memoryPrompt); ? [get().getMemoryPrompt()]
} : [];
const longTermMemoryStartIndex = session.lastSummarizeIndex;
// get short term and unmemorized long term memory // short term memory
const shortTermMemoryMessageIndex = Math.max( const shortTermMemoryStartIndex = Math.max(
0, 0,
n - modelConfig.historyMessageCount, totalMessageCount - modelConfig.historyMessageCount,
); );
const longTermMemoryMessageIndex = session.lastSummarizeIndex;
// try to concat history messages // lets concat send messages, including 4 parts:
// 1. long term memory: summarized memory messages
// 2. pre-defined in-context prompts
// 3. short term memory: latest n messages
// 4. newest input message
const memoryStartIndex = Math.min( const memoryStartIndex = Math.min(
shortTermMemoryMessageIndex, longTermMemoryStartIndex,
longTermMemoryMessageIndex, shortTermMemoryStartIndex,
); );
const threshold = modelConfig.max_tokens; // and if user has cleared history messages, we should exclude the memory too.
const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
const maxTokenThreshold = modelConfig.max_tokens;
// get recent messages as many as possible // get recent messages as much as possible
const reversedRecentMessages = []; const reversedRecentMessages = [];
for ( for (
let i = n - 1, count = 0; let i = totalMessageCount - 1, tokenCount = 0;
i >= memoryStartIndex && count < threshold; i >= contextStartIndex && tokenCount < maxTokenThreshold;
i -= 1 i -= 1
) { ) {
const msg = messages[i]; const msg = messages[i];
if (!msg || msg.isError) continue; if (!msg || msg.isError) continue;
count += estimateTokenLength(msg.content); tokenCount += estimateTokenLength(msg.content);
reversedRecentMessages.push(msg); reversedRecentMessages.push(msg);
} }
// concat // concat all messages
const recentMessages = context.concat(reversedRecentMessages.reverse()); const recentMessages = [
...longTermMemoryPrompts,
...contextPrompts,
...reversedRecentMessages.reverse(),
];
return recentMessages; return recentMessages;
}, },

View File

@ -1,7 +1,7 @@
import { create } from "zustand"; import { create } from "zustand";
import { persist } from "zustand/middleware"; import { persist } from "zustand/middleware";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { StoreKey } from "../constant"; import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
export enum SubmitKey { export enum SubmitKey {
Enter = "Enter", Enter = "Enter",
@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
sendMemory: true, sendMemory: true,
historyMessageCount: 4, historyMessageCount: 4,
compressMessageLengthThreshold: 1000, compressMessageLengthThreshold: 1000,
template: DEFAULT_INPUT_TEMPLATE,
}, },
}; };
@ -176,15 +177,16 @@ export const useAppConfig = create<ChatConfigStore>()(
}), }),
{ {
name: StoreKey.Config, name: StoreKey.Config,
version: 3, version: 3.1,
migrate(persistedState, version) { migrate(persistedState, version) {
if (version === 3) return persistedState as any; if (version === 3.1) return persistedState as any;
const state = persistedState as ChatConfig; const state = persistedState as ChatConfig;
state.modelConfig.sendMemory = true; state.modelConfig.sendMemory = true;
state.modelConfig.historyMessageCount = 4; state.modelConfig.historyMessageCount = 4;
state.modelConfig.compressMessageLengthThreshold = 1000; state.modelConfig.compressMessageLengthThreshold = 1000;
state.modelConfig.frequency_penalty = 0; state.modelConfig.frequency_penalty = 0;
state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
state.dontShowMaskSplashScreen = false; state.dontShowMaskSplashScreen = false;
return state; return state;

View File

@ -3,7 +3,7 @@ import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks"; import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales"; import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat"; import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, ModelType, useAppConfig } from "./config"; import { ModelConfig, useAppConfig } from "./config";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
export type Mask = { export type Mask = {
@ -89,7 +89,18 @@ export const useMaskStore = create<MaskStore>()(
const userMasks = Object.values(get().masks).sort( const userMasks = Object.values(get().masks).sort(
(a, b) => b.id - a.id, (a, b) => b.id - a.id,
); );
return userMasks.concat(BUILTIN_MASKS); const config = useAppConfig.getState();
const buildinMasks = BUILTIN_MASKS.map(
(m) =>
({
...m,
modelConfig: {
...config.modelConfig,
...m.modelConfig,
},
} as Mask),
);
return userMasks.concat(buildinMasks);
}, },
search(text) { search(text) {
return Object.values(get().masks); return Object.values(get().masks);