From daf18edef222c4e3bb152989414a8ec8b9106c6e Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Mon, 12 Aug 2024 07:38:07 +0000 Subject: [PATCH] feat: added image shape option --- datasets/dataset preparations.ipynb | 92 ++++++++++++++++++++++------- flaxdiff/data/__init__.py | 1 + flaxdiff/data/online_loader.py | 14 ++--- setup.py | 2 +- 4 files changed, 79 insertions(+), 30 deletions(-) diff --git a/datasets/dataset preparations.ipynb b/datasets/dataset preparations.ipynb index 3ffcd6a..45a8255 100644 --- a/datasets/dataset preparations.ipynb +++ b/datasets/dataset preparations.ipynb @@ -492,9 +492,27 @@ "import albumentations as A\n", "import queue\n", "\n", + "USER_AGENT = get_datasets_user_agent()\n", + "\n", "data_queue = Queue(16*2000)\n", "error_queue = Queue(16*2000)\n", "\n", + "\n", + "def fetch_single_image(image_url, timeout=None, retries=0):\n", + " for _ in range(retries + 1):\n", + " try:\n", + " request = urllib.request.Request(\n", + " image_url,\n", + " data=None,\n", + " headers={\"user-agent\": USER_AGENT},\n", + " )\n", + " with urllib.request.urlopen(request, timeout=timeout) as req:\n", + " image = PIL.Image.open(io.BytesIO(req.read()))\n", + " break\n", + " except Exception:\n", + " image = None\n", + " return image\n", + "\n", "def map_sample(\n", " url, caption, \n", " image_shape=(256, 256),\n", @@ -540,12 +558,12 @@ " \"error\": str(e)\n", " })\n", " \n", - "def map_batch(batch, num_threads=256, timeout=None, retries=0):\n", + "def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n", " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " executor.map(map_sample, batch[\"url\"], batch['caption'])\n", + " executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n", " \n", - "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):\n", - " map_batch_fn = partial(map_batch, num_threads=num_threads)\n", + "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n", + " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n", " shard_len = len(dataset) // num_workers\n", " print(f\"Local Shard lengths: {shard_len}\")\n", " with multiprocessing.Pool(num_workers) as pool:\n", @@ -558,11 +576,11 @@ " iteration += 1\n", " \n", "class ImageBatchIterator:\n", - " def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):\n", + " def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n", " self.dataset = dataset\n", " self.num_workers = num_workers\n", " self.batch_size = batch_size\n", - " loader = partial(parallel_image_loader, num_threads=num_threads)\n", + " loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n", " self.thread = threading.Thread(target=loader, args=(dataset, num_workers))\n", " self.thread.start()\n", " \n", @@ -592,6 +610,14 @@ " \"image\": images,\n", " }\n", " \n", + "def dataMapper(map: Dict[str, Any]):\n", + " def _map(sample) -> Dict[str, Any]:\n", + " return {\n", + " \"url\": sample[map[\"url\"]],\n", + " \"caption\": sample[map[\"caption\"]],\n", + " }\n", + " return _map\n", + "\n", "class OnlineStreamingDataLoader():\n", " def __init__(\n", " self, \n", @@ -658,18 +684,39 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset length: 
591753\n", - "Local Shard lengths: 36984\n" + "Loading dataset from path\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset length: 591753\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Exception in thread Thread-7:\n", + "Traceback (most recent call last):\n", + " File \"/usr/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n", + " self.run()\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 766, in run_closure\n", + " _threading_Thread_run(self)\n", + " File \"/usr/lib/python3.10/threading.py\", line 953, in run\n", + " self._target(*self._args, **self._kwargs)\n", + "TypeError: parallel_image_loader() got multiple values for argument 'num_workers'\n" ] } ], "source": [ - "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, split=\"train\")" + "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, default_split=\"train\")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -678,7 +725,7 @@ "0" ] }, - "execution_count": 15, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -689,18 +736,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "1571" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" + "ename": "NameError", + "evalue": "name 'data_queue' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdata_queue\u001b[49m\u001b[38;5;241m.\u001b[39mqsize()\n", + "\u001b[0;31mNameError\u001b[0m: name 'data_queue' is not defined" + ] } ], "source": [ @@ -729,14 +777,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2000/2000 [00:44<00:00, 44.77it/s]\n" + "100%|██████████| 2000/2000 [00:37<00:00, 53.41it/s]\n" ] } ], diff --git a/flaxdiff/data/__init__.py b/flaxdiff/data/__init__.py index e69de29..d3410cd 100644 --- a/flaxdiff/data/__init__.py +++ b/flaxdiff/data/__init__.py @@ -0,0 +1 @@ +from .online_loader import OnlineStreamingDataLoader \ No newline at end of file diff --git a/flaxdiff/data/online_loader.py b/flaxdiff/data/online_loader.py index b284dbd..29de81b 100644 --- a/flaxdiff/data/online_loader.py +++ b/flaxdiff/data/online_loader.py @@ -88,12 +88,12 @@ def map_sample( "error": str(e) }) -def map_batch(batch, num_threads=256, timeout=None, retries=0): +def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0): with ThreadPoolExecutor(max_workers=num_threads) as executor: - executor.map(map_sample, batch["url"], batch['caption']) + executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries) -def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256): - map_batch_fn = partial(map_batch, num_threads=num_threads) +def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256): + 
    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
@@ -106,12 +106,12 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=25
         iteration += 1
 
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads)
-        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
+        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
     
     def __iter__(self):
diff --git a/setup.py b/setup.py
index 7d45b84..b754b93 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.13',
+    version='0.1.14',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
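
Usage note: a minimal sketch of how the new image shape option is meant to be driven from the loader's entry point. It assumes OnlineStreamingDataLoader itself accepts an image_shape argument and forwards it to ImageBatchIterator / parallel_image_loader; the patch only shows that plumbing below the loader class, so the top-level parameter name is an assumption. The dataset name and the other arguments are taken from the notebook cell in the diff above.

from flaxdiff.data import OnlineStreamingDataLoader

# Stream URL/caption pairs from the MS COCO dataset and resize every fetched
# image to 128x128 instead of the previous hard-coded 256x256.
# `image_shape` being accepted here is an assumption (see note above).
dataloader = OnlineStreamingDataLoader(
    "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
    batch_size=16,
    num_workers=16,
    default_split="train",
    image_shape=(128, 128),
)

# Per the notebook's collate function, each batch is a dict with "url",
# "caption" and "image" keys, where "image" holds the resized images
# downloaded by the worker threads.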