feat: min image shape
AshishKumar4 committed Aug 13, 2024
1 parent 24f7702 commit 0fd4dbb
Showing 2 changed files with 29 additions and 19 deletions.
46 changes: 28 additions & 18 deletions flaxdiff/data/online_loader.py
@@ -25,7 +25,6 @@
 USER_AGENT = get_datasets_user_agent()

 data_queue = Queue(16*2000)
-error_queue = Queue()


 def fetch_single_image(image_url, timeout=None, retries=0):
@@ -60,6 +59,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
 def map_sample(
     url, caption,
     image_shape=(256, 256),
+    min_image_shape=(128, 128),
     timeout=15,
     retries=3,
     upscale_interpolation=cv2.INTER_CUBIC,
@@ -75,10 +75,10 @@ def map_sample(
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
-        if min(original_height, original_width) < min(image_shape):
+        if min(original_height, original_width) < min(min_image_shape):
             return
         # check if wrong aspect ratio
-        if max(original_height, original_width) / min(original_height, original_width) > 2:
+        if max(original_height, original_width) / min(original_height, original_width) > 2.4:
             return
         # check if the variance is too low
         if np.std(image) < 1e-4:
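
Note (not part of the diff): the hunk above tightens the pre-resize filtering. The shortest side is now compared against the new min_image_shape instead of the target image_shape, and the allowed aspect ratio grows from 2 to 2.4. A minimal standalone sketch of the same checks, with a hypothetical helper name and the thresholds taken from the diff:

import numpy as np

def passes_quality_checks(image: np.ndarray, min_image_shape=(128, 128)) -> bool:
    # Reject images whose shortest side is below min_image_shape,
    # whose aspect ratio exceeds 2.4, or whose pixel variance is near zero.
    height, width = image.shape[:2]
    if min(height, width) < min(min_image_shape):
        return False
    if max(height, width) / min(height, width) > 2.4:
        return False
    if np.std(image) < 1e-4:
        return False
    return True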
@@ -98,40 +98,45 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-        error_queue.put_nowait({
-            "url": url,
-            "caption": caption,
-            "error": str(e)
-        })
+        # error_queue.put_nowait({
+        #     "url": url,
+        #     "caption": caption,
+        #     "error": str(e)
+        # })
+        pass


 def map_batch(
     batch, num_threads=256, image_shape=(256, 256),
+    min_image_shape=(128, 128),
     timeout=15, retries=3, image_processor=default_image_processor,
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
 ):
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape,
+        map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
                                 timeout=timeout, retries=retries, image_processor=image_processor,
                                 upscale_interpolation=upscale_interpolation,
                                 downscale_interpolation=downscale_interpolation)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             executor.map(map_sample_fn, batch["url"], batch['caption'])
     except Exception as e:
-        error_queue.put({
-            "batch": batch,
-            "error": str(e)
-        })
+        # error_queue.put_nowait({
+        #     "batch": batch,
+        #     "error": str(e)
+        # })
+        pass


 def parallel_image_loader(
     dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+    min_image_shape=(128, 128),
     num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
 ):
-    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           min_image_shape=min_image_shape,
                            timeout=timeout, retries=retries, image_processor=image_processor,
                            upscale_interpolation=upscale_interpolation,
                            downscale_interpolation=downscale_interpolation)
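
Aside (not from the commit): the hunk above also comments out the error_queue reporting, so failed samples are now silently skipped. The new min_image_shape is threaded down the call chain the same way the other settings are: bound once with functools.partial and then fanned out over per-sample fields via ThreadPoolExecutor.map. A generic, self-contained sketch of that pattern, using a stand-in worker function rather than the library's map_sample:

from concurrent.futures import ThreadPoolExecutor
from functools import partial

def process(url, caption, image_shape=(256, 256), min_image_shape=(128, 128)):
    # Stand-in for map_sample: per-sample work parameterized by shared settings.
    return (url, caption, image_shape, min_image_shape)

# Bind the shared keyword arguments once, then map over the per-sample columns,
# mirroring how map_batch builds map_sample_fn above.
process_fn = partial(process, image_shape=(256, 256), min_image_shape=(128, 128))
urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]
captions = ["a photo", "another photo"]
with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(process_fn, urls, captions))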
@@ -149,13 +154,14 @@ def parallel_image_loader(
         print(f"Shuffling dataset with seed {iteration}")
         dataset = dataset.shuffle(seed=iteration)
         # Clear the error queue
-        while not error_queue.empty():
-            error_queue.get_nowait()
+        # while not error_queue.empty():
+        #     error_queue.get_nowait()


 class ImageBatchIterator:
     def __init__(
             self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+            min_image_shape=(128, 128),
             num_workers: int = 8, num_threads=256, timeout=15, retries=3,
             image_processor=default_image_processor,
             upscale_interpolation=cv2.INTER_CUBIC,
@@ -165,7 +171,9 @@ def __init__(
         self.num_workers = num_workers
         self.batch_size = batch_size
         loader = partial(parallel_image_loader, num_threads=num_threads,
-                         image_shape=image_shape, num_workers=num_workers,
+                         image_shape=image_shape,
+                         min_image_shape=min_image_shape,
+                         num_workers=num_workers,
                          timeout=timeout, retries=retries, image_processor=image_processor,
                          upscale_interpolation=upscale_interpolation,
                          downscale_interpolation=downscale_interpolation)
@@ -215,6 +223,7 @@ def __init__(
         dataset,
         batch_size=64,
         image_shape=(256, 256),
+        min_image_shape=(128, 128),
         num_workers=16,
         num_threads=512,
         default_split="all",
@@ -253,8 +262,9 @@ def __init__(
             num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           min_image_shape=min_image_shape,
                                            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-                                           timeout=timeout, retries=retries, image_processor=image_processor,
+                                           timeout=timeout, retries=retries, image_processor=image_processor,
                                            upscale_interpolation=upscale_interpolation,
                                            downscale_interpolation=downscale_interpolation)
         self.batch_size = batch_size
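
Usage note (illustrative, not part of the commit): the new min_image_shape default of (128, 128) flows from the top-level loader down through ImageBatchIterator, parallel_image_loader, map_batch, and map_sample. A minimal sketch of driving ImageBatchIterator directly with the added argument, assuming a Hugging Face dataset whose rows carry "url" and "caption" columns (the dataset id below is a placeholder) and that the iterator is consumed as a plain Python iterator:

from datasets import load_dataset
from flaxdiff.data.online_loader import ImageBatchIterator

# Placeholder dataset id; any image-text dataset with "url" and "caption" columns.
dataset = load_dataset("user/some-url-caption-dataset", split="train")

iterator = ImageBatchIterator(
    dataset,
    batch_size=64,
    image_shape=(256, 256),      # target size after resizing
    min_image_shape=(128, 128),  # new in this commit: skip images smaller than this
    num_workers=8,
    num_threads=256,
)

for batch in iterator:
    # Batches contain only images that passed the size, aspect-ratio, and variance checks.
    ...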
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.19',
+    version='0.1.20',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
