diff --git a/config/base/kustomization.yaml b/config/base/kustomization.yaml index 592bd4b2..a4647046 100644 --- a/config/base/kustomization.yaml +++ b/config/base/kustomization.yaml @@ -32,6 +32,7 @@ vars: #patchesStrategicMerge: # - patches/etcd.yaml +# - patches/logger.yaml # - patches/tls.yaml # - patches/uds.yaml # - patches/max_msg_size.yaml diff --git a/config/base/patches/logger.yaml b/config/base/patches/logger.yaml new file mode 100644 index 00000000..fd39dec5 --- /dev/null +++ b/config/base/patches/logger.yaml @@ -0,0 +1,29 @@ +# Copyright 2023 IBM Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Use this patch to change the max size in bytes allowed +# per proxied gRPC message, for headers and data +# +apiVersion: apps/v1 +kind: Deployment +metadata: + name: model-mesh +spec: + template: + spec: + containers: + - name: mm + env: + - name: MM_PAYLOAD_PROCESSORS + value: "logger://*" diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index f4f964ec..136c9f88 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -61,6 +61,12 @@ import com.ibm.watson.modelmesh.TypeConstraintManager.ProhibitedTypeSet; import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap; import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap.EvictionListenerWithTime; +import com.ibm.watson.modelmesh.payload.AsyncPayloadProcessor; +import com.ibm.watson.modelmesh.payload.CompositePayloadProcessor; +import com.ibm.watson.modelmesh.payload.LoggingPayloadProcessor; +import com.ibm.watson.modelmesh.payload.MatchingPayloadProcessor; +import com.ibm.watson.modelmesh.payload.PayloadProcessor; +import com.ibm.watson.modelmesh.payload.RemotePayloadProcessor; import com.ibm.watson.modelmesh.thrift.ApplierException; import com.ibm.watson.modelmesh.thrift.BaseModelMeshService; import com.ibm.watson.modelmesh.thrift.InternalException; @@ -101,6 +107,7 @@ import java.lang.management.MemoryUsage; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.net.URI; import java.nio.channels.ClosedByInterruptException; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -421,6 +428,40 @@ public abstract class ModelMesh extends ThriftService } } + private PayloadProcessor initPayloadProcessor() { + String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null); + logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions); + if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) { + List payloadProcessors = new ArrayList<>(); + for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) { + try { + URI uri = URI.create(processorDefinition); + String processorName = uri.getScheme(); + PayloadProcessor processor = null; + String modelId = uri.getQuery(); + String method = uri.getFragment(); + if ("http".equals(processorName)) { + processor = new RemotePayloadProcessor(uri); + } else if ("logger".equals(processorName)) { + processor = new LoggingPayloadProcessor(); + } + if (processor != null) { + MatchingPayloadProcessor p = MatchingPayloadProcessor.from(modelId, method, processor); + payloadProcessors.add(p); + logger.info("Added PayloadProcessor {}", p.getName()); + } + } catch (IllegalArgumentException iae) { + logger.error("Unable to parse PayloadProcessor URI definition {}", processorDefinition); + } + } + return new AsyncPayloadProcessor(new CompositePayloadProcessor(payloadProcessors), 1, MINUTES, + Executors.newScheduledThreadPool(getIntParameter(MM_PAYLOAD_PROCESSORS_THREADS, 2)), + getIntParameter(MM_PAYLOAD_PROCESSORS_CAPACITY, 64)); + } else { + return null; + } + } + /* ---------------------------------- initialization --------------------------------------------------------- */ @Override @@ -854,10 +895,11 @@ protected final TProcessor initialize() throws Exception { } LogRequestHeaders logHeaders = LogRequestHeaders.getConfiguredLogRequestHeaders(); + PayloadProcessor payloadProcessor = initPayloadProcessor(); grpcServer = new ModelMeshApi((SidecarModelMesh) this, vModelManager, GRPC_PORT, keyCertFile, privateKeyFile, privateKeyPassphrase, clientAuth, caCertFiles, maxGrpcMessageSize, maxGrpcHeadersSize, - maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders); + maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders, payloadProcessor); } if (grpcServer != null) { diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java index fb722eb0..ff143ac6 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java @@ -45,6 +45,8 @@ import com.ibm.watson.modelmesh.api.UnregisterModelRequest; import com.ibm.watson.modelmesh.api.UnregisterModelResponse; import com.ibm.watson.modelmesh.api.VModelStatusInfo; +import com.ibm.watson.modelmesh.payload.Payload; +import com.ibm.watson.modelmesh.payload.PayloadProcessor; import com.ibm.watson.modelmesh.thrift.ApplierException; import com.ibm.watson.modelmesh.thrift.InvalidInputException; import com.ibm.watson.modelmesh.thrift.InvalidStateException; @@ -156,6 +158,10 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase // null if header logging is not enabled. protected final LogRequestHeaders logHeaders; + private final PayloadProcessor payloadProcessor; + + private final ThreadLocal localIdCounter = ThreadLocal.withInitial(() -> new long[1]); + /** * Create and start the server. * @@ -171,16 +177,18 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase * @param maxConnectionAge in seconds * @param maxConnectionAgeGrace in seconds, custom grace time for graceful connection termination * @param logHeaders + * @param payloadProcessor a processor of payloads * @throws IOException */ public ModelMeshApi(SidecarModelMesh delegate, VModelManager vmm, int port, File keyCert, File privateKey, String privateKeyPassphrase, ClientAuth clientAuth, File[] trustCerts, int maxMessageSize, int maxHeadersSize, long maxConnectionAge, long maxConnectionAgeGrace, - LogRequestHeaders logHeaders) throws IOException { + LogRequestHeaders logHeaders, PayloadProcessor payloadProcessor) throws IOException { this.delegate = delegate; this.vmm = vmm; this.logHeaders = logHeaders; + this.payloadProcessor = payloadProcessor; this.multiParallelism = getMultiParallelism(); @@ -293,6 +301,13 @@ public void shutdown(long timeout, TimeUnit unit) throws InterruptedException { if (!done) { server.shutdownNow(); } + if (payloadProcessor != null) { + try { + payloadProcessor.close(); + } catch (IOException e) { + logger.warn("Error closing PayloadProcessor {}: {}", payloadProcessor, e.getMessage()); + } + } threads.shutdownNow(); shutdownEventLoops(); } @@ -686,49 +701,57 @@ public void onHalfClose() { call.close(INTERNAL.withDescription("Half-closed without a request"), emptyMeta()); return; } - final int reqSize = reqMessage.readableBytes(); + int reqReaderIndex = reqMessage.readerIndex(); + int reqSize = reqMessage.readableBytes(); int respSize = -1; + int respReaderIndex = 0; + io.grpc.Status status = INTERNAL; + String modelId = null; + String requestId = null; + ModelResponse response = null; try (InterruptingListener cancelListener = newInterruptingListener()) { if (logHeaders != null) { logHeaders.addToMDC(headers); // MDC cleared in finally block } - ModelResponse response = null; + if (payloadProcessor != null) { + requestId = Thread.currentThread().getId() + "-" + ++localIdCounter.get()[0]; + } try { - try { - String balancedMetaVal = headers.get(BALANCED_META_KEY); - Iterator midIt = modelIds.iterator(); - // guaranteed at least one - String modelId = validateModelId(midIt.next(), isVModel); - if (!midIt.hasNext()) { - // single model case (most common) - response = callModel(modelId, isVModel, methodName, - balancedMetaVal, headers, reqMessage).retain(); - } else { - // multi-model case (specialized) - boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY)); - List idList = new ArrayList<>(); - idList.add(modelId); - while (midIt.hasNext()) { - idList.add(validateModelId(midIt.next(), isVModel)); - } - response = applyParallelMultiModel(idList, isVModel, methodName, - balancedMetaVal, headers, reqMessage, allRequired); + String balancedMetaVal = headers.get(BALANCED_META_KEY); + Iterator midIt = modelIds.iterator(); + // guaranteed at least one + modelId = validateModelId(midIt.next(), isVModel); + if (!midIt.hasNext()) { + // single model case (most common) + response = callModel(modelId, isVModel, methodName, + balancedMetaVal, headers, reqMessage).retain(); + } else { + // multi-model case (specialized) + boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY)); + List idList = new ArrayList<>(); + idList.add(modelId); + while (midIt.hasNext()) { + idList.add(validateModelId(midIt.next(), isVModel)); } - } finally { - releaseReqMessage(); + response = applyParallelMultiModel(idList, isVModel, methodName, + balancedMetaVal, headers, reqMessage, allRequired); } - - respSize = response.data.readableBytes(); - call.sendHeaders(response.metadata); - call.sendMessage(response.data); - response = null; } finally { - if (response != null) { - response.release(); + if (payloadProcessor != null) { + processPayload(reqMessage.readerIndex(reqReaderIndex), + requestId, modelId, methodName, headers, null, true); + } else { + releaseReqMessage(); } + reqMessage = null; // ownership released or transferred } + respReaderIndex = response.data.readerIndex(); + respSize = response.data.readableBytes(); + call.sendHeaders(response.metadata); + call.sendMessage(response.data); + // response is released via ReleaseAfterResponse.releaseAll() status = OK; } catch (Exception e) { status = toStatus(e); @@ -745,6 +768,15 @@ public void onHalfClose() { evictMethodDescriptor(methodName); } } finally { + if (payloadProcessor != null) { + ByteBuf data = null; + Metadata metadata = null; + if (response != null) { + data = response.data.readerIndex(respReaderIndex); + metadata = response.metadata; + } + processPayload(data, requestId, modelId, methodName, metadata, status, false); + } ReleaseAfterResponse.releaseAll(); clearThreadLocals(); //TODO(maybe) additional trailer info in exception case? @@ -757,6 +789,35 @@ public void onHalfClose() { } } + /** + * Invoke PayloadProcessor on the request/response data + * @param data the binary data + * @param payloadId the id of the request + * @param modelId the id of the model + * @param methodName the name of the invoked method + * @param metadata the method name metadata + * @param status null for requests, non-null for responses + * @param takeOwnership whether the processor should take ownership + */ + private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName, + Metadata metadata, io.grpc.Status status, boolean takeOwnership) { + Payload payload = null; + try { + assert payloadProcessor != null; + if (!takeOwnership) { + data.retain(); + } + payload = new Payload(payloadId, modelId, methodName, metadata, data, status); + if (payloadProcessor.process(payload)) { + data = null; // ownership transferred + } + } catch (Throwable t) { + logger.warn("Error while processing payload: {}", payload, t); + } finally { + ReferenceCountUtil.release(data); + } + } + @Override public void onComplete() { releaseReqMessage(); diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java index 6351f5f6..baa64db2 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java @@ -23,6 +23,10 @@ public final class ModelMeshEnvVars { private ModelMeshEnvVars() {} + public static final String MM_PAYLOAD_PROCESSORS = "MM_PAYLOAD_PROCESSORS"; + public static final String MM_PAYLOAD_PROCESSORS_THREADS = "MM_PAYLOAD_PROCESSORS_THREADS"; + public static final String MM_PAYLOAD_PROCESSORS_CAPACITY = "MM_PAYLOAD_PROCESSORS_CAPACITY"; + // This must not be changed after model-mesh is already deployed to a particular env public static final String KV_STORE_PREFIX = "MM_KVSTORE_PREFIX"; diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessor.java new file mode 100644 index 00000000..2d5524f3 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessor.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.IOException; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An async {@link PayloadProcessor} that queues processing requests and process them asynchronously. + */ +public class AsyncPayloadProcessor implements PayloadProcessor { + + private static final Logger logger = LoggerFactory.getLogger(AsyncPayloadProcessor.class); + + private final PayloadProcessor delegate; + + private final LinkedBlockingDeque payloads; + + private final AtomicInteger dropped; + + private final ScheduledExecutorService executorService; + + public AsyncPayloadProcessor(PayloadProcessor delegate) { + this(delegate, 1, TimeUnit.MINUTES, Executors.newScheduledThreadPool(2), 64); + } + + public AsyncPayloadProcessor(PayloadProcessor delegate, int delay, TimeUnit timeUnit, + ScheduledExecutorService executorService, int capacity) { + this.delegate = delegate; + this.dropped = new AtomicInteger(); + this.payloads = new LinkedBlockingDeque<>(capacity); + this.executorService = executorService; + + this.executorService.execute(() -> { + try { + while (true) { + processPayload(payloads.take()); + } + } catch (InterruptedException ie) { + // Here we assume that we're shutting down + logger.info("Payload queue processing interrupted"); + } + // Process any remaining payloads in the queue + for (Payload p; (p = payloads.poll()) != null;) { + processPayload(p); + } + try { + this.delegate.close(); + } catch (IOException e) { + // ignore + } + logger.info("AsyncPayloadProcessor task exiting"); + }); + + this.executorService.scheduleWithFixedDelay(() -> { + if (dropped.get() > 0) { + int droppedRequest = dropped.getAndSet(0); + logger.warn("{} payloads were dropped because of {} capacity limit in the last {} {}", droppedRequest, + capacity, delay, timeUnit); + } + }, 0, delay, timeUnit); + } + + void processPayload(Payload p) { + boolean released = false; + try { + released = delegate.process(p); + } catch (Throwable t) { + logger.warn("Error while processing payload: {}", p, t); + } finally { + if (!released) { + p.release(); + } + } + } + + @Override + public boolean mayTakeOwnership() { + return true; + } + + @Override + public String getName() { + return delegate.getName() + "-async"; + } + + @Override + public boolean process(Payload payload) { + boolean enqueued = payloads.offer(payload); + if (!enqueued) { + dropped.incrementAndGet(); + } + return enqueued; + } + + @Override + public void close() throws IOException { + this.executorService.shutdownNow(); + } + +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessor.java new file mode 100644 index 00000000..90c8c09a --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessor.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A composite {@link PayloadProcessor} that delegates processing to multiple delegate {@link PayloadProcessor}s (sequentially). + */ +public class CompositePayloadProcessor implements PayloadProcessor { + + private static final Logger logger = LoggerFactory.getLogger(CompositePayloadProcessor.class); + + private final List delegates; + + /** + * If any of the delegate processors take ownership of the payload. + * + * @param delegates the delegate processors + */ + public CompositePayloadProcessor(List delegates) { + if (delegates.stream().anyMatch(PayloadProcessor::mayTakeOwnership)) { + throw new IllegalArgumentException( + "CompositePayloadProcessor can only be used with delegate processors that won't take ownership" + ); + } + this.delegates = delegates; + } + + @Override + public String getName() { + return "composite:[" + delegates.stream().map(PayloadProcessor::getName).collect(Collectors.joining()) + "]"; + } + + @Override + public boolean process(Payload payload) { + for (PayloadProcessor processor : delegates) { + boolean consumed = false; + try { + consumed = processor.process(payload); + } catch (Throwable t) { + logger.error("PayloadProcessor {} failed processing payload {}", processor.getName(), payload, t); + } + if (consumed) { + throw new RuntimeException("PayloadProcessor " + processor.getName() + + " unexpectedly took ownership of the payload"); + } + } + return false; + } + + @Override + public void close() throws IOException { + for (PayloadProcessor processor : this.delegates) { + processor.close(); + } + } +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/LoggingPayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/LoggingPayloadProcessor.java new file mode 100644 index 00000000..31c05f4f --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/LoggingPayloadProcessor.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link PayloadProcessor} that logs {@link Payload}s to a logger (INFO level). + */ +public class LoggingPayloadProcessor implements PayloadProcessor { + + private final static Logger LOG = LoggerFactory.getLogger(LoggingPayloadProcessor.class); + + @Override + public String getName() { + return "logger"; + } + + @Override + public boolean process(Payload payload) { + LOG.info("Payload: {}", payload); + return false; + } + +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java new file mode 100644 index 00000000..45402423 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.IOException; + +/** + * A {@link PayloadProcessor} that processes {@link Payload}s only if they match with given model ID or method name. + */ +public class MatchingPayloadProcessor implements PayloadProcessor { + + private final PayloadProcessor delegate; + + private final String methodName; + + private final String modelId; + + MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId) { + this.delegate = delegate; + this.methodName = methodName; + this.modelId = modelId; + } + + @Override + public String getName() { + return delegate.getName(); + } + + @Override + public boolean process(Payload payload) { + boolean processed = false; + boolean methodMatches = true; + if (this.methodName != null) { + methodMatches = payload.getMethod() != null && this.methodName.equals(payload.getMethod()); + } + if (methodMatches) { + boolean modelIdMatches = true; + if (this.modelId != null) { + modelIdMatches = this.modelId.equals(payload.getModelId()); + } + if (modelIdMatches) { + processed = delegate.process(payload); + } + } + return processed; + } + + public static MatchingPayloadProcessor from(String modelId, String method, PayloadProcessor processor) { + if (modelId != null) { + if (modelId.length() > 0) { + modelId = modelId.replaceFirst("/", ""); + if (modelId.length() == 0 || modelId.equals("*")) { + modelId = null; + } + } else { + modelId = null; + } + } + if (method != null) { + if (method.length() == 0 || method.equals("*")) { + method = null; + } + } + return new MatchingPayloadProcessor(processor, method, modelId); + } + + @Override + public void close() throws IOException { + this.delegate.close(); + } +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java b/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java new file mode 100644 index 00000000..9eed4367 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import io.grpc.Metadata; +import io.grpc.Status; +import io.netty.buffer.ByteBuf; +import io.netty.util.ReferenceCountUtil; + +/** + * A model-mesh payload. + */ +public class Payload { + + public enum Kind { + REQUEST, + RESPONSE + } + + private final String id; + + private final String modelId; + + private final String method; + + private final Metadata metadata; + + private final ByteBuf data; + + // null for requests, non-null for responses + private final Status status; + + public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String method, @Nullable Metadata metadata, + @Nullable ByteBuf data, @Nullable Status status) { + this.id = id; + this.modelId = modelId; + this.method = method; + this.metadata = metadata; + this.data = data; + this.status = status; + } + + @Nonnull + public String getId() { + return id; + } + + @Nonnull + public String getModelId() { + return modelId; + } + + @CheckForNull + public String getMethod() { + return method; + } + + @CheckForNull + public Metadata getMetadata() { + return metadata; + } + + @CheckForNull + public ByteBuf getData() { + return data; + } + + @Nonnull + public Kind getKind() { + return status == null ? Kind.REQUEST : Kind.RESPONSE; + } + + @Nullable + public Status getStatus() { + return status; + } + + public void release() { + ReferenceCountUtil.release(this.data); + } + + @Override + public String toString() { + return "Payload{" + + "id='" + id + '\'' + + ", modelId='" + modelId + '\'' + + ", method='" + method + '\'' + + ", status=" + (status == null ? "request" : String.valueOf(status)) + + ", metadata=" + metadata + + ", data=" + (data != null ? data.readableBytes() + "B" : "") + + '}'; + } +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/PayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/PayloadProcessor.java new file mode 100644 index 00000000..614876c7 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/PayloadProcessor.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.Closeable; +import java.io.IOException; + +/** + * A {@link PayloadProcessor} is responsible for processing {@link Payload}s for models served by model-mesh. + * Processing shall not modify/dispose payload data. + */ +public interface PayloadProcessor extends Closeable { + + /** + * Get this processor name. + * + * @return the processor name. + */ + String getName(); + + /** + * Check whether this processor may take ownership (e.g., retaining payload data). + * If this returns {@code false} then {@link #process(Payload)} should never return {@code true}. + */ + default boolean mayTakeOwnership() { + return false; + } + + /** + * Process a payload. + * The indices of any contained byte buffers should not be changed + * + * @param payload the payload to be processed. + * @return {@code true} if the called method took ownership of the payload, {@code false} otherwise. + */ + boolean process(Payload payload); + + @Override + default void close() throws IOException {} +} diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/README.md b/src/main/java/com/ibm/watson/modelmesh/payload/README.md new file mode 100644 index 00000000..1b4e2464 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/README.md @@ -0,0 +1,34 @@ +Processing model-mesh payloads +============================= + +`Model-mesh` exchange `Payloads` with the models deployed within runtimes. +In `model-mesh` a `Payload` consists of information regarding the id of the model and the _method_ of the model being called, together with some data (actual binary requests or responses) and metadata (e.g., headers). +A `PayloadProcessor` is responsible for processing such `Payloads` for models served by `model-mesh`. + +Reasonable examples of `PayloadProcessors` include loggers of prediction requests, data sinks for data visualization, model quality assessment or monitoring tooling. + +A `PayloadProcessor` can be configured to only look at payloads that are consumed and produced by certain models, or payloads containing certain headers, etc. +This configuration is performed at `ModelMesh` instance level. +Multiple `PayloadProcessors` can be configured per each `ModelMesh` instance. + +Implementations of `PayloadProcessors` can care about only specific portions of the payload (e.g., model inputs, model outputs, metadata, specific headers, etc.). + +A `PayloadProcessor` can see input data like the one in this example: +```text +[mmesh.ExamplePredictor/predict, Metadata(content-type=application/grpc,user-agent=grpc-java-netty/1.51.1,mm-model-id=myModel,another-custom-header=custom-value,grpc-accept-encoding=gzip,grpc-timeout=1999774u), CompositeByteBuf(ridx: 0, widx: 2000004, cap: 2000004, components=147) +``` + +A `PayloadProcessor` can see output data as `ByteBuf` like the one in this example: +```text +java.nio.HeapByteBuffer[pos=0 lim=65 cap=65] +``` + +A `PayloadProcessor` can be configured by means of a whitespace separated `String` of URIs. +In a URI like `logger:///*?pytorch1234#predict`: +* the scheme represents the type of processor, e.g., `logger` +* the query represents the model id to observe, e.g., `pytorch1234` +* the fragment represents the method to observe, e.g., `predict` + +Featured `PayloadProcessors`: +* `logger` : logs requests/responses payloads to `model-mesh` logs (_INFO_ level), e.g., use `logger://*` to log every `Payload` +* `http` : sends requests/responses payloads to a remote service (via _HTTP POST_), e.g., use `http://10.10.10.1:8080/consumer/kserve/v2` to send every `Payload` to the specified HTTP endpoint \ No newline at end of file diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java new file mode 100644 index 00000000..8004e0a7 --- /dev/null +++ b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.base64.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link PayloadProcessor} that sends payloads to a remote service via HTTP POST. + */ +public class RemotePayloadProcessor implements PayloadProcessor { + + private final static Logger logger = LoggerFactory.getLogger(RemotePayloadProcessor.class); + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final URI uri; + + private final HttpClient client; + + public RemotePayloadProcessor(URI uri) { + this.uri = uri; + this.client = HttpClient.newHttpClient(); + } + + @Override + public boolean process(Payload payload) { + return sendPayload(payload); + } + + private static PayloadContent prepareContentBody(Payload payload) { + String id = payload.getId(); + String modelId = payload.getModelId(); + String kind = payload.getKind().toString().toLowerCase(); + ByteBuf byteBuf = payload.getData(); + String data; + if (byteBuf != null) { + ByteBuf encoded = Base64.encode(byteBuf, byteBuf.readerIndex(), byteBuf.readableBytes(), false); + //TODO custom jackson serialization for this field to avoid round-tripping to string + data = encoded.toString(StandardCharsets.US_ASCII); + } else { + data = ""; + } + String status = payload.getStatus() != null ? payload.getStatus().getCode().toString() : ""; + return new PayloadContent(id, modelId, data, kind, status); + } + + + private boolean sendPayload(Payload payload) { + try { + PayloadContent payloadContent = prepareContentBody(payload); + HttpRequest request = HttpRequest.newBuilder() + .uri(uri) + .headers("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(objectMapper.writeValueAsString(payloadContent))) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() != 200) { + logger.warn("Processing {} with request {} didn't succeed: {}", payload, payloadContent, response); + } + } catch (Throwable e) { + logger.error("An error occurred while sending payload {} to {}: {}", payload, uri, e.getCause()); + } + return false; + } + + @Override + public String getName() { + return "remote"; + } + + private static class PayloadContent { + private final String id; + private final String modelid; + private final String data; + private final String kind; + private final String status; + + private PayloadContent(String id, String modelid, String data, String kind, String status) { + this.id = id; + this.modelid = modelid; + this.data = data; + this.kind = kind; + this.status = status; + } + + public String getId() { + return id; + } + + public String getKind() { + return kind; + } + + public String getModelid() { + return modelid; + } + + public String getData() { + return data; + } + + public String getStatus() { + return status; + } + + @Override + public String toString() { + return "PayloadContent{" + + "id='" + id + '\'' + + ", modelid='" + modelid + '\'' + + ", data='" + data + '\'' + + ", kind='" + kind + '\'' + + ", status='" + status + '\'' + + '}'; + } + } +} diff --git a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshPayloadProcessingTest.java b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshPayloadProcessingTest.java new file mode 100644 index 00000000..a74ef778 --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshPayloadProcessingTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2021 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh; + +import com.ibm.watson.modelmesh.api.GetStatusRequest; +import com.ibm.watson.modelmesh.api.ModelInfo; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub; +import com.ibm.watson.modelmesh.api.ModelStatusInfo; +import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus; +import com.ibm.watson.modelmesh.api.RegisterModelRequest; +import com.ibm.watson.modelmesh.api.UnregisterModelRequest; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Model-mesh test for payload processing + */ +public class SidecarModelMeshPayloadProcessingTest extends SingleInstanceModelMeshTest { + + @BeforeEach + public void initialize() throws Exception { + System.setProperty(ModelMeshEnvVars.MM_PAYLOAD_PROCESSORS, "logger://*"); + super.initialize(); + } + + @Test + public void testPayloadProcessing() { + ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088) + .usePlaintext().build(); + try { + ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel); + + ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel); + + // verify not found status + ModelStatusInfo status = manageModels.getModelStatus(GetStatusRequest.newBuilder() + .setModelId("i don't exist").build()); + + assertEquals(ModelStatus.NOT_FOUND, status.getStatus()); + assertEquals(0, status.getErrorsCount()); + + + // add a model + String modelId = "myModel"; + ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder() + .setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("ExampleType").build()) + .setLoadNow(true).build()); + + System.out.println("registerModel returned: " + statusInfo); + + // call predict on the model + PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build(); + PredictResponse response = forModel(useModels, modelId) + .predict(req); + + System.out.println("predict returned: " + response.getResultsList()); + + assertEquals(1.0, response.getResults(0).getConfidence(), 0); + + assertEquals("classification for predict me! by model myModel", + response.getResults(0).getCategory()); + + // verify getStatus + status = manageModels.getModelStatus(GetStatusRequest.newBuilder() + .setModelId(modelId).build()); + + assertEquals(ModelStatus.LOADED, status.getStatus()); + assertEquals(0, status.getErrorsCount()); + + // delete + manageModels.unregisterModel(UnregisterModelRequest.newBuilder() + .setModelId(modelId).build()); + } finally { + channel.shutdown(); + } + } + +} diff --git a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java index d290eb1e..0531ff2c 100644 --- a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java @@ -100,7 +100,7 @@ public void grpcTest() throws Exception { response.getResults(0).getCategory()); // verify larger payload - int bigChars = 2_000_000; + int bigChars = 2; StringBuilder sb = new StringBuilder(bigChars); for (int i = 0; i < bigChars; i++) { sb.append('a'); diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessorTest.java new file mode 100644 index 00000000..7cfd07b6 --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/payload/AsyncPayloadProcessorTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +class AsyncPayloadProcessorTest { + + @Test + void testPayloadProcessing() { + DummyPayloadProcessor dummyPayloadProcessor = new DummyPayloadProcessor(); + + ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); + AsyncPayloadProcessor payloadProcessor = new AsyncPayloadProcessor(dummyPayloadProcessor, 1, TimeUnit.NANOSECONDS, scheduler, 100); + + for (int i = 0; i < 10; i++) { + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + } + try { + assertFalse(scheduler.awaitTermination(1, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + // ignore it + } + for (int i = 0; i < 10; i++) { + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + } + try { + assertFalse(scheduler.awaitTermination(1, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + // ignore it + } + assertEquals(20, dummyPayloadProcessor.getProcessCount().get()); + } +} \ No newline at end of file diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessorTest.java new file mode 100644 index 00000000..cbbee29a --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/payload/CompositePayloadProcessorTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CompositePayloadProcessorTest { + + @Test + void testPayloadProcessing() throws IOException { + List delegates = new ArrayList<>(); + delegates.add(new DummyPayloadProcessor()); + delegates.add(new DummyPayloadProcessor()); + + try (CompositePayloadProcessor payloadProcessor = new CompositePayloadProcessor(delegates)) { + for (int i = 0; i < 10; i++) { + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + } + } + for (PayloadProcessor p : delegates) { + DummyPayloadProcessor dummyPayloadProcessor = (DummyPayloadProcessor) p; + assertEquals(10, dummyPayloadProcessor.getProcessCount().get()); + } + } +} \ No newline at end of file diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/DummyPayloadProcessor.java b/src/test/java/com/ibm/watson/modelmesh/payload/DummyPayloadProcessor.java new file mode 100644 index 00000000..a47458a1 --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/payload/DummyPayloadProcessor.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; + +class DummyPayloadProcessor implements PayloadProcessor { + + private final AtomicInteger processCount; + + DummyPayloadProcessor() { + this(new AtomicInteger(0)); + } + + DummyPayloadProcessor(AtomicInteger processCount) { + this.processCount = processCount; + } + + @Override + public String getName() { + return "dummy"; + } + + @Override + public boolean process(Payload payload) { + this.processCount.incrementAndGet(); + return false; + } + + public AtomicInteger getProcessCount() { + return processCount; + } + + public void reset() { + this.processCount.set(0); + } + + @Override + public void close() throws IOException { + // do nothing + } +} diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessorTest.java new file mode 100644 index 00000000..6d53210f --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessorTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class MatchingPayloadProcessorTest { + + @Test + void testPayloadProcessingAny() { + AtomicInteger counter = new AtomicInteger(); + PayloadProcessor delegate = new DummyPayloadProcessor(counter); + MatchingPayloadProcessor payloadProcessor = MatchingPayloadProcessor.from(null, null, delegate); + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + assertEquals(1, counter.get()); + payloadProcessor.process(new Payload("456", "456", null, null, null, null)); + assertEquals(2, counter.get()); + payloadProcessor.process(new Payload("789", "456", null, null, null, null)); + assertEquals(3, counter.get()); + payloadProcessor.process(new Payload("abc", "456", "processRequest", null, null, null)); + assertEquals(4, counter.get()); + } + + @Test + void testPayloadProcessingAnySpecialChars() { + AtomicInteger counter = new AtomicInteger(); + DummyPayloadProcessor delegate = new DummyPayloadProcessor(counter); + List processors = new ArrayList<>(); + processors.add(MatchingPayloadProcessor.from("", "", delegate)); + processors.add(MatchingPayloadProcessor.from("*", "", delegate)); + processors.add(MatchingPayloadProcessor.from("/*", "", delegate)); + processors.add(MatchingPayloadProcessor.from("/*", "", delegate)); + processors.add(MatchingPayloadProcessor.from("/*", "*", delegate)); + processors.add(MatchingPayloadProcessor.from("/", "*", delegate)); + processors.add(MatchingPayloadProcessor.from("", "*", delegate)); + processors.add(MatchingPayloadProcessor.from("", "*", delegate)); + for (PayloadProcessor payloadProcessor : processors) { + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + assertEquals(1, counter.get()); + payloadProcessor.process(new Payload("456", "456", null, null, null, null)); + assertEquals(2, counter.get()); + payloadProcessor.process(new Payload("789", "456", "processRequest", null, null, null)); + assertEquals(3, counter.get()); + payloadProcessor.process(new Payload("abc", "456", "processRequest", null, null, null)); + assertEquals(4, counter.get()); + delegate.reset(); + } + } + + @Test + void testPayloadProcessingModelFilter() { + AtomicInteger counter = new AtomicInteger(); + DummyPayloadProcessor delegate = new DummyPayloadProcessor(counter); + MatchingPayloadProcessor payloadProcessor = MatchingPayloadProcessor.from("someModelId", null, delegate); + payloadProcessor.process(new Payload("123", "nogo", null, null, null, null)); + assertEquals(0, counter.get()); + payloadProcessor.process(new Payload("456", "someModelId", null, null, null, null)); + assertEquals(1, counter.get()); + payloadProcessor.process(new Payload( "789", "nogo", "processRequest", null, null, null)); + assertEquals(1, counter.get()); + payloadProcessor.process(new Payload( "abc", "someModelId", "processRequest", null, null, null)); + assertEquals(2, counter.get()); + } + + @Test + void testPayloadProcessingMethodFilter() { + AtomicInteger counter = new AtomicInteger(); + DummyPayloadProcessor delegate = new DummyPayloadProcessor(counter); + MatchingPayloadProcessor payloadProcessor = MatchingPayloadProcessor.from(null, "getName", delegate); + payloadProcessor.process(new Payload("123", "456", null, null, null, null)); + assertEquals(0, counter.get()); + payloadProcessor.process(new Payload("456", "456", null, null, null, null)); + assertEquals(0, counter.get()); + payloadProcessor.process(new Payload("789", "456", "getName", null, null, null)); + assertEquals(1, counter.get()); + payloadProcessor.process(new Payload("abc", "456", "getName", null, null, null)); + assertEquals(2, counter.get()); + payloadProcessor.process(new Payload("def", "456", "filteredMethod", null, null, null)); + assertEquals(2, counter.get()); + } +} \ No newline at end of file diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java new file mode 100644 index 00000000..7a75a60c --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh.payload; + +import java.net.URI; + +import io.grpc.Metadata; +import io.grpc.Status; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +class RemotePayloadProcessorTest { + + @Test + void testDestinationUnreachable() { + RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123")); + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } +} \ No newline at end of file