diff --git a/api/src/main/java/com/theokanning/openai/IStream.java b/api/src/main/java/com/theokanning/openai/IStream.java new file mode 100644 index 00000000..73d49af7 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/IStream.java @@ -0,0 +1,5 @@ +package com.theokanning.openai; + +public interface IStream { + public void setStream(boolean isSSE); +} diff --git a/api/src/main/java/com/theokanning/openai/completion/CompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/CompletionRequest.java index f7bca93c..99b9df0b 100644 --- a/api/src/main/java/com/theokanning/openai/completion/CompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/CompletionRequest.java @@ -1,5 +1,6 @@ package com.theokanning.openai.completion; +import com.theokanning.openai.IStream; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -20,7 +21,7 @@ @NoArgsConstructor @AllArgsConstructor @Data -public class CompletionRequest { +public class CompletionRequest implements IStream { /** * The name of the model to use. @@ -137,4 +138,9 @@ public class CompletionRequest { * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. */ String user; + + @Override + public void setStream(boolean isSSE) { + this.stream = isSSE; + } } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java index 8ed3ee27..ea0405b0 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java @@ -1,6 +1,7 @@ package com.theokanning.openai.completion.chat; import com.fasterxml.jackson.annotation.JsonProperty; +import com.theokanning.openai.IStream; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -13,7 +14,7 @@ @Builder @AllArgsConstructor @NoArgsConstructor -public class ChatCompletionRequest { +public class ChatCompletionRequest implements IStream { /** * ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported. @@ -94,4 +95,9 @@ public class ChatCompletionRequest { * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. */ String user; + + @Override + public void setStream(boolean isSSE) { + this.stream = isSSE; + } } diff --git a/example/build.gradle b/example/build.gradle index d70ac091..24972075 100644 --- a/example/build.gradle +++ b/example/build.gradle @@ -7,4 +7,5 @@ application { dependencies { implementation project(":service") + implementation 'com.squareup.okhttp3:okhttp-sse:4.9.3' } \ No newline at end of file diff --git a/example/src/main/java/example/OpenAiApiExample.java b/example/src/main/java/example/OpenAiApiExample.java index f7cc0bff..614e5271 100644 --- a/example/src/main/java/example/OpenAiApiExample.java +++ b/example/src/main/java/example/OpenAiApiExample.java @@ -1,13 +1,25 @@ package example; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.service.OpenAiService; import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.image.CreateImageRequest; +import com.theokanning.openai.service.OpenApiStreamService; +import okhttp3.Response; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; class OpenAiApiExample { public static void main(String... args) { String token = System.getenv("OPENAI_TOKEN"); OpenAiService service = new OpenAiService(token); + OpenApiStreamService streamService = new OpenApiStreamService(token); System.out.println("\nCreating completion..."); CompletionRequest completionRequest = CompletionRequest.builder() @@ -26,5 +38,37 @@ public static void main(String... args) { System.out.println("\nImage is located at:"); System.out.println(service.createImage(request).getData().get(0).getUrl()); + + System.out.println("\nCreating stream completion..."); + List messages = new ArrayList<>(); + messages.add(new ChatMessage("user", "Somebody once told me the world is gonna roll me")); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model("gpt-3.5-turbo") + .messages(messages) + .maxTokens(1000) + .n(1) + .build(); + streamService.streamRequest(completionRequest, "v1/chat/completions", new EventSourceListener() { + @Override + public void onClosed(@NotNull EventSource eventSource) { + super.onClosed(eventSource); + } + + @Override + public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) { + super.onEvent(eventSource, id, type, data); + } + + @Override + public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { + super.onFailure(eventSource, t, response); + } + + @Override + public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) { + super.onOpen(eventSource, response); + } + }); + } -} +} \ No newline at end of file diff --git a/service/build.gradle b/service/build.gradle index 7766fa18..35a75a41 100644 --- a/service/build.gradle +++ b/service/build.gradle @@ -6,6 +6,7 @@ dependencies { api 'com.squareup.retrofit2:retrofit:2.9.0' implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0' implementation 'com.squareup.retrofit2:converter-jackson:2.9.0' + implementation 'com.squareup.okhttp3:okhttp-sse:4.9.3' testImplementation(platform('org.junit:junit-bom:5.8.2')) testImplementation('org.junit.jupiter:junit-jupiter') diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index f6eb61d2..8f7068a1 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -88,7 +88,7 @@ public Model getModel(String modelId) { public CompletionResult createCompletion(CompletionRequest request) { return execute(api.createCompletion(request)); } - + public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { return execute(api.createChatCompletion(request)); } @@ -265,4 +265,4 @@ public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) .build(); } -} +} \ No newline at end of file diff --git a/service/src/main/java/com/theokanning/openai/service/OpenApiStreamService.java b/service/src/main/java/com/theokanning/openai/service/OpenApiStreamService.java new file mode 100644 index 00000000..1712fff1 --- /dev/null +++ b/service/src/main/java/com/theokanning/openai/service/OpenApiStreamService.java @@ -0,0 +1,84 @@ +package com.theokanning.openai.service; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.theokanning.openai.IStream; +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import okhttp3.*; +import okhttp3.internal.sse.RealEventSource; +import okhttp3.sse.EventSourceListener; + +import java.time.Duration; +import java.util.concurrent.TimeUnit; + +public class OpenApiStreamService { + private static final String BASE_URL = "https://api.openai.com/"; + private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10); + private static final ObjectMapper errorMapper = defaultObjectMapper(); + + private final OkHttpClient client; + + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + */ + public OpenApiStreamService(final String token) { + this(token, DEFAULT_TIMEOUT); + } + + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + * @param timeout http read timeout, Duration.ZERO means no timeout + */ + public OpenApiStreamService(final String token, final Duration timeout) { + this(defaultClient(token, timeout)); + } + + public OpenApiStreamService(OkHttpClient client) { + this.client = client; + } + + + public static ObjectMapper defaultObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); + return mapper; + } + + public static OkHttpClient defaultClient(String token, Duration timeout) { + return new OkHttpClient.Builder() + .addInterceptor(new AuthenticationInterceptor(token)) + .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS)) + .readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) + .build(); + } + + public void streamRequest(IStream streamRequest, String path, EventSourceListener listener) { + try { + streamRequest.setStream(true); + String requestBody = errorMapper.writeValueAsString(streamRequest); + Request request = new Request.Builder() + .url(BASE_URL + path) + .header("Accept-Encoding", "") + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .post(RequestBody.create(MediaType.get("application/json"), requestBody)) + .build(); + RealEventSource realEventSource = new RealEventSource(request, listener); + realEventSource.connect(client); + } catch (JsonProcessingException e) { + e.getMessage(); + } + } + + +}