feat: added image shape option

AshishKumar4 committed Aug 12, 2024
1 parent 6244b6a commit daf18ed
Showing 4 changed files with 79 additions and 30 deletions.
92 changes: 70 additions & 22 deletions datasets/dataset preparations.ipynb
@@ -492,9 +492,27 @@
"import albumentations as A\n",
"import queue\n",
"\n",
"USER_AGENT = get_datasets_user_agent()\n",
"\n",
"data_queue = Queue(16*2000)\n",
"error_queue = Queue(16*2000)\n",
"\n",
"\n",
"def fetch_single_image(image_url, timeout=None, retries=0):\n",
" for _ in range(retries + 1):\n",
" try:\n",
" request = urllib.request.Request(\n",
" image_url,\n",
" data=None,\n",
" headers={\"user-agent\": USER_AGENT},\n",
" )\n",
" with urllib.request.urlopen(request, timeout=timeout) as req:\n",
" image = PIL.Image.open(io.BytesIO(req.read()))\n",
" break\n",
" except Exception:\n",
" image = None\n",
" return image\n",
"\n",
"def map_sample(\n",
" url, caption, \n",
" image_shape=(256, 256),\n",
@@ -540,12 +558,12 @@
" \"error\": str(e)\n",
" })\n",
" \n",
"def map_batch(batch, num_threads=256, timeout=None, retries=0):\n",
"def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n",
" with ThreadPoolExecutor(max_workers=num_threads) as executor:\n",
" executor.map(map_sample, batch[\"url\"], batch['caption'])\n",
" executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n",
" \n",
"def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):\n",
" map_batch_fn = partial(map_batch, num_threads=num_threads)\n",
"def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n",
" map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n",
" shard_len = len(dataset) // num_workers\n",
" print(f\"Local Shard lengths: {shard_len}\")\n",
" with multiprocessing.Pool(num_workers) as pool:\n",
@@ -558,11 +576,11 @@
" iteration += 1\n",
" \n",
"class ImageBatchIterator:\n",
" def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):\n",
" def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n",
" self.dataset = dataset\n",
" self.num_workers = num_workers\n",
" self.batch_size = batch_size\n",
" loader = partial(parallel_image_loader, num_threads=num_threads)\n",
" loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n",
" self.thread = threading.Thread(target=loader, args=(dataset, num_workers))\n",
" self.thread.start()\n",
" \n",
@@ -592,6 +610,14 @@
" \"image\": images,\n",
" }\n",
" \n",
"def dataMapper(map: Dict[str, Any]):\n",
" def _map(sample) -> Dict[str, Any]:\n",
" return {\n",
" \"url\": sample[map[\"url\"]],\n",
" \"caption\": sample[map[\"caption\"]],\n",
" }\n",
" return _map\n",
"\n",
"class OnlineStreamingDataLoader():\n",
" def __init__(\n",
" self, \n",
@@ -658,18 +684,39 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset length: 591753\n",
"Local Shard lengths: 36984\n"
"Loading dataset from path\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset length: 591753\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-7:\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n",
" self.run()\n",
" File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 766, in run_closure\n",
" _threading_Thread_run(self)\n",
" File \"/usr/lib/python3.10/threading.py\", line 953, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"TypeError: parallel_image_loader() got multiple values for argument 'num_workers'\n"
]
}
],
"source": [
"dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, split=\"train\")"
"dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, default_split=\"train\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -678,7 +725,7 @@
"0"
]
},
"execution_count": 15,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -689,18 +736,19 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1571"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
"ename": "NameError",
"evalue": "name 'data_queue' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdata_queue\u001b[49m\u001b[38;5;241m.\u001b[39mqsize()\n",
"\u001b[0;31mNameError\u001b[0m: name 'data_queue' is not defined"
]
}
],
"source": [
@@ -729,14 +777,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2000/2000 [00:44<00:00, 44.77it/s]\n"
"100%|██████████| 2000/2000 [00:37<00:00, 53.41it/s]\n"
]
}
],
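The stderr output above captures a real bug in this revision of the notebook: `loader` is a `partial` that already binds `num_workers`, yet the thread still passes it positionally via `args=(dataset, num_workers)`, so `parallel_image_loader` receives `num_workers` twice — exactly the `TypeError` in the traceback. The later `NameError: name 'data_queue' is not defined` is consistent with the queue-defining cell not having been executed in that kernel session. A minimal corrected `__init__`, as a sketch against the definitions above (not code from this commit):

```python
from functools import partial
import threading

class ImageBatchIterator:
    def __init__(self, dataset, batch_size: int = 64, image_shape=(256, 256),
                 num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Bind every option as a keyword exactly once, then pass only the
        # dataset positionally; note the one-element tuple (dataset,).
        loader = partial(parallel_image_loader, num_threads=num_threads,
                         image_shape=image_shape, num_workers=num_workers)
        self.thread = threading.Thread(target=loader, args=(dataset,))
        self.thread.start()
```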
1 change: 1 addition & 0 deletions flaxdiff/data/__init__.py
@@ -0,0 +1 @@
+from .online_loader import OnlineStreamingDataLoader
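The new export makes the loader importable directly from `flaxdiff.data`. A minimal usage sketch mirroring the notebook cell above; `image_shape` is the option this commit adds, and since `OnlineStreamingDataLoader.__init__` is collapsed out of this diff, treat its forwarding of that keyword as an assumption:

```python
from flaxdiff.data import OnlineStreamingDataLoader

# Stream (url, caption) pairs and fetch/resize images on the fly.
# Arguments mirror the notebook cell above; image_shape is assumed to be
# forwarded down to ImageBatchIterator / map_sample.
dataloader = OnlineStreamingDataLoader(
    "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
    batch_size=16,
    num_workers=16,
    image_shape=(256, 256),
    default_split="train",
)
```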
14 changes: 7 additions & 7 deletions flaxdiff/data/online_loader.py
@@ -88,12 +88,12 @@ def map_sample(
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, timeout=None, retries=0):
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'])
+        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
 
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads)
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
@@ -106,12 +106,12 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
         iteration += 1
 
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads)
-        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
+        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        self.thread = threading.Thread(target=loader, args=(dataset))
         self.thread.start()
 
     def __iter__(self):
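Two cautions on the package version of this change. First, `ThreadPoolExecutor.map` accepts only positional iterables plus its own `timeout`/`chunksize` keywords, so passing `image_shape=` and `retries=` to `executor.map` raises a `TypeError` rather than reaching `map_sample` (and `timeout=` there would bound result collection, not the per-request fetch). Second, `args=(dataset)` is a parenthesized expression, not a tuple, so `Thread` would unpack the dataset itself as the argument list; it needs a trailing comma. A sketch of both fixes, assuming `map_sample` accepts the `image_shape`, `timeout`, and `retries` keywords the edited call implies (hypothetical, not part of this commit):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
    # Bind the per-sample options onto map_sample once; executor.map then
    # fans the url/caption columns out positionally across the thread pool.
    map_fn = partial(map_sample, image_shape=image_shape,
                     timeout=timeout, retries=retries)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(map_fn, batch["url"], batch["caption"])

# And in ImageBatchIterator.__init__, the worker thread needs a 1-tuple:
# self.thread = threading.Thread(target=loader, args=(dataset,))
```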
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.13',
+    version='0.1.14',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
