Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

support openai stream api #183

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api/src/main/java/com/theokanning/openai/IStream.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.theokanning.openai;

public interface IStream {
public void setStream(boolean isSSE);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.theokanning.openai.completion;

import com.theokanning.openai.IStream;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
Expand All @@ -20,7 +21,7 @@
@NoArgsConstructor
@AllArgsConstructor
@Data
public class CompletionRequest {
public class CompletionRequest implements IStream {

/**
* The name of the model to use.
Expand Down Expand Up @@ -137,4 +138,9 @@ public class CompletionRequest {
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
*/
String user;

@Override
public void setStream(boolean isSSE) {
this.stream = isSSE;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.theokanning.openai.IStream;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
Expand All @@ -13,7 +14,7 @@
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatCompletionRequest {
public class ChatCompletionRequest implements IStream {

/**
* ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported.
Expand Down Expand Up @@ -94,4 +95,9 @@ public class ChatCompletionRequest {
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
*/
String user;

@Override
public void setStream(boolean isSSE) {
this.stream = isSSE;
}
}
1 change: 1 addition & 0 deletions example/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ application {

dependencies {
implementation project(":service")
implementation 'com.squareup.okhttp3:okhttp-sse:4.9.3'
}
46 changes: 45 additions & 1 deletion example/src/main/java/example/OpenAiApiExample.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
package example;

import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.service.OpenApiStreamService;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;

class OpenAiApiExample {
public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);
OpenApiStreamService streamService = new OpenApiStreamService(token);

System.out.println("\nCreating completion...");
CompletionRequest completionRequest = CompletionRequest.builder()
Expand All @@ -26,5 +38,37 @@ public static void main(String... args) {

System.out.println("\nImage is located at:");
System.out.println(service.createImage(request).getData().get(0).getUrl());

System.out.println("\nCreating stream completion...");
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage("user", "Somebody once told me the world is gonna roll me"));
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model("gpt-3.5-turbo")
.messages(messages)
.maxTokens(1000)
.n(1)
.build();
streamService.streamRequest(completionRequest, "v1/chat/completions", new EventSourceListener() {
@Override
public void onClosed(@NotNull EventSource eventSource) {
super.onClosed(eventSource);
}

@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
super.onEvent(eventSource, id, type, data);
}

@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
super.onFailure(eventSource, t, response);
}

@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
super.onOpen(eventSource, response);
}
});

}
}
}
1 change: 1 addition & 0 deletions service/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies {
api 'com.squareup.retrofit2:retrofit:2.9.0'
implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'
implementation 'com.squareup.okhttp3:okhttp-sse:4.9.3'

testImplementation(platform('org.junit:junit-bom:5.8.2'))
testImplementation('org.junit.jupiter:junit-jupiter')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public Model getModel(String modelId) {
public CompletionResult createCompletion(CompletionRequest request) {
return execute(api.createCompletion(request));
}

public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}
Expand Down Expand Up @@ -265,4 +265,4 @@ public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper)
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.build();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.IStream;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import okhttp3.*;
import okhttp3.internal.sse.RealEventSource;
import okhttp3.sse.EventSourceListener;

import java.time.Duration;
import java.util.concurrent.TimeUnit;

public class OpenApiStreamService {
private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
private static final ObjectMapper errorMapper = defaultObjectMapper();

private final OkHttpClient client;

/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
*/
public OpenApiStreamService(final String token) {
this(token, DEFAULT_TIMEOUT);
}

/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenApiStreamService(final String token, final Duration timeout) {
this(defaultClient(token, timeout));
}

public OpenApiStreamService(OkHttpClient client) {
this.client = client;
}


public static ObjectMapper defaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
return mapper;
}

public static OkHttpClient defaultClient(String token, Duration timeout) {
return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
.readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS)
.build();
}

public void streamRequest(IStream streamRequest, String path, EventSourceListener listener) {
try {
streamRequest.setStream(true);
String requestBody = errorMapper.writeValueAsString(streamRequest);
Request request = new Request.Builder()
.url(BASE_URL + path)
.header("Accept-Encoding", "")
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache")
.post(RequestBody.create(MediaType.get("application/json"), requestBody))
.build();
RealEventSource realEventSource = new RealEventSource(request, listener);
realEventSource.connect(client);
} catch (JsonProcessingException e) {
e.getMessage();
}
}


}