Add utils::stream::queue & openai chat completions example (#184)
* Add `utils::stream::queue`

* next: handle annoying buffering of reqwest

* @2024-06-17 00:06+9:00

* handle reqwest buffering with `utils::stream::queue`
kanarus authored Jun 16, 2024
1 parent 3dfd0f5 commit 87ad61d
Showing 10 changed files with 407 additions and 0 deletions.
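
For context: reqwest's `bytes_stream` yields network chunks that need not align with SSE `data: ...` event boundaries (an event is terminated by a blank line), so a partial event must be carried over into the next chunk; that is the buffering the diffs below handle with `utils::stream::queue`. A minimal standalone sketch of the carry-over idea (the `SseRechunker` name is illustrative, not part of this commit):

// Hypothetical helper (not part of this commit): feed raw network chunks in,
// get complete SSE `data:` payloads out.
struct SseRechunker {
    remaining: String, // partial event carried over between chunks
}

impl SseRechunker {
    fn new() -> Self {
        Self { remaining: String::new() }
    }

    /// Append one raw chunk, then drain every now-complete `data:` payload.
    fn feed(&mut self, chunk: &str) -> Vec<String> {
        let mut events = Vec::new();
        self.remaining.push_str(chunk);
        // An SSE event is terminated by a blank line ("\n\n").
        while let Some(end) = self.remaining.find("\n\n") {
            let event: String = self.remaining.drain(..end + 2).collect();
            if let Some(data) = event.strip_prefix("data: ") {
                events.push(data.trim_end().to_string());
            }
        }
        events
    }
}

fn main() {
    let mut r = SseRechunker::new();
    // Two network chunks that split one event down the middle:
    assert!(r.feed("data: hel").is_empty());
    assert_eq!(r.feed("lo\n\ndata: world\n\n"), ["hello", "world"]);
}

The new example binaries below implement the same carry-over with a `remaining: String` inside a `stream::queue` closure.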
1 change: 1 addition & 0 deletions examples/Cargo.toml
@@ -4,6 +4,7 @@ members = [
"sse",
"form",
"hello",
"openai",
"realworld",
"quick_start",
"static_files",
12 changes: 12 additions & 0 deletions examples/openai/Cargo.toml
@@ -0,0 +1,12 @@
[package]
name = "openai"
version = "0.1.0"
edition = "2021"

[dependencies]
ohkami = { workspace = true }
tokio = { workspace = true }
reqwest = { version = "0.12", features = ["json", "stream"] }

[features]
DEBUG = ["ohkami/DEBUG"]
61 changes: 61 additions & 0 deletions examples/openai/src/bin/reqwest_chat_completion.rs
@@ -0,0 +1,61 @@
use std::env;
use ohkami::utils::{StreamExt, stream};
use openai::models::{ChatCompletions, ChatMessage, Role};


#[tokio::main]
async fn main() {
let mut gpt_response = reqwest::Client::builder()
.build().unwrap()
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(env::var("OPENAI_API_KEY").expect("env var `OPENAI_API_KEY` is required"))
.json(&ChatCompletions {
model: "gpt-4o",
stream: true,
messages: vec![
ChatMessage {
role: Role::user,
content: env::args().nth(1).expect("CLI arg (message) is required"),
}
],
})
.send().await.expect("reqwest failed")
.bytes_stream();

/* Handle reqwest's annoying buffering */
let mut chat_completion_chunk_stream = stream::queue(|mut q| async move {
let mut push_line = |mut line: String| {
#[cfg(debug_assertions)] {
assert!(line.ends_with("\n\n"))
}
line.truncate(line.len() - 2);
q.push(line)
};

let mut remaining = String::new();

while let Some(Ok(raw_chunk)) = gpt_response.next().await {
for line in std::str::from_utf8(&raw_chunk).unwrap()
.split_inclusive("\n\n")
{
if let Some(data) = line.strip_prefix("data: ") {
if data.ends_with("\n\n") {
push_line(data.to_string())
} else {
remaining = data.into()
}
} else {
push_line(std::mem::take(&mut remaining) + line)
}
}
}
});

while let Some(chunk) = chat_completion_chunk_stream.next().await {
println!("\n\n[chunk]\n---------------------------\n{}\n---------------------------\n",
chunk
.replace('\n', r"\n")
.replace('\r', r"\r")
);
}
}
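
Assuming standard cargo conventions (the binary name follows the file under `src/bin/`), this example would be run from `examples/openai` with something like `OPENAI_API_KEY=... cargo run --bin reqwest_chat_completion -- "Hello!"`: the API key comes from the environment and the message from the first CLI argument.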
35 changes: 35 additions & 0 deletions examples/openai/src/error.rs
@@ -0,0 +1,35 @@
use ohkami::prelude::*;


#[derive(Debug)]
pub enum Error {
Fetch(reqwest::Error),
}

impl IntoResponse for Error {
fn into_response(self) -> Response {
match self {
Self::Fetch(e) => Response::InternalServerError().with_text(e.to_string()),
}
}
}

const _: () = {
impl std::error::Error for Error {}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Fetch(e) => e.fmt(f)
}
}
}
};

const _: () = {
impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Self {
Self::Fetch(e)
}
}
};
29 changes: 29 additions & 0 deletions examples/openai/src/fangs.rs
@@ -0,0 +1,29 @@
use std::env;
use std::sync::OnceLock;
use ohkami::prelude::*;


#[derive(Clone)]
pub struct WithAPIKey {
api_key: &'static str,
}
impl WithAPIKey {
pub fn from_env() -> Option<Self> {
static API_KEY: OnceLock<Option<String>> = OnceLock::new();

Some(Self {
api_key: API_KEY.get_or_init(|| {
match env::args().nth(1).as_deref() {
Some("--api-key") => env::args().nth(2),
_ => env::var("OPENAI_API_KEY").ok()
}
}).as_deref()?
})
}
}
impl FangAction for WithAPIKey {
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
req.memorize(self.api_key);
Ok(())
}
}
1 change: 1 addition & 0 deletions examples/openai/src/lib.rs
@@ -0,0 +1 @@
pub mod models;
93 changes: 93 additions & 0 deletions examples/openai/src/main.rs
@@ -0,0 +1,93 @@
pub mod error;
pub mod fangs;
pub mod models;

use error::Error;
use models::{UserMessage, ChatMessage, ChatCompletions, Role};

use ohkami::prelude::*;
use ohkami::Memory;
use ohkami::typed::DataStream;
use ohkami::utils::{StreamExt, stream};


#[tokio::main]
async fn main() {
Ohkami::with((
fangs::WithAPIKey::from_env().expect("\
OpenAI API key not found.\n\
\n\
[USAGE]\n\
Run `cargo run` with either\n\
a. the environment variable `OPENAI_API_KEY` set to your API key, or\n\
b. your API key passed via command line arguments `-- --api-key <here>`\n\
"),
), (
"/chat-once".POST(relay_chat_completion),
)).howl("localhost:5050").await
}

pub async fn relay_chat_completion(
api_key: Memory<'_, &'static str>,
UserMessage(message): UserMessage,
) -> Result<DataStream<String, Error>, Error> {
let mut gpt_response = reqwest::Client::new()
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(*api_key)
.json(&ChatCompletions {
model: "gpt-4o",
stream: true,
messages: vec![
ChatMessage {
role: Role::user,
content: message,
}
],
})
.send().await?
.bytes_stream();

Ok(DataStream::from_stream(stream::queue(|mut q| async move {
let mut push_line = |mut line: String| {
#[cfg(debug_assertions)] {
assert!(line.ends_with("\n\n"))
}

line.truncate(line.len() - 2);

#[cfg(debug_assertions)] {
if line != "[DONE]" {
use ohkami::{typed::PayloadType, builtin::payload::JSON};

let chunk: models::ChatCompletionChunk = JSON::parse(line.as_bytes()).unwrap();
print!("{}", chunk.choices[0].delta.content.as_deref().unwrap_or(""));
std::io::Write::flush(&mut std::io::stdout()).unwrap();
} else {
println!()
}
}

q.push(Ok(line));
};

let mut remaining = String::new();
while let Some(Ok(raw_chunk)) = gpt_response.next().await {
for line in std::str::from_utf8(&raw_chunk).unwrap()
.split_inclusive("\n\n")
{
if let Some(data) = line.strip_prefix("data: ") {
if data.ends_with("\n\n") {
push_line(data.to_string())
} else {
remaining = data.into()
}
} else {
#[cfg(debug_assertions)] {
assert!(line.ends_with("\n\n"))
}
push_line(std::mem::take(&mut remaining) + line)
}
}
}
})))
}
52 changes: 52 additions & 0 deletions examples/openai/src/models.rs
@@ -0,0 +1,52 @@
use ohkami::typed::Payload;
use ohkami::builtin::payload::{JSON, Text};
use ohkami::serde::{Deserialize, Serialize};


#[Payload(Text/D)]
pub struct UserMessage(
pub String
);

#[Payload(JSON/S)]
pub struct ChatCompletions {
pub model: &'static str,
pub messages: Vec<ChatMessage>,
pub stream: bool,
}
#[derive(Serialize)]
pub struct ChatMessage {
pub role: Role,
pub content: String,
}

#[Payload(JSON/D)]
pub struct ChatCompletionChunk {
pub id: String,
pub choices: [ChatCompletionChoice; 1],
}
#[derive(Deserialize)]
pub struct ChatCompletionChoice {
pub delta: ChatCompletionDelta,
pub finish_reason: Option<ChatCompletionFinishReason>,
}
#[derive(Deserialize)]
pub struct ChatCompletionDelta {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize)]
#[allow(non_camel_case_types)]
pub enum ChatCompletionFinishReason {
stop,
length,
content_filter,
}

#[derive(Deserialize, Serialize)]
#[allow(non_camel_case_types)]
pub enum Role {
system,
user,
assistant,
}
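
A note on the `#[allow(non_camel_case_types)]` variants: serde serializes unit enum variants by name, so keeping `user`, `assistant`, `stop`, etc. lowercase lets them match OpenAI's JSON values without per-variant rename attributes (assuming `ohkami::serde` follows serde's default behavior here).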
17 changes: 17 additions & 0 deletions examples/sse/src/bin/queue_stream.rs
@@ -0,0 +1,17 @@
use ohkami::utils::{stream, StreamExt};
use tokio::time::sleep;


#[tokio::main]
async fn main() {
let mut qs = stream::queue(|mut q| async move {
for i in 1..=5 {
sleep(std::time::Duration::from_secs(1)).await;
q.push(format!("Hi, I'm message#{i}!"))
}
});

while let Some(message) = qs.next().await {
println!("{message}")
}
}
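
As this example suggests, `stream::queue` builds a stream from an async closure that receives a queue handle: items pushed via `q.push(...)` come out of the stream in order, and the stream terminates once the closure completes; here that yields five messages at one-second intervals.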