From 5ccf90e7faf4d33912d1fed988165b2154a33ca8 Mon Sep 17 00:00:00 2001 From: b-fission Date: Mon, 5 Aug 2024 19:27:24 -0500 Subject: [PATCH 1/3] Auto-detect model type for safetensors files Automatically tick the checkboxes for v2 and SDXL on the common training UI and LoRA extract/merge utilities. --- kohya_gui/common_gui.py | 8 ++++++ kohya_gui/extract_lora_gui.py | 8 ++++++ kohya_gui/merge_lora_gui.py | 8 ++++++ kohya_gui/sd_modeltype.py | 47 +++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+) create mode 100755 kohya_gui/sd_modeltype.py diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 0ca334eb6..7763f65f8 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -5,6 +5,7 @@ from easygui import msgbox, ynbox from typing import Optional from .custom_logging import setup_logging +from .sd_modeltype import SDModelType import os import re @@ -1009,6 +1010,13 @@ def set_pretrained_model_name_or_path_input( v_parameterization = gr.Checkbox(visible=True) sdxl = gr.Checkbox(visible=True) + # Auto-detect model type if safetensors file path is given + if pretrained_model_name_or_path.lower().endswith(".safetensors"): + detect = SDModelType(pretrained_model_name_or_path) + v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True) + sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True) + #TODO: v_parameterization + # If a refresh method is provided, use it to update the choices for the Dropdown widget if refresh_method is not None: args = dict( diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index 62b12fd9f..54fd33389 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -12,6 +12,7 @@ ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -337,6 +338,13 @@ def change_sdxl(sdxl): outputs=[load_tuned_model_to, load_original_model_to], ) + #secondary event on model_tuned for auto-detection of SDXL + model_tuned.change( + lambda sdxl, path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=[sdxl, model_tuned], + outputs=sdxl + ) + extract_button = gr.Button("Extract LoRA model") extract_button.click( diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index a3337c4cf..92659362c 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -16,6 +16,7 @@ create_refresh_button, setup_environment ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -145,6 +146,13 @@ def list_save_to(path): show_progress=False, ) + #secondary event on sd_model for auto-detection of SDXL + sd_model.change( + lambda sdxl, path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=[sdxl_model, sd_model], + outputs=sdxl_model + ) + with gr.Group(), gr.Row(): lora_a_model = gr.Dropdown( label='LoRA model "A" (path to the LoRA A model)', diff --git a/kohya_gui/sd_modeltype.py b/kohya_gui/sd_modeltype.py new file mode 100755 index 000000000..11891bf8e --- /dev/null +++ b/kohya_gui/sd_modeltype.py @@ -0,0 +1,47 @@ +from os.path import isfile +from safetensors import safe_open +import enum + +# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403 + +class ModelType(enum.Enum): + UNKNOWN = 0 + SD1 = 1 + SD2 = 2 + SDXL = 3 + SD3 = 4 + +class SDModelType: + def __init__(self, safetensors_path): + self.model_type = ModelType.UNKNOWN + + if not isfile(safetensors_path): + return + + try: + st = safe_open(filename=safetensors_path, framework="numpy", device="cpu") + def hasKeyPrefix(pfx): + return any(k.startswith(pfx) for k in st.keys()) + + if "model.diffusion_model.x_embedder.proj.weight" in st.keys(): + self.model_type = ModelType.SD3 + elif hasKeyPrefix("conditioner."): + self.model_type = ModelType.SDXL + elif hasKeyPrefix("cond_stage_model.model."): + self.model_type = ModelType.SD2 + elif hasKeyPrefix("model."): + self.model_type = ModelType.SD1 + except: + pass + + def Is_SD1(self): + return self.model_type == ModelType.SD1 + + def Is_SD2(self): + return self.model_type == ModelType.SD2 + + def Is_SDXL(self): + return self.model_type == ModelType.SDXL + + def Is_SD3(self): + return self.model_type == ModelType.SD3 From ed2255685f56674130ba641858c18fbfefd5360f Mon Sep 17 00:00:00 2001 From: b-fission Date: Mon, 5 Aug 2024 19:53:59 -0500 Subject: [PATCH 2/3] autodetect-modeltype: remove unused lambda inputs --- kohya_gui/extract_lora_gui.py | 4 ++-- kohya_gui/merge_lora_gui.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index 54fd33389..ec3c689ba 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -340,8 +340,8 @@ def change_sdxl(sdxl): #secondary event on model_tuned for auto-detection of SDXL model_tuned.change( - lambda sdxl, path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), - inputs=[sdxl, model_tuned], + lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=model_tuned, outputs=sdxl ) diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index 92659362c..72e632124 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -148,8 +148,8 @@ def list_save_to(path): #secondary event on sd_model for auto-detection of SDXL sd_model.change( - lambda sdxl, path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), - inputs=[sdxl_model, sd_model], + lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=sd_model, outputs=sdxl_model ) From c0966bcc3b954087b5dc8e314c559b0824a99dd4 Mon Sep 17 00:00:00 2001 From: b-fission Date: Tue, 6 Aug 2024 12:43:22 -0500 Subject: [PATCH 3/3] autodetect-modeltype: also do the v2 checkbox in extract_lora --- kohya_gui/extract_lora_gui.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index ec3c689ba..f1650e7f6 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -338,11 +338,17 @@ def change_sdxl(sdxl): outputs=[load_tuned_model_to, load_original_model_to], ) - #secondary event on model_tuned for auto-detection of SDXL + #secondary event on model_tuned for auto-detection of v2/SDXL + def change_modeltype_model_tuned(path): + detect = SDModelType(path) + v2 = gr.Checkbox(value=detect.Is_SD2()) + sdxl = gr.Checkbox(value=detect.Is_SDXL()) + return v2, sdxl + model_tuned.change( - lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + change_modeltype_model_tuned, inputs=model_tuned, - outputs=sdxl + outputs=[v2, sdxl] ) extract_button = gr.Button("Extract LoRA model")