Skip to content

Commit

Permalink
Improved markdown streaming implementation in React
Browse files Browse the repository at this point in the history
  • Loading branch information
salmenus committed Apr 24, 2024
1 parent f879dec commit 867198f
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 43 deletions.
20 changes: 12 additions & 8 deletions packages/react/core/src/exports/hooks/useSubmitPromptHandler.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import {ChatAdapter, ChatAdapterExtras, PromptBoxOptions, StandardChatAdapter} from '@nlux/core';
import {MutableRefObject, useCallback, useEffect, useMemo, useRef, useState} from 'react';
import {MutableRefObject, useCallback, useEffect, useMemo, useRef} from 'react';
import {submitPrompt} from '../../../../../shared/src/services/submitPrompt/submitPromptImpl';
import {ChatSegment} from '../../../../../shared/src/types/chatSegment/chatSegment';
import {ChatSegmentAiMessage} from '../../../../../shared/src/types/chatSegment/chatSegmentAiMessage';
Expand Down Expand Up @@ -37,11 +37,6 @@ export const useSubmitPromptHandler = <AiMsg>(props: SubmitPromptHandlerProps<Ai

const hasValidInput = useMemo(() => promptTyped.length > 0, [promptTyped]);

// The prompt that will be submitted
// We store it in a separate variable because the prompt might change.
// Example: When the user types a new message while the previous message is being streamed.
const [promptSubmitted, setPromptSubmitted] = useState<string>('');

// The prompt typed will be read by the submitPrompt function, but it will not be used as a
// dependency for the submitPrompt function (only the promptToSubmit is a dependency to useCallback).
// Hence, the use of useRef to store the value and access it within the submitPrompt function, without
Expand Down Expand Up @@ -85,6 +80,7 @@ export const useSubmitPromptHandler = <AiMsg>(props: SubmitPromptHandlerProps<Ai

setPromptBoxStatus('submitting');
const promptToSubmit = promptTyped;
const streamedMessageIds: Set<string> = new Set();

const {
segment: chatSegment,
Expand All @@ -95,8 +91,6 @@ export const useSubmitPromptHandler = <AiMsg>(props: SubmitPromptHandlerProps<Ai
adapterExtras,
);

setPromptSubmitted(promptToSubmit);

if (chatSegment.status === 'error') {
warn('Error occurred while submitting prompt');
showException('Error occurred while submitting prompt');
Expand Down Expand Up @@ -146,6 +140,8 @@ export const useSubmitPromptHandler = <AiMsg>(props: SubmitPromptHandlerProps<Ai
if (promptTypedRef.current === promptToSubmit) {
domToReactRef.current.setPrompt('');
}

streamedMessageIds.add(aiStreamedMessage.uid);
});

chatSegmentObservable.on('aiMessageReceived', (aiMessage) => {
Expand Down Expand Up @@ -180,6 +176,14 @@ export const useSubmitPromptHandler = <AiMsg>(props: SubmitPromptHandlerProps<Ai
if (promptTypedRef.current === promptToSubmit) {
setPrompt('');
}

if (streamedMessageIds.size > 0) {
streamedMessageIds.forEach((messageId) => {
conversationRef.current?.completeStream(chatSegmentObservable.segmentId, messageId);
});

streamedMessageIds.clear();
}
});

chatSegmentObservable.on('aiChunkReceived', (chunk: string, messageId: string) => {
Expand Down
11 changes: 7 additions & 4 deletions packages/react/core/src/logic/ChatSegment/ChatSegmentComp.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ export const ChatSegmentComp: <AiMsg>(
useImperativeHandle(ref, () => ({
streamChunk: (messageId: string, chunk: string) => {
const messageCompRef = chatItemsRef.get(messageId);
if (messageCompRef?.current) {
messageCompRef.current.streamChunk(chunk);
}
messageCompRef?.current?.streamChunk(chunk);
},
completeStream: (messageId: string) => {
const messageCompRef = chatItemsRef.get(messageId);
messageCompRef?.current?.completeStream();
chatItemsRef.delete(messageId);
},
}), []);

Expand Down Expand Up @@ -132,7 +135,7 @@ export const ChatSegmentComp: <AiMsg>(
ref={ref}
key={chatItem.uid}
uid={chatItem.uid}
status={'rendered'}
status={'streaming'}
direction={'incoming'}
message={chatItem.content}
name={nameFromMessageAndPersona(chatItem.participantRole, props.personaOptions)}
Expand Down
1 change: 1 addition & 0 deletions packages/react/core/src/logic/ChatSegment/props.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ export type ChatSegmentProps<AiMsg> = {

export type ChatSegmentImperativeProps<AiMsg> = {
streamChunk: (messageId: string, chunk: string) => void;
completeStream: (messageId: string) => void;
};
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ export const ConversationComp: ConversationCompType = function <AiMsg>(
const chatSegment = segmentsController.get(segmentId);
chatSegment?.streamChunk(messageId, chunk);
},
completeStream: (segmentId: string, messageId: string) => {
const chatSegment = segmentsController.get(segmentId);
chatSegment?.completeStream(messageId);
},
}), []);

const ForwardRefChatSegmentComp = useMemo(() => forwardRef(
Expand Down
1 change: 1 addition & 0 deletions packages/react/core/src/logic/Conversation/props.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ export type ConversationCompProps<AiMsg> = {

export type ImperativeConversationCompProps = {
streamChunk: (segmentId: string, messageId: string, chunk: string) => void;
completeStream: (segmentId: string, messageId: string) => void;
};
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {Ref, useImperativeHandle, useRef} from 'react';
import {createMarkdownStreamParser, MarkdownStreamParser} from '@nlux/markdown';
import {Ref, useEffect, useImperativeHandle, useMemo, useRef, useState} from 'react';
import {className as compMessageClassName} from '../../../../../shared/src/ui/Message/create';
import {
directionClassName as compMessageDirectionClassName,
Expand All @@ -12,22 +13,76 @@ export const StreamContainerComp = (
props: StreamContainerProps,
ref: Ref<StreamContainerImperativeProps>,
) => {
const streamContainer = useRef<HTMLDivElement>(null);
const {
status,
markdownOptions,
initialMarkdownMessage,
} = props;

// We use references in this component to avoid re-renders — as streaming happens outside of React
// rendering cycle, we don't want to trigger re-renders on every chunk of data received.
const streamContainerRef = useRef<HTMLDivElement | null>(null);
const streamContainerRefPreviousValue = useRef<HTMLDivElement | null>(null);
const markdownStreamParserRef = useRef<MarkdownStreamParser | null>(null);

const [streamContainer, setStreamContainer] = useState<HTMLDivElement>();
const markdownElement = useMemo(() => {
const element = document.createElement('div');
element.className = 'nlux-msg-md';
return element;
}, []);

const [initialMarkdownMessageParsed, setInitialMarkdownMessageParsed] = useState(false);

useEffect(() => {
if (streamContainerRef.current !== streamContainerRefPreviousValue.current) {
streamContainerRefPreviousValue.current = streamContainerRef.current;
setStreamContainer(streamContainerRef.current || undefined);
}
}); // No dependencies, this effect should run on every render
// The 'if' statement inside the effect plays a similar role to a useEffect dependency array
// to prevent setting the streamContainer state to the same value multiple times.

useEffect(() => {
if (streamContainer) {
streamContainer.append(markdownElement);
} else {
const fragment = document.createDocumentFragment();
fragment.append(markdownElement);
}
});

useEffect(() => {
markdownStreamParserRef.current = createMarkdownStreamParser(markdownElement, {
openLinksInNewWindow: markdownOptions?.openLinksInNewWindow ?? true,
syntaxHighlighter: markdownOptions?.syntaxHighlighter ?? undefined,
});

if (!initialMarkdownMessageParsed && initialMarkdownMessage) {
markdownStreamParserRef.current.next(initialMarkdownMessage);
setInitialMarkdownMessageParsed(true);
}
}, [markdownOptions?.openLinksInNewWindow, markdownOptions?.syntaxHighlighter]);

useEffect(() => {
return () => {
streamContainerRefPreviousValue.current = null;
markdownStreamParserRef.current?.complete();
markdownStreamParserRef.current = null;
setStreamContainer(undefined);
};
}, []);

useImperativeHandle(ref, () => ({
streamChunk: (chunk: string) => {
streamContainer.current?.append(
// TODO - Handle markdown
document.createTextNode(chunk),
);
},
streamChunk: (chunk: string) => markdownStreamParserRef.current?.next(chunk),
completeStream: () => markdownStreamParserRef.current?.complete(),
}), []);

const compDirectionClassName = compMessageDirectionClassName['incoming'];
const compStatusClassName = compMessageStatusClassName[props.status];
const compStatusClassName = compMessageStatusClassName[status];
const className = `${compMessageClassName} ${compStatusClassName} ${compDirectionClassName}`;

return (
<div className={className} ref={streamContainer}/>
<div className={className} ref={streamContainerRef}/>
);
};
7 changes: 7 additions & 0 deletions packages/react/core/src/logic/StreamContainer/props.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import {HighlighterExtension} from '@nlux/core';
import {MessageDirection} from '../../../../../shared/src/ui/Message/props';

export type StreamContainerProps = {
uid: string,
direction: MessageDirection,
status: 'rendered' | 'streaming' | 'error';
initialMarkdownMessage?: string;
markdownOptions?: {
syntaxHighlighter?: HighlighterExtension;
openLinksInNewWindow?: boolean;
}
};

export type StreamContainerImperativeProps = {
streamChunk: (chunk: string) => void;
completeStream: () => void;
};
30 changes: 9 additions & 21 deletions packages/react/core/src/ui/ChatItem/ChatItemComp.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {FC, forwardRef, ReactElement, Ref, useImperativeHandle, useMemo, useRef} from 'react';
import {forwardRef, ReactElement, Ref, useImperativeHandle, useMemo, useRef} from 'react';
import {className as compChatItemClassName} from '../../../../../shared/src/ui/ChatItem/create';
import {
directionClassName as compChatItemDirectionClassName,
Expand All @@ -7,6 +7,7 @@ import {StreamContainerImperativeProps} from '../../logic/StreamContainer/props'
import {StreamContainerComp} from '../../logic/StreamContainer/StreamContainerComp';
import {AvatarComp} from '../Avatar/AvatarComp';
import {MessageComp} from '../Message/MessageComp';
import {createMessageRenderer} from '../Message/MessageRenderer';
import {ChatItemImperativeProps, ChatItemProps} from './props';

export const ChatItemComp: <AiMsg>(
Expand All @@ -22,16 +23,13 @@ export const ChatItemComp: <AiMsg>(
}

return <AvatarComp name={props.name} picture={props.picture}/>;
}, [props.picture, props.name]);
}, [props?.picture, props?.name]);

const streamContainer = useRef<StreamContainerImperativeProps | null>(null);

useImperativeHandle(ref, () => ({
streamChunk: (chunk: string) => {
if (streamContainer?.current) {
streamContainer.current.streamChunk(chunk);
}
},
streamChunk: (chunk: string) => streamContainer?.current?.streamChunk(chunk),
completeStream: () => streamContainer?.current?.completeStream(),
}), []);

const isStreaming = useMemo(
Expand All @@ -44,20 +42,10 @@ export const ChatItemComp: <AiMsg>(
: compChatItemDirectionClassName['incoming'];

const className = `${compChatItemClassName} ${compDirectionClassName}`;
const MessageRenderer: FC<void> = useMemo(() => {
if (props.customRenderer) {
if (props.message === undefined) {
return () => null;
}

return () => props.customRenderer ? props.customRenderer({
message: props.message as AiMsg,
}) : null;
}

// TODO - Markdown support
return () => <>{props.message !== undefined ? props.message : ''}</>;
}, [props.customRenderer, props.message]);
const MessageRenderer = useMemo(() => createMessageRenderer(props), [
props.message,
props.customRenderer,
]);

const ForwardRefStreamContainerComp = useMemo(() => forwardRef(
StreamContainerComp,
Expand Down
1 change: 1 addition & 0 deletions packages/react/core/src/ui/ChatItem/props.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ export type ChatItemProps<AiMsg> = {

export type ChatItemImperativeProps = {
streamChunk: (chunk: string) => void;
completeStream: () => void;
};
35 changes: 35 additions & 0 deletions packages/react/core/src/ui/Message/MessageRenderer.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import {HighlighterExtension} from '@nlux/core';
import {FC} from 'react';
import {ChatItemProps} from '../ChatItem/props';

const MarkdownRenderer = (props: {
initialMarkdownMessage?: string,
markdownOptions?: {
syntaxHighlighter?: HighlighterExtension,
openLinksInNewWindow?: boolean,
},
}) => {
// TODO - Implement markdown parsing
const {initialMarkdownMessage} = props;
return <div className={'markdown-NOT-parsed'}>{initialMarkdownMessage}</div>;
};

export const createMessageRenderer: <AiMsg>(props: ChatItemProps<AiMsg>) => FC<void> = function <AiMsg>(props: ChatItemProps<AiMsg>) {
if (props.customRenderer !== undefined) {
if (props.message === undefined) {
return () => null;
}

return () => props.customRenderer!({
message: props.message as AiMsg,
});
}

if (typeof props.message === 'string') {
const messageToRender: string = props.message;
return () => <MarkdownRenderer initialMarkdownMessage={messageToRender}/>;
}

// No custom renderer and message is not a string!
return () => '';
};

0 comments on commit 867198f

Please sign in to comment.