From d73cba3442f0d3037e77a8465029bdfac1416797 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Fri, 31 May 2024 14:59:46 +0100 Subject: [PATCH] Support serving of Virtual Host requests --- trino-s3-proxy/pom.xml | 5 ++ .../server/TrinoS3ProxyServerModule.java | 3 +- .../server/collections/MultiMapHelper.java | 38 ++++++++++ .../s3/proxy/server/credentials/Signer.java | 28 ++++---- .../s3/proxy/server/rest/ParsedS3Request.java | 70 +++++++++++++++++++ .../proxy/server/rest/TrinoS3ProxyClient.java | 40 +++-------- .../proxy/server/rest/TrinoS3ProxyConfig.java | 39 +++++++++++ .../server/rest/TrinoS3ProxyResource.java | 15 ++-- .../TestProxiedAssumedRoleRequests.java | 2 +- .../s3/proxy/server/TestProxiedRequests.java | 2 +- ...tProxiedRequestsToVirtualHostEndpoint.java | 2 +- ...stProxiedRequestsWithVirtualHostProxy.java | 34 +++++++++ .../TestRemoteSessionProxiedRequests.java | 4 +- .../server/credentials/TestAssumingRoles.java | 2 +- .../testing/TestingS3ClientProvider.java | 33 +++++++-- .../testing/TestingTrinoS3ProxyServer.java | 17 +++-- .../s3/proxy/server/testing/TestingUtil.java | 4 +- .../testing/harness/TrinoS3ProxyTest.java | 3 - .../TrinoS3ProxyTestCommonModules.java | 43 ++++++++---- .../harness/TrinoS3ProxyTestExtension.java | 6 +- 20 files changed, 295 insertions(+), 95 deletions(-) create mode 100644 trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java create mode 100644 trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java create mode 100644 trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java create mode 100644 trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java diff --git a/trino-s3-proxy/pom.xml b/trino-s3-proxy/pom.xml index fc6f0e0a..2fe5fa30 100644 --- a/trino-s3-proxy/pom.xml +++ b/trino-s3-proxy/pom.xml @@ -94,6 +94,11 @@ jakarta.annotation-api + + jakarta.validation + jakarta.validation-api + + jakarta.ws.rs jakarta.ws.rs-api diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java index 581262e7..1342ad8d 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/TrinoS3ProxyServerModule.java @@ -25,6 +25,7 @@ import io.trino.s3.proxy.server.remote.VirtualHostStyleRemoteS3Facade; import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient; import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient; +import io.trino.s3.proxy.server.rest.TrinoS3ProxyConfig; import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource; import io.trino.s3.proxy.server.rest.TrinoStsResource; @@ -45,7 +46,7 @@ public final void configure(Binder binder) jaxrsBinder(binder).bind(TrinoStsResource.class); configBinder(binder).bindConfig(SigningControllerConfig.class); - + configBinder(binder).bindConfig(TrinoS3ProxyConfig.class); binder.bind(SigningController.class).in(Scopes.SINGLETON); binder.bind(CredentialsController.class).in(Scopes.SINGLETON); diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java new file mode 100644 index 00000000..42d91d01 --- /dev/null +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/collections/MultiMapHelper.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.s3.proxy.server.collections; + +import jakarta.ws.rs.core.MultivaluedHashMap; +import jakarta.ws.rs.core.MultivaluedMap; + +import java.util.List; +import java.util.Locale; +import java.util.function.BiFunction; + +public class MultiMapHelper +{ + public static MultivaluedMap lowercase(MultivaluedMap map) + { + return lowercase(map, (ignored, values) -> values); + } + + public static MultivaluedMap lowercase(MultivaluedMap map, BiFunction, List> valueMapper) + { + MultivaluedMap result = new MultivaluedHashMap<>(); + map.forEach((name, values) -> result.put(name.toLowerCase(Locale.ROOT), valueMapper.apply(name.toLowerCase(Locale.ROOT), values))); + return result; + } + + private MultiMapHelper() {} +} diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java index ed0f6ab3..0c9dd596 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/credentials/Signer.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableSet; import io.airlift.log.Logger; import jakarta.ws.rs.WebApplicationException; -import jakarta.ws.rs.core.MultivaluedHashMap; import jakarta.ws.rs.core.MultivaluedMap; import jakarta.ws.rs.core.Response; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; @@ -33,11 +32,14 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.List; import java.util.Locale; import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase; final class Signer { @@ -73,8 +75,15 @@ static String sign( Duration maxClockDrift, Optional entity) { - requestHeaders = lowercase(requestHeaders); - queryParameters = lowercase(queryParameters); + BiFunction, List> lowercaseHeaderValues = (key, values) -> { + if (!LOWERCASE_HEADERS.contains(key)) { + return values; + } + return values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList()); + }; + + requestHeaders = lowercase(requestHeaders, lowercaseHeaderValues); + queryParameters = lowercase(queryParameters, lowercaseHeaderValues); SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() .uri(requestURI) @@ -123,17 +132,4 @@ static String sign( return new WebApplicationException(Response.Status.BAD_REQUEST); }); } - - private static MultivaluedMap lowercase(MultivaluedMap map) - { - MultivaluedMap result = new MultivaluedHashMap<>(); - map.forEach((name, values) -> { - name = name.toLowerCase(Locale.ROOT); - if (LOWERCASE_HEADERS.contains(name)) { - values = values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList()); - } - result.put(name, values); - }); - return result; - } } diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java new file mode 100644 index 00000000..4da2a792 --- /dev/null +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/ParsedS3Request.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.s3.proxy.server.rest; + +import com.google.common.base.Splitter; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.core.UriBuilder; +import org.glassfish.jersey.server.ContainerRequest; + +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase; +import static java.util.Objects.requireNonNull; + +public record ParsedS3Request( + String bucketName, + String keyInBucket, + MultivaluedMap requestHeaders, + MultivaluedMap requestQueryParameters, + String httpVerb) +{ + public ParsedS3Request + { + requireNonNull(bucketName, "bucketName is null"); + requireNonNull(keyInBucket, "keyInBucket is null"); + requestHeaders = lowercase(requireNonNull(requestHeaders, "requestHeaders is null")); + requireNonNull(requestQueryParameters, "requestQueryParameters is null"); + requireNonNull(httpVerb, "httpVerb is null"); + } + + public static ParsedS3Request fromRequest(String requestPath, MultivaluedMap requestHeaders, MultivaluedMap requestQueryParameters, String httpVerb, Optional serverHostName) + { + MultivaluedMap headers = lowercase(requestHeaders); + return serverHostName + .flatMap(serverHostNameValue -> { + String lowercaseServerHostName = serverHostNameValue.toLowerCase(Locale.ROOT); + return Optional.ofNullable(headers.getFirst("host")) + .map(value -> UriBuilder.fromUri("http://" + value.toLowerCase(Locale.ROOT)).build().getHost()) + .filter(value -> value.endsWith(lowercaseServerHostName)) + .map(value -> value.substring(0, value.length() - lowercaseServerHostName.length())) + .map(value -> value.endsWith(".") ? value.substring(0, value.length() - 1) : value); + }) + .map(bucket -> new ParsedS3Request(bucket, requestPath, headers, requestQueryParameters, httpVerb)) + .orElseGet(() -> { + List parts = Splitter.on("/").limit(2).splitToList(requestPath); + if (parts.size() <= 1) { + return new ParsedS3Request(requestPath, "", headers, requestQueryParameters, httpVerb); + } + return new ParsedS3Request(parts.get(0), parts.get(1), headers, requestQueryParameters, httpVerb); + }); + } + + public static ParsedS3Request fromRequest(String requestPath, ContainerRequest request, Optional serverHostName) + { + return fromRequest(requestPath, request.getHeaders(), request.getUriInfo().getQueryParameters(), request.getMethod(), serverHostName); + } +} diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java index 90ee32ea..d5ec6665 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java @@ -27,7 +27,7 @@ import jakarta.ws.rs.core.MultivaluedHashMap; import jakarta.ws.rs.core.MultivaluedMap; import jakarta.ws.rs.core.Response; -import org.glassfish.jersey.server.ContainerRequest; +import jakarta.ws.rs.core.UriBuilder; import java.lang.annotation.Retention; import java.lang.annotation.Target; @@ -38,7 +38,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination; import static io.trino.s3.proxy.server.credentials.SigningController.formatRequestInstant; import static java.lang.annotation.ElementType.FIELD; @@ -75,14 +74,12 @@ public void shutDown() } } - public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest request, AsyncResponse asyncResponse, String bucket) + public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse) { - String remotePath = rewriteRequestPath(request, bucket); - - URI remoteUri = remoteS3Facade.buildEndpoint(request.getUriInfo().getRequestUriBuilder(), remotePath, bucket, signingMetadata.region()); + URI remoteUri = remoteS3Facade.buildEndpoint(UriBuilder.newInstance(), request.keyInBucket(), request.bucketName(), signingMetadata.region()); Request.Builder remoteRequestBuilder = new Request.Builder() - .setMethod(request.getMethod()) + .setMethod(request.httpVerb()) .setUri(remoteUri) .setFollowRedirects(true); @@ -91,9 +88,10 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque } MultivaluedMap remoteRequestHeaders = new MultivaluedHashMap<>(); - request.getRequestHeaders().forEach((key, value) -> { - switch (key.toLowerCase()) { + request.requestHeaders().forEach((key, value) -> { + switch (key) { case "x-amz-security-token" -> {} // we add this below + case "authorization" -> {} // we will create our own authorization header case "amz-sdk-invocation-id", "amz-sdk-request" -> {} // don't send these case "x-amz-date" -> remoteRequestHeaders.putSingle("X-Amz-Date", formatRequestInstant(Instant.now())); // use now for the remote request case "host" -> remoteRequestHeaders.putSingle("Host", buildRemoteHost(remoteUri)); // replace source host with the remote AWS host @@ -107,14 +105,13 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque .ifPresent(sessionToken -> remoteRequestHeaders.add("x-amz-security-token", sessionToken)); // set the new signed request auth header - String encodedPath = firstNonNull(remoteUri.getRawPath(), ""); String signature = signingController.signRequest( signingMetadata, Credentials::requiredRemoteCredential, remoteUri, remoteRequestHeaders, - request.getUriInfo().getQueryParameters(), - request.getMethod(), + request.requestQueryParameters(), + request.httpVerb(), Optional.empty()); remoteRequestHeaders.putSingle("Authorization", signature); @@ -134,23 +131,4 @@ private static String buildRemoteHost(URI remoteUri) } return remoteUri.getHost() + ":" + port; } - - private static String rewriteRequestPath(ContainerRequest request, String bucket) - { - String path = "/" + request.getPath(false); - if (!path.startsWith(TrinoS3ProxyRestConstants.S3_PATH)) { - throw new WebApplicationException(Response.Status.BAD_REQUEST); - } - - path = path.substring(TrinoS3ProxyRestConstants.S3_PATH.length()); - if (path.startsWith("/" + bucket)) { - path = path.substring(("/" + bucket).length()); - } - - if (path.isEmpty() && bucket.isEmpty()) { - path = "/"; - } - - return path; - } } diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java new file mode 100644 index 00000000..639670fc --- /dev/null +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyConfig.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.s3.proxy.server.rest; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +public class TrinoS3ProxyConfig +{ + private Optional hostName = Optional.empty(); + + @Config("s3proxy.hostname") + @ConfigDescription("Hostname to use for REST operations, virtual-host style addressing is only supported if this is set") + public TrinoS3ProxyConfig setHostName(String hostName) + { + this.hostName = Optional.ofNullable(hostName); + return this; + } + + @NotNull + public Optional getHostName() + { + return hostName; + } +} diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java index 9e68b77d..785335f3 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyResource.java @@ -13,7 +13,6 @@ */ package io.trino.s3.proxy.server.rest; -import com.google.common.base.Splitter; import com.google.inject.Inject; import io.trino.s3.proxy.server.credentials.SigningController; import io.trino.s3.proxy.server.credentials.SigningServiceType; @@ -26,7 +25,6 @@ import jakarta.ws.rs.core.Context; import org.glassfish.jersey.server.ContainerRequest; -import java.util.List; import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -36,12 +34,14 @@ public class TrinoS3ProxyResource { private final SigningController signingController; private final TrinoS3ProxyClient proxyClient; + private final Optional serverHostName; @Inject - public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient) + public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient, TrinoS3ProxyConfig trinoS3ProxyConfig) { this.signingController = requireNonNull(signingController, "signingController is null"); this.proxyClient = requireNonNull(proxyClient, "proxyClient is null"); + this.serverHostName = requireNonNull(trinoS3ProxyConfig, "restConfig is null").getHostName(); } @GET @@ -54,7 +54,7 @@ public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse as @Path("{path:.*}") public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path) { - proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path)); + proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse); } @HEAD @@ -67,12 +67,11 @@ public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse a @Path("{path:.*}") public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path) { - proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path)); + proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse); } - private String getBucket(String path) + private ParsedS3Request parseRequest(String path, ContainerRequest request) { - List parts = Splitter.on("/").splitToList(path); - return parts.isEmpty() ? "" : parts.getFirst(); + return ParsedS3Request.fromRequest(path, request, serverHostName); } } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java index dd9f3918..506f3a4f 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedAssumedRoleRequests.java @@ -37,7 +37,7 @@ import static org.assertj.core.api.Assertions.assertThat; -@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class) +@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class) public class TestProxiedAssumedRoleRequests extends AbstractTestProxiedRequests { diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java index e3a0434c..04ef35aa 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequests.java @@ -21,7 +21,7 @@ import java.util.List; -@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class) +@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class) public class TestProxiedRequests extends AbstractTestProxiedRequests { diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java index 7d952281..571c5f20 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java @@ -22,7 +22,7 @@ import java.util.List; -@TrinoS3ProxyTest(modules = {WithConfiguredBuckets.class, WithVirtualHostAddressing.class}) +@TrinoS3ProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostAddressing.class}) public class TestProxiedRequestsToVirtualHostEndpoint extends AbstractTestProxiedRequests { diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java new file mode 100644 index 00000000..e31fa573 --- /dev/null +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.s3.proxy.server; + +import com.google.inject.Inject; +import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer; +import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTest; +import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTestCommonModules.WithConfiguredBuckets; +import io.trino.s3.proxy.server.testing.harness.TrinoS3ProxyTestCommonModules.WithVirtualHostEnabledProxy; +import software.amazon.awssdk.services.s3.S3Client; + +import java.util.List; + +@TrinoS3ProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostEnabledProxy.class, WithVirtualHostEnabledProxy.class}) +public class TestProxiedRequestsWithVirtualHostProxy + extends AbstractTestProxiedRequests +{ + @Inject + public TestProxiedRequestsWithVirtualHostProxy(S3Client s3Client, @ForS3MockContainer S3Client storageClient, @ForS3MockContainer List configuredBuckets) + { + super(s3Client, storageClient, configuredBuckets); + } +} diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java index 80744e40..54509f6c 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/TestRemoteSessionProxiedRequests.java @@ -27,7 +27,7 @@ import static io.trino.s3.proxy.server.testing.TestingUtil.clientBuilder; -@TrinoS3ProxyTest(modules = TrinoS3ProxyTestCommonModules.WithConfiguredBuckets.class) +@TrinoS3ProxyTest(filters = TrinoS3ProxyTestCommonModules.WithConfiguredBuckets.class) public class TestRemoteSessionProxiedRequests extends AbstractTestProxiedRequests { @@ -41,7 +41,7 @@ private static S3Client buildInternalClient(Credentials credentials, TestingHttp { AwsBasicCredentials awsBasicCredentials = AwsBasicCredentials.create(credentials.emulated().accessKey(), credentials.emulated().secretKey()); - return clientBuilder(httpServer) + return clientBuilder(httpServer.getBaseUrl()) .credentialsProvider(() -> awsBasicCredentials) .build(); } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java index 4212a220..55f78043 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/credentials/TestAssumingRoles.java @@ -62,7 +62,7 @@ public void testStsSession() EmulatedAssumedRole emulatedAssumedRole = credentialsController.assumeEmulatedRole(testingCredentials.emulated(), "us-east-1", ARN, Optional.empty(), Optional.empty(), Optional.empty()) .orElseThrow(() -> new RuntimeException("Failed to assume role")); - try (S3Client client = clientBuilder(httpServer) + try (S3Client client = clientBuilder(httpServer.getBaseUrl()) .credentialsProvider(() -> AwsSessionCredentials.create(emulatedAssumedRole.credential().accessKey(), emulatedAssumedRole.credential().secretKey(), emulatedAssumedRole.session())) .build()) { // valid assumed role session - should work diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java index 302196c2..fdbfb9a5 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingS3ClientProvider.java @@ -13,31 +13,51 @@ */ package io.trino.s3.proxy.server.testing; +import com.google.inject.BindingAnnotation; import com.google.inject.Inject; import com.google.inject.Provider; import io.airlift.http.server.testing.TestingHttpServer; import io.trino.s3.proxy.server.credentials.Credential; import io.trino.s3.proxy.server.credentials.Credentials; +import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource; import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting; +import jakarta.ws.rs.core.UriBuilder; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.services.s3.S3Client; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.net.URI; +import java.util.Optional; + import static io.trino.s3.proxy.server.testing.TestingUtil.clientBuilder; +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; import static java.util.Objects.requireNonNull; public class TestingS3ClientProvider implements Provider { + private final URI proxyUri; private final Credential testingCredentials; - private final TestingHttpServer httpServer; + private final boolean forcePathStyle; + + @Retention(RUNTIME) + @Target({FIELD, PARAMETER, METHOD}) + @BindingAnnotation + public @interface ForS3ClientProvider {} @Inject - public TestingS3ClientProvider( - TestingHttpServer httpServer, - @ForTesting Credentials testingCredentials) + public TestingS3ClientProvider(TestingHttpServer httpServer, @ForTesting Credentials testingCredentials, @ForS3ClientProvider Optional hostName) { - this.httpServer = requireNonNull(httpServer, "httpServer is null"); + URI localProxyServerUri = httpServer.getBaseUrl(); + this.proxyUri = requireNonNull(hostName, "hostName is null") + .map(serverHostName -> UriBuilder.newInstance().host(serverHostName).port(localProxyServerUri.getPort()).scheme("http").path(TrinoS3ProxyResource.class).build()) + .orElse(UriBuilder.fromUri(localProxyServerUri).path(TrinoS3ProxyResource.class).build()); this.testingCredentials = requireNonNull(testingCredentials, "testingCredentials is null").emulated(); + this.forcePathStyle = hostName.isEmpty(); } @Override @@ -45,8 +65,9 @@ public S3Client get() { AwsBasicCredentials awsBasicCredentials = AwsBasicCredentials.create(testingCredentials.accessKey(), testingCredentials.secretKey()); - return clientBuilder(httpServer) + return clientBuilder(proxyUri) .credentialsProvider(() -> awsBasicCredentials) + .forcePathStyle(forcePathStyle) .build(); } } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java index 4c64a4f4..f4a84655 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingTrinoS3ProxyServer.java @@ -14,6 +14,7 @@ package io.trino.s3.proxy.server.testing; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.inject.Injector; import com.google.inject.Key; @@ -31,6 +32,7 @@ import java.io.Closeable; import java.util.Collection; import java.util.List; +import java.util.Map; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.trino.s3.proxy.server.testing.TestingUtil.TESTING_CREDENTIALS; @@ -69,6 +71,7 @@ public static Builder builder() public static class Builder { private final ImmutableSet.Builder modules = ImmutableSet.builder(); + private final ImmutableMap.Builder properties = ImmutableMap.builder(); private Builder() { @@ -85,18 +88,24 @@ public Builder withMockS3Container() this.modules.add(binder -> { binder.bind(ManagedS3MockContainer.class).asEagerSingleton(); binder.bind(Credentials.class).annotatedWith(TestingUtil.ForTesting.class).toInstance(TESTING_CREDENTIALS); - newOptionalBinder(binder, Key.get(new TypeLiteral>(){}, ManagedS3MockContainer.ForS3MockContainer.class)).setDefault().toInstance(ImmutableList.of()); + newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ManagedS3MockContainer.ForS3MockContainer.class)).setDefault().toInstance(ImmutableList.of()); }); return this; } + public Builder withServerHostName(String serverHostName) + { + properties.put("s3proxy.hostname", serverHostName); + return this; + } + public TestingTrinoS3ProxyServer buildAndStart() { - return start(modules.build()); + return start(modules.build(), properties.buildKeepingLast()); } } - private static TestingTrinoS3ProxyServer start(Collection extraModules) + private static TestingTrinoS3ProxyServer start(Collection extraModules, Map properties) { ImmutableList.Builder modules = ImmutableList.builder() .add(new TestingTrinoS3ProxyServerModule()) @@ -109,7 +118,7 @@ private static TestingTrinoS3ProxyServer start(Collection extraModules) extraModules.forEach(modules::add); Bootstrap app = new Bootstrap(modules.build()); - Injector injector = app.initialize(); + Injector injector = app.setOptionalConfigurationProperties(properties).initialize(); return new TestingTrinoS3ProxyServer(injector); } } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java index 9a2e674f..c33e1a5b 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/TestingUtil.java @@ -14,7 +14,6 @@ package io.trino.s3.proxy.server.testing; import com.google.inject.BindingAnnotation; -import io.airlift.http.server.testing.TestingHttpServer; import io.trino.s3.proxy.server.credentials.Credential; import io.trino.s3.proxy.server.credentials.Credentials; import io.trino.s3.proxy.server.rest.TrinoS3ProxyRestConstants; @@ -46,9 +45,8 @@ public final class TestingUtil @BindingAnnotation public @interface ForTesting {} - public static S3ClientBuilder clientBuilder(TestingHttpServer httpServer) + public static S3ClientBuilder clientBuilder(URI baseUrl) { - URI baseUrl = httpServer.getBaseUrl(); URI localProxyServerUri = baseUrl.resolve(TrinoS3ProxyRestConstants.S3_PATH); return S3Client.builder() diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java index f727ebda..6ba7dc44 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTest.java @@ -13,7 +13,6 @@ */ package io.trino.s3.proxy.server.testing.harness; -import com.google.inject.Module; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,7 +31,5 @@ @TestInstance(PER_CLASS) public @interface TrinoS3ProxyTest { - Class[] modules() default {}; - Class[] filters() default {}; } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java index 0698c2b0..c0829fc8 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestCommonModules.java @@ -14,44 +14,61 @@ package io.trino.s3.proxy.server.testing.harness; import com.google.common.collect.ImmutableList; -import com.google.inject.Binder; import com.google.inject.Key; -import com.google.inject.Module; import com.google.inject.TypeLiteral; import com.google.inject.multibindings.OptionalBinder; import io.trino.s3.proxy.server.remote.RemoteS3Facade; import io.trino.s3.proxy.server.testing.ContainerS3Facade; import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer; +import io.trino.s3.proxy.server.testing.TestingS3ClientProvider.ForS3ClientProvider; +import io.trino.s3.proxy.server.testing.TestingTrinoS3ProxyServer; import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting; import java.util.List; +import static io.trino.s3.proxy.server.testing.TestingUtil.LOCALHOST_DOMAIN; + public final class TrinoS3ProxyTestCommonModules { public static final class WithConfiguredBuckets - implements Module + implements BuilderFilter { private static final List CONFIGURED_BUCKETS = ImmutableList.of("one", "two", "three"); @Override - public void configure(Binder binder) + public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder) { - OptionalBinder.newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ForS3MockContainer.class)) - .setBinding() - .toInstance(CONFIGURED_BUCKETS); + return builder.addModule(binder -> + OptionalBinder.newOptionalBinder(binder, Key.get(new TypeLiteral>() {}, ForS3MockContainer.class)) + .setBinding() + .toInstance(CONFIGURED_BUCKETS)); } } public static final class WithVirtualHostAddressing - implements Module + implements BuilderFilter + { + @Override + public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder) + { + return builder.addModule(binder -> + OptionalBinder.newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class)) + .setBinding() + .to(ContainerS3Facade.VirtualHostStyleContainerS3Facade.class) + .asEagerSingleton()); + } + } + + public static final class WithVirtualHostEnabledProxy + implements BuilderFilter { @Override - public void configure(Binder binder) + public TestingTrinoS3ProxyServer.Builder filter(TestingTrinoS3ProxyServer.Builder builder) { - OptionalBinder.newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class)) - .setBinding() - .to(ContainerS3Facade.VirtualHostStyleContainerS3Facade.class) - .asEagerSingleton(); + return builder + .withServerHostName(LOCALHOST_DOMAIN) + .addModule( + binder -> OptionalBinder.newOptionalBinder(binder, Key.get(String.class, ForS3ClientProvider.class)).setBinding().toInstance(LOCALHOST_DOMAIN)); } } diff --git a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java index 52437d69..78950f16 100644 --- a/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java +++ b/trino-s3-proxy/src/test/java/io/trino/s3/proxy/server/testing/harness/TrinoS3ProxyTestExtension.java @@ -22,6 +22,7 @@ import io.trino.s3.proxy.server.testing.ManagedS3MockContainer; import io.trino.s3.proxy.server.testing.ManagedS3MockContainer.ForS3MockContainer; import io.trino.s3.proxy.server.testing.TestingS3ClientProvider; +import io.trino.s3.proxy.server.testing.TestingS3ClientProvider.ForS3ClientProvider; import io.trino.s3.proxy.server.testing.TestingTrinoS3ProxyServer; import io.trino.s3.proxy.server.testing.TestingUtil.ForTesting; import org.junit.jupiter.api.extension.ExtensionContext; @@ -53,10 +54,6 @@ public Object createTestInstance(TestInstanceFactoryContext factoryContext, Exte TestingTrinoS3ProxyServer.Builder builder = TestingTrinoS3ProxyServer.builder(); - Stream.of(trinoS3ProxyTest.modules()) - .map(TrinoS3ProxyTestExtension::instantiateModule) - .forEach(builder::addModule); - List filters = Stream.of(trinoS3ProxyTest.filters()) .map(TrinoS3ProxyTestExtension::instantiateBuilderFilter) .collect(toImmutableList()); @@ -67,6 +64,7 @@ public Object createTestInstance(TestInstanceFactoryContext factoryContext, Exte TestingTrinoS3ProxyServer trinoS3ProxyServer = builder .withMockS3Container() .addModule(binder -> { + newOptionalBinder(binder, Key.get(String.class, ForS3ClientProvider.class)); binder.bind(S3Client.class).annotatedWith(ForS3MockContainer.class).toProvider(ManagedS3MockContainer.class); newOptionalBinder(binder, Key.get(RemoteS3Facade.class, ForTesting.class)) .setDefault()