-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
utils::stream::queue
& openai chat completions example (#184)
* Add `utils::stream::queue` !!!!!!!!!!!!!!!!! * next: handle annoying buffering of reqwest * @2024-06-17 00:06+9:00 * handle reqwest buffering with `utils::stream::queue`
- Loading branch information
Showing
10 changed files
with
407 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ members = [ | |
"sse", | ||
"form", | ||
"hello", | ||
"openai", | ||
"realworld", | ||
"quick_start", | ||
"static_files", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "openai" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
ohkami = { workspace = true } | ||
tokio = { workspace = true } | ||
reqwest = { version = "0.12", features = ["json", "stream"] } | ||
|
||
[features] | ||
DEBUG = ["ohkami/DEBUG"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
use std::env; | ||
use ohkami::utils::{StreamExt, stream}; | ||
use openai::models::{ChatCompletions, ChatMessage, Role}; | ||
|
||
|
||
#[tokio::main] | ||
async fn main() { | ||
let mut gpt_response = reqwest::Client::builder() | ||
.build().unwrap() | ||
.post("https://api.openai.com/v1/chat/completions") | ||
.bearer_auth(env::var("OPENAI_API_KEY").expect("env var `OPENAI_API_KEY` is required")) | ||
.json(&ChatCompletions { | ||
model: "gpt-4o", | ||
stream: true, | ||
messages: vec![ | ||
ChatMessage { | ||
role: Role::user, | ||
content: env::args().nth(1).expect("CLI arg (message) is required"), | ||
} | ||
], | ||
}) | ||
.send().await.expect("reqwest failed") | ||
.bytes_stream(); | ||
|
||
/* Handle reqwest's annoying buffering */ | ||
let mut chat_completion_chunk_stream = stream::queue(|mut q| async move { | ||
let mut push_line = |mut line: String| { | ||
#[cfg(debug_assertions)] { | ||
assert!(line.ends_with("\n\n")) | ||
} | ||
line.truncate(line.len() - 2); | ||
q.push(line) | ||
}; | ||
|
||
let mut remaining = String::new(); | ||
|
||
while let Some(Ok(raw_chunk)) = gpt_response.next().await { | ||
for line in std::str::from_utf8(&raw_chunk).unwrap() | ||
.split_inclusive("\n\n") | ||
{ | ||
if let Some(data) = line.strip_prefix("data: ") { | ||
if data.ends_with("\n\n") { | ||
push_line(data.to_string()) | ||
} else { | ||
remaining = data.into() | ||
} | ||
} else { | ||
push_line(std::mem::take(&mut remaining) + line) | ||
} | ||
} | ||
} | ||
}); | ||
|
||
while let Some(chunk) = chat_completion_chunk_stream.next().await { | ||
println!("\n\n[chunk]\n---------------------------\n{}\n---------------------------\n", | ||
chunk | ||
.replace('\n', r"\n") | ||
.replace('\r', r"\r") | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
use ohkami::prelude::*; | ||
|
||
|
||
#[derive(Debug)] | ||
pub enum Error { | ||
Fetch(reqwest::Error), | ||
} | ||
|
||
impl IntoResponse for Error { | ||
fn into_response(self) -> Response { | ||
match self { | ||
Self::Fetch(e) => Response::InternalServerError().with_text(e.to_string()), | ||
} | ||
} | ||
} | ||
|
||
const _: () = { | ||
impl std::error::Error for Error {} | ||
|
||
impl std::fmt::Display for Error { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
Self::Fetch(e) => e.fmt(f) | ||
} | ||
} | ||
} | ||
}; | ||
|
||
const _: () = { | ||
impl From<reqwest::Error> for Error { | ||
fn from(e: reqwest::Error) -> Self { | ||
Self::Fetch(e) | ||
} | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use std::env; | ||
use std::sync::OnceLock; | ||
use ohkami::prelude::*; | ||
|
||
|
||
#[derive(Clone)] | ||
pub struct WithAPIKey { | ||
api_key: &'static str, | ||
} | ||
impl WithAPIKey { | ||
pub fn from_env() -> Option<Self> { | ||
static API_KEY: OnceLock<Option<String>> = OnceLock::new(); | ||
|
||
Some(Self { | ||
api_key: API_KEY.get_or_init(|| { | ||
match env::args().nth(1).as_deref() { | ||
Some("--api-key") => env::args().nth(2), | ||
_ => env::var("OPENAI_API_KEY").ok() | ||
} | ||
}).as_deref()? | ||
}) | ||
} | ||
} | ||
impl FangAction for WithAPIKey { | ||
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> { | ||
req.memorize(self.api_key); | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub mod models; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
pub mod error; | ||
pub mod fangs; | ||
pub mod models; | ||
|
||
use error::Error; | ||
use models::{UserMessage, ChatMessage, ChatCompletions, Role}; | ||
|
||
use ohkami::prelude::*; | ||
use ohkami::Memory; | ||
use ohkami::typed::DataStream; | ||
use ohkami::utils::{StreamExt, stream}; | ||
|
||
|
||
#[tokio::main] | ||
async fn main() { | ||
Ohkami::with(( | ||
fangs::WithAPIKey::from_env().expect("\ | ||
OpenAI API key is not found. \n\ | ||
\n\ | ||
[USAGE]\n\ | ||
Run `cargo run` with one of \n\ | ||
a. Set an environment variable `OPENAI_API_KEY` to your API key\n\ | ||
b. Pass your API key by command line arguments `-- --api-key <here>`\n\ | ||
"), | ||
), ( | ||
"/chat-once".POST(relay_chat_completion), | ||
)).howl("localhost:5050").await | ||
} | ||
|
||
pub async fn relay_chat_completion( | ||
api_key: Memory<'_, &'static str>, | ||
UserMessage(message): UserMessage, | ||
) -> Result<DataStream<String, Error>, Error> { | ||
let mut gpt_response = reqwest::Client::new() | ||
.post("https://api.openai.com/v1/chat/completions") | ||
.bearer_auth(*api_key) | ||
.json(&ChatCompletions { | ||
model: "gpt-4o", | ||
stream: true, | ||
messages: vec![ | ||
ChatMessage { | ||
role: Role::user, | ||
content: message, | ||
} | ||
], | ||
}) | ||
.send().await? | ||
.bytes_stream(); | ||
|
||
Ok(DataStream::from_stream(stream::queue(|mut q| async move { | ||
let mut push_line = |mut line: String| { | ||
#[cfg(debug_assertions)] { | ||
assert!(line.ends_with("\n\n")) | ||
} | ||
|
||
line.truncate(line.len() - 2); | ||
|
||
#[cfg(debug_assertions)] { | ||
if line != "[DONE]" { | ||
use ohkami::{typed::PayloadType, builtin::payload::JSON}; | ||
|
||
let chunk: models::ChatCompletionChunk = JSON::parse(line.as_bytes()).unwrap(); | ||
print!("{}", chunk.choices[0].delta.content.as_deref().unwrap_or("")); | ||
std::io::Write::flush(&mut std::io::stdout()).unwrap(); | ||
} else { | ||
println!() | ||
} | ||
} | ||
|
||
q.push(Ok(line)); | ||
}; | ||
|
||
let mut remaining = String::new(); | ||
while let Some(Ok(raw_chunk)) = gpt_response.next().await { | ||
for line in std::str::from_utf8(&raw_chunk).unwrap() | ||
.split_inclusive("\n\n") | ||
{ | ||
if let Some(data) = line.strip_prefix("data: ") { | ||
if data.ends_with("\n\n") { | ||
push_line(data.to_string()) | ||
} else { | ||
remaining = data.into() | ||
} | ||
} else { | ||
#[cfg(debug_assertions)] { | ||
assert!(line.ends_with("\n\n")) | ||
} | ||
push_line(std::mem::take(&mut remaining) + line) | ||
} | ||
} | ||
} | ||
}))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use ohkami::typed::Payload; | ||
use ohkami::builtin::payload::{JSON, Text}; | ||
use ohkami::serde::{Deserialize, Serialize}; | ||
|
||
|
||
#[Payload(Text/D)] | ||
pub struct UserMessage( | ||
pub String | ||
); | ||
|
||
#[Payload(JSON/S)] | ||
pub struct ChatCompletions { | ||
pub model: &'static str, | ||
pub messages: Vec<ChatMessage>, | ||
pub stream: bool, | ||
} | ||
#[derive(Serialize)] | ||
pub struct ChatMessage { | ||
pub role: Role, | ||
pub content: String, | ||
} | ||
|
||
#[Payload(JSON/D)] | ||
pub struct ChatCompletionChunk { | ||
pub id: String, | ||
pub choices: [ChatCompletionChoice; 1], | ||
} | ||
#[derive(Deserialize)] | ||
pub struct ChatCompletionChoice { | ||
pub delta: ChatCompletionDelta, | ||
pub finish_reason: Option<ChatCompletionFinishReason>, | ||
} | ||
#[derive(Deserialize)] | ||
pub struct ChatCompletionDelta { | ||
pub role: Option<Role>, | ||
pub content: Option<String>, | ||
} | ||
#[derive(Deserialize)] | ||
#[allow(non_camel_case_types)] | ||
pub enum ChatCompletionFinishReason { | ||
stop, | ||
length, | ||
content_filter, | ||
} | ||
|
||
#[derive(Deserialize, Serialize)] | ||
#[allow(non_camel_case_types)] | ||
pub enum Role { | ||
system, | ||
user, | ||
assistant, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
use ohkami::utils::{stream, StreamExt}; | ||
use tokio::time::sleep; | ||
|
||
|
||
#[tokio::main] | ||
async fn main() { | ||
let mut qs = stream::queue(|mut q| async move { | ||
for i in 1..=5 { | ||
sleep(std::time::Duration::from_secs(1)).await; | ||
q.push(format!("Hi, I'm message#{i}!")) | ||
} | ||
}); | ||
|
||
while let Some(message) = qs.next().await { | ||
println!("{message}") | ||
} | ||
} |
Oops, something went wrong.