refactor: llm client api

This commit is contained in:
Yidadaa 2023-05-14 23:00:17 +08:00
parent 6da3aab046
commit bd90caa99d
8 changed files with 279 additions and 22 deletions

109
app/client/api.ts Normal file
View File

@ -0,0 +1,109 @@
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { ACCESS_CODE_PREFIX } from "../constant";
import { ModelType, useAccessStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
export enum MessageRole {
System = "system",
User = "user",
Assistant = "assistant",
}
export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;
export interface Message {
role: MessageRole;
content: string;
}
export interface LLMConfig {
temperature?: number;
topP?: number;
stream?: boolean;
presencePenalty?: number;
frequencyPenalty?: number;
}
export interface ChatOptions {
messages: Message[];
model: ChatModel;
config: LLMConfig;
onUpdate: (message: string, chunk: string) => void;
onFinish: (message: string) => void;
onError: (err: Error) => void;
onUnAuth: () => void;
}
export interface LLMUsage {
used: number;
total: number;
}
export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
}
export class ClientApi {
public llm: LLMApi;
constructor() {
this.llm = new ChatGPTApi();
}
headers() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {};
const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;
// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
return headers;
}
config() {}
prompts() {}
masks() {}
}
export const api = new ClientApi();
export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {
"Content-Type": "application/json",
};
const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;
// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
return headers;
}

37
app/client/controller.ts Normal file
View File

@ -0,0 +1,37 @@
// To store message streaming controller
export const ChatControllerPool = {
controllers: {} as Record<string, AbortController>,
addController(
sessionIndex: number,
messageId: number,
controller: AbortController,
) {
const key = this.key(sessionIndex, messageId);
this.controllers[key] = controller;
return key;
},
stop(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
const controller = this.controllers[key];
controller?.abort();
},
stopAll() {
Object.values(this.controllers).forEach((v) => v.abort());
},
hasPending() {
return Object.values(this.controllers).length > 0;
},
remove(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
delete this.controllers[key];
},
key(sessionIndex: number, messageIndex: number) {
return `${sessionIndex},${messageIndex}`;
},
};

View File

@ -0,0 +1,124 @@
import { REQUEST_TIMEOUT_MS } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
EventStreamContentType,
fetchEventSource,
} from "@microsoft/fetch-event-source";
import { ChatOptions, LLMApi, LLMUsage } from "../api";
export class ChatGPTApi implements LLMApi {
public ChatPath = "v1/chat/completions";
path(path: string): string {
const openaiUrl = useAccessStore.getState().openaiUrl;
if (openaiUrl.endsWith("/")) openaiUrl.slice(0, openaiUrl.length - 1);
return [openaiUrl, path].join("/");
}
extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
}
async chat(options: ChatOptions) {
const messages = options.messages.map((v) => ({
role: v.role,
content: v.content,
}));
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.model,
},
};
const requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
};
console.log("[Request] openai payload: ", requestPayload);
const shouldStream = !!options.config.stream;
const controller = new AbortController();
try {
const chatPath = this.path(this.ChatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
signal: controller.signal,
};
// make a fetch request
const reqestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
if (shouldStream) {
let responseText = "";
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
if (
res.ok &&
res.headers.get("Content-Type") === EventStreamContentType
) {
return;
}
if (res.status === 401) {
// TODO: Unauthorized 401
responseText += "\n\n";
} else if (res.status !== 200) {
console.error("[Request] response", res);
throw new Error("[Request] server error");
}
},
onmessage: (ev) => {
if (ev.data === "[DONE]") {
return options.onFinish(responseText);
}
try {
const resJson = JSON.parse(ev.data);
const message = this.extractMessage(resJson);
responseText += message;
options.onUpdate(responseText, message);
} catch (e) {
console.error("[Request] stream error", e);
options.onError(e as Error);
}
},
onclose() {
options.onError(new Error("stream closed unexpected"));
},
onerror(err) {
options.onError(err);
},
});
} else {
const res = await fetch(chatPath, chatPayload);
const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
}
clearTimeout(reqestTimeoutId);
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError(e as Error);
}
}
async usage() {
return {
used: 0,
total: 0,
} as LLMUsage;
}
}

View File

@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100;
export const ACCESS_CODE_PREFIX = "ak-"; export const ACCESS_CODE_PREFIX = "ak-";
export const LAST_INPUT_KEY = "last-input"; export const LAST_INPUT_KEY = "last-input";
export const REQUEST_TIMEOUT_MS = 60000;

View File

@ -43,28 +43,6 @@ const makeRequestParam = (
}; };
}; };
export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {};
const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;
// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
return headers;
}
export function requestOpenaiClient(path: string) { export function requestOpenaiClient(path: string) {
const openaiUrl = useAccessStore.getState().openaiUrl; const openaiUrl = useAccessStore.getState().openaiUrl;
return (body: any, method = "POST") => return (body: any, method = "POST") =>

View File

@ -14,6 +14,7 @@ import { showToast } from "../components/ui-lib";
import { ModelType } from "./config"; import { ModelType } from "./config";
import { createEmptyMask, Mask } from "./mask"; import { createEmptyMask, Mask } from "./mask";
import { StoreKey } from "../constant"; import { StoreKey } from "../constant";
import { api } from "../client/api";
export type Message = ChatCompletionResponseMessage & { export type Message = ChatCompletionResponseMessage & {
date: string; date: string;

View File

@ -14,6 +14,7 @@
}, },
"dependencies": { "dependencies": {
"@hello-pangea/dnd": "^16.2.0", "@hello-pangea/dnd": "^16.2.0",
"@microsoft/fetch-event-source": "^2.0.1",
"@svgr/webpack": "^6.5.1", "@svgr/webpack": "^6.5.1",
"@vercel/analytics": "^0.1.11", "@vercel/analytics": "^0.1.11",
"emoji-picker-react": "^4.4.7", "emoji-picker-react": "^4.4.7",

View File

@ -1111,6 +1111,11 @@
dependencies: dependencies:
"@types/react" ">=16.0.0" "@types/react" ">=16.0.0"
"@microsoft/fetch-event-source@^2.0.1":
version "2.0.1"
resolved "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==
"@next/env@13.3.1-canary.8": "@next/env@13.3.1-canary.8":
version "13.3.1-canary.8" version "13.3.1-canary.8"
resolved "https://registry.yarnpkg.com/@next/env/-/env-13.3.1-canary.8.tgz#9f5cf57999e4f4b59ef6407924803a247cc4e451" resolved "https://registry.yarnpkg.com/@next/env/-/env-13.3.1-canary.8.tgz#9f5cf57999e4f4b59ef6407924803a247cc4e451"