From 15fbb800e8a74771f631cf511f7a11b61ba7bf3e Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Sat, 8 Feb 2025 10:45:27 -0700 Subject: [PATCH] Shift all batching/concurrency decisions into the NimClient --- .../image/image_handlers.py | 68 +++---- .../extraction_workflows/pdf/pdfium_helper.py | 92 +++++----- src/nv_ingest/stages/nim/chart_extraction.py | 140 +++++---------- src/nv_ingest/stages/nim/table_extraction.py | 169 +++++------------- src/nv_ingest/util/nim/helpers.py | 83 +++++---- 5 files changed, 224 insertions(+), 328 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py index 599052ee..9accecf3 100644 --- a/src/nv_ingest/extraction_workflows/image/image_handlers.py +++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py @@ -27,7 +27,6 @@ import numpy as np from PIL import Image -from math import log from wand.image import Image as WandImage import nv_ingest.util.nim.yolox as yolox_utils @@ -186,7 +185,7 @@ def extract_tables_and_charts_from_images( ---------- images : List[np.ndarray] List of images in NumPy array format. - config : PDFiumConfigSchema + config : ImageConfigSchema Configuration object containing YOLOX endpoints, auth token, etc. trace_info : Optional[List], optional Optional tracing data for debugging/performance profiling. @@ -194,8 +193,8 @@ def extract_tables_and_charts_from_images( Returns ------- List[Tuple[int, object]] - A list of (image_index, CroppedImageWithContent) - representing extracted table/chart data from each image. + A list of (image_index, CroppedImageWithContent) representing extracted + table/chart data from each image. """ tables_and_charts = [] yolox_client = None @@ -209,41 +208,31 @@ def extract_tables_and_charts_from_images( config.yolox_infer_protocol, ) - max_batch_size = YOLOX_MAX_BATCH_SIZE - batches = [] - i = 0 - while i < len(images): - batch_size = min(2 ** int(log(len(images) - i, 2)), max_batch_size) - batches.append(images[i : i + batch_size]) # noqa: E203 - i += batch_size - - img_index = 0 - for batch in batches: - data = {"images": batch} - - # NimClient inference - inference_results = yolox_client.infer( - data, - model_name="yolox", - max_batch_size=YOLOX_MAX_BATCH_SIZE, - num_classes=YOLOX_NUM_CLASSES, - conf_thresh=YOLOX_CONF_THRESHOLD, - iou_thresh=YOLOX_IOU_THRESHOLD, - min_score=YOLOX_MIN_SCORE, - final_thresh=YOLOX_FINAL_SCORE, - trace_info=trace_info, # traceable_func arg - stage_name="pdf_content_extractor", # traceable_func arg - ) + # Prepare the payload with all images. + data = {"images": images} + + # Perform inference in a single call. The NimClient handles batching internally. + inference_results = yolox_client.infer( + data, + model_name="yolox", + max_batch_size=YOLOX_MAX_BATCH_SIZE, + num_classes=YOLOX_NUM_CLASSES, + conf_thresh=YOLOX_CONF_THRESHOLD, + iou_thresh=YOLOX_IOU_THRESHOLD, + min_score=YOLOX_MIN_SCORE, + final_thresh=YOLOX_FINAL_SCORE, + trace_info=trace_info, + stage_name="pdf_content_extractor", + ) - # 5) Extract table/chart info from each image's annotations - for annotation_dict, original_image in zip(inference_results, batch): - extract_table_and_chart_images( - annotation_dict, - original_image, - img_index, - tables_and_charts, - ) - img_index += 1 + # Process each result along with its corresponding image. 
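+        # NOTE: the loop below assumes NimClient.infer() returns exactly one annotation
+        # dict per input image, in the same order as `images`; the client preserves that
+        # order even when it splits the request into multiple batches internally.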
+ for i, (annotation_dict, original_image) in enumerate(zip(inference_results, images)): + extract_table_and_chart_images( + annotation_dict, + original_image, + i, + tables_and_charts, + ) except TimeoutError: logger.error("Timeout error during table/chart extraction.") @@ -252,14 +241,13 @@ def extract_tables_and_charts_from_images( except Exception as e: logger.error(f"Unhandled error during table/chart extraction: {str(e)}") traceback.print_exc() - raise e + raise finally: if yolox_client: yolox_client.close() logger.debug(f"Extracted {len(tables_and_charts)} tables and charts from image.") - return tables_and_charts diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index a18c033c..2ba35ec1 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -19,7 +19,6 @@ import concurrent.futures import logging import traceback -from math import log from typing import List from typing import Optional from typing import Tuple @@ -61,55 +60,61 @@ def extract_tables_and_charts_using_image_ensemble( pages: List[Tuple[int, np.ndarray]], config: PDFiumConfigSchema, trace_info: Optional[List] = None, -) -> List[Tuple[int, object]]: # List[Tuple[int, CroppedImageWithContent]] +) -> List[Tuple[int, object]]: + """ + Given a list of (page_index, image) tuples, this function calls the YOLOX-based + inference service to extract table and chart annotations from all pages. + + The NimClient is now responsible for handling batching and concurrency internally. + For each page, the output is processed and the result is added to tables_and_charts. + + Returns + ------- + List[Tuple[int, object]] + For each page, returns (page_index, joined_content) where joined_content + is the result of combining annotations from the inference. + """ tables_and_charts = [] + yolox_client = None try: model_interface = yolox_utils.YoloxPageElementsModelInterface() yolox_client = create_inference_client( - config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol + config.yolox_endpoints, + model_interface, + config.auth_token, + config.yolox_infer_protocol, ) - batches = [] - i = 0 - max_batch_size = YOLOX_MAX_BATCH_SIZE - while i < len(pages): - batch_size = min(2 ** int(log(len(pages) - i, 2)), max_batch_size) - batches.append(pages[i : i + batch_size]) # noqa: E203 - i += batch_size - - page_index = 0 - for batch in batches: - image_page_indices = [page[0] for page in batch] - original_images = [page[1] for page in batch] - - # Prepare data - data = {"images": original_images} - - # Perform inference using NimClient - inference_results = yolox_client.infer( - data, - model_name="yolox", - max_batch_size=YOLOX_MAX_BATCH_SIZE, - num_classes=YOLOX_NUM_CLASSES, - conf_thresh=YOLOX_CONF_THRESHOLD, - iou_thresh=YOLOX_IOU_THRESHOLD, - min_score=YOLOX_MIN_SCORE, - final_thresh=YOLOX_FINAL_SCORE, - trace_info=trace_info, # traceable_func arg - stage_name="pdf_content_extractor", # traceable_func arg - ) + # Collect all page indices and images in order. + image_page_indices = [page[0] for page in pages] + original_images = [page[1] for page in pages] + + # Prepare the data payload with all images. + data = {"images": original_images} + + # Perform inference using the NimClient. 
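+        # A single call covers every page image; the client is expected to split the
+        # request into batches of at most YOLOX_MAX_BATCH_SIZE and run those batches
+        # concurrently on its internal thread pool.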
+ inference_results = yolox_client.infer( + data, + model_name="yolox", + max_batch_size=YOLOX_MAX_BATCH_SIZE, + num_classes=YOLOX_NUM_CLASSES, + conf_thresh=YOLOX_CONF_THRESHOLD, + iou_thresh=YOLOX_IOU_THRESHOLD, + min_score=YOLOX_MIN_SCORE, + final_thresh=YOLOX_FINAL_SCORE, + trace_info=trace_info, + stage_name="pdf_content_extractor", + ) - # Process results - for annotation_dict, page_index, original_image in zip( - inference_results, image_page_indices, original_images - ): - extract_table_and_chart_images( - annotation_dict, - original_image, - page_index, - tables_and_charts, - ) + # Process results: iterate over each image's inference output. + for annotation_dict, page_index, original_image in zip(inference_results, image_page_indices, original_images): + extract_table_and_chart_images( + annotation_dict, + original_image, + page_index, + tables_and_charts, + ) except TimeoutError: logger.error("Timeout error during table/chart extraction.") @@ -118,14 +123,13 @@ def extract_tables_and_charts_using_image_ensemble( except Exception as e: logger.error(f"Unhandled error during table/chart extraction: {str(e)}") traceback.print_exc() - raise e + raise finally: if yolox_client: yolox_client.close() logger.debug(f"Extracted {len(tables_and_charts)} tables and charts.") - return tables_and_charts diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 5f487149..748eb412 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -11,7 +11,6 @@ import pandas as pd from morpheus.config import Config -from concurrent.futures import ThreadPoolExecutor from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage @@ -29,103 +28,62 @@ def _update_metadata( cached_client: NimClient, deplot_client: NimClient, trace_info: Dict, - batch_size: int = 1, - worker_pool_size: int = 1, + batch_size: int = 1, # No longer used + worker_pool_size: int = 1, # No longer used ) -> List[Tuple[str, Dict]]: """ - Given a list of base64-encoded chart images, this function: - - Splits them into batches of size `batch_size`. - - Calls Cached with *all images* in each batch in a single request if protocol != 'grpc'. - If protocol == 'grpc', calls Cached individually for each image in the batch. - - Calls Deplot individually (one request per image) in parallel. - - Joins the results for each image into a final combined inference result. + Given a list of base64-encoded chart images, this function calls both the Cached and Deplot + inference services to extract chart data for all images. The NimClient implementations are + responsible for handling batching and concurrency internally. - Returns - ------- - List[Tuple[str, Dict]] - For each base64-encoded image, returns (original_image_str, joined_chart_content_dict). + For each base64-encoded image, returns: + (original_image_str, joined_chart_content_dict) """ - logger.debug(f"Running chart extraction: batch_size={batch_size}, worker_pool_size={worker_pool_size}") + logger.debug("Running chart extraction using updated concurrency handling.") - def chunk_list(lst, chunk_size): - for i in range(0, len(lst), chunk_size): - yield lst[i : i + chunk_size] + # Prepare data payloads for both clients. We assume that both clients now support receiving + # a list of images via the "base64_images" key. 
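+    # The cached and deplot requests below are issued back to back; any batching and
+    # concurrency now happens inside each NimClient rather than in a thread pool here.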
+ data_cached = {"base64_images": base64_images} + data_deplot = {"base64_images": base64_images} - results = [] + try: + cached_results = cached_client.infer( + data=data_cached, + model_name="cached", + stage_name="chart_data_extraction", + max_batch_size=len(base64_images), + trace_info=trace_info, + ) + except Exception as e: + logger.error(f"Error calling cached_client.infer: {e}", exc_info=True) + raise - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - for batch in chunk_list(base64_images, batch_size): - # 1) Cached calls - if cached_client.protocol == "grpc": - # Submit each image in the batch separately - cached_futures = [] - for image_str in batch: - data = {"base64_images": [image_str]} - fut = executor.submit( - cached_client.infer, - data=data, - model_name="cached", - stage_name="chart_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - cached_futures.append(fut) - else: - # Single request for the entire batch - data = {"base64_images": batch} - future_cached = executor.submit( - cached_client.infer, - data=data, - model_name="cached", - stage_name="chart_data_extraction", - max_batch_size=batch_size, - trace_info=trace_info, - ) - - # 2) Multiple calls to Deplot, one per image in the batch - deplot_futures = [] - for image_str in batch: - # Deplot only supports single-image calls - deplot_data = {"base64_image": image_str} - fut = executor.submit( - deplot_client.infer, - data=deplot_data, - model_name="deplot", - stage_name="chart_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - deplot_futures.append(fut) - - try: - # 3) Retrieve results from Cached - if cached_client.protocol == "grpc": - # Each future should return a single-element list - # We take the 0th item to align with single-image results - cached_results = [] - for fut in cached_futures: - res = fut.result() - if isinstance(res, list) and len(res) == 1: - cached_results.append(res[0]) - else: - # Fallback in case the service returns something unexpected - logger.warning(f"Unexpected CACHED result format: {res}") - cached_results.append(res) - else: - # Single call returning a list of the same length as 'batch' - cached_results = future_cached.result() - - # Retrieve results from Deplot (each call returns a single inference result) - deplot_results = [f.result() for f in deplot_futures] - - # 4) Zip them together, one by one - for img_str, cached_res, deplot_res in zip(batch, cached_results, deplot_results): - chart_content = join_cached_and_deplot_output(cached_res, deplot_res) - results.append((img_str, chart_content)) - - except Exception as e: - logger.error(f"Error processing batch: {batch}, error: {e}", exc_info=True) - raise + try: + deplot_results = deplot_client.infer( + data=data_deplot, + model_name="deplot", + stage_name="chart_data_extraction", + max_batch_size=len(base64_images), + trace_info=trace_info, + ) + except Exception as e: + logger.error(f"Error calling deplot_client.infer: {e}", exc_info=True) + raise + + # Ensure both clients returned lists of results matching the number of input images. 
+ if not (isinstance(cached_results, list) and isinstance(deplot_results, list)): + raise ValueError("Expected list results from both cached_client and deplot_client infer calls.") + + if len(cached_results) != len(base64_images): + raise ValueError(f"Expected {len(base64_images)} cached results, got {len(cached_results)}") + if len(deplot_results) != len(base64_images): + raise ValueError(f"Expected {len(base64_images)} deplot results, got {len(deplot_results)}") + + # Join the corresponding results from both services for each image. + results = [] + for img_str, cached_res, deplot_res in zip(base64_images, cached_results, deplot_results): + joined_chart_content = join_cached_and_deplot_output(cached_res, deplot_res) + results.append((img_str, joined_chart_content)) return results diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index dcc06bb1..e337b23f 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -5,7 +5,6 @@ import functools import logging from typing import Any, Dict, List, Optional, Tuple -from concurrent.futures import ThreadPoolExecutor import pandas as pd @@ -26,137 +25,67 @@ def _update_metadata( base64_images: List[str], paddle_client: NimClient, - batch_size: int = 1, - worker_pool_size: int = 1, + batch_size: int = 1, # No longer used + worker_pool_size: int = 1, # No longer used trace_info: Dict = None, ) -> List[Tuple[str, Tuple[Any, Any]]]: """ - Given a list of base64-encoded images, this function processes them either individually - (if paddle_client.protocol == 'grpc') or in batches (if paddle_client.protocol == 'http'), - then calls the PaddleOCR model to extract data. + Given a list of base64-encoded images, this function filters out images that do not meet the minimum + size requirements and then calls the PaddleOCR model via paddle_client.infer to extract table data. For each base64-encoded image, the result is: (base64_image, (table_content, table_content_format)) - Images that do not meet the minimum size are skipped (("", "")). + Images that do not meet the minimum size are skipped (resulting in ("", "") for that image). + The paddle_client is expected to handle any necessary batching and concurrency. """ - logger.debug( - f"Running table extraction: batch_size={batch_size}, " - f"worker_pool_size={worker_pool_size}, protocol={paddle_client.protocol}" - ) + logger.debug(f"Running table extraction using protocol {paddle_client.protocol}") - # We'll build the final results in the same order as base64_images. - # results[i] = (base64_images[i], (table_content, table_content_format)). + # Initialize the results list in the same order as base64_images. results: List[Optional[Tuple[str, Tuple[Any, Any]]]] = [None] * len(base64_images) - # Pre-decode dimensions once (optional, but efficient if we want to skip small images). - decoded_shapes = [] - for img in base64_images: + valid_images: List[str] = [] + valid_indices: List[int] = [] + + # Pre-decode image dimensions and filter valid images. + for i, img in enumerate(base64_images): array = base64_to_numpy(img) - decoded_shapes.append(array.shape) # e.g. (height, width, channels) - - # ------------------------------------------------ - # GRPC path: submit one request per valid image. 
- # ------------------------------------------------ - if paddle_client.protocol == "grpc": - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - future_to_index = {} - - # Submit individual requests - for i, b64_image in enumerate(base64_images): - height, width = decoded_shapes[i][0], decoded_shapes[i][1] - if width < PADDLE_MIN_WIDTH or height < PADDLE_MIN_HEIGHT: - # Too small, skip inference - results[i] = (b64_image, ("", "")) - continue - - # Enqueue a single-image inference - data = {"base64_images": [b64_image]} # single item - future = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="table_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - future_to_index[future] = i - - # Gather results - for future, i in future_to_index.items(): - b64_image = base64_images[i] - try: - paddle_result = future.result() - # We expect exactly one result for one image - if not isinstance(paddle_result, list) or len(paddle_result) != 1: - raise ValueError(f"Expected 1 result list, got: {paddle_result}") - table_content, table_format = paddle_result[0] - results[i] = (b64_image, (table_content, table_format)) - except Exception as e: - logger.error(f"Error processing image {i}. Error: {e}", exc_info=True) - results[i] = (b64_image, ("", "")) - raise - - # ------------------------------------------------ - # HTTP path: submit requests in batches. - # ------------------------------------------------ - else: - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - # Process images in chunks - for start_idx in range(0, len(base64_images), batch_size): - chunk_indices = range(start_idx, min(start_idx + batch_size, len(base64_images))) - valid_indices = [] - valid_images = [] - - # Check dimensions & collect valid images - for i in chunk_indices: - height, width = decoded_shapes[i][0], decoded_shapes[i][1] - if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT: - valid_indices.append(i) - valid_images.append(base64_images[i]) - else: - # Too small, skip inference - results[i] = (base64_images[i], ("", "")) - - if not valid_images: - # All images in this chunk were too small - continue - - # Submit a single batch inference - data = {"base64_images": valid_images} - future = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="table_data_extraction", - max_batch_size=batch_size, - trace_info=trace_info, - ) - - try: - # This should be a list of (table_content, table_content_format) - # in the same order as valid_images - paddle_result = future.result() - - if not isinstance(paddle_result, list): - raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}") - - if len(paddle_result) != len(valid_images): - raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}") - - # Match each result back to its original index - for idx_in_batch, (tc, tf) in enumerate(paddle_result): - i = valid_indices[idx_in_batch] - results[i] = (base64_images[i], (tc, tf)) - - except Exception as e: - logger.error(f"Error processing batch {valid_images}. 
Error: {e}", exc_info=True) - # If inference fails, we can fill them with empty or re-raise - for vi in valid_indices: - results[vi] = (base64_images[vi], ("", "")) - raise - - # 'results' now has an entry for every image in base64_images + height, width = array.shape[0], array.shape[1] + if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT: + valid_images.append(img) + valid_indices.append(i) + else: + # Image is too small; mark as skipped. + results[i] = (img, ("", "")) + + if valid_images: + data = {"base64_images": valid_images} + try: + # Call infer once for all valid images. The NimClient will handle batching internally. + paddle_result = paddle_client.infer( + data=data, + model_name="paddle", + stage_name="table_data_extraction", + max_batch_size=len(valid_images), + trace_info=trace_info, + ) + + if not isinstance(paddle_result, list): + raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}") + if len(paddle_result) != len(valid_images): + raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}") + + # Assign each result back to its original position. + for idx, result in enumerate(paddle_result): + original_index = valid_indices[idx] + results[original_index] = (base64_images[original_index], result) + + except Exception as e: + logger.error(f"Error processing images {valid_images}. Error: {e}", exc_info=True) + for i in valid_indices: + results[i] = (base64_images[i], ("", "")) + raise + return results diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 1bdbc6b7..b7df0c74 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -6,6 +6,8 @@ import re import threading import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial from typing import Any from typing import Optional from typing import Tuple @@ -188,6 +190,39 @@ def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int: return self._max_batch_sizes[model_name] + def _process_batch(self, batch_input, *, prepared_data, model_name, **kwargs): + """ + Process a single batch input for inference. + + Parameters + ---------- + batch_input : Any + The batch input data to process. + prepared_data : Any + The prepared data used for inference. + model_name : str + The model name to use for inference. + kwargs : dict + Additional parameters for inference. + + Returns + ------- + Any + The parsed output from the inference request. + """ + if self.protocol == "grpc": + logger.debug("Performing gRPC inference for a batch...") + response = self._grpc_infer(batch_input, model_name) + logger.debug("gRPC inference received response for a batch") + elif self.protocol == "http": + logger.debug("Performing HTTP inference for a batch...") + response = self._http_infer(batch_input) + logger.debug("HTTP inference received response for a batch") + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + return self.model_interface.parse_output(response, protocol=self.protocol, data=prepared_data, **kwargs) + def try_set_max_batch_size(self, model_name, model_version: str = ""): """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety.""" self._fetch_max_batch_size(model_name, model_version) @@ -204,7 +239,8 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: model_name : str The name of the model to use for inference. kwargs : dict - Additional parameters for inference. 
+ Additional parameters for inference. Optionally supports "max_pool_workers" to set + the number of worker threads in the thread pool. Returns ------- @@ -216,54 +252,38 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: ValueError If an invalid protocol is specified. """ - try: - # 1. Retrieve or default to the model's maximum batch size + # 1. Retrieve or default to the model's maximum batch size. batch_size = self._fetch_max_batch_size(model_name) max_requested_batch_size = kwargs.get("max_batch_size", batch_size) force_requested_batch_size = kwargs.get("force_max_batch_size", False) - # 1a. In some cases we can't use the absolute max batch size (or don't want to) so we allow override - # 1b. In some cases we can't reliably retrieve the max batch size so we default to 1 and allow forced - # override if not force_requested_batch_size: max_batch_size = min(batch_size, max_requested_batch_size) else: max_batch_size = max_requested_batch_size - # 2. Prepare data for inference + # 2. Prepare data for inference. prepared_data = self.model_interface.prepare_data_for_inference(data) - # 3. Format the input based on protocol + # 3. Format the input based on protocol. # NOTE: This now returns a list of batches. formatted_batches = self.model_interface.format_input( prepared_data, protocol=self.protocol, max_batch_size=max_batch_size ) - # Container for all parsed outputs - all_parsed_outputs = [] - - # 4. Loop over each batch - for batch_input in formatted_batches: - if self.protocol == "grpc": - logger.debug("Performing gRPC inference for a batch...") - response = self._grpc_infer(batch_input, model_name) - logger.debug("gRPC inference received response for a batch") - elif self.protocol == "http": - logger.debug("Performing HTTP inference for a batch...") - response = self._http_infer(batch_input) - logger.debug("HTTP inference received response for a batch") - else: - raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + # Check for a custom maximum pool worker count, and remove it from kwargs. + max_pool_workers = kwargs.pop("max_pool_workers", len(formatted_batches)) - # Parse the output of this batch - parsed_output = self.model_interface.parse_output( - response, protocol=self.protocol, data=prepared_data, **kwargs - ) - # Accumulate parsed outputs - all_parsed_outputs.append(parsed_output) + # 4. Process each batch concurrently using a thread pool. + process_batch_partial = partial( + self._process_batch, prepared_data=prepared_data, model_name=model_name, **kwargs + ) + + with ThreadPoolExecutor(max_workers=max_pool_workers) as executor: + all_parsed_outputs = list(executor.map(process_batch_partial, formatted_batches)) - # 5. Process the parsed outputs for each batch + # 5. Process the parsed outputs for each batch. all_results = [] for parsed_output in all_parsed_outputs: batch_results = self.model_interface.process_inference_results( @@ -272,8 +292,6 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: protocol=self.protocol, **kwargs, ) - # Extend or append based on how `batch_results` is structured - # (assuming it's a list of result items): if isinstance(batch_results, list): all_results.extend(batch_results) else: @@ -284,7 +302,6 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: logger.error(error_str) raise RuntimeError(error_str) - # 6. Return final accumulated results return all_results def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
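
Usage sketch of the updated NimClient.infer() contract (illustrative only; it assumes the same imports and YOLOX_* constants already used in image_handlers.py, the function name is hypothetical, and the worker count of 4 is an arbitrary example):

    def detect_page_elements(images, config, trace_info=None):
        # One infer() call for the whole set of images; NimClient formats them into
        # batches capped at the model's (or the requested) max batch size, fans the
        # batches out across an internal thread pool, and returns results in input order.
        yolox_client = create_inference_client(
            config.yolox_endpoints,
            yolox_utils.YoloxPageElementsModelInterface(),
            config.auth_token,
            config.yolox_infer_protocol,
        )
        try:
            return yolox_client.infer(
                {"images": images},
                model_name="yolox",
                max_batch_size=YOLOX_MAX_BATCH_SIZE,
                num_classes=YOLOX_NUM_CLASSES,
                conf_thresh=YOLOX_CONF_THRESHOLD,
                iou_thresh=YOLOX_IOU_THRESHOLD,
                min_score=YOLOX_MIN_SCORE,
                final_thresh=YOLOX_FINAL_SCORE,
                max_pool_workers=4,  # optional override; defaults to one worker per batch
                trace_info=trace_info,
                stage_name="pdf_content_extractor",
            )
        finally:
            yolox_client.close()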