diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 6b7b74c5..da59f99e 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.3.0" + ".": "0.4.0" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index bd0b5c69..9275574e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## 0.4.0 (2024-11-21) + +Full Changelog: [v0.3.0...v0.4.0](https://github.com/openai/openai-java/compare/v0.3.0...v0.4.0) + +### Features + +* **azure:** Add HttpRequest.Builder extension methods ([#9](https://github.com/openai/openai-java/issues/9)) ([097c7c9](https://github.com/openai/openai-java/commit/097c7c91d23ff3bafdb4c01baea0df9beeadcd74)) + + +### Bug Fixes + +* **azure:** add missing azure changes ([656d3b5](https://github.com/openai/openai-java/commit/656d3b5a6d1c2d68733d5139d3a2982b04009f2a)) + ## 0.3.0 (2024-11-20) Full Changelog: [v0.2.0...v0.3.0](https://github.com/openai/openai-java/compare/v0.2.0...v0.3.0) diff --git a/README.md b/README.md index 9511b424..6e342167 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ -[![Maven Central](https://img.shields.io/maven-central/v/com.openai/openai-java)](https://central.sonatype.com/artifact/com.openai/openai-java/0.3.0) +[![Maven Central](https://img.shields.io/maven-central/v/com.openai/openai-java)](https://central.sonatype.com/artifact/com.openai/openai-java/0.4.0) @@ -25,7 +25,7 @@ The REST API documentation can be foundĀ on [platform.openai.com](https://platfo ```kotlin -implementation("com.openai:openai-java:0.3.0") +implementation("com.openai:openai-java:0.4.0") ``` #### Maven @@ -34,7 +34,7 @@ implementation("com.openai:openai-java:0.3.0") com.openai openai-java - 0.3.0 + 0.4.0 ``` diff --git a/build.gradle.kts b/build.gradle.kts index 282a0796..855ac122 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -4,7 +4,7 @@ plugins { allprojects { group = "com.openai" - version = "0.3.0" // x-release-please-version + version = "0.4.0" // x-release-please-version } nexusPublishing { diff --git a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt index c431be86..ecb4e82f 100644 --- a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt +++ b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt @@ -3,11 +3,13 @@ package com.openai.client.okhttp import com.fasterxml.jackson.databind.json.JsonMapper +import com.openai.azure.AzureOpenAIServiceVersion import com.openai.client.OpenAIClient import com.openai.client.OpenAIClientImpl import com.openai.core.ClientOptions import com.openai.core.http.Headers import com.openai.core.http.QueryParams +import com.openai.credential.Credential import java.net.Proxy import java.time.Clock import java.time.Duration @@ -130,6 +132,12 @@ class OpenAIOkHttpClient private constructor() { fun apiKey(apiKey: String) = apply { clientOptions.apiKey(apiKey) } + fun credential(credential: Credential) = apply { clientOptions.credential(credential) } + + fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply { + clientOptions.azureServiceVersion(azureServiceVersion) + } + fun organization(organization: String?) = apply { clientOptions.organization(organization) } fun project(project: String?) = apply { clientOptions.project(project) } diff --git a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt index 9e8e37c3..cf4711aa 100644 --- a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt +++ b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt @@ -3,11 +3,13 @@ package com.openai.client.okhttp import com.fasterxml.jackson.databind.json.JsonMapper +import com.openai.azure.AzureOpenAIServiceVersion import com.openai.client.OpenAIClientAsync import com.openai.client.OpenAIClientAsyncImpl import com.openai.core.ClientOptions import com.openai.core.http.Headers import com.openai.core.http.QueryParams +import com.openai.credential.Credential import java.net.Proxy import java.time.Clock import java.time.Duration @@ -130,6 +132,12 @@ class OpenAIOkHttpClientAsync private constructor() { fun apiKey(apiKey: String) = apply { clientOptions.apiKey(apiKey) } + fun credential(credential: Credential) = apply { clientOptions.credential(credential) } + + fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply { + clientOptions.azureServiceVersion(azureServiceVersion) + } + fun organization(organization: String?) = apply { clientOptions.organization(organization) } fun project(project: String?) = apply { clientOptions.project(project) } diff --git a/openai-java-core/src/main/kotlin/com/openai/azure/AzureOpenAIServiceVersion.kt b/openai-java-core/src/main/kotlin/com/openai/azure/AzureOpenAIServiceVersion.kt new file mode 100644 index 00000000..1188d2fb --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/azure/AzureOpenAIServiceVersion.kt @@ -0,0 +1,36 @@ +package com.openai.azure + +import java.util.concurrent.ConcurrentHashMap + +class AzureOpenAIServiceVersion private constructor(@get:JvmName("value") val value: String) { + + companion object { + private val values: ConcurrentHashMap = + ConcurrentHashMap() + + @JvmStatic + fun fromString(version: String): AzureOpenAIServiceVersion = + values.computeIfAbsent(version) { AzureOpenAIServiceVersion(version) } + + @JvmStatic val V2022_12_01 = fromString("2022-12-01") + @JvmStatic val V2023_05_15 = fromString("2023-05-15") + @JvmStatic val V2024_02_01 = fromString("2024-02-01") + @JvmStatic val V2024_06_01 = fromString("2024-06-01") + @JvmStatic val V2023_06_01_PREVIEW = fromString("2023-06-01-preview") + @JvmStatic val V2023_07_01_PREVIEW = fromString("2023-07-01-preview") + @JvmStatic val V2024_02_15_PREVIEW = fromString("2024-02-15-preview") + @JvmStatic val V2024_03_01_PREVIEW = fromString("2024-03-01-preview") + @JvmStatic val V2024_04_01_PREVIEW = fromString("2024-04-01-preview") + @JvmStatic val V2024_05_01_PREVIEW = fromString("2024-05-01-preview") + @JvmStatic val V2024_07_01_PREVIEW = fromString("2024-07-01-preview") + @JvmStatic val V2024_08_01_PREVIEW = fromString("2024-08-01-preview") + @JvmStatic val V2024_09_01_PREVIEW = fromString("2024-09-01-preview") + } + + override fun equals(other: Any?): Boolean = + this === other || (other is AzureOpenAIServiceVersion && value == other.value) + + override fun hashCode(): Int = value.hashCode() + + override fun toString(): String = "AzureOpenAIServiceVersion{value=$value}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/azure/HttpRequestBuilderExtensions.kt b/openai-java-core/src/main/kotlin/com/openai/azure/HttpRequestBuilderExtensions.kt new file mode 100644 index 00000000..c6b7ff44 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/azure/HttpRequestBuilderExtensions.kt @@ -0,0 +1,27 @@ +package com.openai.azure + +import com.openai.core.ClientOptions +import com.openai.core.http.HttpRequest +import com.openai.core.isAzureEndpoint +import com.openai.credential.BearerTokenCredential + +@JvmSynthetic +internal fun HttpRequest.Builder.addPathSegmentsForAzure( + clientOptions: ClientOptions, + deploymentModel: String +): HttpRequest.Builder = apply { + if (isAzureEndpoint(clientOptions.baseUrl)) { + addPathSegments("openai", "deployments", deploymentModel) + } +} + +@JvmSynthetic +internal fun HttpRequest.Builder.replaceBearerTokenForAzure( + clientOptions: ClientOptions +): HttpRequest.Builder = apply { + if ( + isAzureEndpoint(clientOptions.baseUrl) && clientOptions.credential is BearerTokenCredential + ) { + replaceHeaders("Authorization", "Bearer ${clientOptions.credential.token()}") + } +} diff --git a/openai-java-core/src/main/kotlin/com/openai/azure/credential/AzureApiKeyCredential.kt b/openai-java-core/src/main/kotlin/com/openai/azure/credential/AzureApiKeyCredential.kt new file mode 100644 index 00000000..3614e335 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/azure/credential/AzureApiKeyCredential.kt @@ -0,0 +1,26 @@ +package com.openai.azure.credential + +import com.openai.credential.Credential + +/** A credential that provides an Azure API key. */ +class AzureApiKeyCredential private constructor(private var apiKey: String) : Credential { + + init { + validateApiKey(apiKey) + } + + companion object { + @JvmStatic fun create(apiKey: String): Credential = AzureApiKeyCredential(apiKey) + + private fun validateApiKey(apiKey: String) { + require(apiKey.isNotEmpty()) { "Azure API key cannot be empty." } + } + } + + fun apiKey(): String = apiKey + + fun update(apiKey: String) = apply { + validateApiKey(apiKey) + this.apiKey = apiKey + } +} diff --git a/openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt b/openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt index 20ab4e7f..6eee95f6 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt @@ -3,11 +3,16 @@ package com.openai.core import com.fasterxml.jackson.databind.json.JsonMapper +import com.openai.azure.AzureOpenAIServiceVersion +import com.openai.azure.AzureOpenAIServiceVersion.Companion.V2024_06_01 +import com.openai.azure.credential.AzureApiKeyCredential import com.openai.core.http.Headers import com.openai.core.http.HttpClient import com.openai.core.http.PhantomReachableClosingHttpClient import com.openai.core.http.QueryParams import com.openai.core.http.RetryingHttpClient +import com.openai.credential.BearerTokenCredential +import com.openai.credential.Credential import java.time.Clock import java.util.concurrent.Executor import java.util.concurrent.Executors @@ -26,7 +31,7 @@ private constructor( @get:JvmName("queryParams") val queryParams: QueryParams, @get:JvmName("responseValidation") val responseValidation: Boolean, @get:JvmName("maxRetries") val maxRetries: Int, - @get:JvmName("apiKey") val apiKey: String, + @get:JvmName("credential") val credential: Credential, @get:JvmName("organization") val organization: String?, @get:JvmName("project") val project: String?, ) { @@ -53,7 +58,8 @@ private constructor( private var queryParams: QueryParams.Builder = QueryParams.builder() private var responseValidation: Boolean = false private var maxRetries: Int = 2 - private var apiKey: String? = null + private var credential: Credential? = null + private var azureServiceVersion: AzureOpenAIServiceVersion? = null private var organization: String? = null private var project: String? = null @@ -68,7 +74,7 @@ private constructor( queryParams = clientOptions.queryParams.toBuilder() responseValidation = clientOptions.responseValidation maxRetries = clientOptions.maxRetries - apiKey = clientOptions.apiKey + credential = clientOptions.credential organization = clientOptions.organization project = clientOptions.project } @@ -171,21 +177,56 @@ private constructor( fun maxRetries(maxRetries: Int) = apply { this.maxRetries = maxRetries } - fun apiKey(apiKey: String) = apply { this.apiKey = apiKey } + fun apiKey(apiKey: String) = apply { + this.credential = BearerTokenCredential.create(apiKey) + } + + fun credential(credential: Credential) = apply { this.credential = credential } + + fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply { + this.azureServiceVersion = azureServiceVersion + } fun organization(organization: String?) = apply { this.organization = organization } fun project(project: String?) = apply { this.project = project } fun fromEnv() = apply { - System.getenv("OPENAI_API_KEY")?.let { apiKey(it) } - System.getenv("OPENAI_ORG_ID")?.let { organization(it) } - System.getenv("OPENAI_PROJECT_ID")?.let { project(it) } + val openAIKey = System.getenv("OPENAI_API_KEY") + val openAIOrgId = System.getenv("OPENAI_ORG_ID") + val openAIProjectId = System.getenv("OPENAI_PROJECT_ID") + val azureOpenAIKey = System.getenv("AZURE_OPENAI_KEY") + val azureEndpoint = System.getenv("AZURE_OPENAI_ENDPOINT") + + when { + !openAIKey.isNullOrEmpty() && !azureOpenAIKey.isNullOrEmpty() -> { + throw IllegalArgumentException( + "Both OpenAI and Azure OpenAI API keys, `OPENAI_API_KEY` and `AZURE_OPENAI_KEY`, are set. Please specify only one" + ) + } + !openAIKey.isNullOrEmpty() -> { + credential(BearerTokenCredential.create(openAIKey)) + organization(openAIOrgId) + project(openAIProjectId) + } + !azureOpenAIKey.isNullOrEmpty() -> { + credential(AzureApiKeyCredential.create(azureOpenAIKey)) + baseUrl(azureEndpoint) + } + !azureEndpoint.isNullOrEmpty() -> { + // Both 'openAIKey' and 'azureOpenAIKey' are not set. + // Only 'azureEndpoint' is set here, and user still needs to call method + // '.credential(BearerTokenCredential(Supplier))' + // to get the token through the supplier, which requires Azure Entra ID as a + // dependency. + baseUrl(azureEndpoint) + } + } } fun build(): ClientOptions { checkNotNull(httpClient) { "`httpClient` is required but was not set" } - checkNotNull(apiKey) { "`apiKey` is required but was not set" } + checkNotNull(credential) { "`credential` is required but was not set" } val headers = Headers.builder() val queryParams = QueryParams.builder() @@ -198,11 +239,26 @@ private constructor( headers.put("X-Stainless-Runtime-Version", getJavaVersion()) organization?.let { headers.put("OpenAI-Organization", it) } project?.let { headers.put("OpenAI-Project", it) } - apiKey?.let { - if (!it.isEmpty()) { - headers.put("Authorization", "Bearer $it") + + when (val currentCredential = credential) { + is AzureApiKeyCredential -> { + headers.put("api-key", currentCredential.apiKey()) + } + is BearerTokenCredential -> { + headers.put("Authorization", "Bearer ${currentCredential.token()}") + } + else -> { + throw IllegalArgumentException("Invalid credential type") } } + + if (isAzureEndpoint(baseUrl)) { + // Default Azure OpenAI version is used if Azure user doesn't + // specific a service API version in 'queryParams'. + // We can update the default value every major announcement if needed. + replaceQueryParams("api-version", (azureServiceVersion ?: V2024_06_01).value) + } + headers.replaceAll(this.headers.build()) queryParams.replaceAll(this.queryParams.build()) @@ -237,7 +293,7 @@ private constructor( queryParams.build(), responseValidation, maxRetries, - apiKey!!, + credential!!, organization, project, ) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/Utils.kt b/openai-java-core/src/main/kotlin/com/openai/core/Utils.kt index 49af26ee..6601f8ef 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/Utils.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/Utils.kt @@ -23,4 +23,11 @@ internal fun , V> SortedMap.toImmutable(): SortedMap.openai.azure.com`. + // Or `https://.azure-api.net` for Azure OpenAI Management URL. + return baseUrl.endsWith(".openai.azure.com", true) || baseUrl.endsWith(".azure-api.net", true) +} + internal interface Enum diff --git a/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt b/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt new file mode 100644 index 00000000..6da402a9 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt @@ -0,0 +1,36 @@ +package com.openai.credential + +import java.util.function.Supplier + +/** + *

A credential that provides a bearer token.

+ * + *

+ * If you are using the OpenAI API, you need to provide a bearer token for authentication. All API + * requests should include your API key in an Authorization HTTP header as follows: "Authorization: + * Bearer OPENAI_API_KEY"

+ * + *

Two ways to provide the token:

+ *
    + * 1. Provide the token directly, 'BearerTokenCredential.create(String)'. The method + * 'ClientOptions.apiKey(String)' is a wrapper for this. 2. Provide a supplier that + * provides the token, 'BearerTokenCredential.create(Supplier)'. + *
+ * + * @param tokenSupplier a supplier that provides the token. + * @see OpenAI + * Authentication + */ +class BearerTokenCredential private constructor(private val tokenSupplier: Supplier) : + Credential { + + companion object { + @JvmStatic fun create(token: String): Credential = BearerTokenCredential { token } + + @JvmStatic + fun create(tokenSupplier: Supplier): Credential = + BearerTokenCredential(tokenSupplier) + } + + fun token(): String = tokenSupplier.get() +} diff --git a/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt b/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt new file mode 100644 index 00000000..f43ab84c --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt @@ -0,0 +1,4 @@ +package com.openai.credential + +/** An interface that represents a credential. */ +interface Credential diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt index d9a13795..53c0e8f5 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.async +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.JsonValue import com.openai.core.RequestOptions @@ -41,10 +43,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("completions") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/EmbeddingServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/EmbeddingServiceAsyncImpl.kt index 604d07d0..b47a66b9 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/async/EmbeddingServiceAsyncImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/async/EmbeddingServiceAsyncImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.async +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.RequestOptions import com.openai.core.handlers.errorHandler @@ -35,10 +37,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("embeddings") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/ImageServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/ImageServiceAsyncImpl.kt index 70f5d82e..968cae2f 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/async/ImageServiceAsyncImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/async/ImageServiceAsyncImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.async +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.RequestOptions import com.openai.core.handlers.errorHandler @@ -34,10 +36,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().get().toString()) .addPathSegments("images", "generations") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt index f52b5332..e77f6432 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.async.chat +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.JsonValue import com.openai.core.RequestOptions @@ -47,10 +49,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("chat", "completions") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/CompletionServiceImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/CompletionServiceImpl.kt index 6dea9334..a839a439 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/CompletionServiceImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/CompletionServiceImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.blocking +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.JsonValue import com.openai.core.RequestOptions @@ -38,10 +40,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("completions") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/EmbeddingServiceImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/EmbeddingServiceImpl.kt index a06f7ddd..8ae7f5fa 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/EmbeddingServiceImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/EmbeddingServiceImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.blocking +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.RequestOptions import com.openai.core.handlers.errorHandler @@ -34,10 +36,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("embeddings") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ImageServiceImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ImageServiceImpl.kt index 31c03cfe..8e4dcd1e 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ImageServiceImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ImageServiceImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.blocking +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.RequestOptions import com.openai.core.handlers.errorHandler @@ -33,10 +35,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().get().toString()) .addPathSegments("images", "generations") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt index 4052936f..76a46916 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt @@ -2,6 +2,8 @@ package com.openai.services.blocking.chat +import com.openai.azure.addPathSegmentsForAzure +import com.openai.azure.replaceBearerTokenForAzure import com.openai.core.ClientOptions import com.openai.core.JsonValue import com.openai.core.RequestOptions @@ -44,10 +46,12 @@ constructor( val request = HttpRequest.builder() .method(HttpMethod.POST) + .addPathSegmentsForAzure(clientOptions, params.model().toString()) .addPathSegments("chat", "completions") .putAllQueryParams(clientOptions.queryParams) .replaceAllQueryParams(params.getQueryParams()) .putAllHeaders(clientOptions.headers) + .replaceBearerTokenForAzure(clientOptions) .replaceAllHeaders(params.getHeaders()) .body(json(clientOptions.jsonMapper, params.getBody())) .build() diff --git a/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt new file mode 100644 index 00000000..d53f5d8f --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt @@ -0,0 +1,70 @@ +package com.openai.core.http + +import com.openai.azure.credential.AzureApiKeyCredential +import com.openai.client.okhttp.OkHttpClient +import com.openai.core.ClientOptions +import com.openai.credential.BearerTokenCredential +import java.util.stream.Stream +import kotlin.test.Test +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +internal class ClientOptionsTest { + + companion object { + private const val FAKE_API_KEY = "test-api-key" + + @JvmStatic + private fun createOkHttpClient(baseUrl: String): OkHttpClient { + return OkHttpClient.builder().baseUrl(baseUrl).build() + } + + @JvmStatic + private fun provideBaseUrls(): Stream { + return Stream.of( + "https://api.openai.com/v1", + "https://example.openai.azure.com", + "https://example.azure-api.net" + ) + } + } + + @ParameterizedTest + @MethodSource("provideBaseUrls") + fun clientOptionsWithoutBaseUrl(baseUrl: String) { + // Arrange + val apiKey = FAKE_API_KEY + + // Act + val clientOptions = + ClientOptions.builder() + .httpClient(createOkHttpClient(baseUrl)) + .credential(BearerTokenCredential.create(apiKey)) + .build() + + // Assert + assertThat(clientOptions.baseUrl).isEqualTo(ClientOptions.PRODUCTION_URL) + } + + @ParameterizedTest + @MethodSource("provideBaseUrls") + fun throwExceptionWhenNullCredential(baseUrl: String) { + // Act + val clientOptionsBuilder = + ClientOptions.builder().httpClient(createOkHttpClient(baseUrl)).baseUrl(baseUrl) + + // Assert + assertThatThrownBy { clientOptionsBuilder.build() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("`credential` is required but was not set") + } + + @Test + fun throwExceptionWhenEmptyCredential() { + assertThatThrownBy { AzureApiKeyCredential.create("") } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Azure API key cannot be empty.") + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt b/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt index 5ed44e5f..d779133b 100644 --- a/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt @@ -4,7 +4,6 @@ package com.openai.services.blocking.fineTuning.jobs import com.openai.TestServerExtension import com.openai.client.okhttp.OpenAIOkHttpClient -import com.openai.models.* import com.openai.models.FineTuningJobCheckpointListParams import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith diff --git a/settings.gradle.kts b/settings.gradle.kts index 3c5725fb..58e9de02 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -4,3 +4,4 @@ include("openai-java") include("openai-java-client-okhttp") include("openai-java-core") include("openai-java-example") +include("openai-azure-java-example")