fix: #1124 mask model config does not work

Yidadaa 2023-05-01 23:37:02 +08:00
parent b2fc7d476a
commit 9f3188fe45
3 changed files with 26 additions and 33 deletions


@@ -14,9 +14,8 @@ const TIME_OUT_MS = 60000;
 const makeRequestParam = (
   messages: Message[],
   options?: {
-    filterBot?: boolean;
     stream?: boolean;
-    model?: ModelType;
+    overrideModel?: ModelType;
   },
 ): ChatRequest => {
   let sendMessages = messages.map((v) => ({
@@ -24,18 +23,14 @@ const makeRequestParam = (
     content: v.content,
   }));
 
-  if (options?.filterBot) {
-    sendMessages = sendMessages.filter((m) => m.role !== "assistant");
-  }
-
   const modelConfig = {
     ...useAppConfig.getState().modelConfig,
     ...useChatStore.getState().currentSession().mask.modelConfig,
   };
 
   // override model config
-  if (options?.model) {
-    modelConfig.model = options.model;
+  if (options?.overrideModel) {
+    modelConfig.model = options.overrideModel;
   }
 
   return {
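
The fix hinges on the spread order in the hunk above: the mask's per-session modelConfig is spread after the global app config, so any field the mask sets wins, and overrideModel then has the final say over the model field alone. A minimal standalone sketch of that three-level precedence (plain objects stand in for the zustand store state; the names here are illustrative):

// Stand-ins for useAppConfig.getState().modelConfig and mask.modelConfig.
type SketchModelConfig = { model: string; temperature: number };

const globalConfig: SketchModelConfig = { model: "gpt-4", temperature: 1 };
const maskConfig: Partial<SketchModelConfig> = { temperature: 0.3 };

function resolveModelConfig(overrideModel?: string): SketchModelConfig {
  // Later spreads win: mask fields shadow global fields.
  const merged = { ...globalConfig, ...maskConfig };
  // An explicit override beats both, but only for the model field.
  if (overrideModel) {
    merged.model = overrideModel;
  }
  return merged;
}

console.log(resolveModelConfig()); // { model: "gpt-4", temperature: 0.3 }
console.log(resolveModelConfig("gpt-3.5-turbo")); // model forced to gpt-3.5-turbo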
@@ -82,8 +77,7 @@ export async function requestChat(
   },
 ) {
   const req: ChatRequest = makeRequestParam(messages, {
-    filterBot: true,
-    model: options?.model,
+    overrideModel: options?.model,
   });
 
   const res = await requestOpenaiClient("v1/chat/completions")(req);
@@ -149,9 +143,8 @@ export async function requestUsage() {
 export async function requestChatStream(
   messages: Message[],
   options?: {
-    filterBot?: boolean;
     modelConfig?: ModelConfig;
-    model?: ModelType;
+    overrideModel?: ModelType;
     onMessage: (message: string, done: boolean) => void;
     onError: (error: Error, statusCode?: number) => void;
     onController?: (controller: AbortController) => void;
@@ -159,8 +152,7 @@ export async function requestChatStream(
 ) {
   const req = makeRequestParam(messages, {
     stream: true,
-    filterBot: options?.filterBot,
-    model: options?.model,
+    overrideModel: options?.overrideModel,
   });
 
   console.log("[Request] ", req);
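
With the rename, call sites pass overrideModel instead of model. A hypothetical invocation under the new signature (assuming messages and session are in scope as in the store code; the callback bodies are placeholders):

requestChatStream(messages, {
  modelConfig: { ...session.mask.modelConfig },
  // Optional: force a specific model regardless of the merged config.
  overrideModel: "gpt-3.5-turbo",
  onMessage(message, done) {
    console.log("[Stream] ", message, done);
  },
  onError(error, statusCode) {
    console.error("[Stream Error] ", statusCode, error);
  },
});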


@@ -236,6 +236,9 @@ export const useChatStore = create<ChatStore>()(
       },
 
       async onUserInput(content) {
+        const session = get().currentSession();
+        const modelConfig = session.mask.modelConfig;
+
         const userMessage: Message = createMessage({
           role: "user",
           content,
@@ -245,7 +248,7 @@ export const useChatStore = create<ChatStore>()(
           role: "assistant",
           streaming: true,
           id: userMessage.id! + 1,
-          model: useAppConfig.getState().modelConfig.model,
+          model: modelConfig.model,
         });
 
         // get recent messages
@@ -279,14 +282,16 @@ export const useChatStore = create<ChatStore>()(
             }
           },
           onError(error, statusCode) {
+            const isAborted = error.message.includes("aborted");
+
             if (statusCode === 401) {
               botMessage.content = Locale.Error.Unauthorized;
-            } else if (!error.message.includes("aborted")) {
+            } else if (!isAborted) {
               botMessage.content += "\n\n" + Locale.Store.Error;
             }
             botMessage.streaming = false;
-            userMessage.isError = true;
-            botMessage.isError = true;
+            userMessage.isError = !isAborted;
+            botMessage.isError = !isAborted;
             set(() => ({}));
             ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
           },
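
The isAborted refactor also narrows what counts as a failure: a user-initiated abort no longer marks either message as an error; only genuine failures do. The branch logic above, mirrored as a tiny self-contained check (classifyError is a hypothetical helper, not part of the codebase):

function classifyError(error: Error, statusCode?: number) {
  // Aborts are detected the same way as above: by inspecting the message.
  const isAborted = error.message.includes("aborted");
  return {
    unauthorized: statusCode === 401,
    markAsError: !isAborted, // user aborts leave both messages un-flagged
  };
}

console.log(classifyError(new Error("The user aborted a request."))); // markAsError: false
console.log(classifyError(new Error("fetch failed"), 500)); // markAsError: true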
@@ -298,8 +303,7 @@ export const useChatStore = create<ChatStore>()(
               controller,
             );
           },
-          filterBot: !useAppConfig.getState().sendBotMessages,
-          modelConfig: useAppConfig.getState().modelConfig,
+          modelConfig: { ...modelConfig },
         });
       },
 
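
Passing { ...modelConfig } rather than the mask's object itself hands the request a shallow snapshot, so editing the mask mid-stream cannot mutate an in-flight request. A small illustration of the difference (hypothetical fields):

const live = { model: "gpt-4", temperature: 0.5 };
const snapshot = { ...live }; // shallow copy, as in modelConfig: { ...modelConfig }

live.temperature = 1.0; // user tweaks the mask while a request is running
console.log(live.temperature, snapshot.temperature); // 1 0.5 — the request is unaffected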
@@ -318,7 +322,7 @@ export const useChatStore = create<ChatStore>()(
 
       getMessagesWithMemory() {
         const session = get().currentSession();
-        const config = useAppConfig.getState();
+        const modelConfig = session.mask.modelConfig;
         const messages = session.messages.filter((msg) => !msg.isError);
         const n = messages.length;
 
@@ -326,7 +330,7 @@ export const useChatStore = create<ChatStore>()(
 
         // long term memory
         if (
-          session.mask.modelConfig.sendMemory &&
+          modelConfig.sendMemory &&
           session.memoryPrompt &&
           session.memoryPrompt.length > 0
         ) {
@@ -337,14 +341,14 @@ export const useChatStore = create<ChatStore>()(
         // get short term and unmemoried long term memory
         const shortTermMemoryMessageIndex = Math.max(
           0,
-          n - config.modelConfig.historyMessageCount,
+          n - modelConfig.historyMessageCount,
         );
         const longTermMemoryMessageIndex = session.lastSummarizeIndex;
         const oldestIndex = Math.max(
           shortTermMemoryMessageIndex,
           longTermMemoryMessageIndex,
         );
-        const threshold = config.modelConfig.compressMessageLengthThreshold;
+        const threshold = modelConfig.compressMessageLengthThreshold;
 
         // get recent messages as many as possible
         const reversedRecentMessages = [];
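
The windowing above keeps whichever cut-off is newer: the short-term window of the last historyMessageCount messages, or the prefix already folded into the memory prompt. Worked through with assumed numbers:

// Assumed values for illustration only.
const n = 40; // messages in the session
const historyMessageCount = 6; // short-term window from the (now mask-level) config
const lastSummarizeIndex = 30; // everything before this index is summarized

const shortTermMemoryMessageIndex = Math.max(0, n - historyMessageCount); // 34
const oldestIndex = Math.max(shortTermMemoryMessageIndex, lastSummarizeIndex); // 34
console.log(oldestIndex); // messages [34, 40) are sent verbatim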
@@ -403,17 +407,17 @@ export const useChatStore = create<ChatStore>()(
           });
         }
 
-        const config = useAppConfig.getState();
+        const modelConfig = session.mask.modelConfig;
         let toBeSummarizedMsgs = session.messages.slice(
           session.lastSummarizeIndex,
         );
 
         const historyMsgLength = countMessages(toBeSummarizedMsgs);
 
-        if (historyMsgLength > config?.modelConfig?.max_tokens ?? 4000) {
+        if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
           const n = toBeSummarizedMsgs.length;
           toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-            Math.max(0, n - config.modelConfig.historyMessageCount),
+            Math.max(0, n - modelConfig.historyMessageCount),
           );
         }
 
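
One caveat in the condition above, present both before and after this change: > binds tighter than ??, so historyMsgLength > modelConfig?.max_tokens ?? 4000 parses as (historyMsgLength > modelConfig?.max_tokens) ?? 4000, and the 4000 fallback can never apply, because a boolean is never nullish. The likely intended form parenthesizes the fallback:

const historyMsgLength = 5000;
const max_tokens: number | undefined = undefined;

// Unparenthesized reading: (5000 > undefined) ?? 4000.
// The comparison yields false, and false ?? 4000 stays false,
// so the truncation branch is never entered via the fallback.

// Parenthesized, the fallback actually participates:
const overLimit = historyMsgLength > (max_tokens ?? 4000);
console.log(overLimit); // true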
@@ -426,12 +430,11 @@ export const useChatStore = create<ChatStore>()(
           "[Chat History] ",
           toBeSummarizedMsgs,
           historyMsgLength,
-          config.modelConfig.compressMessageLengthThreshold,
+          modelConfig.compressMessageLengthThreshold,
         );
 
         if (
-          historyMsgLength >
-            config.modelConfig.compressMessageLengthThreshold &&
+          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
           session.mask.modelConfig.sendMemory
         ) {
           requestChatStream(
@@ -441,8 +444,7 @@ export const useChatStore = create<ChatStore>()(
             date: "",
           }),
           {
-            filterBot: false,
-            model: "gpt-3.5-turbo",
+            overrideModel: "gpt-3.5-turbo",
             onMessage(message, done) {
               session.memoryPrompt = message;
               if (done) {


@@ -17,7 +17,6 @@ export enum Theme {
 }
 
 export const DEFAULT_CONFIG = {
-  sendBotMessages: true as boolean,
   submitKey: SubmitKey.CtrlEnter as SubmitKey,
   avatar: "1f603",
   fontSize: 14,