Skip to content

Commit

Permalink
feat: dataloader map_sample refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 13, 2024
1 parent 95cb0d2 commit 74c0075
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 3,198 deletions.
115 changes: 100 additions & 15 deletions datasets/dataset preparations copy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,51 @@
"# Leonardo Synthetic Dataset"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "023f308ec2164688892e01f46c23a6a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/14381152 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18cfa17b175f4c9fbdce335600c300ce",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/12 shards): 0%| | 0/14381152 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# playground = load_dataset(\"bigdata-pw/playground-liked\", split=\"all\")\n",
"playgroundMap = {\n",
" \"url\": \"url\",\n",
" \"caption\": \"prompt\",\n",
"}\n",
"final_data = mapDataset(playground, (playgroundMap,), batch_size=1000000, workers=None)\n",
"\n",
"final_data.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/playground-liked\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down Expand Up @@ -744,12 +789,12 @@
"leonardoMap = {\n",
" \"url\": \"image_url\",\n",
" \"caption\": \"caption\",\n",
"}"
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -759,8 +804,8 @@
"\n",
"def leonardoFilter(filterMap):\n",
" def _filter(sample):\n",
" if len(sample['negative_prompt']) != 0:\n",
" return False\n",
" # if len(sample['negative_prompt']) != 0:\n",
" # return False\n",
" for key, value in filterMap.items():\n",
" if sample[key] < value[\"min\"] or sample[key] > value[\"max\"]:\n",
" return False\n",
Expand Down Expand Up @@ -794,20 +839,27 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"249it [06:54, 1.57it/s] "
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d522bd677e184c57855719a8d6013f31",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/958 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"leonardo = load_dataset(\"bigdata-pw/leonardo\", split='train', streaming=True)\n",
"leonardo_100m = leonardo.shuffle().take(300_000_000)\n",
"leonardo_100m = leonardo.shuffle().take(600_000_000)\n",
"\n",
"filtered_leonardo_iterator = leonardo_100m.filter(leonardoFilter(heavyFilterMap))\n",
"filtered_leonardo = []\n",
Expand All @@ -821,9 +873,9 @@
" # return {\"url\": urls, \"caption\": captions}\n",
" return [{\"url\": item['url'], \"caption\": item['prompt']} for item in batch]\n",
"\n",
"loader = DataLoader(filtered_leonardo_iterator, batch_size=1000, num_workers=64, persistent_workers=True, collate_fn=collate_fn)\n",
"loader = DataLoader(filtered_leonardo_iterator, batch_size=100000, num_workers=64, persistent_workers=True, collate_fn=collate_fn)\n",
"\n",
"for batch in tqdm.tqdm(loader, total=100_000//1000):\n",
"for batch in tqdm.tqdm(loader):\n",
" filtered_leonardo.extend(batch)"
]
},
Expand All @@ -835,10 +887,10 @@
{
"data": {
"text/plain": [
"191462"
"605231"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -847,6 +899,39 @@
"len(filtered_leonardo)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"data = Dataset.from_list(filtered_leonardo)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4070476bf6a4d379fb2f9637a72f1f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/1 shards): 0%| | 0/605231 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-600k\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit 74c0075

Please sign in to comment.