Skip to content

Commit

Permalink
feat: handle streaming in CLI
Browse files Browse the repository at this point in the history
- Update CLI to properly handle OpenAI streaming response
- Add Spinner util to show progress
- Update chats create command to stream response
- Handle success, errors, and incremental printing from the stream in Spinner
  • Loading branch information
cloudbridgeuy committed Jul 1, 2023
1 parent 535f6ea commit 70b26c9
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/b/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ async-trait = "0.1.68" # Type erasure for async trait methods
tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.…
indicatif = "0.17.3" # A progress bar and cli reporting library for Rust
anyhow = "1.0.71" # Flexible concrete Error type built on std::error::Error
tokio-stream = "0.1.14"

48 changes: 46 additions & 2 deletions crates/b/src/chats.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::error::Error;
use tokio_stream::StreamExt;

use async_trait::async_trait;
use serde_either::SingleOrVec;
Expand All @@ -8,7 +9,7 @@ use serde_json::from_str;
use openai::chats::{Chat, ChatMessage, ChatsApi};
use openai::error::OpenAi as OpenAiError;

use crate::utils::read_from_stdin;
use crate::utils::{read_from_stdin, Spinner};
use crate::{ChatsCommands, Cli, CommandError, CommandHandle, CommandResult};

pub struct ChatsCreateCommand {
Expand Down Expand Up @@ -165,6 +166,49 @@ impl CommandHandle<Chat> for ChatsCreateCommand {
type CallError = OpenAiError;

/// Runs the `chats create` command.
///
/// When `self.api.stream` is `Some(true)` the response is consumed as an
/// OpenAI server-sent-event stream: each chunk's first-choice delta content
/// is printed incrementally through the [`Spinner`], and an empty default
/// `Chat` is returned because the content has already been written to
/// stdout. Otherwise a single blocking `create` call is made and its `Chat`
/// is returned for the caller to render.
///
/// # Errors
///
/// Returns [`OpenAiError::StreamError`] when the stream cannot be created
/// or a chunk fails to decode; otherwise propagates errors from
/// `ChatsApi::create`.
async fn call(&self) -> Result<Chat, OpenAiError> {
    let mut spinner = Spinner::new(false);

    log::debug!("Stream is: {:?}", self.api.stream);

    if Some(true) == self.api.stream {
        log::debug!("Creating stream");

        let chunks = match self.api.create_stream().await {
            Ok(chunks) => chunks,
            Err(e) => {
                log::error!("Error creating stream: {}", e);
                return Err(OpenAiError::StreamError);
            }
        };

        // `create_stream` returns an unnameable `impl Stream` by value;
        // pin it on the stack so `StreamExt::next` can poll it.
        tokio::pin!(chunks);

        while let Some(chunk) = chunks.next().await {
            // Abort on the first bad chunk, stopping the spinner with an
            // error message so the terminal is left in a sane state.
            let chunk = match chunk {
                Ok(chunk) => chunk,
                Err(_) => {
                    log::error!("Error reading stream");
                    spinner.err("Error reading stream");
                    return Err(OpenAiError::StreamError);
                }
            };

            // Only the first choice is streamed; deltas without content
            // (e.g. the initial role-only delta) are skipped.
            if let Some(content) = chunk
                .choices
                .get(0)
                .and_then(|choice| choice.delta.as_ref())
                .and_then(|delta| delta.content.as_ref())
            {
                spinner.print(content);
            }
        }

        spinner.ok();
        // The streamed content was already printed above; hand back an
        // empty Chat so the generic result printer has nothing to add.
        Ok(openai::chats::Chat::default())
    } else {
        log::debug!("Creating chat");
        self.api.create().await
    }
}
}
2 changes: 1 addition & 1 deletion crates/b/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn main() -> Result<(), CommandError> {
}
};

let spinner = Spinner::new(cli.silent || cli.stream);
let mut spinner = Spinner::new(cli.silent || cli.stream);

let result = match command.call().await {
Ok(result) => {
Expand Down
42 changes: 34 additions & 8 deletions crates/b/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,21 @@ use std::time::Duration;

use indicatif::{ProgressBar, ProgressStyle};

/// Spinner state
/// Lifecycle state of a [`Spinner`]; all spinner operations are gated on it.
enum SpinnerState {
    /// Spinner is actively ticking; `print`, `ok`, and `err` are live.
    Running,
    /// Spinner finished successfully and its line was cleared.
    Stopped,
    /// Spinner was constructed silent; every operation is a no-op.
    Silent,
    /// Spinner was abandoned with an error message left on screen.
    Errored,
}

pub struct Spinner {
progress_bar: ProgressBar,
silent: bool,
state: SpinnerState,
}

impl Spinner {
Expand All @@ -15,7 +27,7 @@ impl Spinner {
ProgressBar::hidden()
} else {
let progress_bar = ProgressBar::new_spinner();
progress_bar.enable_steady_tick(Duration::from_millis(120));
progress_bar.enable_steady_tick(Duration::from_millis(100));
progress_bar.set_style(
ProgressStyle::with_template("{spinner:.magenta} {msg}")
.unwrap()
Expand All @@ -24,22 +36,36 @@ impl Spinner {
progress_bar
};
Self {
silent,
state: if silent {
SpinnerState::Silent
} else {
SpinnerState::Running
},
progress_bar,
}
}

/// Prints `msg` to stdout without garbling the spinner line.
///
/// The progress bar is suspended for the duration of the write so the
/// message and the spinner animation do not interleave on the terminal.
/// Does nothing unless the spinner is in the `Running` state.
pub fn print(&mut self, msg: &str) {
    if let SpinnerState::Running = self.state {
        self.progress_bar.suspend(|| {
            use std::io::Write;

            print!("{}", msg);
            // Streamed tokens rarely end in '\n' and stdout is
            // line-buffered, so flush explicitly or the output lags
            // behind the stream; a flush failure is not worth a panic.
            let _ = std::io::stdout().flush();
        });
    }
}

/// Stops the spinner successfully
pub fn ok(&self) {
if !self.silent {
self.progress_bar.finish_and_clear();
pub fn ok(&mut self) {
if let SpinnerState::Running = self.state {
self.state = SpinnerState::Stopped;
self.progress_bar.finish_and_clear()
}
}

/// Stops the spinner with an error
pub fn err(&self, msg: &str) {
if !self.silent {
pub fn err(&mut self, msg: &str) {
if let SpinnerState::Running = self.state {
self.progress_bar.abandon_with_message(msg.to_string());
self.state = SpinnerState::Errored;
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions crates/openai/src/chats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,6 @@ impl ChatsApi {
pub async fn create_stream(
&self,
) -> Result<impl Stream<Item = Result<Chunk, error::OpenAi>>, error::OpenAi> {
if Some(true) == self.stream {
return Err(error::OpenAi::InvalidStream);
}

let mut api = &mut (*self).clone();

let min_available_tokens = api.min_available_tokens.unwrap_or(750);
Expand Down

0 comments on commit 70b26c9

Please sign in to comment.