fix: #1124 mask model config does not work

Yidadaa 2023-05-01 23:37:02 +08:00
parent b2fc7d476a
commit 9f3188fe45
3 changed files with 26 additions and 33 deletions
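
The root cause of #1124: request construction and the chat store both read the global useAppConfig model config directly, so a mask's per-session modelConfig never took effect. The fix merges the mask's config over the global defaults and replaces the old model/filterBot options with a single overrideModel. A minimal sketch of the resulting precedence; the helper name resolveModelConfig is invented for illustration (the commit inlines this logic in makeRequestParam):

// Precedence, lowest to highest: global defaults -> mask config -> override.
// `resolveModelConfig` is a hypothetical helper; the commit inlines this.
function resolveModelConfig(overrideModel?: ModelType): ModelConfig {
  const merged = {
    ...useAppConfig.getState().modelConfig, // app-wide defaults
    ...useChatStore.getState().currentSession().mask.modelConfig, // mask wins
  };
  if (overrideModel) {
    merged.model = overrideModel; // an explicit caller override beats both
  }
  return merged;
}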

View File

@@ -14,9 +14,8 @@ const TIME_OUT_MS = 60000;
 const makeRequestParam = (
   messages: Message[],
   options?: {
-    filterBot?: boolean;
     stream?: boolean;
-    model?: ModelType;
+    overrideModel?: ModelType;
   },
 ): ChatRequest => {
   let sendMessages = messages.map((v) => ({
@@ -24,18 +23,14 @@ const makeRequestParam = (
     content: v.content,
   }));
 
-  if (options?.filterBot) {
-    sendMessages = sendMessages.filter((m) => m.role !== "assistant");
-  }
-
   const modelConfig = {
     ...useAppConfig.getState().modelConfig,
     ...useChatStore.getState().currentSession().mask.modelConfig,
   };
 
   // override model config
-  if (options?.model) {
-    modelConfig.model = options.model;
+  if (options?.overrideModel) {
+    modelConfig.model = options.overrideModel;
   }
 
   return {
@@ -82,8 +77,7 @@ export async function requestChat(
   },
 ) {
   const req: ChatRequest = makeRequestParam(messages, {
-    filterBot: true,
-    model: options?.model,
+    overrideModel: options?.model,
   });
 
   const res = await requestOpenaiClient("v1/chat/completions")(req);
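
requestChat keeps model as its public option name and now forwards it as overrideModel, so the streaming and non-streaming paths share one override mechanism. A hedged usage sketch (the message literal is invented):

// Hypothetical non-streaming call site.
const chatResponse = await requestChat(
  [{ role: "user", content: "Say hi", date: "" }],
  { model: "gpt-3.5-turbo" },
);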
@@ -149,9 +143,8 @@ export async function requestUsage() {
 export async function requestChatStream(
   messages: Message[],
   options?: {
-    filterBot?: boolean;
     modelConfig?: ModelConfig;
-    model?: ModelType;
+    overrideModel?: ModelType;
     onMessage: (message: string, done: boolean) => void;
     onError: (error: Error, statusCode?: number) => void;
     onController?: (controller: AbortController) => void;
@@ -159,8 +152,7 @@ export async function requestChatStream(
 ) {
   const req = makeRequestParam(messages, {
     stream: true,
-    filterBot: options?.filterBot,
-    model: options?.model,
+    overrideModel: options?.overrideModel,
   });
 
   console.log("[Request] ", req);
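
Callers of the streaming path that used to pass filterBot and model now pass only overrideModel; everything else (temperature, max_tokens, and so on) still comes from the merged global + mask config inside makeRequestParam. A sketch under the new signature (content strings invented):

// Hypothetical streaming call site.
requestChatStream(
  [{ role: "user", content: "Hello", date: "" }],
  {
    overrideModel: "gpt-3.5-turbo",
    onMessage(message, done) {
      if (done) console.log("[Done]", message);
    },
    onError(error, statusCode) {
      console.error("[Error]", statusCode, error);
    },
  },
);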

View File

@@ -236,6 +236,9 @@ export const useChatStore = create<ChatStore>()(
       },
 
       async onUserInput(content) {
+        const session = get().currentSession();
+        const modelConfig = session.mask.modelConfig;
+
         const userMessage: Message = createMessage({
           role: "user",
           content,
@@ -245,7 +248,7 @@ export const useChatStore = create<ChatStore>()(
           role: "assistant",
           streaming: true,
           id: userMessage.id! + 1,
-          model: useAppConfig.getState().modelConfig.model,
+          model: modelConfig.model,
         });
 
         // get recent messages
@@ -279,14 +282,16 @@ export const useChatStore = create<ChatStore>()(
             }
           },
           onError(error, statusCode) {
+            const isAborted = error.message.includes("aborted");
             if (statusCode === 401) {
               botMessage.content = Locale.Error.Unauthorized;
-            } else if (!error.message.includes("aborted")) {
+            } else if (!isAborted) {
               botMessage.content += "\n\n" + Locale.Store.Error;
             }
             botMessage.streaming = false;
-            userMessage.isError = true;
-            botMessage.isError = true;
+            userMessage.isError = !isAborted;
+            botMessage.isError = !isAborted;
+
             set(() => ({}));
             ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
           },
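
The onError rewrite means a user-triggered abort no longer flags both messages as failed: isError is now set only for genuine errors. Detection still keys off the text fetch puts into error.message when an AbortController fires, which typically contains "aborted". A standalone sketch of that assumption:

// Sketch: aborting a fetch rejects with an error whose message typically
// contains "aborted" - the includes("aborted") check above relies on this.
const controller = new AbortController();
fetch("https://example.com", { signal: controller.signal }).catch(
  (error: Error) => {
    const isAborted = error.message.includes("aborted");
    // aborted -> user cancelled, keep the messages clean;
    // anything else -> a real failure worth flagging as isError.
    console.log(isAborted ? "cancelled" : "real failure", error.message);
  },
);
controller.abort();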
@@ -298,8 +303,7 @@ export const useChatStore = create<ChatStore>()(
               controller,
             );
           },
-          filterBot: !useAppConfig.getState().sendBotMessages,
-          modelConfig: useAppConfig.getState().modelConfig,
+          modelConfig: { ...modelConfig },
         });
       },
 
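
Passing { ...modelConfig } hands the request layer a shallow copy, so whatever it does with the object cannot write back into the session's mask. A short illustration with invented values:

const maskConfig = { model: "gpt-4", temperature: 1 }; // invented values
const copy = { ...maskConfig };
copy.model = "gpt-3.5-turbo"; // mutation stays local to the copy
console.log(maskConfig.model); // still "gpt-4"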
@@ -318,7 +322,7 @@ export const useChatStore = create<ChatStore>()(
 
       getMessagesWithMemory() {
         const session = get().currentSession();
-        const config = useAppConfig.getState();
+        const modelConfig = session.mask.modelConfig;
         const messages = session.messages.filter((msg) => !msg.isError);
         const n = messages.length;
 
@@ -326,7 +330,7 @@ export const useChatStore = create<ChatStore>()(
 
         // long term memory
         if (
-          session.mask.modelConfig.sendMemory &&
+          modelConfig.sendMemory &&
           session.memoryPrompt &&
           session.memoryPrompt.length > 0
         ) {
@@ -337,14 +341,14 @@ export const useChatStore = create<ChatStore>()(
         // get short term and unmemoried long term memory
         const shortTermMemoryMessageIndex = Math.max(
           0,
-          n - config.modelConfig.historyMessageCount,
+          n - modelConfig.historyMessageCount,
         );
         const longTermMemoryMessageIndex = session.lastSummarizeIndex;
         const oldestIndex = Math.max(
           shortTermMemoryMessageIndex,
           longTermMemoryMessageIndex,
         );
-        const threshold = config.modelConfig.compressMessageLengthThreshold;
+        const threshold = modelConfig.compressMessageLengthThreshold;
 
         // get recent messages as many as possible
         const reversedRecentMessages = [];
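
With the mask config driving the window, the selection rule is: send at most historyMessageCount recent messages verbatim, but never reach back past the last summarized index. A worked example with invented values:

// Worked example of the windowing above (all values invented):
const n = 30; // total non-error messages in the session
const historyMessageCount = 8; // from the mask's modelConfig
const lastSummarizeIndex = 25; // messages before this live in memoryPrompt

const shortTermIndex = Math.max(0, n - historyMessageCount); // 22
const oldestIndex = Math.max(shortTermIndex, lastSummarizeIndex); // 25
// Only messages[25..29] are sent verbatim; older context is represented
// by the summarized memory prompt instead.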
@@ -403,17 +407,17 @@ export const useChatStore = create<ChatStore>()(
           });
         }
 
-        const config = useAppConfig.getState();
+        const modelConfig = session.mask.modelConfig;
         let toBeSummarizedMsgs = session.messages.slice(
           session.lastSummarizeIndex,
         );
 
         const historyMsgLength = countMessages(toBeSummarizedMsgs);
 
-        if (historyMsgLength > config?.modelConfig?.max_tokens ?? 4000) {
+        if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
           const n = toBeSummarizedMsgs.length;
           toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-            Math.max(0, n - config.modelConfig.historyMessageCount),
+            Math.max(0, n - modelConfig.historyMessageCount),
           );
         }
 
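
One pre-existing quirk survives the rename: > binds tighter than ??, so the condition parses as (historyMsgLength > modelConfig?.max_tokens) ?? 4000. A comparison always yields a boolean, never null or undefined, so the 4000 fallback is dead code. A self-contained demo (names and values invented):

const historyMsgLength = 5000;
const maxTokens: number | undefined = undefined;

// As written: the comparison runs first; 5000 > undefined is false, and
// false is not nullish, so the ?? 4000 branch never fires.
const asWritten = (historyMsgLength > maxTokens!) ?? 4000; // false

// Presumably intended: fall back to 4000 when max_tokens is unset.
const intended = historyMsgLength > (maxTokens ?? 4000); // true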
@@ -426,12 +430,11 @@ export const useChatStore = create<ChatStore>()(
           "[Chat History] ",
           toBeSummarizedMsgs,
           historyMsgLength,
-          config.modelConfig.compressMessageLengthThreshold,
+          modelConfig.compressMessageLengthThreshold,
         );
 
         if (
-          historyMsgLength >
-          config.modelConfig.compressMessageLengthThreshold &&
+          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
           session.mask.modelConfig.sendMemory
         ) {
           requestChatStream(
@@ -441,8 +444,7 @@ export const useChatStore = create<ChatStore>()(
             date: "",
           }),
           {
-            filterBot: false,
-            model: "gpt-3.5-turbo",
+            overrideModel: "gpt-3.5-turbo",
             onMessage(message, done) {
               session.memoryPrompt = message;
               if (done) {
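
Note that summarization pins overrideModel: "gpt-3.5-turbo". Given the precedence in makeRequestParam, the override beats both the global and mask configs, so memory compression always runs on the default model even when the mask selects something pricier. A sketch of the resolution (config values invented):

const globalConfig = { model: "gpt-3.5-turbo" }; // invented app default
const maskConfig = { model: "gpt-4" }; // invented mask choice
const resolved = { ...globalConfig, ...maskConfig }; // mask wins: "gpt-4"
resolved.model = "gpt-3.5-turbo"; // overrideModel wins for summarization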

View File

@@ -17,7 +17,6 @@ export enum Theme {
 }
 
 export const DEFAULT_CONFIG = {
-  sendBotMessages: true as boolean,
   submitKey: SubmitKey.CtrlEnter as SubmitKey,
   avatar: "1f603",
   fontSize: 14,