-
Notifications
You must be signed in to change notification settings - Fork 0
/
qianwen_responder.rs
105 lines (94 loc) · 3.49 KB
/
qianwen_responder.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use reqwest::StatusCode;
use crate::data::config::entity::runtime_data::AccountVisitor;
use crate::data::http_api::alibaba::qian_wen_request::{Input, Parameters, QianWenRequest};
use crate::data::http_api::alibaba::qian_wen_response::QianWenResponse;
use crate::http::client::client_sender::channel_manager::{
ChannelBufferManager, ChannelSender, ClientSender,
};
use crate::http::client::specific_responder::{ResponderError, ResponseParser, SpecificResponder};
/// Parser that turns raw DashScope (QianWen) response payloads into client output.
///
/// This is a unit struct with no state: each payload is handled independently by
/// [`ResponseParser::parse_response`].
#[derive(Debug, Default)]
pub struct QianWenResponderParser;
impl ResponseParser for QianWenResponderParser {
    /// Parses one QianWen response payload and forwards its text to `sender`.
    ///
    /// The payload is deserialized as a [`QianWenResponse`]. If it contains at least one
    /// choice, the first choice's message content is appended to the sender's buffer.
    /// When the originating request is a streaming request, that same content is also
    /// pushed out via `send_text`. The flag passed along is whether the choice's
    /// `finish_reason` equals `"stop"` (presumably marking the final chunk). A payload
    /// with no choices is accepted and ignored.
    ///
    /// # Errors
    ///
    /// * [`ResponderError::Request`] if `response` is not valid `QianWenResponse` JSON.
    ///   The message includes the serde error and the lossily decoded raw text.
    /// * [`ResponderError::Response`] if `send_text` fails.
    async fn parse_response(
        &mut self,
        sender: &mut ClientSender,
        response: &[u8],
    ) -> Result<(), ResponderError> {
        let parsed = serde_json::from_slice::<QianWenResponse>(response).map_err(|err| {
            ResponderError::Request(format!(
                "Error when parse response from serde: {}, origin text: {}",
                err,
                String::from_utf8_lossy(response)
            ))
        })?;

        // No choices means no text to forward in either mode.
        let Some(choice) = parsed.output.choices.first() else {
            return Ok(());
        };
        let content = &choice.message.content;
        sender.append_buffer(content.as_str());

        if sender.request.is_stream() {
            // Compare against a `&str` directly instead of allocating a `String` per chunk.
            // NOTE(review): only `"stop"` sets the flag. If the API can end a stream with
            // another finish reason (e.g. a length cutoff), confirm that case is handled.
            let is_last = choice.finish_reason == "stop";
            sender
                .send_text(content, is_last)
                .await
                .map_err(|e| ResponderError::Response(e.to_string()))?;
        }
        Ok(())
    }
}
/// [`SpecificResponder`] for Alibaba's QianWen models served through DashScope.
///
/// This is a unit struct with no state: all per-request data comes from the
/// `ClientSender` and `AccountVisitor` passed to `make_response`.
#[derive(Debug, Default)]
pub struct QianWenResponder;
impl SpecificResponder for QianWenResponder {
    /// Sends the pending chat request to the QianWen endpoint and relays the reply.
    ///
    /// Builds a [`QianWenRequest`] from `sender.request` (model and messages) with
    /// `result_format` set to `"message"`. It posts the request as JSON to
    /// `accessor.endpoint_url` using `accessor.client`, then hands the HTTP response
    /// to `process_stream!` together with a fresh [`QianWenResponderParser`].
    ///
    /// `sender.is_stream()` controls two things: the `X-DashScope-SSE` header
    /// (`enable`/`disable`), and whether `incremental_output` is requested. It is sent
    /// as `Some(true)` when streaming and left unset otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`ResponderError::Request`] in two cases:
    /// * the request cannot be sent;
    /// * the endpoint answers with any status other than `200 OK`. The message then
    ///   includes the status code and the response body, or a placeholder if the body
    ///   itself could not be read.
    async fn make_response(
        &self,
        sender: &mut ClientSender,
        accessor: &AccountVisitor,
    ) -> Result<(), ResponderError> {
        let is_stream = sender.is_stream();

        let response = accessor
            .client
            .post(accessor.endpoint_url.clone())
            .header("X-DashScope-SSE", if is_stream { "enable" } else { "disable" })
            .json(&QianWenRequest {
                model: sender.request.model.clone(),
                input: Input {
                    messages: sender.request.messages.clone(),
                },
                parameters: Parameters {
                    // Only request incremental output for streaming requests; unset otherwise.
                    incremental_output: is_stream.then_some(true),
                    result_format: "message".to_string(),
                },
            })
            .send()
            .await
            .map_err(|e| ResponderError::Request(format!("Error when send request: {}", e)))?;

        let status = response.status();
        if status != StatusCode::OK {
            // The body is only diagnostic. Failing to read it must not hide the status code,
            // so fall back to a placeholder instead of propagating the read error.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read response body: {}>", e));
            return Err(ResponderError::Request(format!(
                "Error when get response with code: {}, error message: {}",
                status, body
            )));
        }

        process_stream!(response, QianWenResponderParser::default(), sender);
        Ok(())
    }
}