From d1b844082739b3da49cc85b562743a7cd0e52085 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Wed, 11 Sep 2024 11:52:03 -0400 Subject: [PATCH] reverted online loader to previous version --- flaxdiff/data/online_loader.py | 82 ++++++++-------------------------- setup.py | 2 +- 2 files changed, 20 insertions(+), 64 deletions(-) diff --git a/flaxdiff/data/online_loader.py b/flaxdiff/data/online_loader.py index 7ffc052..63ff0ab 100644 --- a/flaxdiff/data/online_loader.py +++ b/flaxdiff/data/online_loader.py @@ -26,7 +26,6 @@ USER_AGENT = get_datasets_user_agent() data_queue = Queue(16*2000) -error_queue = Queue(16*2000) def fetch_single_image(image_url, timeout=None, retries=0): @@ -78,12 +77,6 @@ def default_image_processor( return image, original_height, original_width -def default_feature_extractor(sample): - return { - "url": sample["url"], - "caption": sample["caption"], - } - def map_sample( url, caption, @@ -127,6 +120,14 @@ def map_sample( # }) pass + +def default_feature_extractor(sample): + return { + "url": sample["url"], + "caption": sample["caption"], + } + + def map_batch( batch, num_threads=256, image_shape=(256, 256), min_image_shape=(128, 128), @@ -140,14 +141,12 @@ def map_batch( map_sample, image_shape=image_shape, min_image_shape=min_image_shape, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, - downscale_interpolation=downscale_interpolation, - feature_extractor=feature_extractor + downscale_interpolation=downscale_interpolation ) - features = feature_extractor(batch) - url, caption = features["url"], features["caption"] with ThreadPoolExecutor(max_workers=num_threads) as executor: + features = feature_extractor(batch) + url, caption = features["url"], features["caption"] executor.map(map_sample_fn, url, caption) - return None except Exception as e: print(f"Error maping batch", e) traceback.print_exc() @@ -155,40 +154,8 @@ def map_batch( # "batch": batch, # "error": str(e) # }) - return e - - -# def map_batch_repeat_forever( -# batch, num_threads=256, image_shape=(256, 256), -# min_image_shape=(128, 128), -# timeout=15, retries=3, image_processor=default_image_processor, -# upscale_interpolation=cv2.INTER_CUBIC, -# downscale_interpolation=cv2.INTER_AREA, -# feature_extractor=default_feature_extractor, -# ): -# while True: # Repeat forever -# try: -# map_sample_fn = partial( -# map_sample, image_shape=image_shape, min_image_shape=min_image_shape, -# timeout=timeout, retries=retries, image_processor=image_processor, -# upscale_interpolation=upscale_interpolation, -# downscale_interpolation=downscale_interpolation, -# feature_extractor=feature_extractor -# ) -# features = feature_extractor(batch) -# url, caption = features["url"], features["caption"] -# with ThreadPoolExecutor(max_workers=num_threads) as executor: -# executor.map(map_sample_fn, url, caption) -# # Shuffle the batch -# batch = batch.shuffle(seed=np.random.randint(0, 1000000)) -# except Exception as e: -# print(f"Error maping batch", e) -# traceback.print_exc() -# # error_queue.put_nowait({ -# # "batch": batch, -# # "error": str(e) -# # }) -# pass + pass + def parallel_image_loader( dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), @@ -197,11 +164,9 @@ def parallel_image_loader( upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, feature_extractor=default_feature_extractor, - map_batch_fn=map_batch, - ): map_batch_fn = partial( - map_batch_fn, num_threads=num_threads, image_shape=image_shape, + map_batch, num_threads=num_threads, image_shape=image_shape, min_image_shape=min_image_shape, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, @@ -209,23 +174,18 @@ def parallel_image_loader( feature_extractor=feature_extractor ) shard_len = len(dataset) // num_workers - print(f"Local Shard lengths: {shard_len}, workers: {num_workers}") + print(f"Local Shard lengths: {shard_len}") with multiprocessing.Pool(num_workers) as pool: iteration = 0 while True: # Repeat forever shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)] - # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)] print(f"mapping {len(shards)} shards") - errors = pool.map(map_batch_fn, shards) - for error in errors: - if error is not None: - print(f"Error in mapping batch", error) + pool.map(map_batch_fn, shards) iteration += 1 print(f"Shuffling dataset with seed {iteration}") dataset = dataset.shuffle(seed=iteration) - print(f"Dataset shuffled") # Clear the error queue # while not error_queue.empty(): # error_queue.get_nowait() @@ -240,7 +200,6 @@ def __init__( upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, feature_extractor=default_feature_extractor, - map_batch_fn=map_batch, ): self.dataset = dataset self.num_workers = num_workers @@ -255,8 +214,7 @@ def __init__( image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation, - feature_extractor=feature_extractor, - map_batch_fn=map_batch_fn, + feature_extractor=feature_extractor ) self.thread = threading.Thread(target=loader, args=(dataset,)) self.thread.start() @@ -323,7 +281,6 @@ def __init__( upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, feature_extractor=default_feature_extractor, - map_batch_fn=map_batch, ): if isinstance(dataset, str): dataset_path = dataset @@ -351,8 +308,7 @@ def __init__( timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation, - feature_extractor=feature_extractor, - map_batch_fn=map_batch_fn, + feature_extractor=feature_extractor ) self.batch_size = batch_size @@ -364,7 +320,7 @@ def batch_loader(): try: self.batch_queue.put(collate_fn(batch)) except Exception as e: - print("Error collating batch", e) + print("Error processing batch", e) self.loader_thread = threading.Thread(target=batch_loader) self.loader_thread.start() diff --git a/setup.py b/setup.py index c99d2cc..28c047e 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.35.2', + version='0.1.35.3', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',