Skip to content

Commit

Permalink
fix: add close method to stream iterator for connection release (#630)
Browse files Browse the repository at this point in the history
* fix: add close method to stream iterator for connection release

* fix: add close method to stream iterator for connection release

* fix: add close method to stream iterator for connection release

* fix: add close method to stream iterator for connection release

* fix: add close method to stream iterator for connection release

* fix: add close method to stream iterator for connection release
  • Loading branch information
Azure99 authored Jul 1, 2024
1 parent 4e0a508 commit b88d2a0
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 139 deletions.
13 changes: 6 additions & 7 deletions java/src/main/java/com/baidubce/qianfan/Qianfan.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.baidubce.qianfan;

import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.core.builder.*;
import com.baidubce.qianfan.model.BaseRequest;
import com.baidubce.qianfan.model.BaseResponse;
Expand All @@ -36,8 +37,6 @@
import com.baidubce.qianfan.model.rerank.RerankRequest;
import com.baidubce.qianfan.model.rerank.RerankResponse;

import java.util.Iterator;

public class Qianfan {
private final QianfanClient client;

Expand Down Expand Up @@ -71,7 +70,7 @@ public ChatResponse chatCompletion(ChatRequest request) {
return request(request, ChatResponse.class);
}

public Iterator<ChatResponse> chatCompletionStream(ChatRequest request) {
public StreamIterator<ChatResponse> chatCompletionStream(ChatRequest request) {
request.setStream(true);
return requestStream(request, ChatResponse.class);
}
Expand All @@ -84,7 +83,7 @@ public CompletionResponse completion(CompletionRequest request) {
return request(request, CompletionResponse.class);
}

public Iterator<CompletionResponse> completionStream(CompletionRequest request) {
public StreamIterator<CompletionResponse> completionStream(CompletionRequest request) {
request.setStream(true);
return requestStream(request, CompletionResponse.class);
}
Expand Down Expand Up @@ -113,7 +112,7 @@ public Image2TextResponse image2Text(Image2TextRequest request) {
return request(request, Image2TextResponse.class);
}

public Iterator<Image2TextResponse> image2TextStream(Image2TextRequest request) {
public StreamIterator<Image2TextResponse> image2TextStream(Image2TextRequest request) {
request.setStream(true);
return requestStream(request, Image2TextResponse.class);
}
Expand All @@ -134,7 +133,7 @@ public PluginResponse plugin(PluginRequest request) {
return request(request, PluginResponse.class);
}

public Iterator<PluginResponse> pluginStream(PluginRequest request) {
public StreamIterator<PluginResponse> pluginStream(PluginRequest request) {
request.setStream(true);
return requestStream(request, PluginResponse.class);
}
Expand All @@ -143,7 +142,7 @@ public <T extends BaseResponse<T>, U extends BaseRequest<U>> T request(BaseReque
return client.request(request, responseClass);
}

public <T extends BaseResponse<T>, U extends BaseRequest<U>> Iterator<T> requestStream(BaseRequest<U> request, Class<T> responseClass) {
public <T extends BaseResponse<T>, U extends BaseRequest<U>> StreamIterator<T> requestStream(BaseRequest<U> request, Class<T> responseClass) {
return client.requestStream(request, responseClass);
}
}
46 changes: 2 additions & 44 deletions java/src/main/java/com/baidubce/qianfan/QianfanClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,18 @@
import com.baidubce.qianfan.core.ModelEndpointRetriever;
import com.baidubce.qianfan.core.QianfanConfig;
import com.baidubce.qianfan.core.RateLimiter;
import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.core.auth.Auth;
import com.baidubce.qianfan.core.auth.IAuth;
import com.baidubce.qianfan.model.*;
import com.baidubce.qianfan.model.exception.ApiException;
import com.baidubce.qianfan.model.exception.QianfanException;
import com.baidubce.qianfan.model.exception.RequestException;
import com.baidubce.qianfan.model.plugin.PluginMetaInfo;
import com.baidubce.qianfan.model.plugin.PluginResponse;
import com.baidubce.qianfan.util.Json;
import com.baidubce.qianfan.util.StringUtils;
import com.baidubce.qianfan.util.function.ThrowingFunction;
import com.baidubce.qianfan.util.http.*;

import java.util.Iterator;
import java.util.Map;

class QianfanClient {
private static final String SDK_VERSION = "0.0.8";
private static final String QIANFAN_URL_TEMPLATE = "%s/rpc/2.0/ai_custom/v1/wenxinworkshop%s";
Expand Down Expand Up @@ -83,7 +79,7 @@ public <T extends BaseResponse<T>, U extends BaseRequest<U>> T request(BaseReque
);
}

public <T extends BaseResponse<T>, U extends BaseRequest<U>> Iterator<T> requestStream(BaseRequest<U> request, Class<T> responseClass) {
public <T extends BaseResponse<T>, U extends BaseRequest<U>> StreamIterator<T> requestStream(BaseRequest<U> request, Class<T> responseClass) {
return request(
request,
HttpRequest::executeSSE,
Expand Down Expand Up @@ -159,42 +155,4 @@ private void backoffSleep(int retryCount, double backoffFactor, int maxWaitInter
throw new RequestException("Request failed: retry delay interrupted", e);
}
}

private static class StreamIterator<T extends BaseResponse<T>> implements Iterator<T> {
private final Map<String, String> headers;
private final Iterator<String> sseIterator;
private final Class<T> responseClass;

private PluginMetaInfo metaInfo;

public StreamIterator(Map<String, String> headers, Iterator<String> sseIterator, Class<T> responseClass) {
this.headers = headers;
this.sseIterator = sseIterator;
this.responseClass = responseClass;
}

@Override
public boolean hasNext() {
return sseIterator.hasNext();
}

@Override
@SuppressWarnings("unchecked")
public T next() {
String event = sseIterator.next().replaceFirst("data: ", "");
// Skip sse empty line
sseIterator.next();
T response = Json.deserialize(event, responseClass);

// Set meta info for PluginResponse
if (responseClass.equals(PluginResponse.class)) {
if (metaInfo == null) {
metaInfo = Json.deserialize(event, PluginMetaInfo.class);
}
((PluginResponse) response).setMetaInfo(metaInfo);
}

return (T) response.setHeaders(headers);
}
}
}
89 changes: 89 additions & 0 deletions java/src/main/java/com/baidubce/qianfan/core/StreamIterator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.baidubce.qianfan.core;

import com.baidubce.qianfan.model.BaseResponse;
import com.baidubce.qianfan.model.plugin.PluginMetaInfo;
import com.baidubce.qianfan.model.plugin.PluginResponse;
import com.baidubce.qianfan.util.Json;
import com.baidubce.qianfan.util.http.SSEIterator;

import java.io.Closeable;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

public class StreamIterator<T extends BaseResponse<T>> implements Iterator<T>, Closeable {
private final Map<String, String> headers;
private final SSEIterator sseIterator;
private final Class<T> responseClass;

private PluginMetaInfo metaInfo;

public StreamIterator(Map<String, String> headers, SSEIterator sseIterator, Class<T> responseClass) {
this.headers = headers;
this.sseIterator = sseIterator;
this.responseClass = responseClass;
}

@Override
public boolean hasNext() {
return sseIterator.hasNext();
}

@Override
@SuppressWarnings("unchecked")
public T next() {
String event = sseIterator.next().replaceFirst("data: ", "");
// Skip sse empty line
sseIterator.next();
T response = Json.deserialize(event, responseClass);

// Set meta info for PluginResponse
if (responseClass.equals(PluginResponse.class)) {
if (metaInfo == null) {
metaInfo = Json.deserialize(event, PluginMetaInfo.class);
}
((PluginResponse) response).setMetaInfo(metaInfo);
}

return (T) response.setHeaders(headers);
}

@Override
public void forEachRemaining(Consumer<? super T> action) {
Objects.requireNonNull(action);
try {
while (hasNext()) {
action.accept(next());
}
} finally {
try {
close();
} catch (Exception e) {
// ignored
}
}
}

@Override
public void close() throws IOException {
sseIterator.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
package com.baidubce.qianfan.core.builder;

import com.baidubce.qianfan.Qianfan;
import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.model.chat.*;

import java.util.Iterator;
import java.util.List;

public class ChatBuilder extends BaseBuilder<ChatBuilder> {
Expand Down Expand Up @@ -175,7 +175,7 @@ public ChatResponse execute() {
return super.getQianfan().chatCompletion(build());
}

public Iterator<ChatResponse> executeStream() {
public StreamIterator<ChatResponse> executeStream() {
return super.getQianfan().chatCompletionStream(build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package com.baidubce.qianfan.core.builder;

import com.baidubce.qianfan.Qianfan;
import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.model.completion.CompletionRequest;
import com.baidubce.qianfan.model.completion.CompletionResponse;

import java.util.Iterator;
import java.util.List;

public class CompletionBuilder extends BaseBuilder<CompletionBuilder> {
Expand Down Expand Up @@ -92,7 +92,7 @@ public CompletionResponse execute() {
return super.getQianfan().completion(build());
}

public Iterator<CompletionResponse> executeStream() {
public StreamIterator<CompletionResponse> executeStream() {
return super.getQianfan().completionStream(build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package com.baidubce.qianfan.core.builder;

import com.baidubce.qianfan.Qianfan;
import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.model.image.Image2TextRequest;
import com.baidubce.qianfan.model.image.Image2TextResponse;

import java.util.Iterator;
import java.util.List;

public class Image2TextBuilder extends BaseBuilder<Image2TextBuilder> {
Expand Down Expand Up @@ -140,7 +140,7 @@ public Image2TextResponse execute() {
return super.getQianfan().image2Text(build());
}

public Iterator<Image2TextResponse> executeStream() {
public StreamIterator<Image2TextResponse> executeStream() {
return super.getQianfan().image2TextStream(build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
package com.baidubce.qianfan.core.builder;

import com.baidubce.qianfan.Qianfan;
import com.baidubce.qianfan.core.StreamIterator;
import com.baidubce.qianfan.model.plugin.PluginHistory;
import com.baidubce.qianfan.model.plugin.PluginLLM;
import com.baidubce.qianfan.model.plugin.PluginRequest;
import com.baidubce.qianfan.model.plugin.PluginResponse;

import java.util.Iterator;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -102,7 +102,7 @@ public PluginResponse execute() {
return super.getQianfan().plugin(build());
}

public Iterator<PluginResponse> executeStream() {
public StreamIterator<PluginResponse> executeStream() {
return super.getQianfan().pluginStream(build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@

import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

public class HttpClient {
private static final int MAX_CONNECTIONS = 128;
private static final int MAX_CONNECTIONS = 512;

private static final CloseableHttpClient client;

Expand Down Expand Up @@ -94,7 +93,7 @@ private static <T> HttpResponse<T> execute(ClassicHttpRequest request, HttpRespo
});
}

public static HttpResponse<Iterator<String>> executeSSE(ClassicHttpRequest request) throws IOException {
public static HttpResponse<SSEIterator> executeSSE(ClassicHttpRequest request) throws IOException {
// Use legacy API to avoid auto-closing the response
CloseableHttpResponse resp = client.execute(request);

Expand All @@ -103,12 +102,12 @@ public static HttpResponse<Iterator<String>> executeSSE(ClassicHttpRequest reque
headers.put(header.getName(), header.getValue());
}

Iterator<String> body = null;
SSEIterator body = null;
String stringBody = null;

String contentType = headers.getOrDefault(ContentType.HEADER, "");
if (contentType.startsWith(ContentType.TEXT_EVENT_STREAM)) {
body = SSEWrapper.wrap(resp.getEntity().getContent(), resp);
body = new SSEIterator(resp.getEntity().getContent(), resp);
} else {
try {
// If the response is not an SSE stream, read the whole body as a string
Expand All @@ -118,7 +117,7 @@ public static HttpResponse<Iterator<String>> executeSSE(ClassicHttpRequest reque
}
}

return new HttpResponse<Iterator<String>>()
return new HttpResponse<SSEIterator>()
.setCode(resp.getCode())
.setHeaders(headers)
.setBody(body)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

Expand Down Expand Up @@ -120,7 +119,7 @@ public HttpResponse<String> executeString() throws IOException {
return HttpClient.executeString(toClassicHttpRequest());
}

public HttpResponse<Iterator<String>> executeSSE() throws IOException {
public HttpResponse<SSEIterator> executeSSE() throws IOException {
return HttpClient.executeSSE(toClassicHttpRequest());
}

Expand Down
Loading

0 comments on commit b88d2a0

Please sign in to comment.