Skip to content

Commit

Permalink
Shift all batching/concurrency decisions into the NimClient
Browse files Browse the repository at this point in the history
  • Loading branch information
drobison00 committed Feb 8, 2025
1 parent ef99c57 commit 15fbb80
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 328 deletions.
68 changes: 28 additions & 40 deletions src/nv_ingest/extraction_workflows/image/image_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import numpy as np
from PIL import Image
from math import log
from wand.image import Image as WandImage

import nv_ingest.util.nim.yolox as yolox_utils
Expand Down Expand Up @@ -186,16 +185,16 @@ def extract_tables_and_charts_from_images(
----------
images : List[np.ndarray]
List of images in NumPy array format.
config : PDFiumConfigSchema
config : ImageConfigSchema
Configuration object containing YOLOX endpoints, auth token, etc.
trace_info : Optional[List], optional
Optional tracing data for debugging/performance profiling.
Returns
-------
List[Tuple[int, object]]
A list of (image_index, CroppedImageWithContent)
representing extracted table/chart data from each image.
A list of (image_index, CroppedImageWithContent) representing extracted
table/chart data from each image.
"""
tables_and_charts = []
yolox_client = None
Expand All @@ -209,41 +208,31 @@ def extract_tables_and_charts_from_images(
config.yolox_infer_protocol,
)

max_batch_size = YOLOX_MAX_BATCH_SIZE
batches = []
i = 0
while i < len(images):
batch_size = min(2 ** int(log(len(images) - i, 2)), max_batch_size)
batches.append(images[i : i + batch_size]) # noqa: E203
i += batch_size

img_index = 0
for batch in batches:
data = {"images": batch}

# NimClient inference
inference_results = yolox_client.infer(
data,
model_name="yolox",
max_batch_size=YOLOX_MAX_BATCH_SIZE,
num_classes=YOLOX_NUM_CLASSES,
conf_thresh=YOLOX_CONF_THRESHOLD,
iou_thresh=YOLOX_IOU_THRESHOLD,
min_score=YOLOX_MIN_SCORE,
final_thresh=YOLOX_FINAL_SCORE,
trace_info=trace_info, # traceable_func arg
stage_name="pdf_content_extractor", # traceable_func arg
)
# Prepare the payload with all images.
data = {"images": images}

# Perform inference in a single call. The NimClient handles batching internally.
inference_results = yolox_client.infer(
data,
model_name="yolox",
max_batch_size=YOLOX_MAX_BATCH_SIZE,
num_classes=YOLOX_NUM_CLASSES,
conf_thresh=YOLOX_CONF_THRESHOLD,
iou_thresh=YOLOX_IOU_THRESHOLD,
min_score=YOLOX_MIN_SCORE,
final_thresh=YOLOX_FINAL_SCORE,
trace_info=trace_info,
stage_name="pdf_content_extractor",
)

# 5) Extract table/chart info from each image's annotations
for annotation_dict, original_image in zip(inference_results, batch):
extract_table_and_chart_images(
annotation_dict,
original_image,
img_index,
tables_and_charts,
)
img_index += 1
# Process each result along with its corresponding image.
for i, (annotation_dict, original_image) in enumerate(zip(inference_results, images)):
extract_table_and_chart_images(
annotation_dict,
original_image,
i,
tables_and_charts,
)

except TimeoutError:
logger.error("Timeout error during table/chart extraction.")
Expand All @@ -252,14 +241,13 @@ def extract_tables_and_charts_from_images(
except Exception as e:
logger.error(f"Unhandled error during table/chart extraction: {str(e)}")
traceback.print_exc()
raise e
raise

finally:
if yolox_client:
yolox_client.close()

logger.debug(f"Extracted {len(tables_and_charts)} tables and charts from image.")

return tables_and_charts


Expand Down
92 changes: 48 additions & 44 deletions src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import concurrent.futures
import logging
import traceback
from math import log
from typing import List
from typing import Optional
from typing import Tuple
Expand Down Expand Up @@ -61,55 +60,61 @@ def extract_tables_and_charts_using_image_ensemble(
pages: List[Tuple[int, np.ndarray]],
config: PDFiumConfigSchema,
trace_info: Optional[List] = None,
) -> List[Tuple[int, object]]: # List[Tuple[int, CroppedImageWithContent]]
) -> List[Tuple[int, object]]:
"""
Given a list of (page_index, image) tuples, this function calls the YOLOX-based
inference service to extract table and chart annotations from all pages.
The NimClient is now responsible for handling batching and concurrency internally.
For each page, the output is processed and the result is added to tables_and_charts.
Returns
-------
List[Tuple[int, object]]
For each page, returns (page_index, joined_content) where joined_content
is the result of combining annotations from the inference.
"""
tables_and_charts = []
yolox_client = None

try:
model_interface = yolox_utils.YoloxPageElementsModelInterface()
yolox_client = create_inference_client(
config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
config.yolox_endpoints,
model_interface,
config.auth_token,
config.yolox_infer_protocol,
)

batches = []
i = 0
max_batch_size = YOLOX_MAX_BATCH_SIZE
while i < len(pages):
batch_size = min(2 ** int(log(len(pages) - i, 2)), max_batch_size)
batches.append(pages[i : i + batch_size]) # noqa: E203
i += batch_size

page_index = 0
for batch in batches:
image_page_indices = [page[0] for page in batch]
original_images = [page[1] for page in batch]

# Prepare data
data = {"images": original_images}

# Perform inference using NimClient
inference_results = yolox_client.infer(
data,
model_name="yolox",
max_batch_size=YOLOX_MAX_BATCH_SIZE,
num_classes=YOLOX_NUM_CLASSES,
conf_thresh=YOLOX_CONF_THRESHOLD,
iou_thresh=YOLOX_IOU_THRESHOLD,
min_score=YOLOX_MIN_SCORE,
final_thresh=YOLOX_FINAL_SCORE,
trace_info=trace_info, # traceable_func arg
stage_name="pdf_content_extractor", # traceable_func arg
)
# Collect all page indices and images in order.
image_page_indices = [page[0] for page in pages]
original_images = [page[1] for page in pages]

# Prepare the data payload with all images.
data = {"images": original_images}

# Perform inference using the NimClient.
inference_results = yolox_client.infer(
data,
model_name="yolox",
max_batch_size=YOLOX_MAX_BATCH_SIZE,
num_classes=YOLOX_NUM_CLASSES,
conf_thresh=YOLOX_CONF_THRESHOLD,
iou_thresh=YOLOX_IOU_THRESHOLD,
min_score=YOLOX_MIN_SCORE,
final_thresh=YOLOX_FINAL_SCORE,
trace_info=trace_info,
stage_name="pdf_content_extractor",
)

# Process results
for annotation_dict, page_index, original_image in zip(
inference_results, image_page_indices, original_images
):
extract_table_and_chart_images(
annotation_dict,
original_image,
page_index,
tables_and_charts,
)
# Process results: iterate over each image's inference output.
for annotation_dict, page_index, original_image in zip(inference_results, image_page_indices, original_images):
extract_table_and_chart_images(
annotation_dict,
original_image,
page_index,
tables_and_charts,
)

except TimeoutError:
logger.error("Timeout error during table/chart extraction.")
Expand All @@ -118,14 +123,13 @@ def extract_tables_and_charts_using_image_ensemble(
except Exception as e:
logger.error(f"Unhandled error during table/chart extraction: {str(e)}")
traceback.print_exc()
raise e
raise

finally:
if yolox_client:
yolox_client.close()

logger.debug(f"Extracted {len(tables_and_charts)} tables and charts.")

return tables_and_charts


Expand Down
140 changes: 49 additions & 91 deletions src/nv_ingest/stages/nim/chart_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import pandas as pd
from morpheus.config import Config
from concurrent.futures import ThreadPoolExecutor

from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
Expand All @@ -29,103 +28,62 @@ def _update_metadata(
cached_client: NimClient,
deplot_client: NimClient,
trace_info: Dict,
batch_size: int = 1,
worker_pool_size: int = 1,
batch_size: int = 1, # No longer used
worker_pool_size: int = 1, # No longer used
) -> List[Tuple[str, Dict]]:
"""
Given a list of base64-encoded chart images, this function:
- Splits them into batches of size `batch_size`.
- Calls Cached with *all images* in each batch in a single request if protocol != 'grpc'.
If protocol == 'grpc', calls Cached individually for each image in the batch.
- Calls Deplot individually (one request per image) in parallel.
- Joins the results for each image into a final combined inference result.
Given a list of base64-encoded chart images, this function calls both the Cached and Deplot
inference services to extract chart data for all images. The NimClient implementations are
responsible for handling batching and concurrency internally.
Returns
-------
List[Tuple[str, Dict]]
For each base64-encoded image, returns (original_image_str, joined_chart_content_dict).
For each base64-encoded image, returns:
(original_image_str, joined_chart_content_dict)
"""
logger.debug(f"Running chart extraction: batch_size={batch_size}, worker_pool_size={worker_pool_size}")
logger.debug("Running chart extraction using updated concurrency handling.")

def chunk_list(lst, chunk_size):
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
# Prepare data payloads for both clients. We assume that both clients now support receiving
# a list of images via the "base64_images" key.
data_cached = {"base64_images": base64_images}
data_deplot = {"base64_images": base64_images}

results = []
try:
cached_results = cached_client.infer(
data=data_cached,
model_name="cached",
stage_name="chart_data_extraction",
max_batch_size=len(base64_images),
trace_info=trace_info,
)
except Exception as e:
logger.error(f"Error calling cached_client.infer: {e}", exc_info=True)
raise

with ThreadPoolExecutor(max_workers=worker_pool_size) as executor:
for batch in chunk_list(base64_images, batch_size):
# 1) Cached calls
if cached_client.protocol == "grpc":
# Submit each image in the batch separately
cached_futures = []
for image_str in batch:
data = {"base64_images": [image_str]}
fut = executor.submit(
cached_client.infer,
data=data,
model_name="cached",
stage_name="chart_data_extraction",
max_batch_size=1,
trace_info=trace_info,
)
cached_futures.append(fut)
else:
# Single request for the entire batch
data = {"base64_images": batch}
future_cached = executor.submit(
cached_client.infer,
data=data,
model_name="cached",
stage_name="chart_data_extraction",
max_batch_size=batch_size,
trace_info=trace_info,
)

# 2) Multiple calls to Deplot, one per image in the batch
deplot_futures = []
for image_str in batch:
# Deplot only supports single-image calls
deplot_data = {"base64_image": image_str}
fut = executor.submit(
deplot_client.infer,
data=deplot_data,
model_name="deplot",
stage_name="chart_data_extraction",
max_batch_size=1,
trace_info=trace_info,
)
deplot_futures.append(fut)

try:
# 3) Retrieve results from Cached
if cached_client.protocol == "grpc":
# Each future should return a single-element list
# We take the 0th item to align with single-image results
cached_results = []
for fut in cached_futures:
res = fut.result()
if isinstance(res, list) and len(res) == 1:
cached_results.append(res[0])
else:
# Fallback in case the service returns something unexpected
logger.warning(f"Unexpected CACHED result format: {res}")
cached_results.append(res)
else:
# Single call returning a list of the same length as 'batch'
cached_results = future_cached.result()

# Retrieve results from Deplot (each call returns a single inference result)
deplot_results = [f.result() for f in deplot_futures]

# 4) Zip them together, one by one
for img_str, cached_res, deplot_res in zip(batch, cached_results, deplot_results):
chart_content = join_cached_and_deplot_output(cached_res, deplot_res)
results.append((img_str, chart_content))

except Exception as e:
logger.error(f"Error processing batch: {batch}, error: {e}", exc_info=True)
raise
try:
deplot_results = deplot_client.infer(
data=data_deplot,
model_name="deplot",
stage_name="chart_data_extraction",
max_batch_size=len(base64_images),
trace_info=trace_info,
)
except Exception as e:
logger.error(f"Error calling deplot_client.infer: {e}", exc_info=True)
raise

# Ensure both clients returned lists of results matching the number of input images.
if not (isinstance(cached_results, list) and isinstance(deplot_results, list)):
raise ValueError("Expected list results from both cached_client and deplot_client infer calls.")

if len(cached_results) != len(base64_images):
raise ValueError(f"Expected {len(base64_images)} cached results, got {len(cached_results)}")
if len(deplot_results) != len(base64_images):
raise ValueError(f"Expected {len(base64_images)} deplot results, got {len(deplot_results)}")

# Join the corresponding results from both services for each image.
results = []
for img_str, cached_res, deplot_res in zip(base64_images, cached_results, deplot_results):
joined_chart_content = join_cached_and_deplot_output(cached_res, deplot_res)
results.append((img_str, joined_chart_content))

return results

Expand Down
Loading

0 comments on commit 15fbb80

Please sign in to comment.