diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index f8779651..aaabaa53 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -100,8 +100,11 @@ import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; import javax.annotation.concurrent.GuardedBy; +import javax.net.ssl.SSLContext; import java.io.File; +import java.io.IOException; import java.io.InterruptedIOException; +import java.io.UncheckedIOException; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.management.MemoryUsage; @@ -109,6 +112,7 @@ import java.lang.reflect.Method; import java.net.URI; import java.nio.channels.ClosedByInterruptException; +import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; @@ -431,7 +435,7 @@ public abstract class ModelMesh extends ThriftService private PayloadProcessor initPayloadProcessor() { String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null); logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions); - if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) { + if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) { List payloadProcessors = new ArrayList<>(); for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) { try { @@ -442,6 +446,14 @@ private PayloadProcessor initPayloadProcessor() { String method = uri.getFragment(); if ("http".equals(processorName)) { processor = new RemotePayloadProcessor(uri); + } else if ("https".equals(processorName)) { + SSLContext sslContext; + try { + sslContext = SSLContext.getDefault(); + } catch (NoSuchAlgorithmException missingAlgorithmException) { + throw new UncheckedIOException(new IOException(missingAlgorithmException)); + } + processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters()); } else if ("logger".equals(processorName)) { processor = new LoggingPayloadProcessor(); } diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java index 23c2fba1..12a64f1f 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java @@ -23,6 +23,8 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import com.fasterxml.jackson.databind.ObjectMapper; import io.grpc.Metadata; @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor { private final URI uri; + private final SSLContext sslContext; + private final SSLParameters sslParameters; + private final HttpClient client; public RemotePayloadProcessor(URI uri) { + this(uri, null, null); + } + + public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) { this.uri = uri; - this.client = HttpClient.newHttpClient(); + this.sslContext = sslContext; + this.sslParameters = sslParameters; + if (sslContext != null && sslParameters != null) { + this.client = HttpClient.newBuilder() + .sslContext(sslContext) + .sslParameters(sslParameters) + .build(); + } else { + this.client = HttpClient.newHttpClient(); + } } @Override diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java index ec08ea0a..a8da3c4c 100644 --- a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java @@ -16,7 +16,9 @@ package com.ibm.watson.modelmesh.payload; +import java.io.IOException; import java.net.URI; +import java.security.NoSuchAlgorithmException; import io.grpc.Metadata; import io.grpc.Status; @@ -24,22 +26,45 @@ import io.netty.buffer.Unpooled; import org.junit.jupiter.api.Test; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + import static org.junit.jupiter.api.Assertions.assertFalse; class RemotePayloadProcessorTest { + void testDestinationUnreachable() throws IOException { + URI uri = URI.create("http://this-does-not-exist:123"); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } + } + @Test - void testDestinationUnreachable() { - RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123")); - String id = "123"; - String modelId = "456"; - String method = "predict"; - Status kind = Status.INVALID_ARGUMENT; - Metadata metadata = new Metadata(); - metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); - ByteBuf data = Unpooled.buffer(4); - Payload payload = new Payload(id, modelId, method, metadata, data, kind); - assertFalse(remotePayloadProcessor.process(payload)); + void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException { + URI uri = URI.create("https://this-does-not-exist:123"); + SSLContext sslContext = SSLContext.getDefault(); + SSLParameters sslParameters = sslContext.getDefaultSSLParameters(); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } } } \ No newline at end of file