Skip to content

Commit

Permalink
feat: Make it possible to attach a PayloadProcessor to process model …
Browse files Browse the repository at this point in the history
…predictions (#84)

#### Motivation
This PR seeks to address the model-mesh side of kserve/modelmesh-serving#284.

#### Modifications
It provides a `PayloadProcessor` interface. `PayloadProcessors` are picked up by `ModelMesh` instances at startup, and predictions (`Payloads`) are processed asynchronously at a fixed interval.
A first logger implementation allows logging `Payloads` (at _info_ level).

#### Result
An SPI for post-processing model predictions.

---

resolves kserve/modelmesh-serving#284

Signed-off-by: Tommaso Teofili <[email protected]>
  • Loading branch information
tteofili authored Mar 13, 2023
1 parent 426193c commit eb384db
Show file tree
Hide file tree
Showing 20 changed files with 1,237 additions and 33 deletions.
1 change: 1 addition & 0 deletions config/base/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ vars:

#patchesStrategicMerge:
# - patches/etcd.yaml
# - patches/logger.yaml
# - patches/tls.yaml
# - patches/uds.yaml
# - patches/max_msg_size.yaml
Expand Down
29 changes: 29 additions & 0 deletions config/base/patches/logger.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use this patch to configure payload processors through the
# MM_PAYLOAD_PROCESSORS env var (the value below registers the logger processor)
#
apiVersion: apps/v1
kind: Deployment
metadata:
name: model-mesh
spec:
template:
spec:
containers:
- name: mm
env:
- name: MM_PAYLOAD_PROCESSORS
value: "logger://*"
44 changes: 43 additions & 1 deletion src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@
import com.ibm.watson.modelmesh.TypeConstraintManager.ProhibitedTypeSet;
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap;
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap.EvictionListenerWithTime;
import com.ibm.watson.modelmesh.payload.AsyncPayloadProcessor;
import com.ibm.watson.modelmesh.payload.CompositePayloadProcessor;
import com.ibm.watson.modelmesh.payload.LoggingPayloadProcessor;
import com.ibm.watson.modelmesh.payload.MatchingPayloadProcessor;
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
import com.ibm.watson.modelmesh.payload.RemotePayloadProcessor;
import com.ibm.watson.modelmesh.thrift.ApplierException;
import com.ibm.watson.modelmesh.thrift.BaseModelMeshService;
import com.ibm.watson.modelmesh.thrift.InternalException;
Expand Down Expand Up @@ -101,6 +107,7 @@
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
Expand Down Expand Up @@ -421,6 +428,40 @@ public abstract class ModelMesh extends ThriftService
}
}

/**
 * Builds the {@link PayloadProcessor} described by the {@code MM_PAYLOAD_PROCESSORS} parameter.
 * <p>
 * The parameter holds one or more whitespace-separated URI definitions. The URI scheme selects
 * the implementation ({@code http} creates a {@link RemotePayloadProcessor}, {@code logger} a
 * {@link LoggingPayloadProcessor}); the URI query and fragment are handed to
 * {@link MatchingPayloadProcessor#from} as the model id and method respectively.
 * Definitions that cannot be parsed or that use an unsupported scheme are logged and skipped.
 *
 * @return an asynchronous processor wrapping every successfully configured processor, or
 *         {@code null} when the parameter is unset/blank or no definition produced a processor
 */
private PayloadProcessor initPayloadProcessor() {
    String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
    logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
    if (payloadProcessorsDefinitions == null || payloadProcessorsDefinitions.trim().isEmpty()) {
        return null;
    }

    List<PayloadProcessor> payloadProcessors = new ArrayList<>();
    // Trim first so that split() cannot yield an empty leading token
    for (String processorDefinition : payloadProcessorsDefinitions.trim().split("\\s+")) {
        try {
            URI uri = URI.create(processorDefinition);
            String processorName = uri.getScheme();
            String modelId = uri.getQuery();
            String method = uri.getFragment();
            PayloadProcessor processor;
            if ("http".equals(processorName)) {
                processor = new RemotePayloadProcessor(uri);
            } else if ("logger".equals(processorName)) {
                processor = new LoggingPayloadProcessor();
            } else {
                // Surface misconfiguration instead of silently dropping the definition
                logger.warn("Ignoring PayloadProcessor definition '{}': unsupported scheme '{}'",
                        processorDefinition, processorName);
                continue;
            }
            MatchingPayloadProcessor p = MatchingPayloadProcessor.from(modelId, method, processor);
            payloadProcessors.add(p);
            logger.info("Added PayloadProcessor {}", p.getName());
        } catch (IllegalArgumentException iae) {
            // Trailing throwable argument keeps the cause (and stack trace) in the log
            logger.error("Unable to parse PayloadProcessor URI definition {}", processorDefinition, iae);
        }
    }

    if (payloadProcessors.isEmpty()) {
        // Avoid starting a thread pool (and per-request payload handling) with nothing to run
        logger.warn("No usable PayloadProcessor found in definition '{}', payload processing is disabled",
                payloadProcessorsDefinitions);
        return null;
    }

    // NOTE(review): "1, MINUTES" is presumably the asynchronous processing period and the last
    // argument the pending-payload capacity -- confirm against AsyncPayloadProcessor.
    return new AsyncPayloadProcessor(new CompositePayloadProcessor(payloadProcessors), 1, MINUTES,
            Executors.newScheduledThreadPool(getIntParameter(MM_PAYLOAD_PROCESSORS_THREADS, 2)),
            getIntParameter(MM_PAYLOAD_PROCESSORS_CAPACITY, 64));
}

