This commit is contained in:
Yidadaa 2023-06-24 00:18:27 +08:00
parent be597a551d
commit 1722f75dcb
4 changed files with 31 additions and 9 deletions

View File

@ -9,7 +9,7 @@ export const BUILTIN_MASK_ID = 100000;
export const BUILTIN_MASK_STORE = { export const BUILTIN_MASK_STORE = {
buildinId: BUILTIN_MASK_ID, buildinId: BUILTIN_MASK_ID,
masks: {} as Record<number, Mask>, masks: {} as Record<number, BuiltinMask>,
get(id?: number) { get(id?: number) {
if (!id) return undefined; if (!id) return undefined;
return this.masks[id] as Mask | undefined; return this.masks[id] as Mask | undefined;
@ -21,6 +21,6 @@ export const BUILTIN_MASK_STORE = {
}, },
}; };
export const BUILTIN_MASKS: Mask[] = [...CN_MASKS, ...EN_MASKS].map((m) => export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...EN_MASKS].map(
BUILTIN_MASK_STORE.add(m), (m) => BUILTIN_MASK_STORE.add(m),
); );

View File

@ -1,5 +1,7 @@
import { ModelConfig } from "../store";
import { type Mask } from "../store/mask"; import { type Mask } from "../store/mask";
export type BuiltinMask = Omit<Mask, "id"> & { export type BuiltinMask = Omit<Mask, "id" | "modelConfig"> & {
builtin: true; builtin: Boolean;
modelConfig: Partial<ModelConfig>;
}; };

View File

@ -5,7 +5,7 @@ import { trimTopic } from "../utils";
import Locale, { getLang } from "../locales"; import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib"; import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType } from "./config"; import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask"; import { createEmptyMask, Mask } from "./mask";
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant"; import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
import { api, RequestMessage } from "../client/api"; import { api, RequestMessage } from "../client/api";
@ -181,7 +181,16 @@ export const useChatStore = create<ChatStore>()(
session.id = get().globalId; session.id = get().globalId;
if (mask) { if (mask) {
session.mask = { ...mask }; const config = useAppConfig.getState();
const globalModelConfig = config.modelConfig;
session.mask = {
...mask,
modelConfig: {
...globalModelConfig,
...mask.modelConfig,
},
};
session.topic = mask.name; session.topic = mask.name;
} }

View File

@ -3,7 +3,7 @@ import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks"; import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales"; import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat"; import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, ModelType, useAppConfig } from "./config"; import { ModelConfig, useAppConfig } from "./config";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
export type Mask = { export type Mask = {
@ -89,7 +89,18 @@ export const useMaskStore = create<MaskStore>()(
const userMasks = Object.values(get().masks).sort( const userMasks = Object.values(get().masks).sort(
(a, b) => b.id - a.id, (a, b) => b.id - a.id,
); );
return userMasks.concat(BUILTIN_MASKS); const config = useAppConfig.getState();
const buildinMasks = BUILTIN_MASKS.map(
(m) =>
({
...m,
modelConfig: {
...config.modelConfig,
...m.modelConfig,
},
} as Mask),
);
return userMasks.concat(buildinMasks);
}, },
search(text) { search(text) {
return Object.values(get().masks); return Object.values(get().masks);