diff --git a/.gitignore b/.gitignore index e01ff25..7751bc1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ good models tensorboard wandb gcs_mount -datacache \ No newline at end of file +datacache +*.deb +gcsfuse.yml \ No newline at end of file diff --git a/Diffusion flax linen on TPUs.ipynb b/Diffusion flax linen on TPUs.ipynb index 2218a23..98026db 100644 --- a/Diffusion flax linen on TPUs.ipynb +++ b/Diffusion flax linen on TPUs.ipynb @@ -66,15 +66,32 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The dotenv extension is already loaded. To reload it, use:\n", - " %reload_ext dotenv\n" + "cannot find .env file\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-08-01 13:05:48.518555: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-08-01 13:05:48.533175: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-08-01 13:05:48.537626: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-08-01 13:05:49.419359: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "/home/mrwhite0racle/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -110,12 +127,14 @@ "gcs_utils._is_gcs_disabled = True\n", "import json\n", "# For CLIP\n", - "from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel" + "from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel\n", + "\n", + "import wandb" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -145,7 +164,7 @@ " TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -163,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -214,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -276,7 +295,7 @@ "def data_source_cc12m(source=\"/home/mrwhite0racle/research/FlaxDiff/datasets/gcs_mount/arrayrecord/cc12m/\"):\n", " cc12m_records_path = source\n", " cc12m_records = [os.path.join(cc12m_records_path, i) for i in os.listdir(cc12m_records_path) if 'array_record' in i]\n", - " ds = pygrain.ArrayRecordDataSource(cc12m_records)\n", + " ds = pygrain.ArrayRecordDataSource(cc12m_records[:-1])\n", " return ds\n", "\n", "def labelizer_oxford_flowers102(path):\n", @@ -292,10 +311,10 @@ "\n", "# Configure the following for your datasets\n", "datasetMap = {\n", - " \"oxford_flowers102\": {\n", - " # \"source\":data_source_tfds(\"oxford_flowers102\"),\n", - " \"labelizer\":labelizer_oxford_flowers102(\"/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt\"),\n", - " },\n", + " # \"oxford_flowers102\": {\n", + " # # \"source\":data_source_tfds(\"oxford_flowers102\"),\n", + " # \"labelizer\":labelizer_oxford_flowers102(\"/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt\"),\n", + " # },\n", " \"cc12m\": {\n", " \"source\":data_source_cc12m(),\n", " \"labelizer\":labelizer_cc12m,\n", @@ -383,9 +402,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 69, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 
'encoder', ...)} (remaining weight names truncated; the list covers every vision-tower parameter plus 'visual_projection', 'text_projection', and 'logit_scale')\n",
+        "- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], "source": [ "import struct as st\n", "\n", @@ -408,7 +437,11 @@ " unpacked_dict[key] = byte_array\n", " return unpacked_dict\n", "\n", - "def get_dataset_grain(data_name=\"oxford_flowers102\", batch_size=64, image_scale=256, text_encoders=defaultTextEncodeModel(), method=jax.image.ResizeMethod.LANCZOS3):\n", + "def get_dataset_grain(data_name=\"oxford_flowers102\", \n", + " batch_size=64, image_scale=256, \n", + " count=None, num_epochs=None,\n", + " text_encoders=defaultTextEncodeModel(), \n", + " method=jax.image.ResizeMethod.LANCZOS3):\n", " dataset = datasetMap[data_name]\n", " data_source = dataset[\"source\"]\n", " labelizer = dataset[\"labelizer\"]\n", @@ -442,10 +475,10 @@ " } \n", "\n", " sampler = pygrain.IndexSampler(\n", - " num_records=len(data_source),\n", + " num_records=len(data_source) if count is None else count,\n", " shuffle=True,\n", " seed=0,\n", - " num_epochs=None,\n", + " num_epochs=num_epochs,\n", " shard_options=pygrain.NoSharding(),\n", " )\n", "\n", @@ -455,7 +488,7 @@ " data_source=data_source,\n", " sampler=sampler,\n", " operations=transformations,\n", - " worker_count=32,\n", + " worker_count=120,\n", " read_options=pygrain.ReadOptions(64, 50),\n", " worker_buffer_size=20\n", " )\n", @@ -476,11 +509,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ - "data = get_dataset_grain(\"cc12m\", batch_size=64, image_scale=128)" + "data = get_dataset_grain(\"cc12m\", batch_size=256, image_scale=128)" ] }, { @@ -502,21 +535,23 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "data = get_dataset_grain(\"cc12m\", batch_size=64, image_scale=128, count = 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/1000 [00:00.async_save() done, defined at /home/mrwhite0racle/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:415> exception=KeyboardInterrupt()>\n", + "Traceback (most recent call last):\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3577, in run_code\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n", + " File \"/tmp/ipykernel_38123/1732640408.py\", line 3, in \n", + " final_state = trainer.fit(data, 1000, epochs=2)\n", + " File \"/tmp/ipykernel_38123/4134358043.py\", line 429, in fit\n", + " super().fit(data, steps_per_epoch, epochs, {\"batch_size\":batch_size, \"null_labels_seq\":null_labels_full, \"text_embedder\":text_embedder})\n", + " File \"/tmp/ipykernel_38123/4134358043.py\", line 273, in fit\n", + " self.save(epochs)\n", + " File \"/tmp/ipykernel_38123/4134358043.py\", line 147, in save\n", + " self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args}, force=True)\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py\", line 1080, in save\n", + " self._checkpointer.save(save_directory, args=args)\n", + " File 
\"/home/mrwhite0racle/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py\", line 194, in save\n", + " self._handler.save(tmpdir.get(), args=ckpt_args)\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py\", line 424, in save\n", + " asyncio.run(async_save())\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/nest_asyncio.py\", line 30, in run\n", + " return loop.run_until_complete(task)\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/nest_asyncio.py\", line 92, in run_until_complete\n", + " self._run_once()\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/nest_asyncio.py\", line 133, in _run_once\n", + " handle._run()\n", + " File \"/usr/lib/python3.10/asyncio/events.py\", line 80, in _run\n", + " self._context.run(self._callback, *self._args)\n", + " File \"/usr/lib/python3.10/asyncio/tasks.py\", line 315, in __wakeup\n", + " self.__step()\n", + " File \"/usr/lib/python3.10/asyncio/tasks.py\", line 232, in __step\n", + " result = coro.send(None)\n", + " File \"/home/mrwhite0racle/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py\", line 422, in async_save\n", + " f.result()\n", + "KeyboardInterrupt\n" + ] + } + ], "source": [ + "import jax.experimental.pallas.ops.tpu.flash_attention\n", "from flaxdiff.models.simple_unet import l2norm, ConvLayer, TimeEmbedding, TimeProjection, Upsample, Downsample, ResidualBlock, PixelShuffle\n", "from flaxdiff.models.simple_unet import FourierEmbedding\n", "\n", @@ -680,38 +757,71 @@ " def setup(self):\n", " inner_dim = self.dim_head * self.heads\n", " # Weights were exported with old names {to_q, to_k, to_v, to_out}\n", - " self.query = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision, \n", - " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_q\")\n", - " self.key = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision, \n", - " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_k\")\n", - " self.value = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision, \n", - " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_v\")\n", + " dense = functools.partial(\n", + " nn.Dense,\n", + " self.heads * self.dim_head,\n", + " precision=self.precision, \n", + " use_bias=self.use_bias, \n", + " kernel_init=self.kernel_init(), \n", + " dtype=self.dtype\n", + " )\n", + " self.query = dense(name=\"to_q\")\n", + " self.key = dense(name=\"to_k\")\n", + " self.value = dense(name=\"to_v\")\n", + " \n", " self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, \n", " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_out_0\")\n", " # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)\n", + " \n", + " def _reshape_tensor_to_head_dim(self, tensor):\n", + " batch_size, _, seq_len, dim = tensor.shape\n", + " head_size = self.heads\n", + " tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)\n", + " tensor = jnp.transpose(tensor, (0, 2, 1, 3))\n", + " return tensor\n", + " \n", + " def _reshape_tensor_from_head_dim(self, tensor):\n", + " batch_size, _, seq_len, dim = tensor.shape\n", + " head_size = self.heads\n", + " tensor = jnp.transpose(tensor, (0, 2, 1, 3))\n", + " tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)\n", + " return tensor\n", "\n", " @nn.compact\n", " def __call__(self, x:jax.Array, 
context=None):\n",
+        "        # print(x.shape)\n",
-        "        # x has shape [B, H * W, C]\n",
+        "        # x has shape [B, H, W, C]; it is flattened to [B, 1, H*W, C] below for attention\n",
         "        context = x if context is None else context\n",
+        "        \n",
+        "        B, H, W, C = x.shape\n",
+        "        x = x.reshape((B, 1, H * W, C))\n",
+        "        \n",
+        "        B, _H, _W, _C = context.shape\n",
+        "        context = context.reshape((B, 1, _H * _W, _C))\n",
+        "        \n",
         "        query = self.query(x)\n",
         "        key = self.key(context)\n",
         "        value = self.value(context)\n",
         "        \n",
-        "        # print(query.shape, key.shape, value.shape)\n",
+        "        query = self._reshape_tensor_to_head_dim(query)\n",
+        "        key = self._reshape_tensor_to_head_dim(key)\n",
+        "        value = self._reshape_tensor_to_head_dim(value)\n",
         "        \n",
-        "        # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.mha_reference(\n",
-        "        #     query, key, value, None\n",
-        "        # )\n",
-        "        \n",
-        "        hidden_states = nn.dot_product_attention(\n",
-        "            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\n",
+        "        # Pallas TPU flash attention expects [batch, heads, seq_len, head_dim]\n",
+        "        hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(\n",
+        "            query, key, value, None\n",
         "        )\n",
-        "        # hidden_states = self.attnfn(\n",
-        "        #     query, key, value, None\n",
+        "        \n",
+        "        hidden_states = self._reshape_tensor_from_head_dim(hidden_states)\n",
+        "        \n",
+        "        \n",
+        "        # hidden_states = nn.dot_product_attention(\n",
+        "        #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\n",
         "        # )\n",
         "        \n",
         "        proj = self.proj_attn(hidden_states)\n",
+        "        \n",
+        "        proj = proj.reshape((B, H, W, C))\n",
+        "        \n",
         "        return proj\n",
         "\n",
         "\n",
@@ -764,6 +874,7 @@
         "        hidden_states = nn.dot_product_attention(\n",
         "            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\n",
         "        )\n",
+        "        \n",
         "        proj = self.proj_attn(hidden_states)\n",
         "        return proj\n",
         "    \n",
@@ -849,7 +960,7 @@
         "    heads: int = 4\n",
         "    dim_head: int = 32\n",
         "    use_linear_attention: bool = True\n",
-        "    dtype: Any = jnp.float32\n",
+        "    dtype: Any = jnp.bfloat16\n",
         "    precision: Any = jax.lax.Precision.HIGH\n",
         "    use_projection: bool = False\n",
         "    use_flash_attention:bool = True\n",
@@ -939,134 +1050,65 @@
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 103,
+   "execution_count": 83,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-       "19.9 ms ± 337 μs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" + "(16, 1, 256, 64)\n", + "(16, 4, 256, 16)\n" ] } ], "source": [ - "x = jnp.ones((16, 16, 16, 64))\n", - "context = jnp.ones((16, 16, 16, 64))\n", - "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.float32, use_flash_attention=True, use_projection=True, use_self_and_cross=False)\n", - "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", - "%timeit attention_block.apply(params, x, context)" + "x = jnp.ones((16, 1, 16*16, 64))\n", + "batch_size, _, seq_len, dim = x.shape\n", + "head_size = 4\n", + "dim_head = dim // head_size\n", + "k = nn.Dense(dim_head * head_size, precision=jax.lax.Precision.HIGHEST, use_bias=True)\n", + "param = k.init(jax.random.PRNGKey(42), x)\n", + "tensor = k.apply(param, x)\n", + "print(tensor.shape)\n", + "tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)\n", + "tensor = jnp.transpose(tensor, (0, 2, 1, 3))\n", + "print(tensor.shape)\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 119, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(1, 16, 768)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Output : (1, 16, 16, 64)\n" - ] - }, - { - "data": { - "text/html": [ - "
                                             TransformerBlock Summary                                             \n",
-       "┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃ path                module              inputs                 outputs              params                 ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│                    │ TransformerBlock   │ - float32[1,16,16,64] │ float32[1,16,16,64] │                        │\n",
-       "│                    │                    │ - None                │                     │                        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ RMSNorm_0          │ RMSNorm            │ float32[1,16,16,64]   │ float16[1,16,16,64] │ scale: float32[64]     │\n",
-       "│                    │                    │                       │                     │                        │\n",
-       "│                    │                    │                       │                     │ 64 (256 B)             │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ Attention          │ EfficientAttention │ - float16[1,16,16,64] │ float16[1,16,16,64] │                        │\n",
-       "│                    │                    │ - float16[1,16,16,64] │                     │                        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ Attention/to_q     │ DenseGeneral       │ float16[1,16,16,64]   │ float16[1,16,16,64] │ kernel: float32[64,64] │\n",
-       "│                    │                    │                       │                     │                        │\n",
-       "│                    │                    │                       │                     │ 4,096 (16.4 KB)        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ Attention/to_k     │ DenseGeneral       │ float16[1,16,16,64]   │ float16[1,16,16,64] │ kernel: float32[64,64] │\n",
-       "│                    │                    │                       │                     │                        │\n",
-       "│                    │                    │                       │                     │ 4,096 (16.4 KB)        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ Attention/to_v     │ DenseGeneral       │ float16[1,16,16,64]   │ float16[1,16,16,64] │ kernel: float32[64,64] │\n",
-       "│                    │                    │                       │                     │                        │\n",
-       "│                    │                    │                       │                     │ 4,096 (16.4 KB)        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│ Attention/to_out_0 │ DenseGeneral       │ float16[1,16,16,64]   │ float16[1,16,16,64] │ kernel: float32[64,64] │\n",
-       "│                    │                    │                       │                     │                        │\n",
-       "│                    │                    │                       │                     │ 4,096 (16.4 KB)        │\n",
-       "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n",
-       "│                                                                              Total  16,448 (65.8 KB)       │\n",
-       "└────────────────────┴────────────────────┴───────────────────────┴─────────────────────┴────────────────────────┘\n",
-       "                                                                                                                  \n",
-       "                                        Total Parameters: 16,448 (65.8 KB)                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[3m TransformerBlock Summary \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ │ TransformerBlock │ - \u001b[2mfloat32\u001b[0m[1,16,16,64] │ \u001b[2mfloat32\u001b[0m[1,16,16,64] │ │\n", - "│ │ │ - None │ │ │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ RMSNorm_0 │ RMSNorm │ \u001b[2mfloat32\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m64 \u001b[0m\u001b[1;2m(256 B)\u001b[0m │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ Attention │ EfficientAttention │ - \u001b[2mfloat16\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ │\n", - "│ │ │ - \u001b[2mfloat16\u001b[0m[1,16,16,64] │ │ │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ Attention/to_q │ DenseGeneral │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ Attention/to_k │ DenseGeneral │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ Attention/to_v │ DenseGeneral │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│ Attention/to_out_0 │ DenseGeneral │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ \u001b[2mfloat16\u001b[0m[1,16,16,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├────────────────────┼────────────────────┼───────────────────────┼─────────────────────┼────────────────────────┤\n", - "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", - 
"└────────────────────┴────────────────────┴───────────────────────┴─────────────────────┴────────────────────────┘\n", - "\u001b[1m \u001b[0m\n", - "\u001b[1m Total Parameters: 16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m\u001b[1m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\n", - "0.796968 1.2287916\n" + "47.3 μs ± 25.1 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], + "source": [ + "x = jnp.ones((16, 64, 64, 128))\n", + "context = jnp.ones((16, 64, 64, 128))\n", + "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=False, use_projection=False, use_self_and_cross=False)\n", + "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", + "@jax.jit\n", + "def apply(params, x, context):\n", + " return attention_block.apply(params, x, context)\n", + "\n", + "apply(params, x, context)\n", + "\n", + "%timeit -n 1 apply(params, x, context)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "x = jnp.ones((1, 16, 16, 64))\n", "context = jnp.ones((1, 12, 768))\n", @@ -1074,7 +1116,7 @@ "context = jnp.pad(context, ((0, 0), (0, 4), (0, 0)), mode='constant', constant_values=0)\n", "print(context.shape)\n", "context = None#jnp.reshape(context, (1, 1, 16, 768))\n", - "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.float16, use_flash_attention=True, use_projection=False, use_self_and_cross=False)\n", + "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=True, use_projection=False, use_self_and_cross=False)\n", "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", "out = attention_block.apply(params, x, context)\n", "print(\"Output :\", out.shape)\n", @@ -1093,7 +1135,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ @@ -1272,7 +1314,7 @@ " use_linear_attention=False,\n", " use_projection=middle_attention.get(\"use_projection\", False),\n", " use_self_and_cross=False,\n", - " precision=attention_config.get(\"precision\", self.precision),\n", + " precision=middle_attention.get(\"precision\", self.precision),\n", " name=f\"middle_attention_{j}\")(x)\n", " x = ResidualBlock(\n", " middle_conv_type,\n", @@ -1291,7 +1333,8 @@ " for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):\n", " # print(\"Upscaling\", i, \"features\", dim_out)\n", " for j in range(self.num_res_blocks):\n", - " x = jnp.concatenate([x, downs.pop()], axis=-1)\n", + " residual = downs.pop()\n", + " x = jnp.concatenate([x, residual], axis=-1)\n", " # print(\"concat==> \", i, \"concat\", x.shape)\n", " # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1))\n", " kernel_size = (3, 3)\n", @@ -1308,18 +1351,18 @@ " precision=self.precision\n", " )(x, temb)\n", " if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block\n", - " B, H, W, _ = x.shape\n", - " if H > TS:\n", - " padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))\n", - " else:\n", - " padded_context = None\n", + " # B, H, W, _ = x.shape\n", + " # if H > TS:\n", + " # padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))\n", + " # 
else:\n",
+        "            #     padded_context = None\n",
         "                x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), \n",
         "                                     dim_head=dim_out // attention_config['heads'],\n",
         "                                     use_flash_attention=attention_config.get(\"flash_attention\", True),\n",
         "                                     use_projection=attention_config.get(\"use_projection\", False),\n",
         "                                     use_self_and_cross=attention_config.get(\"use_self_and_cross\", True),\n",
         "                                     precision=attention_config.get(\"precision\", self.precision),\n",
-        "                                     name=f\"up_{i}_attention_{j}\")(x, padded_context)\n",
+        "                                     name=f\"up_{i}_attention_{j}\")(x, residual)\n",
         "            # print(\"Upscaling \", i, x.shape)\n",
         "            if i != len(feature_depths) - 1:\n",
         "                x = Upsample(\n",
@@ -1374,29 +1417,51 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": 33,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "# Training"
+    "unet = Unet(emb_features=512, \n",
+    "            feature_depths=[128, 256, 512, 1024],\n",
+    "            attention_configs=[\n",
+    "                None,\n",
+    "                # None,\n",
+    "                # {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":False, \"use_self_and_cross\":True}, \n",
+    "                {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n",
+    "                {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n",
+    "                {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}\n",
+    "            ],\n",
+    "            num_res_blocks=4,\n",
+    "            num_middle_res_blocks=1\n",
+    ")\n",
+    "\n",
+    "inp = jnp.ones((1, 128, 128, 3))\n",
+    "temb = jnp.ones((1,))\n",
+    "textcontext = jnp.ones((1, 77, 768))\n",
+    "\n",
+    "params = unet.init(jax.random.PRNGKey(0), inp, temb, textcontext)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "BATCH_SIZE = 128\n",
-    "IMAGE_SIZE = 128\n",
-    "\n",
-    "cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)\n",
-    "karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)\n",
-    "edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)"
+    "unet.tabulate(jax.random.key(0), inp, temb, textcontext, console_kwargs={\"width\": 200, \"force_jupyter\":True, })"
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Training"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 93,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1439,11 +1504,17 @@
    "                 checkpoint_suffix:str=\"\",\n",
    "                 loss_fn=optax.l2_loss,\n",
    "                 param_transforms:Callable=None,\n",
+    "                 wandb_config:Dict[str, Any]=None\n",
    "                 ):\n",
    "        self.model = model\n",
    "        self.name = name\n",
    "        self.loss_fn = loss_fn\n",
    "        self.input_shapes = input_shapes\n",
+    "        \n",
+    "        # Default to None so the `if self.wandb is not None` check in fit() cannot raise AttributeError\n",
+    "        self.wandb = None\n",
+    "        if wandb_config is not None:\n",
+    "            self.wandb = wandb.init(**wandb_config)\n",
    "\n",
    "        checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "        options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)\n",
@@ -1640,6 +1710,8 @@
    "                    pbar.update(100)\n",
    "                    current_step = current_epoch*steps_per_epoch + i\n",
    "                    summary_writer.scalar('Train Loss', loss, step=current_step)\n",
+    "                    if self.wandb is not None:\n",
+    "                        self.wandb.log({\"train/loss\": loss})\n",
    "            \n",
    "        print(f\"\\n\\tEpoch done\")\n",
    "        end_time = time.time()\n",
@@ -1697,25 +1769,17 @@
    "                 noise_schedule:NoiseScheduler,\n",
    "                 rngs:jax.random.PRNGKey,\n",
    "                 unconditional_prob:float=0.2,\n", 
- " train_state:TrainState=None,\n", " name:str=\"Diffusion\",\n", - " load_from_checkpoint:bool=False,\n", - " checkpoint_suffix:str=\"\",\n", - " param_transforms:Callable=None,\n", " model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),\n", - " loss_fn=optax.l2_loss\n", + " **kwargs\n", " ):\n", " super().__init__(\n", " model=model,\n", " input_shapes=input_shapes,\n", " optimizer=optimizer,\n", " rngs=rngs,\n", - " train_state=train_state,\n", " name=name,\n", - " load_from_checkpoint=load_from_checkpoint,\n", - " checkpoint_suffix=checkpoint_suffix,\n", - " loss_fn=loss_fn,\n", - " param_transforms=param_transforms\n", + " **kwargs\n", " )\n", " self.noise_schedule = noise_schedule\n", " self.model_output_transform = model_output_transform\n", @@ -1834,81 +1898,223 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 94, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n" + "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_14:38:34\n" ] }, { - "name": "stdout", + "data": { + "text/html": [ + "Finishing last run (ID:x90ddjrq) before initializing another..." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

[wandb run summary table: Run history: train/loss █▄▄▃▂▁▁▁▁▁ | Run summary: train/loss 0.07343]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run Diffusion_SDE_VE_TEXT_2024-08-01_14:30:55 at: https://wandb.ai/ashishkumar4/flaxdiff/runs/x90ddjrq
View project at: https://wandb.ai/ashishkumar4/flaxdiff
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240801_143055-x90ddjrq/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require(\"core\")`! See https://wandb.me/wandb-core for more information." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Successfully finished last run (ID:x90ddjrq). Initializing new run:
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_143834-fjuigbav" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run Diffusion_SDE_VE_TEXT_2024-08-01_14:38:34 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/ashishkumar4/flaxdiff" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/ashishkumar4/flaxdiff/runs/fjuigbav" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", "output_type": "stream", "text": [ - "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_08:59:00\n" + "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n" ] } ], "source": [ + "BATCH_SIZE = 64\n", + "IMAGE_SIZE = 128\n", + "\n", + "cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)\n", + "karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)\n", + "edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)\n", + "\n", "experiment_name = \"{name}_{date}\".format(\n", " name=\"Diffusion_SDE_VE_TEXT\", date=datetime.now().strftime(\"%Y-%m-%d_%H:%M:%S\")\n", ")\n", "# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07'\n", "# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-21_02:12:40'\n", "# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-30_05:48:22'\n", - "experiment_name = 'Diffusion_SDE_VE_TEXT_2024-08-01_08:59:00'\n", + "# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-08-01_08:59:00'\n", "print(\"Experiment_Name:\", experiment_name)\n", - "unet = Unet(emb_features=256, \n", - " feature_depths=[128, 128, 256, 512, 1024],\n", - " attention_configs=[\n", - " None,\n", - " None,\n", - " # {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":True}\n", - " ],\n", - " num_res_blocks=2,\n", - " num_middle_res_blocks=1\n", - ")\n", "\n", "dataset_name = \"cc12m\"\n", "datalen = len(datasetMap[dataset_name]['source'])\n", "batches = datalen // BATCH_SIZE\n", "\n", - "# Suggested configurations\n", - "total_epochs = 100\n", - "steps_per_epoch = 1000\n", - "init_value = 1e-6\n", - "peak_value = 2e-4\n", - "warmup_steps = steps_per_epoch * 5\n", - "decay_steps = total_epochs * steps_per_epoch - warmup_steps\n", - "end_value = 1e-6\n", - "exponent = 1.0\n", - "\n", - "# # Create the learning rate schedule\n", - "# learning_rate_schedule = optax.warmup_cosine_decay_schedule(\n", - "# init_value=init_value,\n", - "# peak_value=peak_value,\n", - "# warmup_steps=warmup_steps,\n", - "# decay_steps=decay_steps,\n", - "# end_value=end_value,\n", - "# exponent=exponent\n", - "# )\n", - "\n", - "# solver = optax.adamw(learning_rate=learning_rate_schedule)\n", - "# solver = optax.radam(2e-4)\n", - "solver = optax.adam(2e-4)\n", - "# solver = optax.adamw(2e-6)\n", + "config = {\n", + " \"model\" : {\n", + " \"emb_features\":256, \n", + " \"feature_depths\":[64, 128, 256, 512],\n", + " \"attention_configs\":[\n", + " None,\n", + " # None,\n", + " # None,\n", + " # None,\n", + " # 
None,\n", + " # {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":False, \"use_self_and_cross\":True}, \n", + " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}, \n", + " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}, \n", + " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False},\n", + " ],\n", + " \"num_res_blocks\":2,\n", + " \"num_middle_res_blocks\":1,\n", + " },\n", + " \n", + " \"dataset\": {\n", + " \"name\" : dataset_name,\n", + " \"length\" : datalen,\n", + " \"batches\": batches\n", + " },\n", + " \"learning_rate\": 2e-4,\n", + " \n", + " \"input_shapes\": {\n", + " \"x\": (IMAGE_SIZE, IMAGE_SIZE, 3),\n", + " \"temb\": (),\n", + " \"textcontext\": (77, 768)\n", + " },\n", + "}\n", "\n", - "# solver = optax.lookahead(solver, sync_period=6, slow_step_size=0.5)\n", - "# params_transform = lambda x: optax.LookaheadParams.init_synced(x)\n", + "unet = Unet(**config['model'])\n", + "\n", + "learning_rate = config['learning_rate']\n", + "solver = optax.adam(learning_rate)\n", + "# solver = optax.adamw(2e-6)\n", "\n", "trainer = DiffusionTrainer(unet, optimizer=solver, \n", - " input_shapes={'x': (128, 128, 3), 'temb': (), 'textcontext': (77, 768)}, \n", + " input_shapes=config['input_shapes'], \n", " noise_schedule=edm_schedule,\n", " rngs=jax.random.PRNGKey(4), \n", " name=experiment_name,\n", @@ -1916,9 +2122,13 @@ " # train_state=trainer.best_state,\n", " # loss_fn=lambda x, y: jnp.abs(x - y),\n", " # param_transforms=params_transform,\n", - " load_from_checkpoint=True,\n", - " )\n", - "#trainer.summary()" + " # load_from_checkpoint=True,\n", + " wandb_config={\n", + " \"project\": \"flaxdiff\",\n", + " \"config\": config,\n", + " \"name\": experiment_name,\n", + " },\n", + " )\n" ] }, { @@ -1932,7 +2142,16 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = get_dataset_grain(config['dataset']['name'], batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1940,22 +2159,112 @@ "output_type": "stream", "text": [ "\n", - "Epoch 1/5\n" + "Epoch 1/1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\t\tEpoch 1: 6%|█▋ | 3600/65076 [10:05<2:12:06, 7.76step/s, loss=0.0962]2024-08-01 09:10:31.725089: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 8449539 nanoseconds and will start immediately.\n", - "\t\tEpoch 1: 24%|██████▉ | 15600/65076 [35:51<1:46:18, 7.76step/s, loss=0.0548]" + "\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [05:20<00:00, 3.12step/s, loss=0.0734]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 1\n" + ] + }, + { + "ename": "IndexError", + "evalue": "Too many indices for array: 1 non-None/Ellipsis indices for dim 0.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[92], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# jax.profiler.start_server(6009)\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m final_state 
= trainer.fit(data, 1000, epochs=1)\n", + "Cell In[89], line 431, in DiffusionTrainer.fit -> super().fit(data, steps_per_epoch, epochs, {\"batch_size\": batch_size, \"null_labels_seq\": null_labels_full, \"text_embedder\": text_embedder})\n", + "Cell In[89], line 258, in SimpleTrainer.fit -> self.save(current_epoch)\n", + "Cell In[89], line 141, in SimpleTrainer.save -> 'state': self.get_state(),\n", + "Cell In[89], line 107, in SimpleTrainer.get_state -> return flax.jax_utils.unreplicate(self.state)\n", + "File ~/.local/lib/python3.10/site-packages/flax/jax_utils.py:50, in unreplicate(tree) -> return jax.tree_util.tree_map(lambda x: x[0], tree)\n", + "(internal JAX indexing frames: ArrayImpl.__getitem__ -> _rewriting_take -> _gather -> _index_to_gather -> _canonicalize_tuple_index)\n", + "IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0." + ] + } + ], + "source": [ + "# jax.profiler.start_server(6009)\n", + "final_state = trainer.fit(data, 1000, epochs=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1/1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [02:49<00:00, 5.89step/s, loss=0.0929]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 1\n", + "\n", + "\tEpoch 1 completed. Avg Loss: 0.2138698250055313, Time: 169.75s, Best Loss: 0.2138698250055313 \n", + "\n", + "Epoch 2/1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 2: 100%|█████████████████████████████████| 1000/1000 [01:11<00:00, 14.00step/s, loss=0.0773]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 2\n", + "\n", + "\tEpoch 2 completed. 
Avg Loss: 0.08507582545280457, Time: 71.46s, Best Loss: 0.08507582545280457 \n", + "Saving model at epoch 1\n", + "Error saving checkpoint Checkpoint for step 1 already exists.\n" ] } ], "source": [ - "jax.profiler.start_server(6009)\n", - "data = get_dataset_grain(\"cc12m\", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)\n", - "final_state = trainer.fit(data, batches, epochs=5)" + "# jax.profiler.start_server(6009)\n", + "final_state = trainer.fit(data, 1000, epochs=1)" ] }, { @@ -6753,7 +7062,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/setup_tpu.sh b/setup_tpu.sh index 9a95d45..9b8734c 100755 --- a/setup_tpu.sh +++ b/setup_tpu.sh @@ -4,11 +4,7 @@ pip install jax[tpu] flax[all] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install CPU version of tensorflow -pip install tensorflow[cpu] keras orbax optax clu grain augmax transformers opencv-python pandas - -pip install tensorflow-datasets jupyterlab python-dotenv scikit-learn termcolor wrapt - -pip install "packaging>=22.0" +pip install tensorflow[cpu] keras orbax optax clu grain augmax transformers opencv-python pandas tensorflow-datasets jupyterlab python-dotenv scikit-learn termcolor wrapt wandb wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb sudo dpkg -i knot-resolver-release.deb @@ -66,7 +62,7 @@ sudo apt update sudo apt install gcsfuse # Define the file name -gcsfuse_conf="gcsfuse.yml" +gcsfuse_conf="$HOME/gcsfuse.yml" # Define the contents of the file gcsfuse_conf_content=$(cat <
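
The rerun in the training log above ends with `Error saving checkpoint Checkpoint for step 1 already exists.`: every call to `fit()` restarts its epoch counter at 1, so the second run asks the Orbax `CheckpointManager` to write a step the first run already saved. One way around the collision is to derive the save step from `latest_step()` instead of the per-run epoch index. A sketch follows; the `self.checkpointer` attribute name is an assumption (only the local `checkpointer`/`options` variables are visible above), while the `ckpt` layout mirrors the `save()` shown in the traceback:

    # Hypothetical drop-in for SimpleTrainer.save; assumes self.checkpointer
    # is the orbax.checkpoint.CheckpointManager built in __init__.
    def save(self, epoch=0):
        latest = self.checkpointer.latest_step()  # None before the first save
        step = epoch if latest is None else max(epoch, latest + 1)
        print(f"Saving model at step {step}")
        ckpt = {
            'state': self.get_state(),
            'best_state': self.get_best_state(),
            'best_loss': self.best_loss,
        }
        try:
            save_args = orbax_utils.save_args_from_target(ckpt)
            self.checkpointer.save(step, ckpt, save_kwargs={'save_args': save_args})
        except Exception as e:
            print("Error saving checkpoint", e)
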