diff --git a/evaluate.ipynb b/evaluate.ipynb index c89aa50..6bf9e78 100644 --- a/evaluate.ipynb +++ b/evaluate.ipynb @@ -9,10 +9,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-09-09 14:52:58.205982: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-09-09 14:52:58.282905: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-09-09 14:52:58.304069: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-09-09 14:52:59.339008: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "2024-09-17 11:40:04.350778: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-09-17 11:40:04.427397: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-09-17 11:40:04.448116: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-09-17 11:40:05.513235: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", "There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.\n" ] } @@ -98,22 +98,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Fetching 16 files: 100%|██████████| 16/16 [00:00<00:00, 85926.84it/s]\n", - "/home/mrwhite0racle/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", - " warnings.warn(\n", - "Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /home/mrwhite0racle/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:\n", - "[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ('vision_model', 'vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'post_layernorm', 'bias'), ('vision_model', 'vision_model', 'post_layernorm', 'scale'), ('vision_model', 'vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'vision_model', 'pre_layrnorm', 'scale'), ('visual_projection', 'kernel')]\n", - "You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n", - "Some of the weights of FlaxCLIPTextModel were initialized in bfloat16 precision from the model checkpoint at /home/mrwhite0racle/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/text_encoder:\n", - "[('text_model', 'embeddings', 'position_embedding', 'embedding'), ('text_model', 'embeddings', 'token_embedding', 'embedding'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'final_layer_norm', 'bias'), ('text_model', 'final_layer_norm', 'scale')]\n", - "You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n", - "/home/mrwhite0racle/.local/lib/python3.10/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n", - " warnings.warn(\n", - "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'post_layernorm', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'pre_layrnorm', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'bias'), ('logit_scale',), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('visual_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'post_layernorm', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('text_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'bias')}\n", - "- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "/home/mrwhite0racle/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", - " warnings.warn(\n" + "There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). Please, ensure the directory exists and can be written to.\n" + ] + }, + { + "ename": "OSError", + "evalue": "Can't load config for 'CompVis/stable-diffusion-v1-4'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'CompVis/stable-diffusion-v1-4' is the correct path to a directory containing a model_index.json file", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/diffusers/configuration_utils.py:383\u001b[0m, in \u001b[0;36mConfigMixin.load_config\u001b[0;34m(cls, pretrained_model_name_or_path, return_unused_kwargs, return_commit_hash, **kwargs)\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 382\u001b[0m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 383\u001b[0m config_file \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 384\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 385\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 386\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 387\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 388\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 389\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 390\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 391\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 392\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 393\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 394\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 395\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 396\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_dir_use_symlinks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_dir_use_symlinks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 397\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m RepositoryNotFoundError:\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/huggingface_hub/file_download.py:1221\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, user_agent, force_download, proxies, etag_timeout, token, local_files_only, headers, endpoint, legacy_cache_layout, resume_download, force_filename, local_dir_use_symlinks)\u001b[0m\n\u001b[1;32m 1220\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1221\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_hf_hub_download_to_cache_dir\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1222\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Destination\u001b[39;49;00m\n\u001b[1;32m 1223\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1224\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# File info\u001b[39;49;00m\n\u001b[1;32m 1225\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1226\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1227\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1228\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1229\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# HTTP info\u001b[39;49;00m\n\u001b[1;32m 1230\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1231\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1232\u001b[0m \u001b[43m \u001b[49m\u001b[43metag_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43metag_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1233\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mendpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1234\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Additional options\u001b[39;49;00m\n\u001b[1;32m 1235\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1236\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/huggingface_hub/file_download.py:1335\u001b[0m, in \u001b[0;36m_hf_hub_download_to_cache_dir\u001b[0;34m(cache_dir, repo_id, filename, repo_type, revision, headers, proxies, etag_timeout, endpoint, local_files_only, force_download)\u001b[0m\n\u001b[1;32m 1333\u001b[0m pointer_path \u001b[38;5;241m=\u001b[39m _get_pointer_path(storage_folder, commit_hash, relative_filename)\n\u001b[0;32m-> 1335\u001b[0m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmakedirs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdirname\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblob_path\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 1336\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(pointer_path), exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m:215\u001b[0m, in \u001b[0;36mmakedirs\u001b[0;34m(name, mode, exist_ok)\u001b[0m\n", + "File \u001b[0;32m:225\u001b[0m, in \u001b[0;36mmakedirs\u001b[0;34m(name, mode, exist_ok)\u001b[0m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/mrwhite0racle/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mflaxdiff\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodels\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msimple_unet\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Unet\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mflaxdiff\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodels\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautoencoder\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdiffusers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m StableDiffusionVAE\n\u001b[0;32m----> 6\u001b[0m sd_vae \u001b[38;5;241m=\u001b[39m \u001b[43mStableDiffusionVAE\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdefaultTextEncodeModel\u001b[39m():\n\u001b[1;32m 9\u001b[0m modelname \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mopenai/clip-vit-large-patch14\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Desktop/ml-poc-notebooks/diffusion experiments/flaxdiff/models/autoencoder/diffusers.py:19\u001b[0m, in \u001b[0;36mStableDiffusionVAE.__init__\u001b[0;34m(self, modelname, revision, dtype)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdiffusers\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodels\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvae_flax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FlaxEncoder, FlaxDecoder\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdiffusers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FlaxStableDiffusionPipeline\n\u001b[0;32m---> 19\u001b[0m pipeline, params \u001b[38;5;241m=\u001b[39m \u001b[43mFlaxStableDiffusionPipeline\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodelname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m vae \u001b[38;5;241m=\u001b[39m pipeline\u001b[38;5;241m.\u001b[39mvae\n\u001b[1;32m 27\u001b[0m enc \u001b[38;5;241m=\u001b[39m FlaxEncoder(\n\u001b[1;32m 28\u001b[0m in_channels\u001b[38;5;241m=\u001b[39mvae\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39min_channels,\n\u001b[1;32m 29\u001b[0m out_channels\u001b[38;5;241m=\u001b[39mvae\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mlatent_channels,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 36\u001b[0m dtype\u001b[38;5;241m=\u001b[39mvae\u001b[38;5;241m.\u001b[39mdtype,\n\u001b[1;32m 37\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/diffusers/pipelines/pipeline_flax_utils.py:332\u001b[0m, in \u001b[0;36mFlaxDiffusionPipeline.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;66;03m# 1. Download the checkpoints and configs\u001b[39;00m\n\u001b[1;32m 330\u001b[0m \u001b[38;5;66;03m# use snapshot download here to get it working from from_pretrained\u001b[39;00m\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(pretrained_model_name_or_path):\n\u001b[0;32m--> 332\u001b[0m config_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_config\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 337\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 339\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 340\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 341\u001b[0m \u001b[38;5;66;03m# make sure we only download sub-folders and `diffusers` filenames\u001b[39;00m\n\u001b[1;32m 342\u001b[0m folder_names \u001b[38;5;241m=\u001b[39m [k \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m config_dict\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m k\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m\"\u001b[39m)]\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/site-packages/diffusers/configuration_utils.py:428\u001b[0m, in \u001b[0;36mConfigMixin.load_config\u001b[0;34m(cls, pretrained_model_name_or_path, return_unused_kwargs, return_commit_hash, **kwargs)\u001b[0m\n\u001b[1;32m 420\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 421\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWe couldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt connect to \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mHUGGINGFACE_CO_RESOLVE_ENDPOINT\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m to load this model, couldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find it\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 422\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in the cached files and it looks like \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpretrained_model_name_or_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not the path to a\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://huggingface.co/docs/diffusers/installation#offline-mode\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 426\u001b[0m )\n\u001b[1;32m 427\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m:\n\u001b[0;32m--> 428\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 429\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt load config for \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpretrained_model_name_or_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. If you were trying to load it from \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 430\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://huggingface.co/models\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, make sure you don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt have a local directory with the same name. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 431\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOtherwise, make sure \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpretrained_model_name_or_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m is the correct path to a directory \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 432\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontaining a \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mconfig_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m file\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 433\u001b[0m )\n\u001b[1;32m 435\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 436\u001b[0m \u001b[38;5;66;03m# Load config dict\u001b[39;00m\n\u001b[1;32m 437\u001b[0m config_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_dict_from_json_file(config_file)\n", + "\u001b[0;31mOSError\u001b[0m: Can't load config for 'CompVis/stable-diffusion-v1-4'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'CompVis/stable-diffusion-v1-4' is the correct path to a directory containing a model_index.json file" ] } ], @@ -2256,11 +2266,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import jax\n", + "import jax.experimental\n", + "import jax.experimental.pallas\n", + "import jax.experimental.pallas.ops\n", + "import jax.experimental.pallas.ops.attention\n", + "import jax.experimental.pallas.ops.gpu\n", + "import jax.experimental.pallas.ops.gpu.attention\n", "import jax.numpy as jnp\n", "from flax import linen as nn\n", "from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional\n", @@ -2345,7 +2361,11 @@ " key = self._reshape_tensor_to_head_dim(key)\n", " value = self._reshape_tensor_to_head_dim(value)\n", " \n", - " hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(\n", + " # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(\n", + " # query, key, value, None\n", + " # )\n", + " \n", + " hidden_states = jax.experimental.pallas.ops.attention.mha_forward_kernel(\n", " query, key, value, None\n", " )\n", " \n", @@ -2621,14 +2641,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "170 μs ± 137 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "780 µs ± 7.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -2648,6 +2668,85 @@ "%timeit apply(params, x, context)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "safe_zip() argument 2 is shorter than argument 1", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m:198\u001b[0m, in \u001b[0;36m_run_module_as_main\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32m:88\u001b[0m, in \u001b[0;36m_run_code\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel_launcher.py:18\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mipykernel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m kernelapp \u001b[38;5;28;01mas\u001b[39;00m app\n\u001b[0;32m---> 18\u001b[0m app\u001b[38;5;241m.\u001b[39mlaunch_new_instance()\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/traitlets/config/application.py:1075\u001b[0m, in \u001b[0;36mlaunch_instance\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1074\u001b[0m app\u001b[38;5;241m.\u001b[39minitialize(argv)\n\u001b[0;32m-> 1075\u001b[0m app\u001b[38;5;241m.\u001b[39mstart()\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/kernelapp.py:739\u001b[0m, in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 739\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mio_loop\u001b[38;5;241m.\u001b[39mstart()\n\u001b[1;32m 740\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/tornado/platform/asyncio.py:205\u001b[0m, in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstart\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39masyncio_loop\u001b[38;5;241m.\u001b[39mrun_forever()\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/asyncio/base_events.py:641\u001b[0m, in \u001b[0;36mrun_forever\u001b[0;34m()\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 641\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_once()\n\u001b[1;32m 642\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stopping:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/nest_asyncio.py:133\u001b[0m, in \u001b[0;36m_run_once\u001b[0;34m()\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 133\u001b[0m handle\u001b[38;5;241m.\u001b[39m_run()\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 135\u001b[0m \u001b[38;5;66;03m# restore the current task\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/flax/lib/python3.12/asyncio/events.py:88\u001b[0m, in \u001b[0;36m_run\u001b[0;34m()\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_context\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_callback, \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_args)\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mSystemExit\u001b[39;00m, \u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m):\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:545\u001b[0m, in \u001b[0;36mdispatch_queue\u001b[0;34m()\u001b[0m\n\u001b[1;32m 544\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 545\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprocess_one()\n\u001b[1;32m 546\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:534\u001b[0m, in \u001b[0;36mprocess_one\u001b[0;34m()\u001b[0m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[0;32m--> 534\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m dispatch(\u001b[38;5;241m*\u001b[39margs)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:437\u001b[0m, in \u001b[0;36mdispatch_shell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inspect\u001b[38;5;241m.\u001b[39misawaitable(result):\n\u001b[0;32m--> 437\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m result\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py:362\u001b[0m, in \u001b[0;36mexecute_request\u001b[0;34m()\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_associate_new_top_level_threads_with(parent_header)\n\u001b[0;32m--> 362\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mexecute_request(stream, ident, parent)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py:778\u001b[0m, in \u001b[0;36mexecute_request\u001b[0;34m()\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inspect\u001b[38;5;241m.\u001b[39misawaitable(reply_content):\n\u001b[0;32m--> 778\u001b[0m reply_content \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m reply_content\n\u001b[1;32m 780\u001b[0m \u001b[38;5;66;03m# Flush output before sending the reply.\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py:449\u001b[0m, in \u001b[0;36mdo_execute\u001b[0;34m()\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m accepts_params[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcell_id\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m--> 449\u001b[0m res \u001b[38;5;241m=\u001b[39m shell\u001b[38;5;241m.\u001b[39mrun_cell(\n\u001b[1;32m 450\u001b[0m code,\n\u001b[1;32m 451\u001b[0m store_history\u001b[38;5;241m=\u001b[39mstore_history,\n\u001b[1;32m 452\u001b[0m silent\u001b[38;5;241m=\u001b[39msilent,\n\u001b[1;32m 453\u001b[0m cell_id\u001b[38;5;241m=\u001b[39mcell_id,\n\u001b[1;32m 454\u001b[0m )\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/ipykernel/zmqshell.py:549\u001b[0m, in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_last_traceback \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mrun_cell(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075\u001b[0m, in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3074\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3075\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_cell(\n\u001b[1;32m 3076\u001b[0m raw_cell, store_history, silent, shell_futures, cell_id\n\u001b[1;32m 3077\u001b[0m )\n\u001b[1;32m 3078\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130\u001b[0m, in \u001b[0;36m_run_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3129\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3130\u001b[0m result \u001b[38;5;241m=\u001b[39m runner(coro)\n\u001b[1;32m 3131\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/async_helpers.py:129\u001b[0m, in \u001b[0;36m_pseudo_sync_runner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 129\u001b[0m coro\u001b[38;5;241m.\u001b[39msend(\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334\u001b[0m, in \u001b[0;36mrun_cell_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3331\u001b[0m interactivity \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m silent \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mast_node_interactivity\n\u001b[0;32m-> 3334\u001b[0m has_raised \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_ast_nodes(code_ast\u001b[38;5;241m.\u001b[39mbody, cell_name,\n\u001b[1;32m 3335\u001b[0m interactivity\u001b[38;5;241m=\u001b[39minteractivity, compiler\u001b[38;5;241m=\u001b[39mcompiler, result\u001b[38;5;241m=\u001b[39mresult)\n\u001b[1;32m 3337\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlast_execution_succeeded \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m has_raised\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517\u001b[0m, in \u001b[0;36mrun_ast_nodes\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3516\u001b[0m asy \u001b[38;5;241m=\u001b[39m compare(code)\n\u001b[0;32m-> 3517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_code(code, result, async_\u001b[38;5;241m=\u001b[39masy):\n\u001b[1;32m 3518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577\u001b[0m, in \u001b[0;36mrun_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3576\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3577\u001b[0m exec(code_obj, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muser_global_ns, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muser_ns)\n\u001b[1;32m 3578\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 3579\u001b[0m \u001b[38;5;66;03m# Reset our crash handler in place\u001b[39;00m\n", + "Cell \u001b[0;32mIn[6], line 5\u001b[0m\n\u001b[1;32m 4\u001b[0m attention_block \u001b[38;5;241m=\u001b[39m TransformerBlock(heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m16\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16, use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, use_projection\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, use_self_and_cross\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, kernel_init\u001b[38;5;241m=\u001b[39mkernel_init(\u001b[38;5;241m1.0\u001b[39m))\n\u001b[0;32m----> 5\u001b[0m params \u001b[38;5;241m=\u001b[39m attention_block\u001b[38;5;241m.\u001b[39minit(jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mPRNGKey(\u001b[38;5;241m0\u001b[39m), x, context)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(params, x, context):\n", + "Cell \u001b[0;32mIn[4], line 327\u001b[0m, in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 325\u001b[0m context \u001b[38;5;241m=\u001b[39m projected_x \u001b[38;5;28;01mif\u001b[39;00m context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m context\n\u001b[0;32m--> 327\u001b[0m projected_x \u001b[38;5;241m=\u001b[39m BasicTransformerBlock(\n\u001b[1;32m 328\u001b[0m query_dim\u001b[38;5;241m=\u001b[39minner_dim,\n\u001b[1;32m 329\u001b[0m heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mheads,\n\u001b[1;32m 330\u001b[0m dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdim_head,\n\u001b[1;32m 331\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAttention\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 332\u001b[0m precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision,\n\u001b[1;32m 333\u001b[0m use_bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 334\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype,\n\u001b[1;32m 335\u001b[0m use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_flash_attention,\n\u001b[1;32m 336\u001b[0m use_cross_only\u001b[38;5;241m=\u001b[39m(\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_self_and_cross),\n\u001b[1;32m 337\u001b[0m only_pure_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39monly_pure_attention,\n\u001b[1;32m 338\u001b[0m force_fp32_for_softmax\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforce_fp32_for_softmax,\n\u001b[1;32m 339\u001b[0m kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel_init\n\u001b[1;32m 340\u001b[0m )(projected_x, context)\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_projection \u001b[38;5;241m==\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m:\n", + "Cell \u001b[0;32mIn[4], line 281\u001b[0m, in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_cross_only:\n\u001b[0;32m--> 281\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m hidden_states \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattention1(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(hidden_states))\n\u001b[1;32m 283\u001b[0m \u001b[38;5;66;03m# cross attention\u001b[39;00m\n", + "Cell \u001b[0;32mIn[4], line 86\u001b[0m, in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 84\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reshape_tensor_to_head_dim(value)\n\u001b[0;32m---> 86\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mpallas\u001b[38;5;241m.\u001b[39mops\u001b[38;5;241m.\u001b[39mtpu\u001b[38;5;241m.\u001b[39mflash_attention\u001b[38;5;241m.\u001b[39mflash_attention(\n\u001b[1;32m 87\u001b[0m query, key, value, \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 88\u001b[0m )\n\u001b[1;32m 90\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reshape_tensor_from_head_dim(hidden_states)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py:198\u001b[0m, in \u001b[0;36mflash_attention\u001b[0;34m()\u001b[0m\n\u001b[1;32m 195\u001b[0m block_sizes \u001b[38;5;241m=\u001b[39m BlockSizes\u001b[38;5;241m.\u001b[39mget_default(\n\u001b[1;32m 196\u001b[0m batch_size, num_heads, q_seq_len, kv_seq_len, d_model\n\u001b[1;32m 197\u001b[0m )\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _flash_attention(\n\u001b[1;32m 199\u001b[0m q, k, v, ab, segment_ids, \u001b[38;5;28;01mFalse\u001b[39;00m, causal, sm_scale, block_sizes, debug\n\u001b[1;32m 200\u001b[0m )\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py:216\u001b[0m, in \u001b[0;36m_flash_attention\u001b[0;34m()\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mpartial(jax\u001b[38;5;241m.\u001b[39mcustom_vjp, nondiff_argnums\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m10\u001b[39m))\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_flash_attention\u001b[39m(\n\u001b[1;32m 205\u001b[0m q,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 214\u001b[0m debug,\n\u001b[1;32m 215\u001b[0m ):\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _flash_attention_impl(\n\u001b[1;32m 217\u001b[0m q,\n\u001b[1;32m 218\u001b[0m k,\n\u001b[1;32m 219\u001b[0m v,\n\u001b[1;32m 220\u001b[0m ab,\n\u001b[1;32m 221\u001b[0m segment_ids,\n\u001b[1;32m 222\u001b[0m save_residuals,\n\u001b[1;32m 223\u001b[0m causal,\n\u001b[1;32m 224\u001b[0m sm_scale,\n\u001b[1;32m 225\u001b[0m block_sizes\u001b[38;5;241m.\u001b[39mblock_b,\n\u001b[1;32m 226\u001b[0m block_sizes\u001b[38;5;241m.\u001b[39mblock_q,\n\u001b[1;32m 227\u001b[0m block_sizes\u001b[38;5;241m.\u001b[39mblock_k_major,\n\u001b[1;32m 228\u001b[0m block_sizes\u001b[38;5;241m.\u001b[39mblock_k,\n\u001b[1;32m 229\u001b[0m debug,\n\u001b[1;32m 230\u001b[0m )\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py:737\u001b[0m, in \u001b[0;36m_flash_attention_impl\u001b[0;34m()\u001b[0m\n\u001b[1;32m 728\u001b[0m in_specs \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 729\u001b[0m pl\u001b[38;5;241m.\u001b[39mBlockSpec((block_b, \u001b[38;5;241m1\u001b[39m, block_q, head_dim), q_index_map),\n\u001b[1;32m 730\u001b[0m pl\u001b[38;5;241m.\u001b[39mBlockSpec((block_b, \u001b[38;5;241m1\u001b[39m, block_k_major, head_dim), kv_index_map),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 734\u001b[0m kv_segment_ids_spec,\n\u001b[1;32m 735\u001b[0m ]\n\u001b[0;32m--> 737\u001b[0m o, \u001b[38;5;241m*\u001b[39maux \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mpallas_call(\n\u001b[1;32m 738\u001b[0m kernel,\n\u001b[1;32m 739\u001b[0m grid_spec\u001b[38;5;241m=\u001b[39mpltpu\u001b[38;5;241m.\u001b[39mPrefetchScalarGridSpec(\n\u001b[1;32m 740\u001b[0m num_scalar_prefetch\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 741\u001b[0m grid\u001b[38;5;241m=\u001b[39mgrid,\n\u001b[1;32m 742\u001b[0m in_specs\u001b[38;5;241m=\u001b[39min_specs,\n\u001b[1;32m 743\u001b[0m out_specs\u001b[38;5;241m=\u001b[39mout_specs,\n\u001b[1;32m 744\u001b[0m scratch_shapes\u001b[38;5;241m=\u001b[39mscratch_shapes,\n\u001b[1;32m 745\u001b[0m ),\n\u001b[1;32m 746\u001b[0m out_shape\u001b[38;5;241m=\u001b[39mout_shape,\n\u001b[1;32m 747\u001b[0m debug\u001b[38;5;241m=\u001b[39mdebug,\n\u001b[1;32m 748\u001b[0m compiler_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m 749\u001b[0m mosaic\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m 750\u001b[0m dimension_semantics\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 751\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparallel\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 752\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparallel\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 753\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparallel\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 754\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marbitrary\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 755\u001b[0m )\n\u001b[1;32m 756\u001b[0m )\n\u001b[1;32m 757\u001b[0m ),\n\u001b[1;32m 758\u001b[0m )(q, k, v, ab, q_segment_ids, kv_segment_ids)\n\u001b[1;32m 759\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m save_residuals:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py:1106\u001b[0m, in \u001b[0;36mwrapped\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1105\u001b[0m index_args, rest_args \u001b[38;5;241m=\u001b[39m split_list(flat_args, [grid_mapping\u001b[38;5;241m.\u001b[39mnum_index_operands])\n\u001b[0;32m-> 1106\u001b[0m out_flat \u001b[38;5;241m=\u001b[39m pallas_call_p\u001b[38;5;241m.\u001b[39mbind(\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;241m*\u001b[39mdynamic_grid_bounds, \u001b[38;5;241m*\u001b[39mindex_args, \u001b[38;5;241m*\u001b[39mconsts, \u001b[38;5;241m*\u001b[39mrest_args,\n\u001b[1;32m 1108\u001b[0m jaxpr\u001b[38;5;241m=\u001b[39mjaxpr, name\u001b[38;5;241m=\u001b[39mname,\n\u001b[1;32m 1109\u001b[0m debug\u001b[38;5;241m=\u001b[39mdebug,\n\u001b[1;32m 1110\u001b[0m interpret\u001b[38;5;241m=\u001b[39minterpret,\n\u001b[1;32m 1111\u001b[0m grid_mapping\u001b[38;5;241m=\u001b[39mgrid_mapping,\n\u001b[1;32m 1112\u001b[0m input_output_aliases\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mtuple\u001b[39m(input_output_aliases\u001b[38;5;241m.\u001b[39mitems()),\n\u001b[1;32m 1113\u001b[0m compiler_params\u001b[38;5;241m=\u001b[39mcompiler_params)\n\u001b[1;32m 1114\u001b[0m out \u001b[38;5;241m=\u001b[39m tree_util\u001b[38;5;241m.\u001b[39mtree_unflatten(out_tree, out_flat)\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m: ValueError: safe_zip() argument 2 is shorter than argument 1\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m init \u001b[38;5;241m=\u001b[39m partial(kernel_init, dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[1;32m 4\u001b[0m attention_block \u001b[38;5;241m=\u001b[39m TransformerBlock(heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m16\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16, use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, use_projection\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, use_self_and_cross\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, kernel_init\u001b[38;5;241m=\u001b[39mkernel_init(\u001b[38;5;241m1.0\u001b[39m))\n\u001b[0;32m----> 5\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mattention_block\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(params, x, context):\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m attention_block\u001b[38;5;241m.\u001b[39mapply(params, x, context)\n", + " \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[4], line 327\u001b[0m, in \u001b[0;36mTransformerBlock.__call__\u001b[0;34m(self, x, context)\u001b[0m\n\u001b[1;32m 323\u001b[0m inner_dim \u001b[38;5;241m=\u001b[39m C\n\u001b[1;32m 325\u001b[0m context \u001b[38;5;241m=\u001b[39m projected_x \u001b[38;5;28;01mif\u001b[39;00m context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m context\n\u001b[0;32m--> 327\u001b[0m projected_x \u001b[38;5;241m=\u001b[39m \u001b[43mBasicTransformerBlock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minner_dim\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[43mheads\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mheads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim_head\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdim_head\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mAttention\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 332\u001b[0m \u001b[43m \u001b[49m\u001b[43mprecision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_flash_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_flash_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cross_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_self_and_cross\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 337\u001b[0m \u001b[43m \u001b[49m\u001b[43monly_pure_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43monly_pure_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_fp32_for_softmax\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforce_fp32_for_softmax\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 339\u001b[0m \u001b[43m \u001b[49m\u001b[43mkernel_init\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel_init\u001b[49m\n\u001b[1;32m 340\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprojected_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_projection \u001b[38;5;241m==\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_linear_attention:\n", + " \u001b[0;31m[... skipping hidden 2 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[4], line 281\u001b[0m, in \u001b[0;36mBasicTransformerBlock.__call__\u001b[0;34m(self, hidden_states, context)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;66;03m# self attention\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_cross_only:\n\u001b[0;32m--> 281\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m hidden_states \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention1\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;66;03m# cross attention\u001b[39;00m\n\u001b[1;32m 284\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m hidden_states \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattention2(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(hidden_states), context)\n", + " \u001b[0;31m[... skipping hidden 2 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[4], line 86\u001b[0m, in \u001b[0;36mEfficientAttention.__call__\u001b[0;34m(self, x, context)\u001b[0m\n\u001b[1;32m 83\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reshape_tensor_to_head_dim(key)\n\u001b[1;32m 84\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reshape_tensor_to_head_dim(value)\n\u001b[0;32m---> 86\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpallas\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mops\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtpu\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflash_attention\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflash_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 87\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m 88\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reshape_tensor_from_head_dim(hidden_states)\n\u001b[1;32m 93\u001b[0m \u001b[38;5;66;03m# hidden_states = nn.dot_product_attention(\u001b[39;00m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n", + " \u001b[0;31m[... skipping hidden 27 frame]\u001b[0m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py:948\u001b[0m, in \u001b[0;36m_pallas_call_lowering\u001b[0;34m(ctx, interpret, *in_nodes, **params)\u001b[0m\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m _unsupported_lowering_error(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 944\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pallas_call_registration\u001b[38;5;241m.\u001b[39mpallas_call_lowering(\n\u001b[1;32m 945\u001b[0m ctx, \u001b[38;5;241m*\u001b[39min_nodes, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams\n\u001b[1;32m 946\u001b[0m )\n\u001b[0;32m--> 948\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmlir\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlower_per_platform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpallas_call\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 949\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcpu\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcpu_lowering\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 950\u001b[0m \u001b[43m \u001b[49m\u001b[43mtpu\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtpu_lowering\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 951\u001b[0m \u001b[43m \u001b[49m\u001b[43mcuda\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgpu_lowering\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 952\u001b[0m \u001b[43m \u001b[49m\u001b[43mrocm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgpu_lowering\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 953\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# default_rule\u001b[39;49;00m\n\u001b[1;32m 954\u001b[0m \u001b[43m \u001b[49m\u001b[43meffects\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mno_effects\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 955\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43min_nodes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 956\u001b[0m \u001b[43m \u001b[49m\u001b[43minterpret\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minterpret\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 957\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py:944\u001b[0m, in \u001b[0;36m_pallas_call_lowering..gpu_lowering\u001b[0;34m(ctx, *in_nodes, **params)\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m:\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m _unsupported_lowering_error(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 944\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpallas_call_registration\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpallas_call_lowering\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43min_nodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\n\u001b[1;32m 946\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/triton/pallas_call_registration.py:73\u001b[0m, in \u001b[0;36mpallas_call_lowering\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28mprint\u001b[39m(jaxpr)\n\u001b[1;32m 71\u001b[0m \u001b[38;5;28mprint\u001b[39m(grid_mapping)\n\u001b[0;32m---> 73\u001b[0m lowering_result \u001b[38;5;241m=\u001b[39m \u001b[43mlowering\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlower_jaxpr_to_triton_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlowering_platform\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 76\u001b[0m module_op \u001b[38;5;241m=\u001b[39m lowering_result\u001b[38;5;241m.\u001b[39mmodule\u001b[38;5;241m.\u001b[39moperation\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m debug:\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/triton/lowering.py:335\u001b[0m, in \u001b[0;36mlower_jaxpr_to_triton_module\u001b[0;34m(jaxpr, grid_mapping, name, platform)\u001b[0m\n\u001b[1;32m 320\u001b[0m start_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(\n\u001b[1;32m 321\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(_eval_index_map, ctx, program_ids),\n\u001b[1;32m 322\u001b[0m grid_mapping\u001b[38;5;241m.\u001b[39mblock_mappings,\n\u001b[1;32m 323\u001b[0m )\n\u001b[1;32m 324\u001b[0m block_infos \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 325\u001b[0m BlockInfo(\n\u001b[1;32m 326\u001b[0m block_mapping\u001b[38;5;241m.\u001b[39marray_shape_dtype,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 333\u001b[0m )\n\u001b[1;32m 334\u001b[0m ]\n\u001b[0;32m--> 335\u001b[0m () \u001b[38;5;241m=\u001b[39m \u001b[43mlower_jaxpr_to_triton_ir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_infos\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mentry\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marguments\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 336\u001b[0m tt_dialect\u001b[38;5;241m.\u001b[39mreturn_([])\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m LoweringResult(module, new_grid)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/jax/_src/pallas/triton/lowering.py:361\u001b[0m, in \u001b[0;36mlower_jaxpr_to_triton_ir\u001b[0;34m(ctx, jaxpr, block_infos, *args)\u001b[0m\n\u001b[1;32m 358\u001b[0m env[var] \u001b[38;5;241m=\u001b[39m val\n\u001b[1;32m 360\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m block_infos \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m invar, block_info \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mzip\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_infos\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 362\u001b[0m block_info_env[invar] \u001b[38;5;241m=\u001b[39m block_info\n\u001b[1;32m 364\u001b[0m \u001b[38;5;28mmap\u001b[39m(write_env, jaxpr\u001b[38;5;241m.\u001b[39minvars, args)\n", + "\u001b[0;31mValueError\u001b[0m: safe_zip() argument 2 is shorter than argument 1" + ] + } + ], + "source": [ + "x = jnp.ones((16, 16, 16, 64), dtype=jnp.bfloat16)\n", + "context = jnp.ones((16, 16, 16, 64), dtype=jnp.bfloat16)\n", + "init = partial(kernel_init, dtype=jnp.bfloat16)\n", + "attention_block = TransformerBlock(heads=16, dim_head=64//16, dtype=jnp.bfloat16, use_flash_attention=True, use_projection=False, use_self_and_cross=True, kernel_init=kernel_init(1.0))\n", + "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", + "\n", + "@jax.jit\n", + "def apply(params, x, context):\n", + " return attention_block.apply(params, x, context)\n", + "\n", + "apply(params, x, context).dtype\n", + "\n", + "%timeit apply(params, x, context)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/flaxdiff/models/attention.py b/flaxdiff/models/attention.py index 09584cb..562b829 100644 --- a/flaxdiff/models/attention.py +++ b/flaxdiff/models/attention.py @@ -303,27 +303,30 @@ class TransformerBlock(nn.Module): only_pure_attention:bool = False force_fp32_for_softmax: bool = True kernel_init: Callable = kernel_init(1.0) + norm_inputs: bool = True + explicitly_add_residual: bool = True @nn.compact def __call__(self, x, context=None): inner_dim = self.heads * self.dim_head C = x.shape[-1] - normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x) + if self.norm_inputs: + x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x) if self.use_projection == True: if self.use_linear_attention: projected_x = nn.Dense(features=inner_dim, use_bias=False, precision=self.precision, kernel_init=self.kernel_init, - dtype=self.dtype, name=f'project_in')(normed_x) + dtype=self.dtype, name=f'project_in')(x) else: projected_x = nn.Conv( features=inner_dim, kernel_size=(1, 1), kernel_init=self.kernel_init, strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_in_conv', - )(normed_x) + )(x) else: - projected_x = normed_x + projected_x = x inner_dim = C context = projected_x if context is None else context @@ -356,6 +359,9 @@ def __call__(self, x, context=None): strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_out_conv', )(projected_x) - - out = x + projected_x + + if self.only_pure_attention or self.explicitly_add_residual: + projected_x = x + projected_x + + out = projected_x return out \ No newline at end of file diff --git a/flaxdiff/models/simple_unet.py b/flaxdiff/models/simple_unet.py index 86ef27d..e9e808e 100644 --- a/flaxdiff/models/simple_unet.py +++ b/flaxdiff/models/simple_unet.py @@ -83,6 +83,8 @@ def __call__(self, x, temb, textcontext): precision=attention_config.get("precision", self.precision), only_pure_attention=attention_config.get("only_pure_attention", True), force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False), + norm_inputs=attention_config.get("norm_inputs", True), + explicitly_add_residual=attention_config.get("explicitly_add_residual", True), kernel_init=self.kernel_init(1.0), name=f"down_{i}_attention_{j}")(x, textcontext) # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in) @@ -125,6 +127,8 @@ def __call__(self, x, temb, textcontext): precision=middle_attention.get("precision", self.precision), only_pure_attention=middle_attention.get("only_pure_attention", True), force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False), + norm_inputs=middle_attention.get("norm_inputs", True), + explicitly_add_residual=middle_attention.get("explicitly_add_residual", True), kernel_init=self.kernel_init(1.0), name=f"middle_attention_{j}")(x, textcontext) x = ResidualBlock( @@ -171,6 +175,8 @@ def __call__(self, x, temb, textcontext): precision=attention_config.get("precision", self.precision), only_pure_attention=attention_config.get("only_pure_attention", True), force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False), + norm_inputs=attention_config.get("norm_inputs", True), + explicitly_add_residual=attention_config.get("explicitly_add_residual", True), kernel_init=self.kernel_init(1.0), name=f"up_{i}_attention_{j}")(x, textcontext) # print("Upscaling ", i, x.shape) diff --git a/flaxdiff/models/simple_vit.py b/flaxdiff/models/simple_vit.py index 001b48f..ad3aebb 100644 --- a/flaxdiff/models/simple_vit.py +++ b/flaxdiff/models/simple_vit.py @@ -69,6 +69,8 @@ class UViT(nn.Module): precision: PrecisionLike = None kernel_init: Callable = partial(kernel_init, scale=1.0) add_residualblock_output: bool = False + norm_inputs: bool = False + explicitly_add_residual: bool = False def setup(self): if self.norm_groups > 0: @@ -110,16 +112,20 @@ def __call__(self, x, temb, textcontext=None): for i in range(self.num_layers // 2): x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, - use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, + use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax, only_pure_attention=False, + norm_inputs=self.norm_inputs, + explicitly_add_residual=self.explicitly_add_residual, kernel_init=self.kernel_init())(x) skips.append(x) # Middle block x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, - use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, + use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax, only_pure_attention=False, + norm_inputs=self.norm_inputs, + explicitly_add_residual=self.explicitly_add_residual, kernel_init=self.kernel_init())(x) # # Out blocks @@ -131,6 +137,8 @@ def __call__(self, x, temb, textcontext=None): dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, only_pure_attention=False, + norm_inputs=self.norm_inputs, + explicitly_add_residual=self.explicitly_add_residual, kernel_init=self.kernel_init())(x) # print(f'Shape of x after transformer blocks: {x.shape}') diff --git a/setup.py b/setup.py index 30c0227..cb46dd0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.35.4', + version='0.1.35.5', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',