diff --git a/trino-s3-proxy/pom.xml b/trino-s3-proxy/pom.xml
index fc6f0e0a..2fe5fa30 100644
--- a/trino-s3-proxy/pom.xml
+++ b/trino-s3-proxy/pom.xml
@@ -94,6 +94,11 @@
jakarta.annotation-api
+
+ jakarta.validation
+ jakarta.validation-api
+
+
jakarta.ws.rs
jakarta.ws.rs-api
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java
index 581262e7..1342ad8d 100644
--- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java
@@ -25,6 +25,7 @@
import io.trino.s3.proxy.server.remote.VirtualHostStyleRemoteS3Facade;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient;
+import io.trino.s3.proxy.server.rest.TrinoS3ProxyConfig;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource;
import io.trino.s3.proxy.server.rest.TrinoStsResource;
@@ -45,7 +46,7 @@ public final void configure(Binder binder)
jaxrsBinder(binder).bind(TrinoStsResource.class);
configBinder(binder).bindConfig(SigningControllerConfig.class);
-
+ configBinder(binder).bindConfig(TrinoS3ProxyConfig.class);
binder.bind(SigningController.class).in(Scopes.SINGLETON);
binder.bind(CredentialsController.class).in(Scopes.SINGLETON);
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java
new file mode 100644
index 00000000..42d91d01
--- /dev/null
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.s3.proxy.server.collections;
+
+import jakarta.ws.rs.core.MultivaluedHashMap;
+import jakarta.ws.rs.core.MultivaluedMap;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.function.BiFunction;
+
+public class MultiMapHelper
+{
+ public static MultivaluedMap lowercase(MultivaluedMap map)
+ {
+ return lowercase(map, (ignored, values) -> values);
+ }
+
+ public static MultivaluedMap lowercase(MultivaluedMap map, BiFunction, List> valueMapper)
+ {
+ MultivaluedMap result = new MultivaluedHashMap<>();
+ map.forEach((name, values) -> result.put(name.toLowerCase(Locale.ROOT), valueMapper.apply(name.toLowerCase(Locale.ROOT), values)));
+ return result;
+ }
+
+ private MultiMapHelper() {}
+}
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java
index ed0f6ab3..0c9dd596 100644
--- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java
@@ -16,7 +16,6 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import jakarta.ws.rs.WebApplicationException;
-import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
@@ -33,11 +32,14 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
+import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
+import java.util.function.BiFunction;
import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;
final class Signer
{
@@ -73,8 +75,15 @@ static String sign(
Duration maxClockDrift,
Optional entity)
{
- requestHeaders = lowercase(requestHeaders);
- queryParameters = lowercase(queryParameters);
+ BiFunction, List> lowercaseHeaderValues = (key, values) -> {
+ if (!LOWERCASE_HEADERS.contains(key)) {
+ return values;
+ }
+ return values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
+ };
+
+ requestHeaders = lowercase(requestHeaders, lowercaseHeaderValues);
+ queryParameters = lowercase(queryParameters, lowercaseHeaderValues);
SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
.uri(requestURI)
@@ -123,17 +132,4 @@ static String sign(
return new WebApplicationException(Response.Status.BAD_REQUEST);
});
}
-
- private static MultivaluedMap lowercase(MultivaluedMap map)
- {
- MultivaluedMap result = new MultivaluedHashMap<>();
- map.forEach((name, values) -> {
- name = name.toLowerCase(Locale.ROOT);
- if (LOWERCASE_HEADERS.contains(name)) {
- values = values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
- }
- result.put(name, values);
- });
- return result;
- }
}
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java
new file mode 100644
index 00000000..4da2a792
--- /dev/null
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.s3.proxy.server.rest;
+
+import com.google.common.base.Splitter;
+import jakarta.ws.rs.core.MultivaluedMap;
+import jakarta.ws.rs.core.UriBuilder;
+import org.glassfish.jersey.server.ContainerRequest;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Optional;
+
+import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;
+import static java.util.Objects.requireNonNull;
+
+public record ParsedS3Request(
+ String bucketName,
+ String keyInBucket,
+ MultivaluedMap requestHeaders,
+ MultivaluedMap requestQueryParameters,
+ String httpVerb)
+{
+ public ParsedS3Request
+ {
+ requireNonNull(bucketName, "bucketName is null");
+ requireNonNull(keyInBucket, "keyInBucket is null");
+ requestHeaders = lowercase(requireNonNull(requestHeaders, "requestHeaders is null"));
+ requireNonNull(requestQueryParameters, "requestQueryParameters is null");
+ requireNonNull(httpVerb, "httpVerb is null");
+ }
+
+ public static ParsedS3Request fromRequest(String requestPath, MultivaluedMap requestHeaders, MultivaluedMap requestQueryParameters, String httpVerb, Optional serverHostName)
+ {
+ MultivaluedMap headers = lowercase(requestHeaders);
+ return serverHostName
+ .flatMap(serverHostNameValue -> {
+ String lowercaseServerHostName = serverHostNameValue.toLowerCase(Locale.ROOT);
+ return Optional.ofNullable(headers.getFirst("host"))
+ .map(value -> UriBuilder.fromUri("http://" + value.toLowerCase(Locale.ROOT)).build().getHost())
+ .filter(value -> value.endsWith(lowercaseServerHostName))
+ .map(value -> value.substring(0, value.length() - lowercaseServerHostName.length()))
+ .map(value -> value.endsWith(".") ? value.substring(0, value.length() - 1) : value);
+ })
+ .map(bucket -> new ParsedS3Request(bucket, requestPath, headers, requestQueryParameters, httpVerb))
+ .orElseGet(() -> {
+ List parts = Splitter.on("/").limit(2).splitToList(requestPath);
+ if (parts.size() <= 1) {
+ return new ParsedS3Request(requestPath, "", headers, requestQueryParameters, httpVerb);
+ }
+ return new ParsedS3Request(parts.get(0), parts.get(1), headers, requestQueryParameters, httpVerb);
+ });
+ }
+
+ public static ParsedS3Request fromRequest(String requestPath, ContainerRequest request, Optional serverHostName)
+ {
+ return fromRequest(requestPath, request.getHeaders(), request.getUriInfo().getQueryParameters(), request.getMethod(), serverHostName);
+ }
+}
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java
index 90ee32ea..d5ec6665 100644
--- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java
@@ -27,7 +27,7 @@
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
-import org.glassfish.jersey.server.ContainerRequest;
+import jakarta.ws.rs.core.UriBuilder;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
@@ -38,7 +38,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
-import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
import static io.trino.s3.proxy.server.credentials.SigningController.formatRequestInstant;
import static java.lang.annotation.ElementType.FIELD;
@@ -75,14 +74,12 @@ public void shutDown()
}
}
- public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest request, AsyncResponse asyncResponse, String bucket)
+ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse)
{
- String remotePath = rewriteRequestPath(request, bucket);
-
- URI remoteUri = remoteS3Facade.buildEndpoint(request.getUriInfo().getRequestUriBuilder(), remotePath, bucket, signingMetadata.region());
+ URI remoteUri = remoteS3Facade.buildEndpoint(UriBuilder.newInstance(), request.keyInBucket(), request.bucketName(), signingMetadata.region());
Request.Builder remoteRequestBuilder = new Request.Builder()
- .setMethod(request.getMethod())
+ .setMethod(request.httpVerb())
.setUri(remoteUri)
.setFollowRedirects(true);
@@ -91,9 +88,10 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
}
MultivaluedMap remoteRequestHeaders = new MultivaluedHashMap<>();
- request.getRequestHeaders().forEach((key, value) -> {
- switch (key.toLowerCase()) {
+ request.requestHeaders().forEach((key, value) -> {
+ switch (key) {
case "x-amz-security-token" -> {} // we add this below
+ case "authorization" -> {} // we will create our own authorization header
case "amz-sdk-invocation-id", "amz-sdk-request" -> {} // don't send these
case "x-amz-date" -> remoteRequestHeaders.putSingle("X-Amz-Date", formatRequestInstant(Instant.now())); // use now for the remote request
case "host" -> remoteRequestHeaders.putSingle("Host", buildRemoteHost(remoteUri)); // replace source host with the remote AWS host
@@ -107,14 +105,13 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
.ifPresent(sessionToken -> remoteRequestHeaders.add("x-amz-security-token", sessionToken));
// set the new signed request auth header
- String encodedPath = firstNonNull(remoteUri.getRawPath(), "");
String signature = signingController.signRequest(
signingMetadata,
Credentials::requiredRemoteCredential,
remoteUri,
remoteRequestHeaders,
- request.getUriInfo().getQueryParameters(),
- request.getMethod(),
+ request.requestQueryParameters(),
+ request.httpVerb(),
Optional.empty());
remoteRequestHeaders.putSingle("Authorization", signature);
@@ -134,23 +131,4 @@ private static String buildRemoteHost(URI remoteUri)
}
return remoteUri.getHost() + ":" + port;
}
-
- private static String rewriteRequestPath(ContainerRequest request, String bucket)
- {
- String path = "/" + request.getPath(false);
- if (!path.startsWith(TrinoS3ProxyRestConstants.S3_PATH)) {
- throw new WebApplicationException(Response.Status.BAD_REQUEST);
- }
-
- path = path.substring(TrinoS3ProxyRestConstants.S3_PATH.length());
- if (path.startsWith("/" + bucket)) {
- path = path.substring(("/" + bucket).length());
- }
-
- if (path.isEmpty() && bucket.isEmpty()) {
- path = "/";
- }
-
- return path;
- }
}
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java
new file mode 100644
index 00000000..639670fc
--- /dev/null
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.s3.proxy.server.rest;
+
+import io.airlift.configuration.Config;
+import io.airlift.configuration.ConfigDescription;
+import jakarta.validation.constraints.NotNull;
+
+import java.util.Optional;
+
+public class TrinoS3ProxyConfig
+{
+ private Optional hostName = Optional.empty();
+
+ @Config("s3proxy.hostname")
+ @ConfigDescription("Hostname to use for REST operations, virtual-host style addressing is only supported if this is set")
+ public TrinoS3ProxyConfig setHostName(String hostName)
+ {
+ this.hostName = Optional.ofNullable(hostName);
+ return this;
+ }
+
+ @NotNull
+ public Optional getHostName()
+ {
+ return hostName;
+ }
+}
diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java
index 9e68b77d..785335f3 100644
--- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java
+++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java
@@ -13,7 +13,6 @@
*/
package io.trino.s3.proxy.server.rest;
-import com.google.common.base.Splitter;
import com.google.inject.Inject;
import io.trino.s3.proxy.server.credentials.SigningController;
import io.trino.s3.proxy.server.credentials.SigningServiceType;
@@ -26,7 +25,6 @@
import jakarta.ws.rs.core.Context;
import org.glassfish.jersey.server.ContainerRequest;
-import java.util.List;
import java.util.Optional;
import static java.util.Objects.requireNonNull;
@@ -36,12 +34,14 @@ public class TrinoS3ProxyResource
{
private final SigningController signingController;
private final TrinoS3ProxyClient proxyClient;
+ private final Optional serverHostName;
@Inject
- public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient)
+ public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient, TrinoS3ProxyConfig trinoS3ProxyConfig)
{
this.signingController = requireNonNull(signingController, "signingController is null");
this.proxyClient = requireNonNull(proxyClient, "proxyClient is null");
+ this.serverHostName = requireNonNull(trinoS3ProxyConfig, "restConfig is null").getHostName();
}
@GET
@@ -54,7 +54,7 @@ public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse as
@Path("{path:.*}")
public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
- proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
+ proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}
@HEAD
@@ -67,12 +67,11 @@ public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse a
@Path("{path:.*}")
public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
- proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
+ proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}
- private String getBucket(String path)
+ private ParsedS3Request parseRequest(String path, ContainerRequest request)
{
- List parts = Splitter.on("/").splitToList(path);
- return parts.isEmpty() ? "" : parts.getFirst();
+ return ParsedS3Request.fromRequest(path, request, serverHostName);
}
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java
index dd9f3918..506f3a4f 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java
@@ -37,7 +37,7 @@
import static org.assertj.core.api.Assertions.assertThat;
-@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class)
+@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class)
public class TestProxiedAssumedRoleRequests
extends AbstractTestProxiedRequests
{
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java
index e3a0434c..04ef35aa 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java
@@ -21,7 +21,7 @@
import java.util.List;
-@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class)
+@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class)
public class TestProxiedRequests
extends AbstractTestProxiedRequests
{
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java
index 7d952281..571c5f20 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java
@@ -22,7 +22,7 @@
import java.util.List;
-@TrinoS3ProxyTest(modules = {WithConfiguredBuckets.class, WithVirtualHostAddressing.class})
+@TrinoS3ProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostAddressing.class})
public class TestProxiedRequestsToVirtualHostEndpoint
extends AbstractTestProxiedRequests
{
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java
new file mode 100644
index 00000000..e31fa573
--- /dev/null
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.s3.proxy.server;
+
+import com.google.inject.Inject;
+import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer;
+import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTest;
+import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTestCommonModules.WithConfiguredBuckets;
+import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTestCommonModules.WithVirtualHostEnabledProxy;
+import software.amazon.awssdk.services.s3.S3Client;
+
+import java.util.List;
+
+@TrinoS3ProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostEnabledProxy.class, WithVirtualHostEnabledProxy.class})
+public class TestProxiedRequestsWithVirtualHostProxy
+ extends AbstractTestProxiedRequests
+{
+ @Inject
+ public TestProxiedRequestsWithVirtualHostProxy(S3Client s3Client, @ForS3MockContainer S3Client storageClient, @ForS3MockContainer List configuredBuckets)
+ {
+ super(s3Client, storageClient, configuredBuckets);
+ }
+}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java
index 80744e40..54509f6c 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java
@@ -27,7 +27,7 @@
import static io.trino.s3.proxy.server.testing.TestingUtil.clientBuilder;
-@TrinoS3ProxyTest(modules = TrinoS3ProxyTestCommonModules.WithConfiguredBuckets.class)
+@TrinoS3ProxyTest(filters = TrinoS3ProxyTestCommonModules.WithConfiguredBuckets.class)
public class TestRemoteSessionProxiedRequests
extends AbstractTestProxiedRequests
{
@@ -41,7 +41,7 @@ private static S3Client buildInternalClient(Credentials credentials, TestingHttp
{
AwsBasicCredentials awsBasicCredentials = AwsBasicCredentials.create(credentials.emulated().accessKey(), credentials.emulated().secretKey());
- return clientBuilder(httpServer)
+ return clientBuilder(httpServer.getBaseUrl())
.credentialsProvider(() -> awsBasicCredentials)
.build();
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java
index 4212a220..55f78043 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java
@@ -62,7 +62,7 @@ public void testStsSession()
EmulatedAssumedRole emulatedAssumedRole = credentialsController.assumeEmulatedRole(testingCredentials.emulated(), "us-east-1", ARN, Optional.empty(), Optional.empty(), Optional.empty())
.orElseThrow(() -> new RuntimeException("Failed to assume role"));
- try (S3Client client = clientBuilder(httpServer)
+ try (S3Client client = clientBuilder(httpServer.getBaseUrl())
.credentialsProvider(() -> AwsSessionCredentials.create(emulatedAssumedRole.credential().accessKey(), emulatedAssumedRole.credential().secretKey(), emulatedAssumedRole.session()))
.build()) {
// valid assumed role session - should work
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java
index 302196c2..fdbfb9a5 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java
@@ -13,31 +13,51 @@
*/
package io.trino.s3.proxy.server.testing;
+import com.google.inject.BindingAnnotation;
import com.google.inject.Inject;
import com.google.inject.Provider;
import io.airlift.http.server.testing.TestingHttpServer;
import io.trino.s3.proxy.server.credentials.Credential;
import io.trino.s3.proxy.server.credentials.Credentials;
+import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource;
import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting;
+import jakarta.ws.rs.core.UriBuilder;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.services.s3.S3Client;
+import java.lang.annotation.Retention;
+import java.lang.annotation.Target;
+import java.net.URI;
+import java.util.Optional;
+
import static io.trino.s3.proxy.server.testing.TestingUtil.clientBuilder;
+import static java.lang.annotation.ElementType.FIELD;
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.PARAMETER;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static java.util.Objects.requireNonNull;
public class TestingS3ClientProvider
implements Provider
{
+ private final URI proxyUri;
private final Credential testingCredentials;
- private final TestingHttpServer httpServer;
+ private final boolean forcePathStyle;
+
+ @Retention(RUNTIME)
+ @Target({FIELD, PARAMETER, METHOD})
+ @BindingAnnotation
+ public @interface ForS3ClientProvider {}
@Inject
- public TestingS3ClientProvider(
- TestingHttpServer httpServer,
- @ForTesting Credentials testingCredentials)
+ public TestingS3ClientProvider(TestingHttpServer httpServer, @ForTesting Credentials testingCredentials, @ForS3ClientProvider Optional hostName)
{
- this.httpServer = requireNonNull(httpServer, "httpServer is null");
+ URI localProxyServerUri = httpServer.getBaseUrl();
+ this.proxyUri = requireNonNull(hostName, "hostName is null")
+ .map(serverHostName -> UriBuilder.newInstance().host(serverHostName).port(localProxyServerUri.getPort()).scheme("http").path(TrinoS3ProxyResource.class).build())
+ .orElse(UriBuilder.fromUri(localProxyServerUri).path(TrinoS3ProxyResource.class).build());
this.testingCredentials = requireNonNull(testingCredentials, "testingCredentials is null").emulated();
+ this.forcePathStyle = hostName.isEmpty();
}
@Override
@@ -45,8 +65,9 @@ public S3Client get()
{
AwsBasicCredentials awsBasicCredentials = AwsBasicCredentials.create(testingCredentials.accessKey(), testingCredentials.secretKey());
- return clientBuilder(httpServer)
+ return clientBuilder(proxyUri)
.credentialsProvider(() -> awsBasicCredentials)
+ .forcePathStyle(forcePathStyle)
.build();
}
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java
index 4c64a4f4..f4a84655 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java
@@ -14,6 +14,7 @@
package io.trino.s3.proxy.server.testing;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Injector;
import com.google.inject.Key;
@@ -31,6 +32,7 @@
import java.io.Closeable;
import java.util.Collection;
import java.util.List;
+import java.util.Map;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static io.trino.s3.proxy.server.testing.TestingUtil.TESTING_CREDENTIALS;
@@ -69,6 +71,7 @@ public static Builder builder()
public static class Builder
{
private final ImmutableSet.Builder modules = ImmutableSet.builder();
+ private final ImmutableMap.Builder properties = ImmutableMap.builder();
private Builder()
{
@@ -85,18 +88,24 @@ public Builder withMockS3Container()
this.modules.add(binder -> {
binder.bind(ManagedS3MockContainer.class).asEagerSingleton();
binder.bind(Credentials.class).annotatedWith(TestingUtil.ForTesting.class).toInstance(TESTING_CREDENTIALS);
- newOptionalBinder(binder, Key.get(new TypeLiteral>(){}, ManagedS3MockContainer.ForS3MockContainer.class)).setDefault().toInstance(ImmutableList.of());
+ newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ManagedS3MockContainer.ForS3MockContainer.class)).setDefault().toInstance(ImmutableList.of());
});
return this;
}
+ public Builder withServerHostName(String serverHostName)
+ {
+ properties.put("s3proxy.hostname", serverHostName);
+ return this;
+ }
+
public TestingTrinoS3ProxyServer buildAndStart()
{
- return start(modules.build());
+ return start(modules.build(), properties.buildKeepingLast());
}
}
- private static TestingTrinoS3ProxyServer start(Collection extraModules)
+ private static TestingTrinoS3ProxyServer start(Collection extraModules, Map properties)
{
ImmutableList.Builder modules = ImmutableList.builder()
.add(new TestingTrinoS3ProxyServerModule())
@@ -109,7 +118,7 @@ private static TestingTrinoS3ProxyServer start(Collection extraModules)
extraModules.forEach(modules::add);
Bootstrap app = new Bootstrap(modules.build());
- Injector injector = app.initialize();
+ Injector injector = app.setOptionalConfigurationProperties(properties).initialize();
return new TestingTrinoS3ProxyServer(injector);
}
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java
index 9a2e674f..c33e1a5b 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java
@@ -14,7 +14,6 @@
package io.trino.s3.proxy.server.testing;
import com.google.inject.BindingAnnotation;
-import io.airlift.http.server.testing.TestingHttpServer;
import io.trino.s3.proxy.server.credentials.Credential;
import io.trino.s3.proxy.server.credentials.Credentials;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyRestConstants;
@@ -46,9 +45,8 @@ public final class TestingUtil
@BindingAnnotation
public @interface ForTesting {}
- public static S3ClientBuilder clientBuilder(TestingHttpServer httpServer)
+ public static S3ClientBuilder clientBuilder(URI baseUrl)
{
- URI baseUrl = httpServer.getBaseUrl();
URI localProxyServerUri = baseUrl.resolve(TrinoS3ProxyRestConstants.S3_PATH);
return S3Client.builder()
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java
index f727ebda..6ba7dc44 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java
@@ -13,7 +13,6 @@
*/
package io.trino.s3.proxy.server.testing.harness;
-import com.google.inject.Module;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -32,7 +31,5 @@
@TestInstance(PER_CLASS)
public @interface TrinoS3ProxyTest
{
- Class extends Module>[] modules() default {};
-
Class extends BuilderFilter>[] filters() default {};
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java
index 0698c2b0..c0829fc8 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java
@@ -14,44 +14,61 @@
package io.trino.s3.proxy.server.testing.harness;
import com.google.common.collect.ImmutableList;
-import com.google.inject.Binder;
import com.google.inject.Key;
-import com.google.inject.Module;
import com.google.inject.TypeLiteral;
import com.google.inject.multibindings.OptionalBinder;
import io.trino.s3.proxy.server.remote.RemoteS3Facade;
import io.trino.s3.proxy.server.testing.ContainerS3Facade;
import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer;
+import io.trino.s3.proxy.server.testing.TestingS3ClientProvider.ForS3ClientProvider;
+import io.trino.s3.proxy.server.testing.TestingTrinoS3ProxyServer;
import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting;
import java.util.List;
+import static io.trino.s3.proxy.server.testing.TestingUtil.LOCALHOST_DOMAIN;
+
public final class TrinoS3ProxyTestCommonModules
{
public static final class WithConfiguredBuckets
- implements Module
+ implements BuilderFilter
{
private static final List CONFIGURED_BUCKETS = ImmutableList.of("one", "two", "three");
@Override
- public void configure(Binder binder)
+ public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder)
{
- OptionalBinder.newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ForS3MockContainer.class))
- .setBinding()
- .toInstance(CONFIGURED_BUCKETS);
+ return builder.addModule(binder ->
+ OptionalBinder.newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ForS3MockContainer.class))
+ .setBinding()
+ .toInstance(CONFIGURED_BUCKETS));
}
}
public static final class WithVirtualHostAddressing
- implements Module
+ implements BuilderFilter
+ {
+ @Override
+ public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder)
+ {
+ return builder.addModule(binder ->
+ OptionalBinder.newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class))
+ .setBinding()
+ .to(ContainerS3Facade.VirtualHostStyleContainerS3Facade.class)
+ .asEagerSingleton());
+ }
+ }
+
+ public static final class WithVirtualHostEnabledProxy
+ implements BuilderFilter
{
@Override
- public void configure(Binder binder)
+ public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder)
{
- OptionalBinder.newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class))
- .setBinding()
- .to(ContainerS3Facade.VirtualHostStyleContainerS3Facade.class)
- .asEagerSingleton();
+ return builder
+ .withServerHostName(LOCALHOST_DOMAIN)
+ .addModule(
+ binder -> OptionalBinder.newOptionalBinder(binder, Key.get(String.class, ForS3ClientProvider.class)).setBinding().toInstance(LOCALHOST_DOMAIN));
}
}
diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java
index 52437d69..78950f16 100644
--- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java
+++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java
@@ -22,6 +22,7 @@
import io.trino.s3.proxy.server.testing.ManagedS3MockContainer;
import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer;
import io.trino.s3.proxy.server.testing.TestingS3ClientProvider;
+import io.trino.s3.proxy.server.testing.TestingS3ClientProvider.ForS3ClientProvider;
import io.trino.s3.proxy.server.testing.TestingTrinoS3ProxyServer;
import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting;
import org.junit.jupiter.api.extension.ExtensionContext;
@@ -53,10 +54,6 @@ public Object createTestInstance(TestInstanceFactoryContext factoryContext, Exte
TestingTrinoS3ProxyServer.Builder builder = TestingTrinoS3ProxyServer.builder();
- Stream.of(trinoS3ProxyTest.modules())
- .map(TrinoS3ProxyTestExtension::instantiateModule)
- .forEach(builder::addModule);
-
List filters = Stream.of(trinoS3ProxyTest.filters())
.map(TrinoS3ProxyTestExtension::instantiateBuilderFilter)
.collect(toImmutableList());
@@ -67,6 +64,7 @@ public Object createTestInstance(TestInstanceFactoryContext factoryContext, Exte
TestingTrinoS3ProxyServer trinoS3ProxyServer = builder
.withMockS3Container()
.addModule(binder -> {
+ newOptionalBinder(binder, Key.get(String.class, ForS3ClientProvider.class));
binder.bind(S3Client.class).annotatedWith(ForS3MockContainer.class).toProvider(ManagedS3MockContainer.class);
newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class))
.setDefault()