Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serving of Virtual Host requests #51

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions trino-s3-proxy/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@
<artifactId>jakarta.annotation-api</artifactId>
</dependency>

<dependency>
<groupId>jakarta.validation</groupId>
<artifactId>jakarta.validation-api</artifactId>
</dependency>

<dependency>
<groupId>jakarta.ws.rs</groupId>
<artifactId>jakarta.ws.rs-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.trino.s3.proxy.server.remote.VirtualHostStyleRemoteS3Facade;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyConfig;
import io.trino.s3.proxy.server.rest.TrinoS3ProxyResource;
import io.trino.s3.proxy.server.rest.TrinoStsResource;

Expand All @@ -45,7 +46,7 @@ public final void configure(Binder binder)
jaxrsBinder(binder).bind(TrinoStsResource.class);

configBinder(binder).bindConfig(SigningControllerConfig.class);

configBinder(binder).bindConfig(TrinoS3ProxyConfig.class);
binder.bind(SigningController.class).in(Scopes.SINGLETON);
binder.bind(CredentialsController.class).in(Scopes.SINGLETON);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.collections;

import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;

import java.util.List;
import java.util.Locale;
import java.util.function.BiFunction;

public class MultiMapHelper
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

{
public static <V> MultivaluedMap<String, V> lowercase(MultivaluedMap<String, V> map)
{
return lowercase(map, (ignored, values) -> values);
}

public static <V> MultivaluedMap<String, V> lowercase(MultivaluedMap<String, V> map, BiFunction<String, List<V>, List<V>> valueMapper)
{
MultivaluedMap<String, V> result = new MultivaluedHashMap<>();
map.forEach((name, values) -> result.put(name.toLowerCase(Locale.ROOT), valueMapper.apply(name.toLowerCase(Locale.ROOT), values)));
return result;
}

private MultiMapHelper() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
Expand All @@ -33,11 +32,14 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;

final class Signer
{
Expand Down Expand Up @@ -73,8 +75,15 @@ static String sign(
Duration maxClockDrift,
Optional<byte[]> entity)
{
requestHeaders = lowercase(requestHeaders);
queryParameters = lowercase(queryParameters);
BiFunction<String, List<String>, List<String>> lowercaseHeaderValues = (key, values) -> {
if (!LOWERCASE_HEADERS.contains(key)) {
return values;
}
return values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
};

requestHeaders = lowercase(requestHeaders, lowercaseHeaderValues);
queryParameters = lowercase(queryParameters, lowercaseHeaderValues);

SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
.uri(requestURI)
Expand Down Expand Up @@ -123,17 +132,4 @@ static String sign(
return new WebApplicationException(Response.Status.BAD_REQUEST);
});
}

private static MultivaluedMap<String, String> lowercase(MultivaluedMap<String, String> map)
{
MultivaluedMap<String, String> result = new MultivaluedHashMap<>();
map.forEach((name, values) -> {
name = name.toLowerCase(Locale.ROOT);
if (LOWERCASE_HEADERS.contains(name)) {
values = values.stream().map(value -> value.toLowerCase(Locale.ROOT)).collect(toImmutableList());
}
result.put(name, values);
});
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.rest;

import com.google.common.base.Splitter;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.UriBuilder;
import org.glassfish.jersey.server.ContainerRequest;

import java.util.List;
import java.util.Locale;
import java.util.Optional;

import static io.trino.s3.proxy.server.collections.MultiMapHelper.lowercase;
import static java.util.Objects.requireNonNull;

public record ParsedS3Request(
String bucketName,
String keyInBucket,
MultivaluedMap<String, String> requestHeaders,
MultivaluedMap<String, String> requestQueryParameters,
String httpVerb)
{
public ParsedS3Request
{
requireNonNull(bucketName, "bucketName is null");
requireNonNull(keyInBucket, "keyInBucket is null");
requestHeaders = lowercase(requireNonNull(requestHeaders, "requestHeaders is null"));
requireNonNull(requestQueryParameters, "requestQueryParameters is null");
requireNonNull(httpVerb, "httpVerb is null");
}

public static ParsedS3Request fromRequest(String requestPath, MultivaluedMap<String, String> requestHeaders, MultivaluedMap<String, String> requestQueryParameters, String httpVerb, Optional<String> serverHostName)
{
MultivaluedMap<String, String> headers = lowercase(requestHeaders);
return serverHostName
.flatMap(serverHostNameValue -> {
String lowercaseServerHostName = serverHostNameValue.toLowerCase(Locale.ROOT);
return Optional.ofNullable(headers.getFirst("host"))
.map(value -> UriBuilder.fromUri("http://" + value.toLowerCase(Locale.ROOT)).build().getHost())
.filter(value -> value.endsWith(lowercaseServerHostName))
.map(value -> value.substring(0, value.length() - lowercaseServerHostName.length()))
.map(value -> value.endsWith(".") ? value.substring(0, value.length() - 1) : value);
})
.map(bucket -> new ParsedS3Request(bucket, requestPath, headers, requestQueryParameters, httpVerb))
.orElseGet(() -> {
List<String> parts = Splitter.on("/").limit(2).splitToList(requestPath);
if (parts.size() <= 1) {
return new ParsedS3Request(requestPath, "", headers, requestQueryParameters, httpVerb);
}
return new ParsedS3Request(parts.get(0), parts.get(1), headers, requestQueryParameters, httpVerb);
});
}

public static ParsedS3Request fromRequest(String requestPath, ContainerRequest request, Optional<String> serverHostName)
{
return fromRequest(requestPath, request.getHeaders(), request.getUriInfo().getQueryParameters(), request.getMethod(), serverHostName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import org.glassfish.jersey.server.ContainerRequest;
import jakarta.ws.rs.core.UriBuilder;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;
Expand All @@ -38,7 +38,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
import static io.trino.s3.proxy.server.credentials.SigningController.formatRequestInstant;
import static java.lang.annotation.ElementType.FIELD;
Expand Down Expand Up @@ -75,14 +74,12 @@ public void shutDown()
}
}

public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest request, AsyncResponse asyncResponse, String bucket)
public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse)
{
String remotePath = rewriteRequestPath(request, bucket);

URI remoteUri = remoteS3Facade.buildEndpoint(request.getUriInfo().getRequestUriBuilder(), remotePath, bucket, signingMetadata.region());
URI remoteUri = remoteS3Facade.buildEndpoint(UriBuilder.newInstance(), request.keyInBucket(), request.bucketName(), signingMetadata.region());

Request.Builder remoteRequestBuilder = new Request.Builder()
.setMethod(request.getMethod())
.setMethod(request.httpVerb())
.setUri(remoteUri)
.setFollowRedirects(true);

Expand All @@ -91,9 +88,10 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
}

MultivaluedMap<String, String> remoteRequestHeaders = new MultivaluedHashMap<>();
request.getRequestHeaders().forEach((key, value) -> {
switch (key.toLowerCase()) {
request.requestHeaders().forEach((key, value) -> {
switch (key) {
case "x-amz-security-token" -> {} // we add this below
case "authorization" -> {} // we will create our own authorization header
case "amz-sdk-invocation-id", "amz-sdk-request" -> {} // don't send these
case "x-amz-date" -> remoteRequestHeaders.putSingle("X-Amz-Date", formatRequestInstant(Instant.now())); // use now for the remote request
case "host" -> remoteRequestHeaders.putSingle("Host", buildRemoteHost(remoteUri)); // replace source host with the remote AWS host
Expand All @@ -107,14 +105,13 @@ public void proxyRequest(SigningMetadata signingMetadata, ContainerRequest reque
.ifPresent(sessionToken -> remoteRequestHeaders.add("x-amz-security-token", sessionToken));

// set the new signed request auth header
String encodedPath = firstNonNull(remoteUri.getRawPath(), "");
String signature = signingController.signRequest(
signingMetadata,
Credentials::requiredRemoteCredential,
remoteUri,
remoteRequestHeaders,
request.getUriInfo().getQueryParameters(),
request.getMethod(),
request.requestQueryParameters(),
request.httpVerb(),
Optional.empty());
remoteRequestHeaders.putSingle("Authorization", signature);

Expand All @@ -134,23 +131,4 @@ private static String buildRemoteHost(URI remoteUri)
}
return remoteUri.getHost() + ":" + port;
}

private static String rewriteRequestPath(ContainerRequest request, String bucket)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move buildRemoteHost() too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do you think we could move this to? rewriteRequestsPath got removed because the new S3 Metadata record has been parsed enough that we don't need to compute it again.

buildRemoteHost computes the remote host, which doesn't really belong to S3RequestMetadata I don't think.

Happy to create a helper class though, perhaps RemoteRequestUtil?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure actually. Let's leave it.

{
String path = "/" + request.getPath(false);
if (!path.startsWith(TrinoS3ProxyRestConstants.S3_PATH)) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}

path = path.substring(TrinoS3ProxyRestConstants.S3_PATH.length());
if (path.startsWith("/" + bucket)) {
path = path.substring(("/" + bucket).length());
}

if (path.isEmpty() && bucket.isEmpty()) {
path = "/";
}

return path;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.rest;

import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;
import jakarta.validation.constraints.NotNull;

import java.util.Optional;

public class TrinoS3ProxyConfig
mosiac1 marked this conversation as resolved.
Show resolved Hide resolved
{
private Optional<String> hostName = Optional.empty();

@Config("s3proxy.hostname")
@ConfigDescription("Hostname to use for REST operations, virtual-host style addressing is only supported if this is set")
public TrinoS3ProxyConfig setHostName(String hostName)
{
this.hostName = Optional.ofNullable(hostName);
return this;
}

@NotNull
public Optional<String> getHostName()
{
return hostName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.s3.proxy.server.rest;

import com.google.common.base.Splitter;
import com.google.inject.Inject;
import io.trino.s3.proxy.server.credentials.SigningController;
import io.trino.s3.proxy.server.credentials.SigningServiceType;
Expand All @@ -26,7 +25,6 @@
import jakarta.ws.rs.core.Context;
import org.glassfish.jersey.server.ContainerRequest;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
Expand All @@ -36,12 +34,14 @@ public class TrinoS3ProxyResource
{
private final SigningController signingController;
private final TrinoS3ProxyClient proxyClient;
private final Optional<String> serverHostName;

@Inject
public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient)
public TrinoS3ProxyResource(SigningController signingController, TrinoS3ProxyClient proxyClient, TrinoS3ProxyConfig trinoS3ProxyConfig)
{
this.signingController = requireNonNull(signingController, "signingController is null");
this.proxyClient = requireNonNull(proxyClient, "proxyClient is null");
this.serverHostName = requireNonNull(trinoS3ProxyConfig, "restConfig is null").getHostName();
}

@GET
Expand All @@ -54,7 +54,7 @@ public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse as
@Path("{path:.*}")
public void s3Get(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}

@HEAD
Expand All @@ -67,12 +67,11 @@ public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse a
@Path("{path:.*}")
public void s3Head(@Context ContainerRequest request, @Suspended AsyncResponse asyncResponse, @PathParam("path") String path)
{
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), request, asyncResponse, getBucket(path));
proxyClient.proxyRequest(signingController.validateAndParseAuthorization(request, SigningServiceType.S3, Optional.empty()), parseRequest(path, request), asyncResponse);
}

private String getBucket(String path)
private ParsedS3Request parseRequest(String path, ContainerRequest request)
mosiac1 marked this conversation as resolved.
Show resolved Hide resolved
{
List<String> parts = Splitter.on("/").splitToList(path);
return parts.isEmpty() ? "" : parts.getFirst();
return ParsedS3Request.fromRequest(path, request, serverHostName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import static org.assertj.core.api.Assertions.assertThat;

@TrinoS3ProxyTest(modules = WithConfiguredBuckets.class)
@TrinoS3ProxyTest(filters = WithConfiguredBuckets.class)
public class TestProxiedAssumedRoleRequests
extends AbstractTestProxiedRequests
{
Expand Down
Loading
Loading