Skip to content

Commit

Permalink
Support serving of Virtual Host requests
Browse files Browse the repository at this point in the history
  • Loading branch information
vagaerg committed Jun 5, 2024
1 parent 5c7f2e5 commit d73cba3
Show file tree
Hide file tree
Showing 20 changed files with 295 additions and 95 deletions.
5 changes: 5 additions & 0 deletions trino-s3-proxy/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@
<artifactId>jakarta.annotation-api</artifactId>
</dependency>

<dependency>
<groupId>jakarta.validation</groupId>
<artifactId>jakarta.validation-api</artifactId>
</dependency>

<dependency>
<groupId>jakarta.ws.rs</groupId>
<artifactId>jakarta.ws.rs-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.trino.s3.proxy.server.remote.VirtualHostStyleRemoteS3Facade;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyConfig;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource;
import io.trino.s3.proxy.server.rest.TrinoStsResource;

Expand All @@ -45,7 +46,7 @@ public final void configure(Binder binder)
jaxrsBinder(binder).bind(TrinoStsResource.class);

configBinder(binder).bindConfig(SigningControllerConfig.class);

configBinder(binder).bindConfig(TrinoS3ProxyConfig.class);
binder.bind(SigningController.class).in(Scopes.SINGLETON);
binder.bind(CredentialsController.class).in(Scopes.SINGLETON);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.collections;

import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;

import java.util.List;
import java.util.Locale;
import java.util.function.BiFunction;

