Simplify saving #7

Merged 3 commits on Dec 2, 2024
Changes from all commits
82 changes: 47 additions & 35 deletions src/batch_processors/batchers.py
@@ -2,7 +2,17 @@
 import logging
 import os
 import pickle
-from typing import Callable, Generic, Iterable, List, Optional, Tuple, TypeVar
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 from tqdm import tqdm
 
@@ -51,8 +61,7 @@ def __init__(
         self.process_func = process_func
         self.batch_size = batch_size
         self.pickle_file = pickle_file
-        self.processed_items: List[T] = []
-        self.results: List[R] = []
+        self.processed_items: Dict[str, R] = {}
         self.recover_from_checkpoint = recover_from_checkpoint
         self.use_tqdm = use_tqdm
 
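The core of the change is visible here: two parallel lists that had to stay index-aligned are replaced by a single dict keyed by str(item). A minimal sketch of the in-memory (and on-disk) state before and after, with illustrative values that are not taken from the repo:

# Before: two lists whose pairing depended on matching indices.
old_state = {
    "processed_items": ["a", "b"],
    "results": ["Processed: a", "Processed: b"],
}

# After: one mapping, so each result is bound to its input by construction.
new_state = {"a": "Processed: a", "b": "Processed: b"}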
@@ -97,15 +106,16 @@ def process_batch(self, batch: List[T], batch_number: int, total_jobs: int):
             total_jobs (int): The total number of jobs to process.
         """
         if self.use_tqdm:
-            batch_results = [
-                self.process_item(i, item)
+            batch_results = {
+                str(item): self.process_item(i, item)
                 for i, item in enumerate(tqdm(batch, desc=f"Batch {batch_number}"))
-            ]
+            }
         else:
-            batch_results = [self.process_item(i, item) for i, item in enumerate(batch)]
+            batch_results = {
+                str(item): self.process_item(i, item) for i, item in enumerate(batch)
+            }
 
-        self.processed_items.extend(batch)
-        self.results.extend(batch_results)
+        self.processed_items.update(batch_results)
 
         if self.pickle_file:
             self.save_progress()
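One consequence of keying results by str(item), sketched below with assumed inputs: distinct items with the same string form collapse to a single entry, so the last one processed wins.

# Two distinct items with the same string form share one key.
batch = [1, "1"]
batch_results = {str(item): f"Processed: {item}" for item in batch}
assert batch_results == {"1": "Processed: 1"}  # only one entry survives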
@@ -121,8 +131,7 @@ def load_progress(self):
         if self.pickle_file and os.path.exists(self.pickle_file):
             with open(self.pickle_file, "rb") as f:
                 data = pickle.load(f)
-            self.processed_items = data["processed_items"]
-            self.results = data["results"]
+            self.processed_items = data
             self.logger.info(
                 f"Recovered {len(self.processed_items)} items from checkpoint"
             )
@@ -137,13 +146,11 @@ def save_progress(self):
         """
         with open(self.pickle_file, "wb") as f:
             pickle.dump(
-                {"processed_items": self.processed_items, "results": self.results},
+                self.processed_items,
                 f,
             )
 
-    def process_items_in_batches(
-        self, input_items: Iterable[T]
-    ) -> Tuple[List[T], List[R]]:
+    def process_items_in_batches(self, input_items: Iterable[T]) -> Dict[str, R]:
         """
         Process all input items in batches.
 
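With a single dict as the entire checkpoint, saving and loading reduce to a plain pickle round-trip. A minimal sketch, assuming a throwaway file path:

import pickle

state = {"a": "Processed: a"}
with open("checkpoint.pkl", "wb") as f:  # hypothetical path
    pickle.dump(state, f)  # the dict itself is the checkpoint, nothing nested

with open("checkpoint.pkl", "rb") as f:
    assert pickle.load(f) == state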
@@ -154,14 +161,17 @@ def process_items_in_batches(
         Returns:
-            Tuple[List[T], List[R]]: A tuple containing the list of processed items and their results.
+            Dict[str, R]: A dictionary containing the processed items and their results.
         """
         input_items = list(input_items) # Convert iterable to list
+        if self.recover_from_checkpoint:
+            recovered_items = set(self.processed_items.keys())
+            input_items = [
+                item for item in input_items if str(item) not in recovered_items
+            ]
 
         total_jobs = len(input_items)
-        start_index = len(self.processed_items) if self.recover_from_checkpoint else 0
 
-        for i in range(start_index, total_jobs, self.batch_size):
+        for i in range(0, total_jobs, self.batch_size):
             batch = input_items[i : i + self.batch_size]
             self.process_batch(batch, i // self.batch_size + 1, total_jobs)
 
-        return self.processed_items, self.results
+        return self.processed_items
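Recovery now happens up front: instead of resuming at a start index, which silently assumed an identical input order across runs, already-checkpointed items are filtered out before batching. A sketch of the filtering step with assumed values:

checkpoint = {"1": "Processed: 1", "2": "Processed: 2"}  # loaded from pickle
input_items = [1, 2, 3, 4]
remaining = [item for item in input_items if str(item) not in checkpoint]
assert remaining == [3, 4]  # only unprocessed items are re-batched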


@@ -173,7 +183,7 @@ class AsyncBatchProcessor(BatchProcessor[T, R]):
 
     def __init__(
         self,
-        process_func: Callable[[T], R],
+        process_func: Callable[[T], Awaitable[R]],
         batch_size: int = 100,
         pickle_file: Optional[str] = None,
         logfile: Optional[str] = None,
@@ -185,7 +195,7 @@ def __init__(
         Initialize the AsyncBatchProcessor.
 
         Args:
-            process_func (Callable[[T], R]): The function to process each item.
+            process_func (Callable[[T], Awaitable[R]]): The async function to process each item.
             batch_size (int, optional): The number of items to process in each batch. Defaults to 100.
             pickle_file (Optional[str], optional): The file to use for saving/loading progress. Defaults to None.
             logfile (Optional[str], optional): The file to use for logging. Defaults to None.
@@ -205,7 +215,7 @@ def __init__(
             asyncio.Semaphore(max_concurrent) if max_concurrent is not None else None
         )
 
-    async def process_item(self, job_number: int, item: T) -> R:
+    async def process_item(self, job_number: int, item: T) -> Tuple[T, R]:
         """
         Process a single item asynchronously.
 
@@ -214,13 +224,13 @@ async def process_item(self, job_number: int, item: T) -> R:
             item (T): The item to process.
 
         Returns:
-            R: The result of processing the item.
+            Tuple[T, R]: A tuple containing the input item and the result of processing it.
         """
 
         async def _process():
             result = await self.process_func(item)
             self.logger.info(f"Processed job {job_number}: {item}")
-            return result
+            return item, result
 
         if self.semaphore:
             async with self.semaphore:
@@ -245,18 +255,17 @@ async def process_batch(self, batch: List[T], batch_number: int, total_jobs: int):
             for i, item in enumerate(batch)
         ]
 
-        batch_results = []
+        batch_results = {}
         for task in asyncio.as_completed(tasks):
-            result = await task
-            batch_results.append(result)
+            item, result = await task
+            batch_results[str(item)] = result
             if self.use_tqdm:
                 pbar.update(1)
 
         if self.use_tqdm:
             pbar.close()
 
-        self.processed_items.extend(batch)
-        self.results.extend(batch_results)
+        self.processed_items.update(batch_results)
 
         if self.pickle_file:
             self.save_progress()
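This hunk also explains why process_item now returns the (item, result) pair: asyncio.as_completed yields tasks in completion order, not submission order, so a bare result can no longer be matched to its input. A self-contained sketch, with function names invented for illustration:

import asyncio
import random

async def work(item: str) -> tuple[str, str]:
    await asyncio.sleep(random.random() / 100)  # finish in arbitrary order
    return item, f"Processed: {item}"

async def main() -> dict[str, str]:
    tasks = [work(item) for item in ("a", "b", "c")]
    results: dict[str, str] = {}
    for task in asyncio.as_completed(tasks):
        item, result = await task  # the pairing survives reordering
        results[item] = result
    return results

assert asyncio.run(main()) == {
    "a": "Processed: a",
    "b": "Processed: b",
    "c": "Processed: c",
}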
@@ -265,24 +274,27 @@ async def process_batch(self, batch: List[T], batch_number: int, total_jobs: int):
             f"Batch {batch_number} completed. Total processed: {len(self.processed_items)}/{total_jobs}"
         )
 
-    async def process_items_in_batches(
-        self, input_items: Iterable[T]
-    ) -> Tuple[List[T], List[R]]:
+    async def process_items_in_batches(self, input_items: Iterable[T]) -> Dict[str, R]:
         """
         Process all input items in batches asynchronously.
 
         Args:
             input_items (Iterable[T]): The items to process.
 
         Returns:
-            Tuple[List[T], List[R]]: A tuple containing the list of processed items and their results.
+            Dict[str, R]: A dictionary containing the processed items and their results.
         """
         input_items = list(input_items) # Convert iterable to list
+        if self.recover_from_checkpoint:
+            recovered_items = set(self.processed_items.keys())
+            input_items = [
+                item for item in input_items if str(item) not in recovered_items
+            ]
 
         total_jobs = len(input_items)
-        start_index = len(self.processed_items) if self.recover_from_checkpoint else 0
 
-        for i in range(start_index, total_jobs, self.batch_size):
+        for i in range(0, total_jobs, self.batch_size):
             batch = input_items[i : i + self.batch_size]
             await self.process_batch(batch, i // self.batch_size + 1, total_jobs)
 
-        return self.processed_items, self.results
+        return self.processed_items
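Putting the new API together, an end-to-end sketch; the import path and file names are assumptions about the repo layout, and the constructor arguments are those visible in this diff:

import asyncio

from src.batch_processors.batchers import AsyncBatchProcessor  # assumed path

async def classify(text: str) -> str:
    await asyncio.sleep(0)  # stand-in for a real async call
    return f"Processed: {text}"

processor = AsyncBatchProcessor(
    classify,
    batch_size=2,
    pickle_file="run_checkpoint.pkl",  # hypothetical checkpoint file
    recover_from_checkpoint=True,  # skip items whose str() is already saved
)
results = asyncio.run(processor.process_items_in_batches(["a", "b", "c"]))
# results == {"a": "Processed: a", "b": "Processed: b", "c": "Processed: c"}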
28 changes: 16 additions & 12 deletions tests/test_batchers.py
@@ -50,11 +50,12 @@ def test_batch_processor_process_items(input_data):
     and returns the expected results.
     """
     processor = BatchProcessor(sync_process_func, batch_size=10)
-    processed_items, results = processor.process_items_in_batches(input_data)
+    processed_items = processor.process_items_in_batches(input_data)
 
     assert len(processed_items) == len(input_data)
-    assert len(results) == len(input_data)
-    assert all(result.startswith("Processed:") for result in results)
+    assert set(processed_items.keys()) == set(map(str, input_data))
+    assert all(key in value for key, value in processed_items.items())
+    assert all(result.startswith("Processed:") for result in processed_items.values())
 
 
 def test_batch_processor_checkpoint(input_data, temp_pickle_file):
@@ -77,11 +78,12 @@ def test_batch_processor_checkpoint(input_data, temp_pickle_file):
         pickle_file=temp_pickle_file,
         recover_from_checkpoint=True,
     )
-    processed_items, results = processor2.process_items_in_batches(input_data)
+    processed_items = processor2.process_items_in_batches(input_data)
 
     assert len(processed_items) == len(input_data)
-    assert len(results) == len(input_data)
-    assert all(result.startswith("Processed:") for result in results)
+    assert set(processed_items.keys()) == set(map(str, input_data))
+    assert all(key in value for key, value in processed_items.items())
+    assert all(result.startswith("Processed:") for result in processed_items.values())
 
 
 @pytest.mark.asyncio
@@ -93,11 +95,12 @@ async def test_async_batch_processor_process_items(input_data):
     asynchronously and returns the expected results.
     """
     processor = AsyncBatchProcessor(async_process_func, batch_size=10)
-    processed_items, results = await processor.process_items_in_batches(input_data)
+    processed_items = await processor.process_items_in_batches(input_data)
 
     assert len(processed_items) == len(input_data)
-    assert len(results) == len(input_data)
-    assert all(result.startswith("Processed:") for result in results)
+    assert set(processed_items.keys()) == set(map(str, input_data))
+    assert all(key in value for key, value in processed_items.items())
+    assert all(result.startswith("Processed:") for result in processed_items.values())
 
 
 @pytest.mark.asyncio
@@ -144,11 +147,12 @@ async def test_async_batch_processor_checkpoint(input_data, temp_pickle_file):
         pickle_file=temp_pickle_file,
         recover_from_checkpoint=True,
     )
-    processed_items, results = await processor2.process_items_in_batches(input_data)
+    processed_items = await processor2.process_items_in_batches(input_data)
 
     assert len(processed_items) == len(input_data)
-    assert len(results) == len(input_data)
-    assert all(result.startswith("Processed:") for result in results)
+    assert set(processed_items.keys()) == set(map(str, input_data))
+    assert all(key in value for key, value in processed_items.items())
+    assert all(result.startswith("Processed:") for result in processed_items.values())
 
 
 def test_tqdm_usage(capsys, input_data):