Skip to content

Commit

Permalink
Avoid catch (Throwable t) in AmazonBedrockStreamingChatProcessor (e…
Browse files Browse the repository at this point in the history
…lastic#115715)

`CompletableFuture.runAsync` implicitly catches all `Throwable`
instances thrown by the task, which includes `Error` instances that no
reasonable application should catch. Moreover, discarding the return
value from these methods means that any such `Error` will be ignored,
allowing the JVM to carry on running in an invalid state.

This commit replaces these trappy calls with more appropriate exception
handling.
  • Loading branch information
DaveCTurner authored Oct 28, 2024
1 parent 03f2559 commit a1670c1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/115715.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 115715
summary: Avoid `catch (Throwable t)` in `AmazonBedrockStreamingChatProcessor`
area: Machine Learning
type: bug
issues: []
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
requires org.slf4j;
requires software.amazon.awssdk.retries.api;
requires org.reactivestreams;
requires org.elasticsearch.logging;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;

import java.util.ArrayDeque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -27,6 +28,8 @@
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingChatProcessor.class);

private final AtomicReference<Throwable> error = new AtomicReference<>(null);
private final AtomicLong demand = new AtomicLong(0);
private final AtomicBoolean isDone = new AtomicBoolean(false);
Expand Down Expand Up @@ -75,13 +78,13 @@ public void onNext(ConverseStreamOutput item) {

// this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response
private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) {
CompletableFuture.runAsync(() -> {
runOnUtilityThreadPool(() -> {
var text = event.delta().text();
var result = new ArrayDeque<StreamingChatCompletionResults.Result>(1);
result.offer(new StreamingChatCompletionResults.Result(text));
var results = new StreamingChatCompletionResults.Results(result);
downstream.onNext(results);
}, threadPool.executor(UTILITY_THREAD_POOL_NAME));
});
}

@Override
Expand All @@ -108,6 +111,14 @@ public void onComplete() {
}
}

private void runOnUtilityThreadPool(Runnable runnable) {
try {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(runnable);
} catch (Exception e) {
logger.error(Strings.format("failed to fork [%s] to utility thread pool", runnable), e);
}
}

private class StreamSubscription implements Flow.Subscription {
@Override
public void request(long n) {
Expand Down Expand Up @@ -142,7 +153,7 @@ private void requestOnMlThread(long n) {
if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) {
upstream.request(n);
} else {
CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME));
runOnUtilityThreadPool(() -> upstream.request(n));
}
}

Expand Down

0 comments on commit a1670c1

Please sign in to comment.