Commit d1b8440

reverted online loader to previous version

AshishKumar4 committed Sep 11, 2024
1 parent 3b0e749 commit d1b8440
Showing 2 changed files with 20 additions and 64 deletions.
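Net effect of the revert: the unused error_queue, map_batch's per-batch error returns, the commented-out map_batch_repeat_forever helper, and the pluggable map_batch_fn parameter are all removed; feature extraction moves back inside the ThreadPoolExecutor block in map_batch; and setup.py bumps the version from 0.1.35.2 to 0.1.35.3.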
82 changes: 19 additions & 63 deletions flaxdiff/data/online_loader.py
@@ -26,7 +26,6 @@
 USER_AGENT = get_datasets_user_agent()
 
 data_queue = Queue(16*2000)
-error_queue = Queue(16*2000)
 
 
 def fetch_single_image(image_url, timeout=None, retries=0):
@@ -78,12 +77,6 @@ def default_image_processor(
     return image, original_height, original_width
 
 
-def default_feature_extractor(sample):
-    return {
-        "url": sample["url"],
-        "caption": sample["caption"],
-    }
-
 def map_sample(
     url,
     caption,
Expand Down Expand Up @@ -127,6 +120,14 @@ def map_sample(
# })
pass


def default_feature_extractor(sample):
return {
"url": sample["url"],
"caption": sample["caption"],
}


def map_batch(
batch, num_threads=256, image_shape=(256, 256),
min_image_shape=(128, 128),
@@ -140,55 +141,21 @@ def map_batch(
             map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
             timeout=timeout, retries=retries, image_processor=image_processor,
             upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor
+            downscale_interpolation=downscale_interpolation
         )
-        features = feature_extractor(batch)
-        url, caption = features["url"], features["caption"]
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            features = feature_extractor(batch)
+            url, caption = features["url"], features["caption"]
             executor.map(map_sample_fn, url, caption)
-        return None
     except Exception as e:
         print(f"Error maping batch", e)
         traceback.print_exc()
-        # error_queue.put_nowait({
-        #     "batch": batch,
-        #     "error": str(e)
-        # })
-        return e
-
-
-# def map_batch_repeat_forever(
-#     batch, num_threads=256, image_shape=(256, 256),
-#     min_image_shape=(128, 128),
-#     timeout=15, retries=3, image_processor=default_image_processor,
-#     upscale_interpolation=cv2.INTER_CUBIC,
-#     downscale_interpolation=cv2.INTER_AREA,
-#     feature_extractor=default_feature_extractor,
-# ):
-#     while True:  # Repeat forever
-#         try:
-#             map_sample_fn = partial(
-#                 map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
-#                 timeout=timeout, retries=retries, image_processor=image_processor,
-#                 upscale_interpolation=upscale_interpolation,
-#                 downscale_interpolation=downscale_interpolation,
-#                 feature_extractor=feature_extractor
-#             )
-#             features = feature_extractor(batch)
-#             url, caption = features["url"], features["caption"]
-#             with ThreadPoolExecutor(max_workers=num_threads) as executor:
-#                 executor.map(map_sample_fn, url, caption)
-#             # Shuffle the batch
-#             batch = batch.shuffle(seed=np.random.randint(0, 1000000))
-#         except Exception as e:
-#             print(f"Error maping batch", e)
-#             traceback.print_exc()
-#             # error_queue.put_nowait({
-#             #     "batch": batch,
-#             #     "error": str(e)
-#             # })
-#             pass
+        pass
 
 
 def parallel_image_loader(
     dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
@@ -197,35 +164,28 @@ def parallel_image_loader(
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
     feature_extractor=default_feature_extractor,
-    map_batch_fn=map_batch,
-
 ):
     map_batch_fn = partial(
-        map_batch_fn, num_threads=num_threads, image_shape=image_shape,
+        map_batch, num_threads=num_threads, image_shape=image_shape,
         min_image_shape=min_image_shape,
         timeout=timeout, retries=retries, image_processor=image_processor,
         upscale_interpolation=upscale_interpolation,
         downscale_interpolation=downscale_interpolation,
         feature_extractor=feature_extractor
     )
     shard_len = len(dataset) // num_workers
-    print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
+    print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
             shards = [dataset[i*shard_len:(i+1)*shard_len]
                       for i in range(num_workers)]
-            # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
-            print(f"mapping {len(shards)} shards")
-            errors = pool.map(map_batch_fn, shards)
-            for error in errors:
-                if error is not None:
-                    print(f"Error in mapping batch", error)
+            pool.map(map_batch_fn, shards)
             iteration += 1
             print(f"Shuffling dataset with seed {iteration}")
             dataset = dataset.shuffle(seed=iteration)
             print(f"Dataset shuffled")
             # Clear the error queue
             # while not error_queue.empty():
             #     error_queue.get_nowait()
@@ -240,7 +200,6 @@ def __init__(
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
         feature_extractor=default_feature_extractor,
-        map_batch_fn=map_batch,
     ):
         self.dataset = dataset
         self.num_workers = num_workers
@@ -255,8 +214,7 @@ def __init__(
             image_processor=image_processor,
             upscale_interpolation=upscale_interpolation,
             downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor,
-            map_batch_fn=map_batch_fn,
+            feature_extractor=feature_extractor
         )
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
@@ -323,7 +281,6 @@ def __init__(
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
         feature_extractor=default_feature_extractor,
-        map_batch_fn=map_batch,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
@@ -351,8 +308,7 @@ def __init__(
             timeout=timeout, retries=retries, image_processor=image_processor,
             upscale_interpolation=upscale_interpolation,
             downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor,
-            map_batch_fn=map_batch_fn,
+            feature_extractor=feature_extractor
        )
        self.batch_size = batch_size

@@ -364,7 +320,7 @@ def batch_loader():
                 try:
                     self.batch_queue.put(collate_fn(batch))
                 except Exception as e:
-                    print("Error collating batch", e)
+                    print("Error processing batch", e)
 
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
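With map_batch_fn gone, feature_extractor remains the customization point for adapting the loader to differently named dataset columns. A minimal sketch, assuming a hypothetical dataset whose columns are "image_url" and "text"; the only requirement visible in this diff is that the returned dict expose "url" and "caption" keys, as default_feature_extractor does:

    # A minimal sketch, assuming hypothetical "image_url" and "text" columns.
    # map_batch calls feature_extractor on a whole shard, so returning the
    # sliced columns yields the parallel url/caption sequences it maps over.
    def my_feature_extractor(batch):
        return {
            "url": batch["image_url"],    # hypothetical column name
            "caption": batch["text"],     # hypothetical column name
        }

    # e.g. parallel_image_loader(dataset, feature_extractor=my_feature_extractor)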
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.35.2',
+    version='0.1.35.3',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
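The version bump suggests this revert shipped as a release. Assuming the package is published to PyPI under the same name (an assumption, not shown in this diff), the reverted loader would be picked up with:

    pip install flaxdiff==0.1.35.3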