Skip to content

Commit

Permalink
Validate signatures for AWS chunked input streams
Browse files Browse the repository at this point in the history
Follow the AWS chunked protocol to validate chunks that include
an AWS signature extension. Alters `Signer` so that it saves
values during main request validation need to validate the chunk
signatures.

Closes #55
  • Loading branch information
Randgalt committed Jun 18, 2024
1 parent a7b8809 commit 447f220
Show file tree
Hide file tree
Showing 14 changed files with 700 additions and 1,036 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriBuilder;
import org.apache.commons.httpclient.ChunkedInputStream;
import org.glassfish.jersey.server.ContainerRequest;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
Expand Down Expand Up @@ -85,14 +83,12 @@ private static RequestContent buildRequestContent(ContainerRequest request)
default -> ContentType.STANDARD;
};

Supplier<InputStream> inputStreamSupplier = () -> buildInputStream(request.getEntityStream(), contentType);

Supplier<Optional<byte[]>> bytesSupplier;
if (contentType == ContentType.STANDARD) {
// memoize the entity bytes so it can be called multiple times
bytesSupplier = Suppliers.memoize(() -> {
try {
return Optional.of(toByteArray(inputStreamSupplier.get()));
return Optional.of(toByteArray(request.getEntityStream()));
}
catch (IOException e) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
Expand Down Expand Up @@ -129,30 +125,11 @@ public Optional<InputStream> inputStream()
{
return standardBytes()
.map(bytes -> (InputStream) new ByteArrayInputStream(bytes))
.or(() -> Optional.of(inputStreamSupplier.get()));
.or(() -> Optional.of(request.getEntityStream()));
}
};
}

private static InputStream buildInputStream(InputStream entityStream, ContentType contentType)
{
return (contentType == ContentType.AWS_CHUNKED) ? awsChunkedStream(entityStream) : entityStream;
}

private static InputStream awsChunkedStream(InputStream inputStream)
{
// TODO do we need to add a Jersey MessageBodyWriter that handles aws-chunked?

// TODO move this into a Jersey MessageBodyReader
try {
// AWS's custom chunked encoding doesn't get handled by Jersey. Do it manually.
return new ChunkedInputStream(inputStream);
}
catch (IOException e) {
throw new UncheckedIOException(e);
}
}

private static String encoding(MultivaluedMap<String, String> requestHeaders)
{
return firstNonNull(requestHeaders.getFirst("content-encoding"), firstNonNull(requestHeaders.getFirst("transfer-encoding"), ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
.ifPresent(sessionToken -> remoteRequestHeaders.add("x-amz-security-token", sessionToken));

request.requestContent().contentLength().ifPresent(length -> remoteRequestHeaders.putSingle("content-length", Integer.toString(length)));
request.requestContent().inputStream().ifPresent(inputStream -> {
remoteRequestBuilder.setBodyGenerator(new StreamingBodyGenerator(inputStream));
remoteRequestHeaders.putSingle("x-amz-content-sha256", "UNSIGNED-PAYLOAD");
});

signingController.inputStreamForContent(request.requestContent(), signingMetadata, Credentials::emulated)
.ifPresent(inputStream -> {
remoteRequestBuilder.setBodyGenerator(new StreamingBodyGenerator(inputStream));
remoteRequestHeaders.putSingle("x-amz-content-sha256", "UNSIGNED-PAYLOAD");
});

// set the new signed request auth header
String signature = signingController.signRequest(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.s3.proxy.server.signing;

import com.google.common.base.Splitter;
import org.apache.commons.httpclient.util.EncodingUtil;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static org.apache.commons.httpclient.HttpParser.parseHeaders;

// based/copied on Apache Commons ChunkedInputStream
class AwsChunkedInputStream
extends InputStream
{
private final InputStream delegate;
private final Optional<ChunkSigningSession> chunkSigningSession;

private int chunkSize;
private int position;
private boolean latent = true;
private boolean eof;
private boolean closed;

AwsChunkedInputStream(InputStream delegate, Optional<ChunkSigningSession> chunkSigningSession)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.chunkSigningSession = requireNonNull(chunkSigningSession, "chunkSigningSession is null");
}

public int read()
throws IOException
{
checkState(!closed, "Stream is closed");

if (eof) {
return -1;
}
if (position >= chunkSize) {
nextChunk();
if (eof) {
return -1;
}
}
position++;
int i = delegate.read();
if (i >= 0) {
chunkSigningSession.ifPresent(session -> session.write((byte) (i & 0xff)));
}
return i;
}

public int read(byte[] b, int off, int len)
throws IOException
{
checkState(!closed, "Stream is closed");

if (eof) {
return -1;
}
if (position >= chunkSize) {
nextChunk();
if (eof) {
return -1;
}
}

len = Math.min(len, chunkSize - position);
int count = delegate.read(b, off, len);
position += count;

chunkSigningSession.ifPresent(session -> session.write(b, off, count));

return count;
}

private void readCRLF()
throws IOException
{
int cr = delegate.read();
int lf = delegate.read();
if ((cr != '\r') || (lf != '\n')) {
throw new IOException("CRLF expected at end of chunk: " + cr + "/" + lf);
}
}

private void nextChunk()
throws IOException
{
if (!latent) {
readCRLF();
}

ChunkMetadata metadata = chunkMetadata(delegate);
chunkSigningSession.ifPresent(session -> {
String chunkSignature = metadata.chunkSignature().orElseThrow(() -> new UncheckedIOException(new IOException("Chunk is missing a signature: " + metadata.rawDataString)));
session.startChunk(chunkSignature);
});

chunkSize = metadata.chunkSize;
latent = false;
position = 0;
if (chunkSize == 0) {
chunkSigningSession.ifPresent(ChunkSigningSession::complete);
eof = true;
parseHeaders(delegate, "UTF-8");
}
}

private record ChunkMetadata(String rawDataString, int chunkSize, Optional<String> chunkSignature)
{
private ChunkMetadata
{
requireNonNull(rawDataString, "rawDataString is null");
requireNonNull(chunkSignature, "chunkSignature is null");
}
}

private static ChunkMetadata chunkMetadata(InputStream in)
throws IOException
{
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
// States: 0=normal, 1=\r was scanned, 2=inside quoted string, -1=end
int state = 0;
while (state != -1) {
int b = in.read();
if (b == -1) {
throw new IOException("chunked stream ended unexpectedly");
}
switch (state) {
case 0:
switch (b) {
case '\r':
state = 1;
break;
case '\"':
state = 2;
/* fall through */
default:
outputStream.write(b);
}
break;

case 1:
if (b == '\n') {
state = -1;
}
else {
// this was not CRLF
throw new IOException("Protocol violation: Unexpected single newline character in chunk size");
}
break;

case 2:
switch (b) {
case '\\':
b = in.read();
outputStream.write(b);
break;
case '\"':
state = 0;
/* fall through */
default:
outputStream.write(b);
}
break;
default:
throw new RuntimeException("assertion failed");
}
}

String dataString = EncodingUtil.getAsciiString(outputStream.toByteArray());

String chunkSizeString;
Optional<String> chunkSignature;

int separatorIndex = dataString.indexOf(';');
if (separatorIndex > 0) {
chunkSizeString = dataString.substring(0, separatorIndex).trim();

if ((separatorIndex + 1) < dataString.length()) {
String remainder = dataString.substring(separatorIndex + 1).trim();
chunkSignature = Splitter.on(';').trimResults().withKeyValueSeparator('=').split(remainder)
.entrySet()
.stream()
.filter(entry -> entry.getKey().equalsIgnoreCase("chunk-signature"))
.map(Map.Entry::getValue)
.findFirst();
}
else {
chunkSignature = Optional.empty();
}
}
else {
chunkSizeString = dataString.trim();
chunkSignature = Optional.empty();
}

int chunkSize;
try {
chunkSize = Integer.parseInt(chunkSizeString, 16);
}
catch (NumberFormatException e) {
throw new IOException("Bad chunk size: " + chunkSizeString);
}

return new ChunkMetadata(dataString, chunkSize, chunkSignature);
}

public void close()
throws IOException
{
if (!closed) {
try {
if (!eof) {
exhaustInputStream(this);
}
}
finally {
eof = true;
closed = true;
}
}
}

@SuppressWarnings("StatementWithEmptyBody")
private static void exhaustInputStream(InputStream inStream)
throws IOException
{
// read and discard the remainder of the message
byte[] buffer = new byte[8192];
while (inStream.read(buffer) >= 0) {
// NOP
}
}
}
Loading

0 comments on commit 447f220

Please sign in to comment.