Skip to content

Commit

Permalink
feat: add OpenAI streaming API support
Browse files Browse the repository at this point in the history
- Add support for OpenAI Chat API streaming responses
- Use tokio-stream and reqwest-eventsource crates
- Update OpenAI crate with streaming create method
- Handle accumulating and returning the stream in the OpenAI crate
  • Loading branch information
cloudbridgeuy committed Jun 30, 2023
1 parent 62abc06 commit 535f6ea
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 90 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ serde_yaml = "0.9.19" # YAML data format for Serde
tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.…
reqwest-eventsource = "0.4.0"
futures = "0.3.28"
tokio-stream = "0.1.14"
226 changes: 136 additions & 90 deletions crates/openai/src/chats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ use std::collections::HashMap;
use std::fs::{create_dir_all, File};
use std::io::{BufReader, BufWriter};

use futures::StreamExt;
use crate::client::Client;
use crate::error;
use crate::utils::{directory_exists, file_exists, get_home_directory};
use gpt_tokenizer::Default as DefaultTokenizer;
use log;
use serde::{Deserialize, Serialize};
use serde_either::SingleOrVec;
use serde_json::Value;

use crate::client::Client;
use crate::error;
use crate::utils::{directory_exists, file_exists, get_home_directory};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};

#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ChatsApi {
Expand Down Expand Up @@ -270,8 +273,14 @@ impl ChatsApi {
Ok(self)
}

