diff --git a/kohya_gui/class_flux1.py b/kohya_gui/class_flux1.py new file mode 100644 index 00000000..5d2e50a7 --- /dev/null +++ b/kohya_gui/class_flux1.py @@ -0,0 +1,221 @@ +import gradio as gr +from typing import Tuple +from .common_gui import ( + get_folder_path, + get_any_file_path, + list_files, + list_dirs, + create_refresh_button, + document_symbol, +) + + +class flux1Training: + """ + This class configures and initializes the advanced training settings for a machine learning model, + including options for headless operation, fine-tuning, training type selection, and default directory paths. + + Attributes: + headless (bool): If True, run without the Gradio interface. + finetuning (bool): If True, enables fine-tuning of the model. + training_type (str): Specifies the type of training to perform. + no_token_padding (gr.Checkbox): Checkbox to disable token padding. + gradient_accumulation_steps (gr.Slider): Slider to set the number of gradient accumulation steps. + weighted_captions (gr.Checkbox): Checkbox to enable weighted captions. + """ + + def __init__( + self, + headless: bool = False, + finetuning: bool = False, + training_type: str = "", + config: dict = {}, + flux1_checkbox: gr.Checkbox = False, + ) -> None: + """ + Initializes the AdvancedTraining class with given settings. + + Parameters: + headless (bool): Run in headless mode without GUI. + finetuning (bool): Enable model fine-tuning. + training_type (str): The type of training to be performed. + config (dict): Configuration options for the training process. + """ + self.headless = headless + self.finetuning = finetuning + self.training_type = training_type + self.config = config + self.flux1_checkbox = flux1_checkbox + + # Define the behavior for changing noise offset type. + def noise_offset_type_change( + noise_offset_type: str, + ) -> Tuple[gr.Group, gr.Group]: + """ + Returns a tuple of Gradio Groups with visibility set based on the noise offset type. + + Parameters: + noise_offset_type (str): The selected noise offset type. + + Returns: + Tuple[gr.Group, gr.Group]: A tuple containing two Gradio Group elements with their visibility set. + """ + if noise_offset_type == "Original": + return (gr.Group(visible=True), gr.Group(visible=False)) + else: + return (gr.Group(visible=False), gr.Group(visible=True)) + + with gr.Accordion( + "Flux.1", open=True, elem_id="flux1_tab", visible=False + ) as flux1_accordion: + with gr.Group(): + # gr.Markdown("### Flux.1 Specific Parameters") + # with gr.Row(): + # self.weighting_scheme = gr.Dropdown( + # label="Weighting Scheme", + # choices=["logit_normal", "sigma_sqrt", "mode", "cosmap"], + # value=self.config.get("flux1.weighting_scheme", "logit_normal"), + # interactive=True, + # ) + # self.logit_mean = gr.Number( + # label="Logit Mean", + # value=self.config.get("flux1.logit_mean", 0.0), + # interactive=True, + # ) + # self.logit_std = gr.Number( + # label="Logit Std", + # value=self.config.get("flux1.logit_std", 1.0), + # interactive=True, + # ) + # self.mode_scale = gr.Number( + # label="Mode Scale", + # value=self.config.get("flux1.mode_scale", 1.29), + # interactive=True, + # ) + + with gr.Row(): + self.ae = gr.Textbox( + label="VAE Path", + placeholder="Path to VAE model", + value=self.config.get("flux1.ae", ""), + interactive=True, + ) + self.ae_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.ae_button.click( + get_any_file_path, + outputs=self.ae, + show_progress=False, + ) + + self.clip_l = gr.Textbox( + label="CLIP-L Path", + placeholder="Path to CLIP-L model", + value=self.config.get("flux1.clip_l", ""), + interactive=True, + ) + self.clip_l_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.clip_l_button.click( + get_any_file_path, + outputs=self.clip_l, + show_progress=False, + ) + + # self.clip_g = gr.Textbox( + # label="CLIP-G Path", + # placeholder="Path to CLIP-G model", + # value=self.config.get("flux1.clip_g", ""), + # interactive=True, + # ) + # self.clip_g_button = gr.Button( + # document_symbol, + # elem_id="open_folder_small", + # visible=(not headless), + # interactive=True, + # ) + # self.clip_g_button.click( + # get_any_file_path, + # outputs=self.clip_g, + # show_progress=False, + # ) + + self.t5xxl = gr.Textbox( + label="T5-XXL Path", + placeholder="Path to T5-XXL model", + value=self.config.get("flux1.t5xxl", ""), + interactive=True, + ) + self.t5xxl_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.t5xxl_button.click( + get_any_file_path, + outputs=self.t5xxl, + show_progress=False, + ) + + # with gr.Row(): + # self.save_clip = gr.Checkbox( + # label="Save CLIP models", + # value=self.config.get("flux1.save_clip", False), + # interactive=True, + # ) + # self.save_t5xxl = gr.Checkbox( + # label="Save T5-XXL model", + # value=self.config.get("flux1.save_t5xxl", False), + # interactive=True, + # ) + + with gr.Row(): + # self.t5xxl_device = gr.Textbox( + # label="T5-XXL Device", + # placeholder="Device for T5-XXL (e.g., cuda:0)", + # value=self.config.get("flux1.t5xxl_device", ""), + # interactive=True, + # ) + # self.t5xxl_dtype = gr.Dropdown( + # label="T5-XXL Dtype", + # choices=["float32", "fp16", "bf16"], + # value=self.config.get("flux1.t5xxl_dtype", "bf16"), + # interactive=True, + # ) + # self.flux1_text_encoder_batch_size = gr.Number( + # label="Text Encoder Batch Size", + # value=self.config.get("flux1.text_encoder_batch_size", 1), + # minimum=1, + # maximum=1024, + # step=1, + # interactive=True, + # ) + self.flux1_cache_text_encoder_outputs = gr.Checkbox( + label="Cache Text Encoder Outputs", + value=self.config.get("flux1.cache_text_encoder_outputs", False), + info="Cache text encoder outputs to speed up inference", + interactive=True, + ) + self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox( + label="Cache Text Encoder Outputs to Disk", + value=self.config.get( + "flux1.cache_text_encoder_outputs_to_disk", False + ), + info="Cache text encoder outputs to disk to speed up inference", + interactive=True, + ) + + self.flux1_checkbox.change( + lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox), + inputs=[self.flux1_checkbox], + outputs=[flux1_accordion], + ) diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 1f619d36..49ff236f 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -36,6 +36,7 @@ from .class_huggingface import HuggingFace from .class_metadata import MetaData from .class_gui_config import KohyaSSGUIConfig +from .class_flux1 import flux1Training from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -238,6 +239,11 @@ def save_configuration( loraplus_lr_ratio, loraplus_text_encoder_lr_ratio, loraplus_unet_lr_ratio, + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -449,6 +455,11 @@ def open_configuration( loraplus_lr_ratio, loraplus_text_encoder_lr_ratio, loraplus_unet_lr_ratio, + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, training_preset, ): # Get list of function parameters and their values @@ -691,6 +702,11 @@ def train_model( loraplus_lr_ratio, loraplus_text_encoder_lr_ratio, loraplus_unet_lr_ratio, + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -1007,7 +1023,35 @@ def train_model( network_module = "lycoris.kohya" network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}" - if LoRA_type in ["Flux1", "Kohya LoCon", "Standard"]: + if LoRA_type in ["Flux1"]: + kohya_lora_var_list = [ + "down_lr_weight", + "mid_lr_weight", + "up_lr_weight", + "block_lr_zero_threshold", + "block_dims", + "block_alphas", + "conv_block_dims", + "conv_block_alphas", + "rank_dropout", + "module_dropout", + ] + network_module = "networks.lora_flux" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + + # Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style... + if LoRA_type in ["Flux1"]: + network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f" {key}={value}" + + if LoRA_type in ["Kohya LoCon", "Standard"]: kohya_lora_var_list = [ "down_lr_weight", "mid_lr_weight", @@ -1028,7 +1072,7 @@ def train_model( } # Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style... - if LoRA_type in ["Flux1", "Kohya LoCon"]: + if LoRA_type in ["Kohya LoCon"]: network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' for key, value in kohya_lora_vars.items(): @@ -1208,7 +1252,7 @@ def train_model( "noise_offset_random_strength": noise_offset_random_strength if noise_offset_type == "Original" else None, "noise_offset_type": noise_offset_type, "optimizer_type": optimizer, - "optimizer_args": str(optimizer_args).replace('"', "").split(), + "optimizer_args": str(optimizer_args).replace('"', "").split() if optimizer_args != [] else None, "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": int(persistent_data_loader_workers), @@ -1263,6 +1307,14 @@ def train_model( "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, + + # Flux.1 specific parameters + "flux1_cache_text_encoder_outputs": flux1_cache_text_encoder_outputs if flux1_checkbox else None, + "flux1_cache_text_encoder_outputs_to_disk": flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None, + "ae": ae if flux1_checkbox else None, + "clip_l": clip_l if flux1_checkbox else None, + "t5xxl": t5xxl if flux1_checkbox else None, + } # Given dictionary `config_toml_data` @@ -1530,6 +1582,9 @@ def list_presets(path): sdxl_params = SDXLParameters( source_model.sdxl_checkbox, config=config ) + + # Add FLUX1 Parameters + flux1_training = flux1Training(headless=headless, config=config, flux1_checkbox=source_model.flux1_checkbox) # LyCORIS Specific parameters with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion: @@ -2392,6 +2447,12 @@ def update_LoRA_settings( loraplus_lr_ratio, loraplus_text_encoder_lr_ratio, loraplus_unet_lr_ratio, + # Flux1 parameters + flux1_training.flux1_cache_text_encoder_outputs, + flux1_training.flux1_cache_text_encoder_outputs_to_disk, + flux1_training.ae, + flux1_training.clip_l, + flux1_training.t5xxl, ] configuration.button_open_config.click(