diff --git a/flaxdiff/data/online_loader.py b/flaxdiff/data/online_loader.py index 23826d5..a1ea175 100644 --- a/flaxdiff/data/online_loader.py +++ b/flaxdiff/data/online_loader.py @@ -25,7 +25,6 @@ USER_AGENT = get_datasets_user_agent() data_queue = Queue(16*2000) -error_queue = Queue() def fetch_single_image(image_url, timeout=None, retries=0): @@ -60,6 +59,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC): def map_sample( url, caption, image_shape=(256, 256), + min_image_shape=(128, 128), timeout=15, retries=3, upscale_interpolation=cv2.INTER_CUBIC, @@ -75,10 +75,10 @@ def map_sample( image = np.array(image) original_height, original_width = image.shape[:2] # check if the image is too small - if min(original_height, original_width) < min(image_shape): + if min(original_height, original_width) < min(min_image_shape): return # check if wrong aspect ratio - if max(original_height, original_width) / min(original_height, original_width) > 2: + if max(original_height, original_width) / min(original_height, original_width) > 2.4: return # check if the variance is too low if np.std(image) < 1e-4: @@ -98,40 +98,45 @@ def map_sample( "original_width": original_width, }) except Exception as e: - error_queue.put_nowait({ - "url": url, - "caption": caption, - "error": str(e) - }) + # error_queue.put_nowait({ + # "url": url, + # "caption": caption, + # "error": str(e) + # }) + pass def map_batch( batch, num_threads=256, image_shape=(256, 256), + min_image_shape=(128, 128), timeout=15, retries=3, image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, ): try: - map_sample_fn = partial(map_sample, image_shape=image_shape, + map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation) with ThreadPoolExecutor(max_workers=num_threads) as executor: executor.map(map_sample_fn, batch["url"], batch['caption']) except Exception as e: - error_queue.put({ - "batch": batch, - "error": str(e) - }) + # error_queue.put_nowait({ + # "batch": batch, + # "error": str(e) + # }) + pass def parallel_image_loader( dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), + min_image_shape=(128, 128), num_threads=256, timeout=15, retries=3, image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, ): - map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape, + map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape, + min_image_shape=min_image_shape, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation) @@ -149,13 +154,14 @@ def parallel_image_loader( print(f"Shuffling dataset with seed {iteration}") dataset = dataset.shuffle(seed=iteration) # Clear the error queue - while not error_queue.empty(): - error_queue.get_nowait() + # while not error_queue.empty(): + # error_queue.get_nowait() class ImageBatchIterator: def __init__( self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), + min_image_shape=(128, 128), num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, @@ -165,7 +171,9 @@ def __init__( self.num_workers = num_workers self.batch_size = batch_size loader = partial(parallel_image_loader, num_threads=num_threads, - image_shape=image_shape, num_workers=num_workers, + image_shape=image_shape, + min_image_shape=min_image_shape, + num_workers=num_workers, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation) @@ -215,6 +223,7 @@ def __init__( dataset, batch_size=64, image_shape=(256, 256), + min_image_shape=(128, 128), num_workers=16, num_threads=512, default_split="all", @@ -253,8 +262,9 @@ def __init__( num_shards=global_process_count, index=global_process_index) print(f"Dataset length: {len(dataset)}") self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, + min_image_shape=min_image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads, - timeout=timeout, retries=retries, image_processor=image_processor, + timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation) self.batch_size = batch_size diff --git a/setup.py b/setup.py index f8106b8..8de9f32 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.19', + version='0.1.20', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',