diff --git a/evaluate.ipynb b/evaluate.ipynb index d771d70..03aa578 100644 --- a/evaluate.ipynb +++ b/evaluate.ipynb @@ -9,10 +9,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-08-09 14:08:15.831690: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-08-09 14:08:15.856254: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-08-09 14:08:15.863580: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-08-09 14:08:16.965220: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + "2024-08-10 11:00:36.427705: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-08-10 11:00:36.451551: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-08-10 11:00:36.458826: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-08-10 11:00:37.418510: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -97,11 +97,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ - "CONFIG_BIG = {\n", + "OLD_CONFIG_BIG = {\n", " # \"dtype\": \"bfloat16\",\n", " # \"precision\": \"high\",\n", " # \"activation\": \"swish\",\n", @@ -139,10 +139,11 @@ " \"use_self_and_cross\": False\n", " }\n", " ],\n", - " \"num_middle_res_blocks\": 1\n", + " \"num_middle_res_blocks\": 1,\n", + " \"named_norms\": True\n", "}\n", "\n", - "CONFIG_MEDIUM = {\n", + "OLD_CONFIG_MEDIUM = {\n", " # \"dtype\": \"bfloat16\",\n", " # \"precision\": \"high\",\n", " # \"activation\": \"swish\",\n", @@ -180,22 +181,123 @@ " \"use_self_and_cross\": False\n", " }\n", " ],\n", + " \"num_middle_res_blocks\": 1,\n", + " \"named_norms\": True\n", + "}\n", + "\n", + "OLD_CONFIG_SMALL = OLD_CONFIG_MEDIUM.copy()\n", + "OLD_CONFIG_SMALL['feature_depths'] = [64, 128, 256, 512]" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG_BIG = {\n", + " # \"dtype\": \"bfloat16\",\n", + " # \"precision\": \"high\",\n", + " # \"activation\": \"swish\",\n", + " \"emb_features\": 512,\n", + " \"feature_depths\": [\n", + " # 64,\n", + " 128,\n", + " 256,\n", + " 512,\n", + " 512\n", + " ],\n", + " \"num_res_blocks\": 4,\n", + " \"output_channels\": 3,\n", + " \"attention_configs\": [\n", + " None,\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 16,\n", + " \"use_projection\": False,\n", + " 
\"flash_attention\": False,\n", + " \"use_self_and_cross\": True,\n", + " \"only_pure_attention\": False\n", + " },\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 16,\n", + " \"use_projection\": False,\n", + " \"flash_attention\": False,\n", + " \"use_self_and_cross\": True,\n", + " \"only_pure_attention\": False\n", + " },\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 16,\n", + " \"use_projection\": False,\n", + " \"flash_attention\": False,\n", + " \"use_self_and_cross\": True,\n", + " \"only_pure_attention\": False\n", + " }\n", + " ],\n", + " \"num_middle_res_blocks\": 1,\n", + " \"named_norms\": False,\n", + " \"norm_groups\": 0,\n", + "}\n", + "\n", + "CONFIG_MEDIUM = {\n", + " # \"dtype\": \"bfloat16\",\n", + " # \"precision\": \"high\",\n", + " # \"activation\": \"swish\",\n", + " \"emb_features\": 512,\n", + " \"feature_depths\": [\n", + " # 64,\n", + " 128,\n", + " 256,\n", + " 512,\n", + " 512\n", + " ],\n", + " \"num_res_blocks\": 3,\n", + " \"output_channels\": 3,\n", + " \"attention_configs\": [\n", + " None,\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 8,\n", + " \"use_projection\": False,\n", + " \"flash_attention\": False,\n", + " \"use_self_and_cross\": False\n", + " },\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 8,\n", + " \"use_projection\": False,\n", + " \"flash_attention\": False,\n", + " \"use_self_and_cross\": False\n", + " },\n", + " {\n", + " # \"dtype\": jnp.bfloat16,\n", + " \"heads\": 8,\n", + " \"use_projection\": False,\n", + " \"flash_attention\": False,\n", + " \"use_self_and_cross\": False\n", + " }\n", + " ],\n", + " \"named_norms\": False,\n", + " \"norm_groups\": 0,\n", " \"num_middle_res_blocks\": 1\n", "}\n", "\n", "CONFIG_SMALL = CONFIG_MEDIUM.copy()\n", - "CONFIG_SMALL['feature_depths'] = [64, 128, 256, 512]" + "CONFIG_SMALL['feature_depths'] = [64, 128, 256, 512]\n", + "CONFIG_SMALL['emb_features'] = 256" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c0483b95e02241e2a11930ce269879cc", + "model_id": "70070037f0d84b57b99cef0d0709cb72", "version_major": 2, "version_minor": 0 }, @@ -210,8 +312,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/mrwhite0racle/.local/lib/python3.10/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. 
Please use CLIPImageProcessor instead.\n", - " warnings.warn(\n", "Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /home/mrwhite0racle/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:\n", "[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ('vision_model', 'vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'vision_model', 'embeddings', 'position_embedding', 'embedding'), ... (per-layer norm, MLP, and attention-projection weight entries for vision_model encoder layers 0-23 elided) ..., ('vision_model', 'vision_model', 'post_layernorm', 'bias'), ('vision_model', 'vision_model', 'post_layernorm', 'scale'), ('vision_model', 'vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'vision_model', 'pre_layrnorm', 'scale'), ('visual_projection', 'kernel')]\n", "You should probably UPCAST the model weights to float32 if this was not intended. 
See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n", @@ -219,6 +319,14 @@ "[('text_model', 'embeddings', 'position_embedding', 'embedding'), ('text_model', 'embeddings', 'token_embedding', 'embedding'), ... (per-layer norm, MLP, and attention-projection weight entries for text_model encoder layers 0-11 elided) ..., ('text_model', 'final_layer_norm', 'bias'), ('text_model', 'final_layer_norm', 'scale')]\n", "You should probably UPCAST the model weights to float32 if this was not intended. 
See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mrwhite0racle/.local/lib/python3.10/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -228,13 +336,14 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n", "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n" ] }, @@ -242,29 +351,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loading model from checkpoint 234181\n", - "Loaded model from checkpoint at epoch 234181 0.046310272\n", + "Loading model from checkpoint at step 264702\n", + "Loaded model from checkpoint at epoch 0 step 264702 1000000000.0\n", "Generating states for DiffusionTrainer\n" ] } ], "source": [ - "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__3\", CONFIG_MEDIUM)\n", - "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-128-v4-16_flaxdiff-0-1-8__ldm_1\", CONFIG_SMALL)\n", - "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_1\", CONFIG_BIG)\n", + "checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__3\", OLD_CONFIG_MEDIUM)# --> OG\n", + "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-128-v4-16_flaxdiff-0-1-8__ldm_1\", CONFIG_SMALL) --> OG\n", + "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_1\", CONFIG_BIG) --> OG\n", "\n", - "checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__new-combined_1\", CONFIG_MEDIUM)\n", - "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale__new-combined_1\", CONFIG_BIG)\n", + "# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-256-v4-32_lighter-arch_combined_30m_1\", CONFIG_MEDIUM) # --> Good\n", + "# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_new_arch_combined_30\", CONFIG_BIG) # --> Good\n", + "checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-128-v4-16_flaxdiff-0-1-9_light_combined_30m_1\", CONFIG_SMALL) # --> Good\n", "\n", - "# checkpoint_id, model_config = (\"dataset-combined_aesthetic/image_size-512/batch-128-v4-16_flaxdiff-0-1-8_new-combined_ldm_1\", CONFIG_SMALL)\n", + "checkpoint_base_path = \"gs://flaxdiff-datasets-regional/checkpoints/\"\n", "\n", "model_config = model_config.copy()\n", - "IMAGE_SIZE=128\n", + "IMAGE_SIZE=512\n", "DIFFUSION_INPUT_SIZE = IMAGE_SIZE\n", "\n", "autoencoder = None\n", "\n", "if \"ldm\" in checkpoint_id:\n", + " IMAGE_SIZE=512\n", + " 
DIFFUSION_INPUT_SIZE = IMAGE_SIZE\n", " DIFFUSION_INPUT_SIZE = IMAGE_SIZE // 8\n", " model_config['output_channels'] = 4\n", " autoencoder = sd_vae\n", @@ -272,7 +384,6 @@ "unet = Unet(activation=jax.nn.swish, dtype=jnp.bfloat16, precision=jax.lax.Precision.HIGH, **model_config)\n", "solver = optax.adam(2e-4)\n", "\n", - "\n", "edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)\n", "karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)\n", "\n", @@ -288,39 +399,61 @@ " },\n", " # train_state=trainer.best_state,\n", " # loss_fn=lambda x, y: jnp.abs(x - y),\n", - " # param_transforms=params_transform,\n", - " load_from_checkpoint=True,\n", - " checkpoint_id=checkpoint_id,#,\n", - " checkpoint_base_path=\"gs://flaxdiff-datasets-regional/checkpoints/\",\n", - " # checkpoint_epoch=1,\n", + " load_from_checkpoint=os.path.join(checkpoint_base_path, checkpoint_id),\n", + " checkpoint_base_path=checkpoint_base_path,\n", + " # checkpoint_step=221728,\n", " autoencoder=autoencoder\n", " )" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'model_config' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel_config\u001b[49m\n", - "\u001b[0;31mNameError\u001b[0m: name 'model_config' is not defined" - ] + "data": { + "text/plain": [ + "dict_keys(['conv1', 'conv2', 'norm1', 'norm2', 'temb_projection'])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "model_config" + "trainer.get_state().params['params'][\"down_0_residual_0\"].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.summary()" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -686,9 +819,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 
'kernel'), ... (long list of unused vision-tower parameter names elided) ...}\n", "- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], "source": [ "def defaultTextEncodeModel():\n", " modelname = \"openai/clip-vit-large-patch14\"\n", @@ -718,42 +861,43 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 50, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using classifier-free guidance\n" + ] + } + ], "source": [ "null_labels, null_labels_full = encodePrompts([\"\"], textEncoderModel, textTokenizer)\n", "\n", "\n", - "sampler = EulerAncestralSampler(trainer.model, trainer.get_best_state().ema_params, karas_ve_schedule, autoencoder=trainer.autoencoder, \n", + "sampler = EulerAncestralSampler(trainer.model, trainer.get_state().ema_params, karas_ve_schedule, autoencoder=trainer.autoencoder, \n", " image_size=IMAGE_SIZE,\n", " model_output_transform=trainer.model_output_transform, guidance_scale=4, null_labels_seq=null_labels_full)" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/200 [00:00" + "
" ] }, "metadata": {}, @@ -763,22 +907,17 @@ "source": [ "\n", "prompts = [\n", - " # \"a beautiful landscape\",\n", - " # \"a beautiful landscape\",\n", - " # \"a beautiful landscape\",\n", + " \"a beautiful landscape\",\n", + " \"a beautiful landscape\",\n", + " \"a beautiful landscape\",\n", " # \"a beautiful river with a mountain\",\n", " # \"a beautiful river with a mountain\",\n", " # \"a beautiful river with a mountain\",\n", " # \"a beautiful river with a mountain\",\n", " # \"a beautiful river with a mountain\",\n", - " \"A fairy\",\n", - " \"The moon in the night sky\",\n", - " \"A fairy\",\n", - " \"The moon in the night sky\",\n", - " \"A fairy\",\n", - " \"The moon in the night sky\",\n", - " \"A fairy\",\n", - " \"The moon in the night sky\",\n", + " # \"A fairy\",\n", + " # \"The moon in the night sky\",\n", + " # \"A microscope\"\n", " ]\n", "pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer)\n", "samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))\n", @@ -1818,7 +1957,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1839,7 +1978,8 @@ " dtype: Optional[Dtype] = None\n", " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", + " force_fp32_for_softmax: bool = True\n", "\n", " def setup(self):\n", " inner_dim = self.dim_head * self.heads\n", @@ -1849,7 +1989,7 @@ " self.heads * self.dim_head,\n", " precision=self.precision, \n", " use_bias=self.use_bias, \n", - " kernel_init=self.kernel_init(), \n", + " kernel_init=self.kernel_init, \n", " dtype=self.dtype\n", " )\n", " self.query = dense(name=\"to_q\")\n", @@ -1857,7 +1997,7 @@ " self.value = dense(name=\"to_v\")\n", " \n", " self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, \n", - " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_out_0\")\n", + " kernel_init=self.kernel_init, dtype=self.dtype, name=\"to_out_0\")\n", " # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)\n", " \n", " def _reshape_tensor_to_head_dim(self, tensor):\n", @@ -1930,7 +2070,8 @@ " dtype: Optional[Dtype] = None\n", " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", + " force_fp32_for_softmax: bool = True\n", "\n", " def setup(self):\n", " inner_dim = self.dim_head * self.heads\n", @@ -1940,7 +2081,7 @@ " axis=-1, \n", " precision=self.precision, \n", " use_bias=self.use_bias, \n", - " kernel_init=self.kernel_init(), \n", + " kernel_init=self.kernel_init, \n", " dtype=self.dtype\n", " )\n", " self.query = dense(name=\"to_q\")\n", @@ -1954,7 +2095,7 @@ " use_bias=self.use_bias, \n", " dtype=self.dtype, \n", " name=\"to_out_0\",\n", - " kernel_init=self.kernel_init()\n", + " kernel_init=self.kernel_init\n", " # kernel_init=jax.nn.initializers.xavier_uniform()\n", " )\n", "\n", @@ -1974,10 +2115,9 @@ " \n", " hidden_states = nn.dot_product_attention(\n", " query, key, value, dtype=self.dtype, broadcast_dropout=False, \n", - " dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,\n", + " dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,\n", " 
deterministic=True\n", " )\n", - " \n", " proj = self.proj_attn(hidden_states)\n", " proj = proj.reshape(orig_x_shape)\n", " return proj\n", @@ -2051,10 +2191,11 @@ " dtype: Optional[Dtype] = None\n", " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", " use_flash_attention:bool = False\n", " use_cross_only:bool = False\n", " only_pure_attention:bool = False\n", + " force_fp32_for_softmax: bool = True\n", " \n", " def setup(self):\n", " if self.use_flash_attention:\n", @@ -2070,7 +2211,8 @@ " precision=self.precision,\n", " use_bias=self.use_bias,\n", " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", + " kernel_init=self.kernel_init,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax\n", " )\n", " self.attention2 = attenBlock(\n", " query_dim=self.query_dim,\n", @@ -2080,7 +2222,8 @@ " precision=self.precision,\n", " use_bias=self.use_bias,\n", " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", + " kernel_init=self.kernel_init,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax\n", " )\n", " \n", " self.ff = FlaxFeedForward(dim=self.query_dim)\n", @@ -2114,24 +2257,24 @@ " use_flash_attention:bool = False\n", " use_self_and_cross:bool = True\n", " only_pure_attention:bool = False\n", + " force_fp32_for_softmax: bool = True\n", + " kernel_init: Callable = kernel_init(1.0)\n", "\n", " @nn.compact\n", - " def __call__(self, x:jnp.array, context:jnp.array=None):\n", + " def __call__(self, x, context=None):\n", " inner_dim = self.heads * self.dim_head\n", - " \n", " C = x.shape[-1]\n", - " \n", " normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)\n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", " projected_x = nn.Dense(features=inner_dim, \n", " use_bias=False, precision=self.precision, \n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " dtype=self.dtype, name=f'project_in')(normed_x)\n", " else:\n", " projected_x = nn.Conv(\n", " features=inner_dim, kernel_size=(1, 1),\n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,\n", " precision=self.precision, name=f'project_in_conv',\n", " )(normed_x)\n", @@ -2151,52 +2294,54 @@ " dtype=self.dtype,\n", " use_flash_attention=self.use_flash_attention,\n", " use_cross_only=(not self.use_self_and_cross),\n", - " only_pure_attention=self.only_pure_attention\n", + " only_pure_attention=self.only_pure_attention,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax,\n", + " kernel_init=self.kernel_init\n", " )(projected_x, context)\n", " \n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", " projected_x = nn.Dense(features=C, precision=self.precision, \n", " dtype=self.dtype, use_bias=False, \n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " name=f'project_out')(projected_x)\n", " else:\n", " projected_x = nn.Conv(\n", " features=C, kernel_size=(1, 1),\n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,\n", " precision=self.precision, name=f'project_out_conv',\n", " )(projected_x)\n", - " \n", + "\n", " out = x + projected_x\n", - " # out = out.reshape(orig_x_shape)\n", " return out" ] }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "176 μs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "172 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ - "x = jnp.ones((16, 16, 16, 64))\n", - "context = jnp.ones((16, 16, 16, 64))\n", - "attention_block = TransformerBlock(heads=16, dim_head=64//16, dtype=jnp.bfloat16, use_flash_attention=False, use_projection=False, use_self_and_cross=True)\n", + "x = jnp.ones((16, 16, 16, 64), dtype=jnp.bfloat16)\n", + "context = jnp.ones((16, 16, 16, 64), dtype=jnp.bfloat16)\n", + "init = partial(kernel_init, dtype=jnp.bfloat16)\n", + "attention_block = TransformerBlock(heads=16, dim_head=64//16, dtype=jnp.bfloat16, use_flash_attention=False, use_projection=False, use_self_and_cross=True, kernel_init=kernel_init(1.0))\n", "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", "\n", "@jax.jit\n", "def apply(params, x, context):\n", " return attention_block.apply(params, x, context)\n", "\n", - "apply(params, x, context)\n", + "apply(params, x, context).dtype\n", "\n", "%timeit apply(params, x, context)" ] diff --git a/flaxdiff/models/attention.py b/flaxdiff/models/attention.py index 6b29f45..09584cb 100644 --- a/flaxdiff/models/attention.py +++ b/flaxdiff/models/attention.py @@ -22,7 +22,8 @@ class EfficientAttention(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) + kernel_init: Callable = kernel_init(1.0) + force_fp32_for_softmax: bool = True def setup(self): inner_dim = self.dim_head * self.heads @@ -32,7 +33,7 @@ def setup(self): self.heads * self.dim_head, precision=self.precision, use_bias=self.use_bias, - kernel_init=self.kernel_init(), + kernel_init=self.kernel_init, dtype=self.dtype ) self.query = dense(name="to_q") @@ -40,7 +41,7 @@ def setup(self): self.value = dense(name="to_v") self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, - kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0") + kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0") # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16) def _reshape_tensor_to_head_dim(self, tensor): @@ -113,7 +114,8 @@ class NormalAttention(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) + kernel_init: Callable = kernel_init(1.0) + force_fp32_for_softmax: bool = True def setup(self): inner_dim = self.dim_head * self.heads @@ -123,7 +125,7 @@ def setup(self): axis=-1, precision=self.precision, use_bias=self.use_bias, - kernel_init=self.kernel_init(), + kernel_init=self.kernel_init, dtype=self.dtype ) self.query = dense(name="to_q") @@ -137,7 +139,7 @@ def setup(self): use_bias=self.use_bias, dtype=self.dtype, name="to_out_0", - kernel_init=self.kernel_init() + kernel_init=self.kernel_init # kernel_init=jax.nn.initializers.xavier_uniform() ) @@ -157,7 +159,7 @@ def __call__(self, x, context=None): hidden_states = nn.dot_product_attention( query, key, value, dtype=self.dtype, broadcast_dropout=False, - dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True, + dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax, deterministic=True ) proj = self.proj_attn(hidden_states) @@ -233,10 +235,11 @@ class BasicTransformerBlock(nn.Module): dtype: Optional[Dtype] = 
None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) + kernel_init: Callable = kernel_init(1.0) use_flash_attention:bool = False use_cross_only:bool = False only_pure_attention:bool = False + force_fp32_for_softmax: bool = True def setup(self): if self.use_flash_attention: @@ -252,7 +255,8 @@ def setup(self): precision=self.precision, use_bias=self.use_bias, dtype=self.dtype, - kernel_init=self.kernel_init + kernel_init=self.kernel_init, + force_fp32_for_softmax=self.force_fp32_for_softmax ) self.attention2 = attenBlock( query_dim=self.query_dim, @@ -262,7 +266,8 @@ def setup(self): precision=self.precision, use_bias=self.use_bias, dtype=self.dtype, - kernel_init=self.kernel_init + kernel_init=self.kernel_init, + force_fp32_for_softmax=self.force_fp32_for_softmax ) self.ff = FlaxFeedForward(dim=self.query_dim) @@ -296,6 +301,8 @@ class TransformerBlock(nn.Module): use_flash_attention:bool = False use_self_and_cross:bool = True only_pure_attention:bool = False + force_fp32_for_softmax: bool = True + kernel_init: Callable = kernel_init(1.0) @nn.compact def __call__(self, x, context=None): @@ -306,12 +313,12 @@ def __call__(self, x, context=None): if self.use_linear_attention: projected_x = nn.Dense(features=inner_dim, use_bias=False, precision=self.precision, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init, dtype=self.dtype, name=f'project_in')(normed_x) else: projected_x = nn.Conv( features=inner_dim, kernel_size=(1, 1), - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init, strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_in_conv', )(normed_x) @@ -331,19 +338,21 @@ def __call__(self, x, context=None): dtype=self.dtype, use_flash_attention=self.use_flash_attention, use_cross_only=(not self.use_self_and_cross), - only_pure_attention=self.only_pure_attention + only_pure_attention=self.only_pure_attention, + force_fp32_for_softmax=self.force_fp32_for_softmax, + kernel_init=self.kernel_init )(projected_x, context) if self.use_projection == True: if self.use_linear_attention: projected_x = nn.Dense(features=C, precision=self.precision, dtype=self.dtype, use_bias=False, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init, name=f'project_out')(projected_x) else: projected_x = nn.Conv( features=C, kernel_size=(1, 1), - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init, strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_out_conv', )(projected_x) diff --git a/flaxdiff/models/common.py b/flaxdiff/models/common.py index 98489c2..e619668 100644 --- a/flaxdiff/models/common.py +++ b/flaxdiff/models/common.py @@ -267,15 +267,17 @@ class ResidualBlock(nn.Module): kernel_init:Callable=kernel_init(1.0) dtype: Optional[Dtype] = None precision: PrecisionLike = None + named_norms:bool=False def setup(self): if self.norm_groups > 0: norm = partial(nn.GroupNorm, self.norm_groups) + self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm() + self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm() else: norm = partial(nn.RMSNorm, 1e-5) - - self.norm1 = norm() - self.norm2 = norm() + self.norm1 = norm() + self.norm2 = norm() @nn.compact def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None): diff --git a/flaxdiff/models/simple_unet.py b/flaxdiff/models/simple_unet.py index 160c03d..86ef27d 100644 --- 
a/flaxdiff/models/simple_unet.py +++ b/flaxdiff/models/simple_unet.py @@ -19,15 +19,16 @@ class Unet(nn.Module): norm_groups:int=8 dtype: Optional[Dtype] = None precision: PrecisionLike = None + named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms + kernel_init: Callable = partial(kernel_init, dtype=jnp.float32) def setup(self): if self.norm_groups > 0: norm = partial(nn.GroupNorm, self.norm_groups) + self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm() else: norm = partial(nn.RMSNorm, 1e-5) - - # self.last_up_norm = norm() - self.conv_out_norm = norm() + self.conv_out_norm = norm() @nn.compact def __call__(self, x, temb, textcontext): @@ -49,7 +50,7 @@ def __call__(self, x, temb, textcontext): features=self.feature_depths[0], kernel_size=(3, 3), strides=(1, 1), - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), dtype=self.dtype, precision=self.precision )(x) @@ -64,13 +65,14 @@ def __call__(self, x, temb, textcontext): down_conv_type, name=f"down_{i}_residual_{j}", features=dim_in, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, norm_groups=self.norm_groups, dtype=self.dtype, - precision=self.precision + precision=self.precision, + named_norms=self.named_norms )(x, temb) if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), @@ -80,6 +82,8 @@ def __call__(self, x, temb, textcontext): use_self_and_cross=attention_config.get("use_self_and_cross", True), precision=attention_config.get("precision", self.precision), only_pure_attention=attention_config.get("only_pure_attention", True), + force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False), + kernel_init=self.kernel_init(1.0), name=f"down_{i}_attention_{j}")(x, textcontext) # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in) downs.append(x) @@ -102,13 +106,14 @@ def __call__(self, x, temb, textcontext): middle_conv_type, name=f"middle_res1_{j}", features=middle_dim_out, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, norm_groups=self.norm_groups, dtype=self.dtype, - precision=self.precision + precision=self.precision, + named_norms=self.named_norms )(x, temb) if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32), @@ -119,18 +124,21 @@ def __call__(self, x, temb, textcontext): use_self_and_cross=False, precision=middle_attention.get("precision", self.precision), only_pure_attention=middle_attention.get("only_pure_attention", True), + force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False), + kernel_init=self.kernel_init(1.0), name=f"middle_attention_{j}")(x, textcontext) x = ResidualBlock( middle_conv_type, name=f"middle_res2_{j}", features=middle_dim_out, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, norm_groups=self.norm_groups, dtype=self.dtype, - precision=self.precision + precision=self.precision, + named_norms=self.named_norms )(x, temb) # Upscaling Blocks @@ -145,13 +153,14 @@ def __call__(self, 
x, temb, textcontext): up_conv_type,# if j == 0 else "separable", name=f"up_{i}_residual_{j}", features=dim_out, - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), kernel_size=kernel_size, strides=(1, 1), activation=self.activation, norm_groups=self.norm_groups, dtype=self.dtype, - precision=self.precision + precision=self.precision, + named_norms=self.named_norms )(x, temb) if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), @@ -161,6 +170,8 @@ def __call__(self, x, temb, textcontext): use_self_and_cross=attention_config.get("use_self_and_cross", True), precision=attention_config.get("precision", self.precision), only_pure_attention=attention_config.get("only_pure_attention", True), + force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False), + kernel_init=self.kernel_init(1.0), name=f"up_{i}_attention_{j}")(x, textcontext) # print("Upscaling ", i, x.shape) if i != len(feature_depths) - 1: @@ -179,7 +190,7 @@ def __call__(self, x, temb, textcontext): features=self.feature_depths[0], kernel_size=(3, 3), strides=(1, 1), - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), dtype=self.dtype, precision=self.precision )(x) @@ -190,13 +201,14 @@ def __call__(self, x, temb, textcontext): conv_type, name="final_residual", features=self.feature_depths[0], - kernel_init=kernel_init(1.0), + kernel_init=self.kernel_init(1.0), kernel_size=(3,3), strides=(1, 1), activation=self.activation, norm_groups=self.norm_groups, dtype=self.dtype, - precision=self.precision + precision=self.precision, + named_norms=self.named_norms )(x, temb) x = self.conv_out_norm(x) @@ -208,7 +220,7 @@ def __call__(self, x, temb, textcontext): kernel_size=(3, 3), strides=(1, 1), # activation=jax.nn.mish - kernel_init=kernel_init(0.0), + kernel_init=self.kernel_init(0.0), dtype=self.dtype, precision=self.precision )(x) diff --git a/flaxdiff/trainer/diffusion_trainer.py b/flaxdiff/trainer/diffusion_trainer.py index 19bd2fa..f8b824c 100644 --- a/flaxdiff/trainer/diffusion_trainer.py +++ b/flaxdiff/trainer/diffusion_trainer.py @@ -16,6 +16,7 @@ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics from flaxdiff.models.autoencoder.autoencoder import AutoEncoder +from flax.training.dynamic_scale import DynamicScale class TrainState(SimpleTrainState): rngs: jax.random.PRNGKey @@ -83,7 +84,8 @@ def generate_states( new_state = existing_state if param_transforms is not None: - params = param_transforms(params) + new_state['params'] = param_transforms(new_state['params']) + new_state['ema_params'] = param_transforms(new_state['ema_params']) state = TrainState.create( apply_fn=model.apply, @@ -92,7 +94,7 @@ def generate_states( tx=optimizer, rngs=rngs, metrics=Metrics.empty(), - dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None + dynamic_scale = DynamicScale() if use_dynamic_scale else None ) if existing_best_state is not None: diff --git a/flaxdiff/trainer/simple_trainer.py b/flaxdiff/trainer/simple_trainer.py index a136249..594e06c 100644 --- a/flaxdiff/trainer/simple_trainer.py +++ b/flaxdiff/trainer/simple_trainer.py @@ -22,7 +22,7 @@ from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array from termcolor import colored from typing import Dict, Callable, Sequence, Any, Union, Tuple - +from flax.training.dynamic_scale import DynamicScale from 
flaxdiff.utils import RandomMarkovState PROCESS_COLOR_MAP = { @@ -68,7 +68,7 @@ class Metrics(metrics.Collection): # Define the TrainState class SimpleTrainState(train_state.TrainState): metrics: Metrics - dynamic_scale: flax.training.dynamic_scale.DynamicScale + dynamic_scale: DynamicScale class SimpleTrainer: state: SimpleTrainState @@ -177,13 +177,16 @@ def generate_states( params = model.init(subkey, **input_vars) else: params = existing_state['params'] + + if param_transforms is not None: + params = param_transforms(params) state = SimpleTrainState.create( apply_fn=model.apply, params=params, tx=optimizer, metrics=Metrics.empty(), - dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None + dynamic_scale = DynamicScale() if use_dynamic_scale else None ) if existing_best_state is not None: best_state = state.replace( diff --git a/setup.py b/setup.py index 0bb2cfe..c6941d0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.10', + version='0.1.12', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',
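
A minimal usage sketch of the backward-compatibility path this patch introduces (not part of the diff itself). It assumes the new Unet signature from flaxdiff/models/simple_unet.py above and the OLD_CONFIG_MEDIUM / CONFIG_BIG dicts defined earlier in evaluate.ipynb; the `upcast` hook is a hypothetical example of a param_transforms callable, not an API shipped by the library.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet

# Older checkpoints were saved with explicitly named GroupNorms
# ("GroupNorm_0"/"GroupNorm_1"); their configs set named_norms=True and keep
# the default norm_groups=8, so the original parameter tree is reproduced and
# the old checkpoint restores without key remapping.
old_unet = Unet(activation=jax.nn.swish, dtype=jnp.bfloat16,
                precision=jax.lax.Precision.HIGH, **OLD_CONFIG_MEDIUM)

# Newer configs set norm_groups=0 (RMSNorm) and named_norms=False, so norm
# parameters land under the attribute names ("norm1"/"norm2"), matching the
# keys inspected via trainer.get_state() in the notebook above.
new_unet = Unet(activation=jax.nn.swish, dtype=jnp.bfloat16,
                precision=jax.lax.Precision.HIGH, **CONFIG_BIG)

# generate_states() now applies param_transforms to both the 'params' and the
# 'ema_params' trees of a restored state. As a hypothetical example, this hook
# upcasts every leaf to float32, the remedy the CLIP loading warning above
# ("You should probably UPCAST the model weights to float32") recommends:
upcast = lambda tree: jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), tree)
# e.g. passed as param_transforms=upcast when constructing the trainer.

Applying the transform to both trees matters because sampling here reads trainer.get_state().ema_params; transforming only 'params' (as the old code did) would leave the EMA weights in the untransformed layout.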