/* ---------------------------------- initialization --------------------------------------------------------- */

@Override
Expand Down Expand Up @@ -854,10 +895,11 @@ protected final TProcessor initialize() throws Exception {
}

LogRequestHeaders logHeaders = LogRequestHeaders.getConfiguredLogRequestHeaders();
PayloadProcessor payloadProcessor = initPayloadProcessor();

grpcServer = new ModelMeshApi((SidecarModelMesh) this, vModelManager, GRPC_PORT, keyCertFile, privateKeyFile,
privateKeyPassphrase, clientAuth, caCertFiles, maxGrpcMessageSize, maxGrpcHeadersSize,
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders);
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders, payloadProcessor);
}

if (grpcServer != null) {
Expand Down
123 changes: 92 additions & 31 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
import com.ibm.watson.modelmesh.api.UnregisterModelResponse;
import com.ibm.watson.modelmesh.api.VModelStatusInfo;
import com.ibm.watson.modelmesh.payload.Payload;
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
import com.ibm.watson.modelmesh.thrift.ApplierException;
import com.ibm.watson.modelmesh.thrift.InvalidInputException;
import com.ibm.watson.modelmesh.thrift.InvalidStateException;
Expand Down Expand Up @@ -156,6 +158,10 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
// null if header logging is not enabled.
protected final LogRequestHeaders logHeaders;

// Processor for request/response payloads; may be null, in which case no payload processing is done.
private final PayloadProcessor payloadProcessor;

// Per-thread counter; combined with the current thread's id to build the id of each processed request.
private final ThreadLocal<long[]> localIdCounter = ThreadLocal.withInitial(() -> new long[1]);

/**
* Create <b>and start</b> the server.
*
Expand All @@ -171,16 +177,18 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
* @param maxConnectionAge in seconds
* @param maxConnectionAgeGrace in seconds, custom grace time for graceful connection termination
* @param logHeaders
* @param payloadProcessor a processor of payloads
* @throws IOException
*/
public ModelMeshApi(SidecarModelMesh delegate, VModelManager vmm, int port, File keyCert, File privateKey,
String privateKeyPassphrase, ClientAuth clientAuth, File[] trustCerts,
int maxMessageSize, int maxHeadersSize, long maxConnectionAge, long maxConnectionAgeGrace,
LogRequestHeaders logHeaders) throws IOException {
LogRequestHeaders logHeaders, PayloadProcessor payloadProcessor) throws IOException {

this.delegate = delegate;
this.vmm = vmm;
this.logHeaders = logHeaders;
this.payloadProcessor = payloadProcessor;

this.multiParallelism = getMultiParallelism();

Expand Down Expand Up @@ -293,6 +301,13 @@ public void shutdown(long timeout, TimeUnit unit) throws InterruptedException {
if (!done) {
server.shutdownNow();
}
if (payloadProcessor != null) {
try {
payloadProcessor.close();
} catch (IOException e) {
logger.warn("Error closing PayloadProcessor {}: {}", payloadProcessor, e.getMessage());
}
}
threads.shutdownNow();
shutdownEventLoops();
}
Expand Down Expand Up @@ -686,49 +701,57 @@ public void onHalfClose() {
call.close(INTERNAL.withDescription("Half-closed without a request"), emptyMeta());
return;
}
final int reqSize = reqMessage.readableBytes();
int reqReaderIndex = reqMessage.readerIndex();
int reqSize = reqMessage.readableBytes();
int respSize = -1;
int respReaderIndex = 0;

io.grpc.Status status = INTERNAL;
String modelId = null;
String requestId = null;
ModelResponse response = null;
try (InterruptingListener cancelListener = newInterruptingListener()) {
if (logHeaders != null) {
logHeaders.addToMDC(headers); // MDC cleared in finally block
}
ModelResponse response = null;
if (payloadProcessor != null) {
requestId = Thread.currentThread().getId() + "-" + ++localIdCounter.get()[0];
}
try {
try {
String balancedMetaVal = headers.get(BALANCED_META_KEY);
Iterator<String> midIt = modelIds.iterator();
// guaranteed at least one
String modelId = validateModelId(midIt.next(), isVModel);
if (!midIt.hasNext()) {
// single model case (most common)
response = callModel(modelId, isVModel, methodName,
balancedMetaVal, headers, reqMessage).retain();
} else {
// multi-model case (specialized)
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
List<String> idList = new ArrayList<>();
idList.add(modelId);
while (midIt.hasNext()) {
idList.add(validateModelId(midIt.next(), isVModel));
}
response = applyParallelMultiModel(idList, isVModel, methodName,
balancedMetaVal, headers, reqMessage, allRequired);
String balancedMetaVal = headers.get(BALANCED_META_KEY);
Iterator<String> midIt = modelIds.iterator();
// guaranteed at least one
modelId = validateModelId(midIt.next(), isVModel);
if (!midIt.hasNext()) {
// single model case (most common)
response = callModel(modelId, isVModel, methodName,
balancedMetaVal, headers, reqMessage).retain();
} else {
// multi-model case (specialized)
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
List<String> idList = new ArrayList<>();
idList.add(modelId);
while (midIt.hasNext()) {
idList.add(validateModelId(midIt.next(), isVModel));
}
} finally {
releaseReqMessage();
response = applyParallelMultiModel(idList, isVModel, methodName,
balancedMetaVal, headers, reqMessage, allRequired);
}

respSize = response.data.readableBytes();
call.sendHeaders(response.metadata);
call.sendMessage(response.data);
response = null;
} finally {
if (response != null) {
response.release();
if (payloadProcessor != null) {
processPayload(reqMessage.readerIndex(reqReaderIndex),
requestId, modelId, methodName, headers, null, true);
} else {
releaseReqMessage();
}
reqMessage = null; // ownership released or transferred
}

respReaderIndex = response.data.readerIndex();
respSize = response.data.readableBytes();
call.sendHeaders(response.metadata);
call.sendMessage(response.data);
// response is released via ReleaseAfterResponse.releaseAll()
status = OK;
} catch (Exception e) {
status = toStatus(e);
Expand All @@ -745,6 +768,15 @@ public void onHalfClose() {
evictMethodDescriptor(methodName);
}
} finally {
if (payloadProcessor != null) {
ByteBuf data = null;
Metadata metadata = null;
if (response != null) {
data = response.data.readerIndex(respReaderIndex);
metadata = response.metadata;
}
processPayload(data, requestId, modelId, methodName, metadata, status, false);
}
ReleaseAfterResponse.releaseAll();
clearThreadLocals();
//TODO(maybe) additional trailer info in exception case?
Expand All @@ -757,6 +789,35 @@ public void onHalfClose() {
}
}

/**
 * Invoke the PayloadProcessor on the request/response data.
 * <p>
 * Failures while building or processing the payload are logged and never propagated to the caller.
 *
 * @param data the binary data; may be null (e.g. for a response when the call failed
 *             before any response was produced)
 * @param payloadId the id of the request
 * @param modelId the id of the model
 * @param methodName the name of the invoked method
 * @param metadata the request headers or response metadata, may be null
 * @param status null for requests, non-null for responses
 * @param takeOwnership whether the caller's reference to {@code data} is handed over to this
 *                      method; if false, an additional reference is retained here first
 */
private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName,
        Metadata metadata, io.grpc.Status status, boolean takeOwnership) {
    Payload payload = null;
    try {
        assert payloadProcessor != null;
        if (!takeOwnership) {
            // Null-safe retain: failed calls reach here with no response data, and must still
            // be reported to the processor along with their status
            ReferenceCountUtil.retain(data);
        }
        payload = new Payload(payloadId, modelId, methodName, metadata, data, status);
        if (payloadProcessor.process(payload)) {
            data = null; // ownership transferred
        }
    } catch (Throwable t) {
        logger.warn("Error while processing payload: {}", payload, t);
    } finally {
        // Drops the reference held by this method unless the processor took it; no-op for null
        ReferenceCountUtil.release(data);
    }
}

@Override
public void onComplete() {
releaseReqMessage();
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public final class ModelMeshEnvVars {

private ModelMeshEnvVars() {}

// Space-separated list of PayloadProcessor URI definitions, e.g. "logger://*"
public static final String MM_PAYLOAD_PROCESSORS = "MM_PAYLOAD_PROCESSORS";
// Number of threads in the scheduled pool used for payload processing
public static final String MM_PAYLOAD_PROCESSORS_THREADS = "MM_PAYLOAD_PROCESSORS_THREADS";
// Capacity setting passed to the asynchronous payload processor
// NOTE(review): presumably the max number of payloads held for processing -- confirm
public static final String MM_PAYLOAD_PROCESSORS_CAPACITY = "MM_PAYLOAD_PROCESSORS_CAPACITY";

// This must not be changed after model-mesh is already deployed to a particular env
public static final String KV_STORE_PREFIX = "MM_KVSTORE_PREFIX";

Expand Down
Loading

0 comments on commit eb384db

Please sign in to comment.