From a1670c18bfeac4a14057130d328053ecc3c8fe71 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 28 Oct 2024 15:32:00 +0000 Subject: [PATCH] Avoid `catch (Throwable t)` in `AmazonBedrockStreamingChatProcessor` (#115715) `CompletableFuture.runAsync` implicitly catches all `Throwable` instances thrown by the task, which includes `Error` instances that no reasonable application should catch. Moreover, discarding the return value from these methods means that any such `Error` will be ignored, allowing the JVM to carry on running in an invalid state. This commit replaces these trappy calls with more appropriate exception handling. --- docs/changelog/115715.yaml | 5 +++++ .../inference/src/main/java/module-info.java | 1 + .../AmazonBedrockStreamingChatProcessor.java | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 docs/changelog/115715.yaml diff --git a/docs/changelog/115715.yaml b/docs/changelog/115715.yaml new file mode 100644 index 0000000000000..378f2c42e5e50 --- /dev/null +++ b/docs/changelog/115715.yaml @@ -0,0 +1,5 @@ +pr: 115715 +summary: Avoid `catch (Throwable t)` in `AmazonBedrockStreamingChatProcessor` +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 60cb254e0afbe..53974657e4e23 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -33,6 +33,7 @@ requires org.slf4j; requires software.amazon.awssdk.retries.api; requires org.reactivestreams; + requires org.elasticsearch.logging; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java index 439fc5b65efd5..12f394e300e0f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java @@ -14,11 +14,12 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import java.util.ArrayDeque; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -27,6 +28,8 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; class AmazonBedrockStreamingChatProcessor implements Flow.Processor { + private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingChatProcessor.class); + private final AtomicReference error = new AtomicReference<>(null); private final AtomicLong demand = new AtomicLong(0); private final AtomicBoolean isDone = new AtomicBoolean(false); @@ -75,13 +78,13 @@ public void onNext(ConverseStreamOutput item) { // this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) { - CompletableFuture.runAsync(() -> { + runOnUtilityThreadPool(() -> { var text = event.delta().text(); var result = new ArrayDeque(1); result.offer(new StreamingChatCompletionResults.Result(text)); var results = new StreamingChatCompletionResults.Results(result); downstream.onNext(results); - }, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + }); } @Override @@ -108,6 +111,14 @@ public void onComplete() { } } + private void runOnUtilityThreadPool(Runnable runnable) { + try { + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(runnable); + } catch (Exception e) { + logger.error(Strings.format("failed to fork [%s] to utility thread pool", runnable), e); + } + } + private class StreamSubscription implements Flow.Subscription { @Override public void request(long n) { @@ -142,7 +153,7 @@ private void requestOnMlThread(long n) { if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) { upstream.request(n); } else { - CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME)); + runOnUtilityThreadPool(() -> upstream.request(n)); } }