public class MultiMapHelper
{
public static <V> MultivaluedMap<String, V> lowercase(MultivaluedMap<String, V> map)
{
return lowercase(map, (ignored, values) -> values);
}

public static <V> MultivaluedMap<String, V> lowercase(MultivaluedMap<String, V> map, BiFunction<String, List<V>, List<V>> valueMapper)
{
MultivaluedMap<String, V> result = new MultivaluedHashMap<>();
map.forEach((name, values) -> result.put(name.toLowerCase(Locale.ROOT), valueMapper.apply(name.toLowerCase(Locale.ROOT), values)));
return result;
}

private MultiMapHelper() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
Expand All @@ -33,11 +32,14 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;

final class Signer
{
Expand Down Expand Up @@ -73,8 +75,15 @@ static String sign(
Duration maxClockDrift,
Optional<byte[]> entity)
{
requestHeaders = lowercase(requestHeaders);
queryParameters = lowercase(queryParameters);
BiFunction<String, List<String>, List<String>> lowercaseHeaderValues = (key, values) -> {
if (!LOWERCASE_HEADERS.contains(key)) {
return values;
}
return values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
};

requestHeaders = lowercase(requestHeaders, lowercaseHeaderValues);
queryParameters = lowercase(queryParameters, lowercaseHeaderValues);

SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
.uri(requestURI)
Expand Down Expand Up @@ -123,17 +132,4 @@ static String sign(
return new WebApplicationException(Response.Status.BAD_REQUEST);
});
}

private static MultivaluedMap<String, String> lowercase(MultivaluedMap<String, String> map)
{
MultivaluedMap<String, String> result = new MultivaluedHashMap<>();
map.forEach((name, values) -> {
name = name.toLowerCase(Locale.ROOT);
if (LOWERCASE_HEADERS.contains(name)) {
values = values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
}
result.put(name, values);
});
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.rest;

import com.google.common.base.Splitter;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.UriBuilder;
import org.glassfish.jersey.server.ContainerRequest;

import java.util.List;
import java.util.Locale;
import java.util.Optional;

import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;
import static java.util.Objects.requireNonNull;

public record ParsedS3Request(
String bucketName,
String keyInBucket,
MultivaluedMap<String, String> requestHeaders,
MultivaluedMap<String, String> requestQueryParameters,
String httpVerb)
{
public ParsedS3Request
{
requireNonNull(bucketName, "bucketName is null");
requireNonNull(keyInBucket, "keyInBucket is null");
requestHeaders = lowercase(requireNonNull(requestHeaders, "requestHeaders is null"));
requireNonNull(requestQueryParameters, "requestQueryParameters is null");
requireNonNull(httpVerb, "httpVerb is null");
}

public static ParsedS3Request fromRequest(String requestPath, MultivaluedMap<String, String> requestHeaders, MultivaluedMap<String, String> requestQueryParameters, String httpVerb, Optional<String> serverHostName)
{
MultivaluedMap<String, String> headers = lowercase(requestHeaders);
return serverHostName
.flatMap(serverHostNameValue -> {
String lowercaseServerHostName = serverHostNameValue.toLowerCase(Locale.ROOT);
return Optional.ofNullable(headers.getFirst("host"))
.map(value -> UriBuilder.fromUri("http://" + value.toLowerCase(Locale.ROOT)).build().getHost())
.filter(value -> value.endsWith(lowercaseServerHostName))
.map(value -> value.substring(0, value.length() - lowercaseServerHostName.length()))
.map(value -> value.endsWith(".") ? value.substring(0, value.length() - 1) : value);
})
.map(bucket -> new ParsedS3Request(bucket, requestPath, headers, requestQueryParameters, httpVerb))
.orElseGet(() -> {
List<String> parts = Splitter.on("/").limit(2).splitToList(requestPath);
if (parts.size() <= 1) {
return new ParsedS3Request(requestPath, "", headers, requestQueryParameters, httpVerb);
}
return new ParsedS3Request(parts.get(0), parts.get(1), headers, requestQueryParameters, httpVerb);
});
}

public static ParsedS3Request fromRequest(String requestPath, ContainerRequest request, Optional<String> serverHostName)
{
return fromRequest(requestPath, request.getHeaders(), request.getUriInfo().getQueryParameters(), request.getMethod(), serverHostName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import org.glassfish.jersey.server.ContainerRequest;
import jakarta.ws.rs.core.UriBuilder;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;
Expand All @@ -38,7 +38,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
import static io.trino.s3.proxy.server.credentials.SigningController.formatRequestInstant;
import static java.lang.annotation.ElementType.FIELD;
Expand Down Expand Up @@ -75,14 +74,12 @@ public void shutDown()
}
}

public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest request, AsyncResponse asyncResponse, String bucket)
public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse)
{
String remotePath = rewriteRequestPath(request, bucket);

URI remoteUri = remoteS3Facade.buildEndpoint(request.getUriInfo().getRequestUriBuilder(), remotePath, bucket, signingMetadata.region());
URI remoteUri = remoteS3Facade.buildEndpoint(UriBuilder.newInstance(), request.keyInBucket(), request.bucketName(), signingMetadata.region());

Request.Builder remoteRequestBuilder = new Request.Builder()
.setMethod(request.getMethod())
.setMethod(request.httpVerb())
.setUri(remoteUri)
.setFollowRedirects(true);

Expand All @@ -91,9 +88,10 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
}

MultivaluedMap<String, String> remoteRequestHeaders = new MultivaluedHashMap<>();
request.getRequestHeaders().forEach((key, value) -> {
switch (key.toLowerCase()) {
request.requestHeaders().forEach((key, value) -> {
switch (key) {
case "x-amz-security-token" -> {} // we add this below
case "authorization" -> {} // we will create our own authorization header
case "amz-sdk-invocation-id", "amz-sdk-request" -> {} // don't send these
case "x-amz-date" -> remoteRequestHeaders.putSingle("X-Amz-Date", formatRequestInstant(Instant.now())); // use now for the remote request
case "host" -> remoteRequestHeaders.putSingle("Host", buildRemoteHost(remoteUri)); // replace source host with the remote AWS host
Expand All @@ -107,14 +105,13 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
.ifPresent(sessionToken -> remoteRequestHeaders.add("x-amz-security-token", sessionToken));

// set the new signed request auth header
String encodedPath = firstNonNull(remoteUri.getRawPath(), "");
String signature = signingController.signRequest(
signingMetadata,
Credentials::requiredRemoteCredential,
remoteUri,
remoteRequestHeaders,
request.getUriInfo().getQueryParameters(),
request.getMethod(),
request.requestQueryParameters(),
request.httpVerb(),
Optional.empty());
remoteRequestHeaders.putSingle("Authorization", signature);

Expand All @@ -134,23 +131,4 @@ private static String buildRemoteHost(URI remoteUri)
}
return remoteUri.getHost() + ":" + port;
}

private static String rewriteRequestPath(ContainerRequest request, String bucket)
{
String path = "/" + request.getPath(false);
if (!path.startsWith(TrinoS3ProxyRestConstants.S3_PATH)) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}

path = path.substring(TrinoS3ProxyRestConstants.S3_PATH.length());
if (path.startsWith("/" + bucket)) {
path = path.substring(("/" + bucket).length());
}

if (path.isEmpty() && bucket.isEmpty()) {
path = "/";
}

return path;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.rest;

import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;
import jakarta.validation.constraints.NotNull;

import java.util.Optional;

public class TrinoS3ProxyConfig
{
private Optional<String> hostName = Optional.empty();

@Config("s3proxy.hostname")
@ConfigDescription("Hostname to use for REST operations, virtual-host style addressing is only supported if this is set")
public TrinoS3ProxyConfig setHostName(String hostName)
{
this.hostName = Optional.ofNullable(hostName);
return this;
}

@NotNull
public Optional<String> getHostName()
{
return hostName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.s3.proxy.server.rest;

import com.google.common.base.Splitter;
import com.google.inject.Inject;
import io.trino.s3.proxy.server.credentials.SigningController;
import io.trino.s3.proxy.server.credentials.SigningServiceType;
Expand All @@ -26,7 +25,6 @@
import jakarta.ws.rs.core.Context;
import org.glassfish.jersey.server.ContainerRequest;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
Expand All @@ -36,12 +34,14 @@ public class TrinoS3ProxyResource
{
private final SigningController signingController;
private final TrinoS3ProxyClient proxyClient;
private final Optional<String> serverHostName;

@Inject
public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient)
public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient, TrinoS3ProxyConfig trinoS3ProxyConfig)
{
this.signingController = requireNonNull(signingController, "signingController is null");
this.proxyClient = requireNonNull(proxyClient, "proxyClient is null");
this.serverHostName = requireNonNull(trinoS3ProxyConfig, "restConfig is null").getHostName();
}

@GET
Expand All @@ -54,7 +54,7 @@ public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse as
@Path("{path:.*}")
public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}

@HEAD
Expand All @@ -67,12 +67,11 @@ public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse a
@Path("{path:.*}")
public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}

private String getBucket(String path)
private ParsedS3Request parseRequest(String path, ContainerRequest request)
{
List<String> parts = Splitter.on("/").splitToList(path);
return parts.isEmpty() ? "" : parts.getFirst();
return ParsedS3Request.fromRequest(path, request, serverHostName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import static org.assertj.core.api.Assertions.assertThat;

@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class)
@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class)
public class TestProxiedAssumedRoleRequests
extends AbstractTestProxiedRequests
{
Expand Down
Loading

0 comments on commit d73cba3

Please sign in to comment.