From bd90caa99d1501bbbd75cc722e185e9266973d9b Mon Sep 17 00:00:00 2001
From: Yidadaa
Date: Sun, 14 May 2023 23:00:17 +0800
Subject: [PATCH] refactor: llm client api

---
 app/client/api.ts              | 115 ++++++++++++++++++++++++
 app/client/controller.ts       |  37 ++++++++
 app/client/platforms/openai.ts | 156 +++++++++++++++++++++++++++++++++
 app/constant.ts                |   2 +
 app/requests.ts                |  22 ------
 app/store/chat.ts              |   1 +
 package.json                   |   1 +
 yarn.lock                      |   5 ++
 8 files changed, 317 insertions(+), 22 deletions(-)
 create mode 100644 app/client/api.ts
 create mode 100644 app/client/controller.ts
 create mode 100644 app/client/platforms/openai.ts

diff --git a/app/client/api.ts b/app/client/api.ts
new file mode 100644
index 00000000..103e95e5
--- /dev/null
+++ b/app/client/api.ts
@@ -0,0 +1,115 @@
+import { fetchEventSource } from "@microsoft/fetch-event-source";
+import { ACCESS_CODE_PREFIX } from "../constant";
+import { ModelType, useAccessStore } from "../store";
+import { ChatGPTApi } from "./platforms/openai";
+
+export enum MessageRole {
+  System = "system",
+  User = "user",
+  Assistant = "assistant",
+}
+
+export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
+export type ChatModel = ModelType;
+
+export interface Message {
+  role: MessageRole;
+  content: string;
+}
+
+export interface LLMConfig {
+  temperature?: number;
+  topP?: number;
+  stream?: boolean;
+  presencePenalty?: number;
+  frequencyPenalty?: number;
+}
+
+export interface ChatOptions {
+  messages: Message[];
+  model: ChatModel;
+  config: LLMConfig;
+
+  onUpdate: (message: string, chunk: string) => void;
+  onFinish: (message: string) => void;
+  onError: (err: Error) => void;
+  onUnAuth: () => void;
+}
+
+export interface LLMUsage {
+  used: number;
+  total: number;
+}
+
+/** Contract implemented by every LLM backend that ClientApi can hold. */
+export abstract class LLMApi {
+  abstract chat(options: ChatOptions): Promise<void>;
+  abstract usage(): Promise<LLMUsage>;
+}
+
+export class ClientApi {
+  public llm: LLMApi;
+
+  constructor() {
+    this.llm = new ChatGPTApi();
+  }
+
+  /** Builds the Authorization header from the access store state. */
+  headers() {
+    const accessStore = useAccessStore.getState();
+    let headers: Record<string, string> = {};
+
+    const makeBearer = (token: string) => `Bearer ${token.trim()}`;
+    const validString = (x: string) => x && x.length > 0;
+
+    // use user's api key first
+    if (validString(accessStore.token)) {
+      headers.Authorization = makeBearer(accessStore.token);
+    } else if (
+      accessStore.enabledAccessControl() &&
+      validString(accessStore.accessCode)
+    ) {
+      headers.Authorization = makeBearer(
+        ACCESS_CODE_PREFIX + accessStore.accessCode,
+      );
+    }
+
+    return headers;
+  }
+
+  config() {}
+
+  prompts() {}
+
+  masks() {}
+}
+
+export const api = new ClientApi();
+
+/**
+ * Builds JSON request headers: Content-Type plus an Authorization bearer
+ * taken from the user's token or, failing that, the prefixed access code.
+ */
+export function getHeaders() {
+  const accessStore = useAccessStore.getState();
+  let headers: Record<string, string> = {
+    "Content-Type": "application/json",
+  };
+
+  const makeBearer = (token: string) => `Bearer ${token.trim()}`;
+  const validString = (x: string) => x && x.length > 0;
+
+  // use user's api key first
+  if (validString(accessStore.token)) {
+    headers.Authorization = makeBearer(accessStore.token);
+  } else if (
+    accessStore.enabledAccessControl() &&
+    validString(accessStore.accessCode)
+  ) {
+    headers.Authorization = makeBearer(
+      ACCESS_CODE_PREFIX + accessStore.accessCode,
+    );
+  }
+
+  return headers;
+}
diff --git a/app/client/controller.ts b/app/client/controller.ts
new file mode 100644
index 00000000..86cb99e7
--- /dev/null
+++ b/app/client/controller.ts
@@ -0,0 +1,37 @@
+// Keeps the AbortController of each streaming message, keyed by session and message
+export const ChatControllerPool = {
+  controllers: {} as Record<string, AbortController>,
+
+  addController(
+    sessionIndex: number,
+    messageId: number,
+    controller: AbortController,
+  ) {
+    const key = this.key(sessionIndex, messageId);
+    this.controllers[key] = controller;
+    return key;
+  },
+
+  stop(sessionIndex: number, messageId: number) {
+    const key = this.key(sessionIndex, messageId);
+    const controller = this.controllers[key];
+    controller?.abort();
+  },
+
+  stopAll() {
+    Object.values(this.controllers).forEach((v) => v.abort());
+  },
+
+  hasPending() {
+    return Object.values(this.controllers).length > 0;
+  },
+
+  remove(sessionIndex: number, messageId: number) {
+    const key = this.key(sessionIndex, messageId);
+    delete this.controllers[key];
+  },
+
+  key(sessionIndex: number, messageIndex: number) {
+    return `${sessionIndex},${messageIndex}`;
+  },
+};
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
new file mode 100644
index 00000000..7d4d94da
--- /dev/null
+++ b/app/client/platforms/openai.ts
@@ -0,0 +1,156 @@
+import { REQUEST_TIMEOUT_MS } from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@microsoft/fetch-event-source";
+import { ChatOptions, LLMApi, LLMUsage } from "../api";
+
+/** LLMApi implementation that posts to the OpenAI chat completions path. */
+export class ChatGPTApi implements LLMApi {
+  public ChatPath = "v1/chat/completions";
+
+  /** Joins the configured OpenAI base url and `path` with a single "/". */
+  path(path: string): string {
+    let openaiUrl = useAccessStore.getState().openaiUrl;
+    // slice() returns a new string, so the trimmed url must be re-assigned
+    if (openaiUrl.endsWith("/")) {
+      openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1);
+    }
+    return [openaiUrl, path].join("/");
+  }
+
+  /**
+   * Reads the text of the first choice: `message.content` for a regular
+   * response, `delta.content` for a streamed chunk, "" when neither exists.
+   */
+  extractMessage(res: any) {
+    const choice = res.choices?.at(0);
+    return choice?.message?.content ?? choice?.delta?.content ?? "";
+  }
+
+  /**
+   * Posts `options.messages` to the chat endpoint and reports the outcome
+   * through the callbacks in `options`; request failures go to `onError`.
+   */
+  async chat(options: ChatOptions) {
+    const messages = options.messages.map((v) => ({
+      role: v.role,
+      content: v.content,
+    }));
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      ...{
+        model: options.model,
+      },
+    };
+
+    const requestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+    };
+
+    console.log("[Request] openai payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    // aborts the request when no response arrives within REQUEST_TIMEOUT_MS
+    const requestTimeoutId = setTimeout(
+      () => controller.abort(),
+      REQUEST_TIMEOUT_MS,
+    );
+
+    try {
+      const chatPath = this.path(this.ChatPath);
+      // NOTE(review): no headers are attached here; presumably getHeaders()
+      // from ../api is meant to be sent with the request - confirm.
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+      };
+
+      if (shouldStream) {
+        let responseText = "";
+        let finished = false;
+
+        await fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            // the server answered; the stream itself may outlive the timeout
+            clearTimeout(requestTimeoutId);
+            if (
+              res.ok &&
+              res.headers.get("Content-Type") === EventStreamContentType
+            ) {
+              return;
+            }
+
+            if (res.status === 401) {
+              // TODO: Unauthorized 401
+              responseText += "\n\n";
+            } else if (res.status !== 200) {
+              console.error("[Request] response", res);
+              throw new Error("[Request] server error");
+            }
+          },
+          onmessage: (ev) => {
+            if (ev.data === "[DONE]") {
+              finished = true;
+              return options.onFinish(responseText);
+            }
+            try {
+              const resJson = JSON.parse(ev.data);
+              const message = this.extractMessage(resJson);
+              responseText += message;
+              options.onUpdate(responseText, message);
+            } catch (e) {
+              console.error("[Request] stream error", e);
+              options.onError(e as Error);
+            }
+          },
+          onclose() {
+            // closing after "[DONE]" is the normal end of a stream
+            if (!finished) {
+              throw new Error("stream closed unexpected");
+            }
+          },
+          onerror(err) {
+            // rethrow: returning normally makes fetchEventSource retry
+            throw err;
+          },
+        });
+
+        // fetchEventSource resolves silently once its signal is aborted, so
+        // a timeout that hit before the stream finished is reported here
+        if (controller.signal.aborted && !finished) {
+          options.onError(new Error("[Request] timeout"));
+        }
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat reqeust", e);
+      options.onError(e as Error);
+    } finally {
+      clearTimeout(requestTimeoutId);
+    }
+  }
+
+  /** Placeholder: always reports zero used and zero total. */
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    } as LLMUsage;
+  }
+}
diff --git a/app/constant.ts b/app/constant.ts
index d0f9fc74..577c0af6 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100;
 export const ACCESS_CODE_PREFIX = "ak-";
 
 export const LAST_INPUT_KEY = "last-input";
