Skip to content

Commit

Permalink
Use SSL context/params in RPP for HTTPS
Browse files Browse the repository at this point in the history
Signed-off-by: Rui Vieira <[email protected]>
  • Loading branch information
ruivieira committed Jul 10, 2024
1 parent af8c300 commit 88b4596
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
14 changes: 13 additions & 1 deletion src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,19 @@
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;

import javax.annotation.concurrent.GuardedBy;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.UncheckedIOException;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
Expand Down Expand Up @@ -431,7 +435,7 @@ public abstract class ModelMesh extends ThriftService
private PayloadProcessor initPayloadProcessor() {
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) {
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
try {
Expand All @@ -442,6 +446,14 @@ private PayloadProcessor initPayloadProcessor() {
String method = uri.getFragment();
if ("http".equals(processorName)) {
processor = new RemotePayloadProcessor(uri);
} else if ("https".equals(processorName)) {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException missingAlgorithmException) {
throw new UncheckedIOException(new IOException(missingAlgorithmException));
}
processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters());
} else if ("logger".equals(processorName)) {
processor = new LoggingPayloadProcessor();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.Metadata;
Expand All @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor {

private final URI uri;

private final SSLContext sslContext;
private final SSLParameters sslParameters;

private final HttpClient client;

public RemotePayloadProcessor(URI uri) {
this(uri, null, null);
}

public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) {
this.uri = uri;
this.client = HttpClient.newHttpClient();
this.sslContext = sslContext;
this.sslParameters = sslParameters;
if (sslContext != null && sslParameters != null) {
this.client = HttpClient.newBuilder()
.sslContext(sslContext)
.sslParameters(sslParameters)
.build();
} else {
this.client = HttpClient.newHttpClient();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,55 @@

package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.net.URI;
import java.security.NoSuchAlgorithmException;

import io.grpc.Metadata;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import static org.junit.jupiter.api.Assertions.assertFalse;

class RemotePayloadProcessorTest {

void testDestinationUnreachable() throws IOException {
URI uri = URI.create("http://this-does-not-exist:123");
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}

@Test
void testDestinationUnreachable() {
RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123"));
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException {
URI uri = URI.create("https://this-does-not-exist:123");
SSLContext sslContext = SSLContext.getDefault();
SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}
}

0 comments on commit 88b4596

Please sign in to comment.