diff --git a/datasets/dataset preparations.ipynb b/datasets/dataset preparations.ipynb index 9746ea2..de907b2 100644 --- a/datasets/dataset preparations.ipynb +++ b/datasets/dataset preparations.ipynb @@ -509,13 +509,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bfeaa42fbd294774a7448be406e482fe", + "model_id": "c4b6a43425844467a5d03921ef25cee5", "version_major": 2, "version_minor": 0 }, @@ -529,7 +529,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d99905af2df44f708066961258a637ec", + "model_id": "836816f93eda40d3a17398c2ce6795ea", "version_major": 2, "version_minor": 0 }, @@ -547,13 +547,69 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'URL': ['https://images.pexels.com/photos/1464610/pexels-photo-1464610.jpeg?auto=compress&cs=tinysrgb&h=350',\n", + " 'https://dspncdn.com/a1/media/236x/ec/51/59/ec515909d46c49b5f9a1c98db3a50c83.jpg',\n", + " 'http://images.singletracks.com/blog/wp-content/uploads/2014/10/empire_link-enhanced92719.jpg',\n", + " 'https://us.123rf.com/450wm/yupiramos/yupiramos1909/yupiramos190942486/129795457-recently-married-couple-characters-vector-illustration-design.jpg?ver=6',\n", + " 'https://us.123rf.com/450wm/capacitorphoto/capacitorphoto1509/capacitorphoto150900170/45946146-gegrillter-lachs-und-tomaten-zitrone-rosmarin-auf-dem-h%C3%B6lzernen-hintergrund-.jpg?ver=6'],\n", + " 'TEXT': ['Cafe Latte in Round Red Cup and Saucer',\n", + " 'Stunning Adventure Photography by Stevin Tuchiwsky',\n", + " 'Trail: Empire Link, Park City, Utah. Rider: The man himself, Chips Chippendale of Singletrack Magazine. Photo: Jeff.',\n", + " 'recently married couple characters vector illustration design',\n", + " 'Grilled salmon and tomato, lemon, rosemary on the wooden background.'],\n", + " 'WIDTH': [6000, 236, 1200, 450, 450],\n", + " 'HEIGHT': [4000, 295, 799, 450, 300],\n", + " 'similarity': [0.31981587409973145,\n", + " 0.3211732804775238,\n", + " 0.33248665928840637,\n", + " 0.3292200565338135,\n", + " 0.30775392055511475],\n", + " 'hash': [-7039592731149973688,\n", + " -2579203913430252464,\n", + " 4425963208672941281,\n", + " -4722615866461648706,\n", + " 4426107764757616487],\n", + " 'punsafe': [0.00025475025177001953,\n", + " 8.619567211098911e-07,\n", + " 0.002868086099624634,\n", + " 0.0034500062465667725,\n", + " 5.051814514445141e-05],\n", + " 'pwatermark': [0.16689063608646393,\n", + " 0.08742427825927734,\n", + " 0.04702727869153023,\n", + " 0.6831610202789307,\n", + " 0.10683029145002365],\n", + " 'aesthetic': [7.005258560180664,\n", + " 7.266702175140381,\n", + " 7.009383201599121,\n", + " 7.558045387268066,\n", + " 7.084789276123047]}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "laion_aesthetic[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "12b1976efc3d4b77a800e0fd85a7e29f", + "model_id": "f022af3eb7d246ce8b449acf15dd7e7d", "version_major": 2, "version_minor": 0 }, @@ -567,7 +623,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "883961031d5c4c6f9dbb9624e25cca7c", + "model_id": "22feb9d8b0e64c0fb7d05aa065775134", "version_major": 2, "version_minor": 0 }, @@ -583,6 +639,45 @@ "laion_400m = load_dataset(\"laion/laion400m\", split=\"train\")" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'url': ['http://t0.gstatic.com/images?q=tbn:ANd9GcTnX7EwHrzccCd3Ki1mmjgocoPMPB_aGKw4g9PrghYZX1ojZiuS',\n", + " 'https://careers.cfainstitute.org/getasset/4794ad7b-a4a8-4fc7-b135-a92a187b3d86/',\n", + " 'http://img.beckett.com/images/items/custom/marketplace/66045141/migrated.jpg',\n", + " 'https://ae01.alicdn.com/kf/HTB1LWfYsr1YBuNjSszhq6AUsFXaW/high-waist-sleeveless-mini-soft-jeans-dress-frilled-women-ruffles-casual-summer-sundress-short-denim-beach.jpg_3-74x74.jpg',\n", + " 'https://images.wolfgangsvault.com/images/catalog/thumb/JRM09062-UV.jpg'],\n", + " 'NSFW': ['UNLIKELY', 'UNLIKELY', 'UNLIKELY', 'UNLIKELY', 'UNLIKELY'],\n", + " 'similarity': [0.30712828040122986,\n", + " 0.35018008947372437,\n", + " 0.3508765399456024,\n", + " 0.359369695186615,\n", + " 0.3070274293422699],\n", + " 'LICENSE': ['?', '?', '?', '?', '?'],\n", + " 'caption': ['bedroom minimalist home interior storage for kids bedroom design',\n", + " 'InterOcean Capital Group, LLC logo',\n", + " '2001 Absolute Memorabilia #190 Jay Gibbons RPM RC',\n", + " 'high waist sleeveless mini soft jeans dress frilled women ruffles casual summer sundress short denim beach dress cotton',\n", + " '\"Al Hirt / Boston Pops / Arthur Fiedler Vinyl 12\"\" (Used)\"'],\n", + " 'key': [9048, 2517, 6649, 102, 9286],\n", + " 'original_width': [277, 360, 318, 800, 190],\n", + " 'original_height': [182, 180, 231, 800, 190]}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "laion_400m[:5]" + ] + }, { "cell_type": "code", "execution_count": 10, @@ -617,69 +712,77 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:google.auth.compute_engine._metadata:Compute Engine Metadata server unavailable on attempt 1 of 3. Reason: timed out\n", - "WARNING:google.auth.compute_engine._metadata:Compute Engine Metadata server unavailable on attempt 2 of 3. Reason: timed out\n" - ] - }, + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e95c2249860746f8a44a327fe3015215", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Filter (num_proc=32): 0%| | 0/361020613 [00:00 1\u001b[0m \u001b[43mlaion_aesthetic\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_to_disk\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_25M\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/datasets/arrow_dataset.py:1520\u001b[0m, in \u001b[0;36mDataset.save_to_disk\u001b[0;34m(self, dataset_path, fs, max_shard_size, num_shards, num_proc, storage_options)\u001b[0m\n\u001b[1;32m 1517\u001b[0m num_shards \u001b[38;5;241m=\u001b[39m num_shards \u001b[38;5;28;01mif\u001b[39;00m num_shards \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m num_proc\n\u001b[1;32m 1519\u001b[0m fs: fsspec\u001b[38;5;241m.\u001b[39mAbstractFileSystem\n\u001b[0;32m-> 1520\u001b[0m fs, _ \u001b[38;5;241m=\u001b[39m \u001b[43murl_to_fs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_remote_filesystem(fs):\n\u001b[1;32m 1523\u001b[0m parent_cache_files_paths \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 1524\u001b[0m Path(cache_filename[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfilename\u001b[39m\u001b[38;5;124m\"\u001b[39m])\u001b[38;5;241m.\u001b[39mresolve()\u001b[38;5;241m.\u001b[39mparent \u001b[38;5;28;01mfor\u001b[39;00m cache_filename \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_files\n\u001b[1;32m 1525\u001b[0m }\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/fsspec/core.py:408\u001b[0m, in \u001b[0;36murl_to_fs\u001b[0;34m(url, **kwargs)\u001b[0m\n\u001b[1;32m 406\u001b[0m inkwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfo\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m urls\n\u001b[1;32m 407\u001b[0m urlpath, protocol, _ \u001b[38;5;241m=\u001b[39m chain[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 408\u001b[0m fs \u001b[38;5;241m=\u001b[39m \u001b[43mfilesystem\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fs, urlpath\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/fsspec/registry.py:303\u001b[0m, in \u001b[0;36mfilesystem\u001b[0;34m(protocol, **storage_options)\u001b[0m\n\u001b[1;32m 296\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 297\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marrow_hdfs\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m protocol has been deprecated and will be \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 298\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mremoved in the future. Specify it as \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhdfs\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 299\u001b[0m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[1;32m 300\u001b[0m )\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m get_filesystem_class(protocol)\n\u001b[0;32m--> 303\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/fsspec/spec.py:81\u001b[0m, in \u001b[0;36m_Cached.__call__\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_cache[token]\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 81\u001b[0m obj \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;66;03m# Setting _fs_token here causes some static linters to complain.\u001b[39;00m\n\u001b[1;32m 83\u001b[0m obj\u001b[38;5;241m.\u001b[39m_fs_token_ \u001b[38;5;241m=\u001b[39m token\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/gcsfs/core.py:319\u001b[0m, in \u001b[0;36mGCSFileSystem.__init__\u001b[0;34m(self, project, access, token, block_size, consistency, cache_timeout, secure_serialize, check_connection, requests_timeout, requester_pays, asynchronous, session_kwargs, loop, timeout, endpoint_url, default_location, version_aware, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_connection:\n\u001b[1;32m 314\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 315\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe `check_connection` argument is deprecated and will be removed in a future release.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 316\u001b[0m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[1;32m 317\u001b[0m )\n\u001b[0;32m--> 319\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcredentials \u001b[38;5;241m=\u001b[39m \u001b[43mGoogleCredentials\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproject\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccess\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/gcsfs/credentials.py:50\u001b[0m, in \u001b[0;36mGoogleCredentials.__init__\u001b[0;34m(self, project, access, token, check_credentials)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlock \u001b[38;5;241m=\u001b[39m threading\u001b[38;5;241m.\u001b[39mLock()\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtoken \u001b[38;5;241m=\u001b[39m token\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_credentials:\n\u001b[1;32m 53\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 54\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe `check_credentials` argument is deprecated and will be removed in a future release.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 55\u001b[0m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[1;32m 56\u001b[0m )\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/gcsfs/credentials.py:232\u001b[0m, in \u001b[0;36mGoogleCredentials.connect\u001b[0;34m(self, method)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m meth \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgoogle_default\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcache\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcloud\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manon\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmeth\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 233\u001b[0m logger\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConnected with method \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, meth)\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/gcsfs/credentials.py:249\u001b[0m, in \u001b[0;36mGoogleCredentials.connect\u001b[0;34m(self, method)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAll connection methods have failed!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 249\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattribute__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m_connect_\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmethod \u001b[38;5;241m=\u001b[39m method\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/gcsfs/credentials.py:77\u001b[0m, in \u001b[0;36mGoogleCredentials._connect_google_default\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_connect_google_default\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m---> 77\u001b[0m credentials, project \u001b[38;5;241m=\u001b[39m \u001b[43mgauth\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdefault\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscopes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscope\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m msg \u001b[38;5;241m=\u001b[39m textwrap\u001b[38;5;241m.\u001b[39mdedent(\n\u001b[1;32m 79\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\\\u001b[39;00m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124;03m User-provided project '{}' does not match the google default project '{}'. Either\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 86\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproject \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproject \u001b[38;5;241m!=\u001b[39m project:\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/_default.py:657\u001b[0m, in \u001b[0;36mdefault\u001b[0;34m(scopes, request, quota_project_id, default_scopes)\u001b[0m\n\u001b[1;32m 645\u001b[0m checkers \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 646\u001b[0m \u001b[38;5;66;03m# Avoid passing scopes here to prevent passing scopes to user credentials.\u001b[39;00m\n\u001b[1;32m 647\u001b[0m \u001b[38;5;66;03m# with_scopes_if_required() below will ensure scopes/default scopes are\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mlambda\u001b[39;00m: _get_gce_credentials(request, quota_project_id\u001b[38;5;241m=\u001b[39mquota_project_id),\n\u001b[1;32m 654\u001b[0m )\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m checker \u001b[38;5;129;01min\u001b[39;00m checkers:\n\u001b[0;32m--> 657\u001b[0m credentials, project_id \u001b[38;5;241m=\u001b[39m \u001b[43mchecker\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m credentials \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 659\u001b[0m credentials \u001b[38;5;241m=\u001b[39m with_scopes_if_required(\n\u001b[1;32m 660\u001b[0m credentials, scopes, default_scopes\u001b[38;5;241m=\u001b[39mdefault_scopes\n\u001b[1;32m 661\u001b[0m )\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/_default.py:653\u001b[0m, in \u001b[0;36mdefault..\u001b[0;34m()\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mgoogle\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mauth\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcredentials\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CredentialsWithQuotaProject\n\u001b[1;32m 641\u001b[0m explicit_project_id \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39menviron\u001b[38;5;241m.\u001b[39mget(\n\u001b[1;32m 642\u001b[0m environment_vars\u001b[38;5;241m.\u001b[39mPROJECT, os\u001b[38;5;241m.\u001b[39menviron\u001b[38;5;241m.\u001b[39mget(environment_vars\u001b[38;5;241m.\u001b[39mLEGACY_PROJECT)\n\u001b[1;32m 643\u001b[0m )\n\u001b[1;32m 645\u001b[0m checkers \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 646\u001b[0m \u001b[38;5;66;03m# Avoid passing scopes here to prevent passing scopes to user credentials.\u001b[39;00m\n\u001b[1;32m 647\u001b[0m \u001b[38;5;66;03m# with_scopes_if_required() below will ensure scopes/default scopes are\u001b[39;00m\n\u001b[1;32m 648\u001b[0m \u001b[38;5;66;03m# safely set on the returned credentials since requires_scopes will\u001b[39;00m\n\u001b[1;32m 649\u001b[0m \u001b[38;5;66;03m# guard against setting scopes on user credentials.\u001b[39;00m\n\u001b[1;32m 650\u001b[0m \u001b[38;5;28;01mlambda\u001b[39;00m: _get_explicit_environ_credentials(quota_project_id\u001b[38;5;241m=\u001b[39mquota_project_id),\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mlambda\u001b[39;00m: _get_gcloud_sdk_credentials(quota_project_id\u001b[38;5;241m=\u001b[39mquota_project_id),\n\u001b[1;32m 652\u001b[0m _get_gae_credentials,\n\u001b[0;32m--> 653\u001b[0m \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[43m_get_gce_credentials\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquota_project_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquota_project_id\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 654\u001b[0m )\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m checker \u001b[38;5;129;01min\u001b[39;00m checkers:\n\u001b[1;32m 657\u001b[0m credentials, project_id \u001b[38;5;241m=\u001b[39m checker()\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/_default.py:326\u001b[0m, in \u001b[0;36m_get_gce_credentials\u001b[0;34m(request, quota_project_id)\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m request \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 324\u001b[0m request \u001b[38;5;241m=\u001b[39m google\u001b[38;5;241m.\u001b[39mauth\u001b[38;5;241m.\u001b[39mtransport\u001b[38;5;241m.\u001b[39m_http_client\u001b[38;5;241m.\u001b[39mRequest()\n\u001b[0;32m--> 326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43m_metadata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_on_gce\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 327\u001b[0m \u001b[38;5;66;03m# Get the project ID.\u001b[39;00m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 329\u001b[0m project_id \u001b[38;5;241m=\u001b[39m _metadata\u001b[38;5;241m.\u001b[39mget_project_id(request\u001b[38;5;241m=\u001b[39mrequest)\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/compute_engine/_metadata.py:78\u001b[0m, in \u001b[0;36mis_on_gce\u001b[0;34m(request)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mis_on_gce\u001b[39m(request):\n\u001b[1;32m 69\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Checks to see if the code runs on Google Compute Engine\u001b[39;00m\n\u001b[1;32m 70\u001b[0m \n\u001b[1;32m 71\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m bool: True if the code runs on Google Compute Engine, False otherwise.\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mping\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnt\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 82\u001b[0m \u001b[38;5;66;03m# TODO: implement GCE residency detection on Windows\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/compute_engine/_metadata.py:131\u001b[0m, in \u001b[0;36mping\u001b[0;34m(request, timeout, retry_count)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attempt \u001b[38;5;129;01min\u001b[39;00m backoff:\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 131\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_METADATA_IP_ROOT\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mGET\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\n\u001b[1;32m 133\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 135\u001b[0m metadata_flavor \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mheaders\u001b[38;5;241m.\u001b[39mget(_METADATA_FLAVOR_HEADER)\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 137\u001b[0m response\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m==\u001b[39m http_client\u001b[38;5;241m.\u001b[39mOK\n\u001b[1;32m 138\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m metadata_flavor \u001b[38;5;241m==\u001b[39m _METADATA_FLAVOR_VALUE\n\u001b[1;32m 139\u001b[0m )\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/google/auth/transport/_http_client.py:104\u001b[0m, in \u001b[0;36mRequest.__call__\u001b[0;34m(self, url, method, body, headers, timeout, **kwargs)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 102\u001b[0m _LOGGER\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaking request: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, method, url)\n\u001b[0;32m--> 104\u001b[0m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 105\u001b[0m response \u001b[38;5;241m=\u001b[39m connection\u001b[38;5;241m.\u001b[39mgetresponse()\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Response(response)\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1336\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1333\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrequest\u001b[39m(\u001b[38;5;28mself\u001b[39m, method, url, body\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, headers\u001b[38;5;241m=\u001b[39m{}, \u001b[38;5;241m*\u001b[39m,\n\u001b[1;32m 1334\u001b[0m encode_chunked\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 1335\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Send a complete request to the server.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1336\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1382\u001b[0m, in \u001b[0;36mHTTPConnection._send_request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1378\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(body, \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 1379\u001b[0m \u001b[38;5;66;03m# RFC 2616 Section 3.7.1 says that text default has a\u001b[39;00m\n\u001b[1;32m 1380\u001b[0m \u001b[38;5;66;03m# default charset of iso-8859-1.\u001b[39;00m\n\u001b[1;32m 1381\u001b[0m body \u001b[38;5;241m=\u001b[39m _encode(body, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbody\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m-> 1382\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendheaders\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1331\u001b[0m, in \u001b[0;36mHTTPConnection.endheaders\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[0;32m-> 1331\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmessage_body\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1091\u001b[0m, in \u001b[0;36mHTTPConnection._send_output\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1089\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\r\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer)\n\u001b[1;32m 1090\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer[:]\n\u001b[0;32m-> 1091\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1093\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1094\u001b[0m \n\u001b[1;32m 1095\u001b[0m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n\u001b[1;32m 1096\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(message_body, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mread\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;66;03m# Let file-like take precedence over byte-like. This\u001b[39;00m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# is needed to allow the current position of mmap'ed\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# files to be taken into account.\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1035\u001b[0m, in \u001b[0;36mHTTPConnection.send\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1033\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msock \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1034\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_open:\n\u001b[0;32m-> 1035\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1036\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1037\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m NotConnected()\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/http/client.py:1001\u001b[0m, in \u001b[0;36mHTTPConnection.connect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Connect to the host and port specified in __init__.\"\"\"\u001b[39;00m\n\u001b[1;32m 1000\u001b[0m sys\u001b[38;5;241m.\u001b[39maudit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttp.client.connect\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhost, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mport)\n\u001b[0;32m-> 1001\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msock \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create_connection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1002\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhost\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mport\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msource_address\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1003\u001b[0m \u001b[38;5;66;03m# Might fail in OSs that don't implement TCP_NODELAY\u001b[39;00m\n\u001b[1;32m 1004\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n", - "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/socket.py:838\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, all_errors)\u001b[0m\n\u001b[1;32m 836\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m source_address:\n\u001b[1;32m 837\u001b[0m sock\u001b[38;5;241m.\u001b[39mbind(source_address)\n\u001b[0;32m--> 838\u001b[0m \u001b[43msock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43msa\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 839\u001b[0m \u001b[38;5;66;03m# Break explicitly a reference cycle\u001b[39;00m\n\u001b[1;32m 840\u001b[0m exceptions\u001b[38;5;241m.\u001b[39mclear()\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] + "data": { + "text/plain": [ + "185733069" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "laion_aesthetic.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_25M\")" + "len(laion_400m)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d8906416d261403aa9ef71f3626b3cd8", + "model_id": "e739c7ff96f74c39b32d27c72b61e831", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Loading dataset from disk: 0%| | 0/17 [00:00 2:\n", - " return\n", + " if max(original_height, original_width) / min(original_height, original_width) > 2.4:\n", + " return None, original_height, original_width\n", " # check if the variance is too low\n", - " if np.std(image) < 1e-4:\n", - " return\n", - " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", + " if np.std(image) < 1e-5:\n", + " return None, original_height, original_width\n", + " # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", " downscale = max(original_width, original_height) > max(image_shape)\n", " interpolation = downscale_interpolation if downscale else upscale_interpolation\n", - " image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)\n", + "\n", + " image = A.longest_max_size(image, max(\n", + " image_shape), interpolation=interpolation)\n", " image = A.pad(\n", " image,\n", " min_height=image_shape[0],\n", @@ -1218,87 +1307,56 @@ " border_mode=cv2.BORDER_CONSTANT,\n", " value=[255, 255, 255],\n", " )\n", - " data_queue.put({\n", - " \"url\": url,\n", - " \"caption\": caption,\n", - " \"image\": image\n", - " })\n", + " return image, original_height, original_width\n", " except Exception as e:\n", - " error_queue.put({\n", - " \"url\": url,\n", - " \"caption\": caption,\n", - " \"error\": str(e)\n", - " })\n", - " \n", - "def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n", - " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", - " executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n", + " # print(\"Error processing image\", e, image_shape, interpolation)\n", + " # traceback.print_exc()\n", + " return None, 0, 0\n", + "\n", + "def default_feature_extractor(sample):\n", + " url = None\n", + " if \"url\" in sample:\n", + " url = sample[\"url\"]\n", + " elif \"URL\" in sample:\n", + " url = sample[\"URL\"]\n", + " elif \"image_url\" in sample:\n", + " url = sample[\"image_url\"]\n", + " else:\n", + " print(\"No url found in sample, skipping\", sample.keys())\n", " \n", - "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n", - " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n", - " shard_len = len(dataset) // num_workers\n", - " print(f\"Local Shard lengths: {shard_len}\")\n", - " with multiprocessing.Pool(num_workers) as pool:\n", - " iteration = 0\n", - " while True:\n", - " # Repeat forever\n", - " dataset = dataset.shuffle(seed=iteration)\n", - " shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]\n", - " pool.map(map_batch_fn, shards)\n", - " iteration += 1\n", - " \n", - "class ImageBatchIterator:\n", - " def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n", - " self.dataset = dataset\n", - " self.num_workers = num_workers\n", - " self.batch_size = batch_size\n", - " loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n", - " self.thread = threading.Thread(target=loader, args=(dataset))\n", - " self.thread.start()\n", + " caption = None\n", + " if \"caption\" in sample:\n", + " caption = sample[\"caption\"]\n", + " elif \"CAPTION\" in sample:\n", + " caption = sample[\"CAPTION\"]\n", + " elif \"txt\" in sample:\n", + " caption = sample[\"txt\"]\n", + " elif \"TEXT\" in sample:\n", + " caption = sample[\"TEXT\"]\n", + " elif \"text\" in sample:\n", + " caption = sample[\"text\"]\n", + " else:\n", + " print(\"No caption found in sample, skipping\", sample.keys())\n", " \n", - " def __iter__(self):\n", - " return self\n", - " \n", - " def __next__(self):\n", - " def fetcher(_):\n", - " return data_queue.get()\n", - " with ThreadPoolExecutor(max_workers=self.batch_size) as executor:\n", - " batch = list(executor.map(fetcher, range(self.batch_size)))\n", - " return batch\n", - " \n", - " def __del__(self):\n", - " self.thread.join()\n", + " print(\"url\", url, \"caption\", caption)\n", " \n", - " def __len__(self):\n", - " return len(self.dataset) // self.batch_size\n", - " \n", - "def default_collate(batch):\n", - " urls = [sample[\"url\"] for sample in batch]\n", - " captions = [sample[\"caption\"] for sample in batch]\n", - " images = np.stack([sample[\"image\"] for sample in batch], axis=0)\n", " return {\n", - " \"url\": urls,\n", - " \"caption\": captions,\n", - " \"image\": images,\n", + " \"url\": url,\n", + " \"caption\": caption,\n", " }\n", " \n", - "def dataMapper(map: Dict[str, Any]):\n", - " def _map(sample) -> Dict[str, Any]:\n", - " return {\n", - " \"url\": sample[map[\"url\"]],\n", - " \"caption\": sample[map[\"caption\"]],\n", - " }\n", - " return _map\n", "\n", "class OnlineStreamingDataLoader():\n", " def __init__(\n", - " self, \n", - " dataset, \n", - " batch_size=64, \n", - " num_workers=16, \n", + " self,\n", + " dataset,\n", + " batch_size=64,\n", + " image_shape=(256, 256),\n", + " min_image_shape=(128, 128),\n", + " num_workers=16,\n", " num_threads=512,\n", " default_split=\"all\",\n", - " pre_map_maker=dataMapper, \n", + " pre_map_maker=dataMapper,\n", " pre_map_def={\n", " \"url\": \"URL\",\n", " \"caption\": \"TEXT\",\n", @@ -1307,53 +1365,145 @@ " global_process_index=0,\n", " prefetch=1000,\n", " collate_fn=default_collate,\n", + " timeout=15,\n", + " retries=3,\n", + " image_processor=default_image_processor,\n", + " upscale_interpolation=cv2.INTER_CUBIC,\n", + " downscale_interpolation=cv2.INTER_AREA,\n", + " feature_extractor=default_feature_extractor,\n", " ):\n", " if isinstance(dataset, str):\n", " dataset_path = dataset\n", " print(\"Loading dataset from path\")\n", - " dataset = load_dataset(dataset_path, split=default_split)\n", + " if \"gs://\" in dataset:\n", + " dataset = load_from_disk(dataset_path)\n", + " else:\n", + " dataset = load_dataset(dataset_path, split=default_split)\n", " elif isinstance(dataset, list):\n", " if isinstance(dataset[0], str):\n", " print(\"Loading multiple datasets from paths\")\n", - " dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]\n", - " else:\n", - " print(\"Concatenating multiple datasets\")\n", - " dataset = concatenate_datasets(dataset)\n", - " dataset = dataset.map(pre_map_maker(pre_map_def))\n", - " self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)\n", + " dataset = [load_from_disk(dataset_path) if \"gs://\" in dataset_path else load_dataset(\n", + " dataset_path, split=default_split) for dataset_path in dataset]\n", + " print(\"Concatenating multiple datasets\")\n", + " dataset = concatenate_datasets(dataset)\n", + " dataset = dataset.shuffle(seed=0)\n", + " # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)\n", + " self.dataset = dataset.shard(\n", + " num_shards=global_process_count, index=global_process_index)\n", " print(f\"Dataset length: {len(dataset)}\")\n", - " self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)\n", - " self.collate_fn = collate_fn\n", - " \n", + " self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,\n", + " min_image_shape=min_image_shape,\n", + " num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,\n", + " timeout=timeout, retries=retries, image_processor=image_processor,\n", + " upscale_interpolation=upscale_interpolation,\n", + " downscale_interpolation=downscale_interpolation,\n", + " feature_extractor=feature_extractor)\n", + " self.batch_size = batch_size\n", + "\n", " # Launch a thread to load batches in the background\n", " self.batch_queue = queue.Queue(prefetch)\n", - " \n", + "\n", " def batch_loader():\n", " for batch in self.iterator:\n", - " self.batch_queue.put(batch)\n", - " \n", + " try:\n", + " print(\"Putting batch in queue\")\n", + " self.batch_queue.put(collate_fn(batch))\n", + " except Exception as e:\n", + " print(\"Error collating batch\", e)\n", + "\n", " self.loader_thread = threading.Thread(target=batch_loader)\n", " self.loader_thread.start()\n", - " \n", + "\n", " def __iter__(self):\n", " return self\n", - " \n", + "\n", " def __next__(self):\n", - " return self.collate_fn(self.batch_queue.get())\n", + " return self.batch_queue.get()\n", " # return self.collate_fn(next(self.iterator))\n", - " \n", + "\n", " def __len__(self):\n", - " return len(self.dataset) // self.batch_size\n", - " " + " return len(self.dataset)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading dataset from path\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0a80f89fa4564d07941570f584498906", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/128 [00:00