+
+export const REQUEST_TIMEOUT_MS = 60000;
diff --git a/app/requests.ts b/app/requests.ts
index d9750a5b..df81b4f9 100644
--- a/app/requests.ts
+++ b/app/requests.ts
@@ -43,28 +43,6 @@ const makeRequestParam = (
   };
 };
 
-export function getHeaders() {
-  const accessStore = useAccessStore.getState();
-  let headers: Record<string, string> = {};
-
-  const makeBearer = (token: string) => `Bearer ${token.trim()}`;
-  const validString = (x: string) => x && x.length > 0;
-
-  // use user's api key first
-  if (validString(accessStore.token)) {
-    headers.Authorization = makeBearer(accessStore.token);
-  } else if (
-    accessStore.enabledAccessControl() &&
-    validString(accessStore.accessCode)
-  ) {
-    headers.Authorization = makeBearer(
-      ACCESS_CODE_PREFIX + accessStore.accessCode,
-    );
-  }
-
-  return headers;
-}
-
 export function requestOpenaiClient(path: string) {
   const openaiUrl = useAccessStore.getState().openaiUrl;
   return (body: any, method = "POST") =>
diff --git a/app/store/chat.ts b/app/store/chat.ts
index cb11087d..17cf7707 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -14,6 +14,7 @@ import { showToast } from "../components/ui-lib";
 import { ModelType } from "./config";
 import { createEmptyMask, Mask } from "./mask";
 import { StoreKey } from "../constant";
+import { api } from "../client/api";
 
 export type Message = ChatCompletionResponseMessage & {
   date: string;
diff --git a/package.json b/package.json
index 2f194174..6b13b9b6 100644
--- a/package.json
+++ b/package.json
@@ -14,6 +14,7 @@
   },
   "dependencies": {
     "@hello-pangea/dnd": "^16.2.0",
+    "@microsoft/fetch-event-source": "^2.0.1",
     "@svgr/webpack": "^6.5.1",
     "@vercel/analytics": "^0.1.11",
     "emoji-picker-react": "^4.4.7",
diff --git a/yarn.lock b/yarn.lock
index 22610c6a..a6695acb 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1111,6 +1111,11 @@
   dependencies:
     "@types/react" ">=16.0.0"
 
+"@microsoft/fetch-event-source@^2.0.1":
+  version "2.0.1"
+  resolved "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
+  integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==
+
 "@next/env@13.3.1-canary.8":
   version "13.3.1-canary.8"
   resolved "https://registry.yarnpkg.com/@next/env/-/env-13.3.1-canary.8.tgz#9f5cf57999e4f4b59ef6407924803a247cc4e451"