import { useState } from 'react';
import { useAsync } from 'react-use';
import { Subscription } from 'rxjs';

import { logError } from '@grafana/runtime';
import { useAppNotification } from 'app/core/copy/appNotification';

import { openai } from './llms';
import { isLLMPluginEnabled, OPEN_AI_MODEL } from './utils';

// Declared instead of imported from utils to make this hook modular
// Ideally we will want to move the hook itself to a different scope later.
type Message = openai.Message;

export enum StreamStatus {
  IDLE = 'idle',
  GENERATING = 'generating',
  COMPLETED = 'completed',
}

// TODO: Add tests
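/**
 * Stream chat completions for a set of messages from the LLM plugin's
 * OpenAI-compatible API, exposing the accumulated reply and stream status.
 *
 * Minimal usage sketch (the prompt below is hypothetical, not part of this module):
 *
 * @example
 * const { setMessages, reply, streamStatus } = useOpenAIStream();
 * // Kick off a generation by setting the messages to send:
 * setMessages([{ role: 'user', content: 'Suggest a dashboard title' }]);
 * // `reply` accumulates the streamed response; `streamStatus` moves from
 * // 'generating' to 'completed' and back to 'idle'.
 */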
export function useOpenAIStream(
  model = OPEN_AI_MODEL,
  temperature = 1
): {
  setMessages: React.Dispatch<React.SetStateAction<Message[]>>;
  reply: string;
  streamStatus: StreamStatus;
  error: Error | undefined;
  value:
    | {
        enabled: boolean | undefined;
        stream?: undefined;
      }
    | {
        enabled: boolean | undefined;
        stream: Subscription;
      }
    | undefined;
} {
  // The messages array to send to the LLM, updated when the button is clicked.
  const [messages, setMessages] = useState<Message[]>([]);
  // The latest reply from the LLM.
  const [reply, setReply] = useState('');
  const [streamStatus, setStreamStatus] = useState<StreamStatus>(StreamStatus.IDLE);
  const [error, setError] = useState<Error>();
  const { error: notifyError } = useAppNotification();
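
  // Check whether the LLM plugin is installed and enabled before attempting to stream.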
  const { error: enabledError, value: enabled } = useAsync(
    async () => await isLLMPluginEnabled(),
    [isLLMPluginEnabled]
  );
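
  // (Re)create the completion stream whenever the messages to send (or plugin availability) change.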
  const { error: asyncError, value } = useAsync(async () => {
    if (!enabled || !messages.length) {
      return { enabled };
    }

    setStreamStatus(StreamStatus.GENERATING);
    setError(undefined);
    // Stream the completions. Each element is the next stream chunk.
    const stream = openai
      .streamChatCompletions({
        model,
        temperature,
        messages,
      })
      .pipe(
        // Accumulate the stream content into a stream of strings, where each
        // element contains the accumulated message so far.
        openai.accumulateContent()
        // The stream is just a regular Observable, so we can use standard rxjs
        // functionality to update state, e.g. recording when the stream
        // has completed.
        // The operator decision tree on the rxjs website is a useful resource:
        // https://rxjs.dev/operator-decision-tree.
      );
    // Subscribe to the stream and update the state for each returned value.
    return {
      enabled,
      stream: stream.subscribe({
        next: setReply,
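        // On error, stop generating, clear the pending messages, surface the error
        // to the user, and log it for diagnostics.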
        error: (e: Error) => {
          setStreamStatus(StreamStatus.IDLE);
          setMessages([]);
          setError(e);
          notifyError('OpenAI Error', `${e.message}`);
          logError(e, { messages: JSON.stringify(messages), model, temperature: String(temperature) });
        },
        complete: () => {
          setStreamStatus(StreamStatus.COMPLETED);
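          // Defer the reset to idle so consumers can react to the completed state first.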
          setTimeout(() => {
            setStreamStatus(StreamStatus.IDLE);
          });
          setMessages([]);
          setError(undefined);
        },
      }),
    };
  }, [messages, enabled]);
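
  // If either async call failed, surface the error through the hook's error state.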
  if (asyncError || enabledError) {
    setError(asyncError || enabledError);
  }

  return {
    setMessages,
    reply,
    streamStatus,
    error,
    value,
  };
}