/// Creates a completion for the chat message
pub async fn create(&self) -> Result<Chat, error::OpenAi> {
/// Creates a completion for the chat message in stream format.
pub async fn create_stream(
&self,
) -> Result<impl Stream<Item = Result<Chunk, error::OpenAi>>, error::OpenAi> {
if Some(true) == self.stream {
return Err(error::OpenAi::InvalidStream);
}

let mut api = &mut (*self).clone();

let min_available_tokens = api.min_available_tokens.unwrap_or(750);
Expand Down Expand Up @@ -307,129 +316,166 @@ impl ChatsApi {

log::debug!("Request: {}", request);

if let Some(true) = &self.stream {
log::debug!("Streaming completion");
let mut event_source = match self.client.post_stream("/chat/completions", request).await
{
Ok(response) => response,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};
log::debug!("Streaming completion");
let mut event_source = match self.client.post_stream("/chat/completions", request).await {
Ok(response) => response,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};

let (tx, rx) = mpsc::channel(100);
let acc = Arc::new(Mutex::new(String::new()));
let acc_clone = Arc::clone(&acc);

let mut acc: String = String::new();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
Err(_) => {
if tx.send(Err(error::OpenAi::StreamError)).await.is_err() {
return;
}
}
Ok(event) => match event {
reqwest_eventsource::Event::Open { .. } => {}
reqwest_eventsource::Event::Message(message) => {
log::debug!("Message: {:?}", message);

if message.data == "[DONE]" {
break;
return;
}

let response: Chunk = match serde_json::from_str(&message.data) {
Err(e) => {
return Err(error::OpenAi::SerializationError {
body: e.to_string(),
})
let chunk: Chunk = match serde_json::from_str(&message.data) {
Err(_) => {
if tx.send(Err(error::OpenAi::StreamError)).await.is_err() {
return;
}
return;
}
Ok(output) => output,
};

log::debug!("Response: {:?}", response);
log::debug!("Response: {:?}", chunk);

if let Some(choice) = response.choices.get(0) {
if let Some(choice) = &chunk.choices.get(0) {
if let Some(delta) = &choice.delta {
if let Some(content) = &delta.content {
print!("{}", content);
acc.push_str(content);
let mut accumulator = acc.lock().await;
accumulator.push_str(&content.clone());
}
}
}

if tx.send(Ok(chunk)).await.is_err() {
return;
}
}
},
}
}
});

log::debug!("Checking for session, {:?}", session);
if let Some(session) = session {
let session_file = get_sessions_file(&session)?;
api.session = Some(session);
api.min_available_tokens = Some(min_available_tokens);
api.max_supported_tokens = Some(max_supported_tokens);
api.messages = messages;

let data = acc_clone.lock().await;
let data_string = &*data;

api.messages.push(ChatMessage {
content: Some(data_string.to_string()),
role: "assistant".to_string(),
..Default::default()
});
serialize_sessions_file(&session_file, api)?;
}

log::debug!("Returning acc: {}", acc);

log::debug!("Checking for session, {:?}", session);
if let Some(session) = session {
let session_file = get_sessions_file(&session)?;
api.session = Some(session);
api.min_available_tokens = Some(min_available_tokens);
api.max_supported_tokens = Some(max_supported_tokens);
api.messages = messages;
api.messages.push(ChatMessage {
content: Some(acc.clone()),
role: "assistant".to_string(),
..Default::default()
Ok(ReceiverStream::from(rx))
}

/// Creates a completion for the chat message
pub async fn create(&self) -> Result<Chat, error::OpenAi> {
let mut api = &mut (*self).clone();

let min_available_tokens = api.min_available_tokens.unwrap_or(750);
let max_supported_tokens = api.max_supported_tokens.unwrap_or(4096);
let session = api.session.clone();
let messages = api.messages.clone();

api.session = None;
api.min_available_tokens = None;
api.max_supported_tokens = None;
api.messages = trim_messages(
api.messages.clone(),
max_supported_tokens - min_available_tokens,
)?
.iter()
.map(|m| ChatMessage {
role: m.role.clone(),
content: m.content.clone(),
..Default::default()
})
.collect();

log::debug!("Trimmed messages to {:?}", api.messages);

let request = match serde_json::to_string(api) {
Ok(request) => request,
Err(err) => {
return Err(error::OpenAi::SerializationError {
body: err.to_string(),
});
serialize_sessions_file(&session_file, api)?;
}
};

Ok(Chat {
choices: vec![ChatChoice {
message: ChatMessage {
content: Some(acc.clone()),
role: "assistant".to_string(),
..Default::default()
},
..Default::default()
}],
..Default::default()
})
} else {
let body = match self.client.post("/chat/completions", request).await {
Ok(response) => match response.text().await {
Ok(text) => text,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
},
log::debug!("Request: {}", request);

let body = match self.client.post("/chat/completions", request).await {
Ok(response) => match response.text().await {
Ok(text) => text,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};
},
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};

log::debug!("Response: {}", body);
log::debug!("Response: {}", body);

let body: Chat = match serde_json::from_str(&body) {
Ok(body) => body,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};

log::debug!("Checking for session, {:?}", session);
if let Some(session) = session {
let session_file = get_sessions_file(&session)?;
api.session = Some(session);
api.min_available_tokens = Some(min_available_tokens);
api.max_supported_tokens = Some(max_supported_tokens);
api.messages = messages;
api.messages
.push(body.choices.first().unwrap().message.clone());
serialize_sessions_file(&session_file, api)?;
let body: Chat = match serde_json::from_str(&body) {
Ok(body) => body,
Err(e) => {
return Err(error::OpenAi::RequestError {
body: e.to_string(),
})
}
};

Ok(body)
log::debug!("Checking for session, {:?}", session);
if let Some(session) = session {
let session_file = get_sessions_file(&session)?;
api.session = Some(session);
api.min_available_tokens = Some(min_available_tokens);
api.max_supported_tokens = Some(max_supported_tokens);
api.messages = messages;
api.messages
.push(body.choices.first().unwrap().message.clone());
serialize_sessions_file(&session_file, api)?;
}

Ok(body)
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/openai/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ custom_error! {pub OpenAi
NoSession = "no session",
RequestError{body: String} = "request error: {body}",
SerializationError{body: String} = "serialization error: {body}",
StreamError = "stream error",
TrimError = "could not find a message to trim",
UknownError = "unknown error",
}

0 comments on commit 535f6ea

Please sign in to comment.