From 33937c22846bc5ae5978b67e7b7da584ac1f3f47 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 21 Aug 2024 07:52:46 -0400 Subject: [PATCH] Fix issue with split_mode and train_blocks --- kohya_gui/class_flux1.py | 34 +++++--- kohya_gui/dreambooth_gui.py | 127 ++++++++++++++++++------------ kohya_gui/finetune_gui.py | 99 ++++++++++++++--------- kohya_gui/lora_gui.py | 120 +++++++++++++++++++--------- sd-scripts | 2 +- test/config/dataset-multires.toml | 40 ++++++++++ 6 files changed, 283 insertions(+), 139 deletions(-) create mode 100644 test/config/dataset-multires.toml diff --git a/kohya_gui/class_flux1.py b/kohya_gui/class_flux1.py index db08fe808..e9ec461df 100644 --- a/kohya_gui/class_flux1.py +++ b/kohya_gui/class_flux1.py @@ -112,19 +112,10 @@ def noise_offset_type_change( value=self.config.get("flux1.timestep_sampling", "sigma"), interactive=True, ) - - self.flux1_cache_text_encoder_outputs = gr.Checkbox( - label="Cache Text Encoder Outputs", - value=self.config.get("flux1.cache_text_encoder_outputs", False), - info="Cache text encoder outputs to speed up inference", - interactive=True, - ) - self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox( - label="Cache Text Encoder Outputs to Disk", - value=self.config.get( - "flux1.cache_text_encoder_outputs_to_disk", False - ), - info="Cache text encoder outputs to disk to speed up inference", + self.apply_t5_attn_mask = gr.Checkbox( + label="Apply T5 Attention Mask", + value=self.config.get("flux1.apply_t5_attn_mask", False), + info="Apply attention mask to T5-XXL encode and FLUX double blocks ", interactive=True, ) with gr.Row(): @@ -158,12 +149,29 @@ def noise_offset_type_change( step=1, interactive=True, ) + + with gr.Row(): + self.flux1_cache_text_encoder_outputs = gr.Checkbox( + label="Cache Text Encoder Outputs", + value=self.config.get("flux1.cache_text_encoder_outputs", False), + info="Cache text encoder outputs to speed up inference", + interactive=True, + ) + self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox( + label="Cache Text Encoder Outputs to Disk", + value=self.config.get( + "flux1.cache_text_encoder_outputs_to_disk", False + ), + info="Cache text encoder outputs to disk to speed up inference", + interactive=True, + ) self.mem_eff_save = gr.Checkbox( label="Memory Efficient Save", value=self.config.get("flux1.mem_eff_save", False), info="[Experimentsl] Enable memory efficient save. 
We do not recommend using it unless you are familiar with the code.", interactive=True, ) + with gr.Row(visible=True if finetuning else False): self.blockwise_fused_optimizers = gr.Checkbox( label="Blockwise Fused Optimizer", diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 0e8ca86e1..09681cc60 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -17,7 +17,9 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, + validate_file_path, + validate_folder_path, + validate_model_path, validate_args_setting, setup_environment, ) @@ -190,7 +192,6 @@ def save_configuration( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -207,7 +208,6 @@ def save_configuration( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, @@ -226,6 +226,7 @@ def save_configuration( single_blocks_to_swap, double_blocks_to_swap, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -393,7 +394,6 @@ def open_configuration( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -410,7 +410,6 @@ def open_configuration( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, @@ -429,6 +428,7 @@ def open_configuration( single_blocks_to_swap, double_blocks_to_swap, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -591,7 +591,6 @@ def train_model( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -608,7 +607,6 @@ def train_model( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, @@ -627,6 +625,7 @@ def train_model( single_blocks_to_swap, double_blocks_to_swap, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -647,42 +646,46 @@ def train_model( log.info(f"Validating lr scheduler arguments...") if not validate_args_setting(lr_scheduler_args): return - + log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): return TRAIN_BUTTON_VISIBLE # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(reg_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + if not 
validate_folder_path(train_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(vae): return TRAIN_BUTTON_VISIBLE - + # # End of path validation # @@ -821,7 +824,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -840,16 +843,22 @@ def train_model( ) if sdxl: - run_cmd.append(rf'{scriptdir}/sd-scripts/sdxl_train.py') + run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py") elif sd3_checkbox: run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py") elif flux1_checkbox: run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py") else: run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py") - - cache_text_encoder_outputs = (sdxl and sdxl_cache_text_encoder_outputs) or (sd3_checkbox and sd3_cache_text_encoder_outputs) or (flux1_checkbox and flux1_cache_text_encoder_outputs) - cache_text_encoder_outputs_to_disk = (sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) + + cache_text_encoder_outputs = ( + (sdxl and sdxl_cache_text_encoder_outputs) + or (sd3_checkbox and sd3_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + ) + cache_text_encoder_outputs_to_disk = ( + sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk + ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) no_half_vae = sdxl and sdxl_no_half_vae if max_data_loader_n_workers == "" or None: max_data_loader_n_workers = 0 @@ -862,9 +871,8 @@ def train_model( max_train_steps = int(max_train_steps) if sdxl: - train_text_encoder = ( - (learning_rate_te1 != None and learning_rate_te1 > 0) or - (learning_rate_te2 != None and learning_rate_te2 > 0) + train_text_encoder = (learning_rate_te1 != None and learning_rate_te1 > 0) or ( + learning_rate_te2 != None and learning_rate_te2 > 0 ) # def save_huggingface_to_toml(self, toml_file_path: str): @@ -895,7 +903,9 @@ def train_model( "full_bf16": full_bf16, "full_fp16": full_fp16, "fused_backward_pass": fused_backward_pass, - "fused_optimizer_groups": int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None, + "fused_optimizer_groups": ( + int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None + ), "gradient_accumulation_steps": int(gradient_accumulation_steps), "gradient_checkpointing": gradient_checkpointing, "huber_c": huber_c, @@ -909,9 +919,9 @@ def train_model( "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength, "keep_tokens": int(keep_tokens), "learning_rate": learning_rate, # both for sd1.5 and sdxl - "learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5 - "learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl - "learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl + "learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5 + "learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl + "learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl "logging_dir": logging_dir, "log_config": log_config, "log_tracker_config": log_tracker_config, @@ -921,7 +931,9 @@ def train_model( "lr_scheduler": lr_scheduler, "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), "lr_scheduler_num_cycles": ( - int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch) + int(lr_scheduler_num_cycles) + if 
lr_scheduler_num_cycles != "" + else int(epoch) ), "lr_scheduler_power": lr_scheduler_power, "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None, @@ -930,7 +942,9 @@ def train_model( "max_bucket_reso": max_bucket_reso, "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), - "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None, + "max_train_epochs": ( + int(max_train_epochs) if int(max_train_epochs) != 0 else None + ), "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None, "mem_eff_attn": mem_eff_attn, "metadata_author": metadata_author, @@ -1006,7 +1020,6 @@ def train_model( "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, - # SD3 only Parameters # "cache_text_encoder_outputs": see previous assignment above for code # "cache_text_encoder_outputs_to_disk": see previous assignment above for code @@ -1020,9 +1033,10 @@ def train_model( # "t5xxl": see previous assignment above for code "t5xxl_device": t5xxl_device if sd3_checkbox else None, "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None, - "text_encoder_batch_size": sd3_text_encoder_batch_size if sd3_checkbox else None, + "text_encoder_batch_size": ( + sd3_text_encoder_batch_size if sd3_checkbox else None + ), "weighting_scheme": weighting_scheme if sd3_checkbox else None, - # Flux.1 specific parameters # "cache_text_encoder_outputs": see previous assignment above for code # "cache_text_encoder_outputs_to_disk": see previous assignment above for code @@ -1036,11 +1050,16 @@ def train_model( "train_blocks": train_blocks if flux1_checkbox else None, "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None, "guidance_scale": guidance_scale if flux1_checkbox else None, - "blockwise_fused_optimizers": blockwise_fused_optimizers if flux1_checkbox else None, - "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None, + "blockwise_fused_optimizers": ( + blockwise_fused_optimizers if flux1_checkbox else None + ), + "cpu_offload_checkpointing": ( + cpu_offload_checkpointing if flux1_checkbox else None + ), "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None, "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None, "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -1058,8 +1077,8 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_dreambooth-{formatted_datetime}.toml" - + tmpfilename = rf"{output_dir}/config_dreambooth-{formatted_datetime}.toml" + # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: toml.dump(config_toml_data, toml_file) @@ -1068,7 +1087,7 @@ def train_model( log.error(f"Failed to write TOML file: {toml_file.name}") run_cmd.append(f"--config_file") - run_cmd.append(rf'{tmpfilename}') + run_cmd.append(rf"{tmpfilename}") # Initialize a dictionary with always-included keyword arguments kwargs_for_training = { @@ -1173,17 +1192,26 @@ def dreambooth_tab( sdxl_checkbox=source_model.sdxl_checkbox, config=config, ) - + # Add SDXL Parameters sdxl_params = SDXLParameters( - source_model.sdxl_checkbox, config=config, trainer="finetune", 
+ source_model.sdxl_checkbox, + config=config, + trainer="finetune", ) - + # Add FLUX1 Parameters - flux1_training = flux1Training(headless=headless, config=config, flux1_checkbox=source_model.flux1_checkbox, finetuning=True) + flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + finetuning=True, + ) # Add SD3 Parameters - sd3_training = sd3Training(headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox) + sd3_training = sd3Training( + headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox + ) with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless, config=config) @@ -1341,7 +1369,6 @@ def dreambooth_tab( metadata.metadata_license, metadata.metadata_tags, metadata.metadata_title, - # SD3 Parameters sd3_training.sd3_cache_text_encoder_outputs, sd3_training.sd3_cache_text_encoder_outputs_to_disk, @@ -1358,7 +1385,6 @@ def dreambooth_tab( sd3_training.sd3_text_encoder_batch_size, sd3_training.weighting_scheme, source_model.sd3_checkbox, - # Flux1 parameters flux1_training.flux1_cache_text_encoder_outputs, flux1_training.flux1_cache_text_encoder_outputs_to_disk, @@ -1377,6 +1403,7 @@ def dreambooth_tab( flux1_training.single_blocks_to_swap, flux1_training.double_blocks_to_swap, flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, ] configuration.button_open_config.click( diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index 182196274..0b51186a9 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -18,8 +18,11 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, - validate_args_setting, setup_environment, + validate_file_path, + validate_folder_path, + validate_model_path, + validate_args_setting, + setup_environment, ) from .class_accelerate_launch import AccelerateLaunch from .class_configuration_file import ConfigurationFile @@ -197,7 +200,6 @@ def save_configuration( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -214,7 +216,6 @@ def save_configuration( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, @@ -233,6 +234,7 @@ def save_configuration( single_blocks_to_swap, double_blocks_to_swap, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -406,7 +408,6 @@ def open_configuration( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -423,7 +424,6 @@ def open_configuration( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, @@ -443,6 +443,7 @@ def open_configuration( double_blocks_to_swap, mem_eff_save, training_preset, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -621,7 +622,6 @@ def train_model( metadata_license, metadata_tags, metadata_title, - # SD3 parameters sd3_cache_text_encoder_outputs, sd3_cache_text_encoder_outputs_to_disk, @@ -638,7 +638,6 @@ def train_model( sd3_text_encoder_batch_size, weighting_scheme, sd3_checkbox, - # Flux.1 flux1_cache_text_encoder_outputs, 
flux1_cache_text_encoder_outputs_to_disk, @@ -657,6 +656,7 @@ def train_model( single_blocks_to_swap, double_blocks_to_swap, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -689,33 +689,37 @@ def train_model( # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(image_folder): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + # # End of path validation # - + # if not validate_paths( # dataset_config=dataset_config, # finetune_image_folder=image_folder, @@ -869,7 +873,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -901,8 +905,14 @@ def train_model( if use_latent_files == "Yes" else f"{train_dir}/{caption_metadata_filename}" ) - cache_text_encoder_outputs = (sdxl_checkbox and sdxl_cache_text_encoder_outputs) or (sd3_checkbox and sd3_cache_text_encoder_outputs) or (flux1_checkbox and flux1_cache_text_encoder_outputs) - cache_text_encoder_outputs_to_disk = (sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) + cache_text_encoder_outputs = ( + (sdxl_checkbox and sdxl_cache_text_encoder_outputs) + or (sd3_checkbox and sd3_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + ) + cache_text_encoder_outputs_to_disk = ( + sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk + ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) no_half_vae = sdxl_checkbox and sdxl_no_half_vae if max_data_loader_n_workers == "" or None: @@ -945,7 +955,9 @@ def train_model( "full_bf16": full_bf16, "full_fp16": full_fp16, "fused_backward_pass": fused_backward_pass, - "fused_optimizer_groups": int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None, + "fused_optimizer_groups": ( + int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None + ), "gradient_accumulation_steps": int(gradient_accumulation_steps), "gradient_checkpointing": gradient_checkpointing, "huber_c": huber_c, @@ -1052,7 +1064,6 @@ def train_model( "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, - # SD3 only Parameters # "cache_text_encoder_outputs": see previous assignment above for code # "cache_text_encoder_outputs_to_disk": see previous assignment above for code @@ -1066,9 +1077,10 @@ def train_model( # "t5xxl": see previous assignment above for code "t5xxl_device": t5xxl_device if sd3_checkbox else None, "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None, - "text_encoder_batch_size": 
sd3_text_encoder_batch_size if sd3_checkbox else None, + "text_encoder_batch_size": ( + sd3_text_encoder_batch_size if sd3_checkbox else None + ), "weighting_scheme": weighting_scheme if sd3_checkbox else None, - # Flux.1 specific parameters # "cache_text_encoder_outputs": see previous assignment above for code # "cache_text_encoder_outputs_to_disk": see previous assignment above for code @@ -1082,11 +1094,16 @@ def train_model( "train_blocks": train_blocks if flux1_checkbox else None, "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None, "guidance_scale": guidance_scale if flux1_checkbox else None, - "blockwise_fused_optimizers": blockwise_fused_optimizers if flux1_checkbox else None, - "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None, + "blockwise_fused_optimizers": ( + blockwise_fused_optimizers if flux1_checkbox else None + ), + "cpu_offload_checkpointing": ( + cpu_offload_checkpointing if flux1_checkbox else None + ), "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None, "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None, "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -1104,7 +1121,7 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_finetune-{formatted_datetime}.toml" + tmpfilename = rf"{output_dir}/config_finetune-{formatted_datetime}.toml" # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: toml.dump(config_toml_data, toml_file) @@ -1270,7 +1287,9 @@ def list_presets(path): # Add SDXL Parameters sdxl_params = SDXLParameters( - source_model.sdxl_checkbox, config=config, trainer="finetune", + source_model.sdxl_checkbox, + config=config, + trainer="finetune", ) with gr.Row(): @@ -1278,12 +1297,19 @@ def list_presets(path): train_text_encoder = gr.Checkbox( label="Train text encoder", value=True ) - + # Add FLUX1 Parameters - flux1_training = flux1Training(headless=headless, config=config, flux1_checkbox=source_model.flux1_checkbox, finetuning=True) - + flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + finetuning=True, + ) + # Add SD3 Parameters - sd3_training = sd3Training(headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox) + sd3_training = sd3Training( + headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox + ) with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): with gr.Row(): @@ -1461,7 +1487,6 @@ def list_presets(path): metadata.metadata_license, metadata.metadata_tags, metadata.metadata_title, - # SD3 Parameters sd3_training.sd3_cache_text_encoder_outputs, sd3_training.sd3_cache_text_encoder_outputs_to_disk, @@ -1478,7 +1503,6 @@ def list_presets(path): sd3_training.sd3_text_encoder_batch_size, sd3_training.weighting_scheme, source_model.sd3_checkbox, - # Flux1 parameters flux1_training.flux1_cache_text_encoder_outputs, flux1_training.flux1_cache_text_encoder_outputs_to_disk, @@ -1497,6 +1521,7 @@ def list_presets(path): flux1_training.single_blocks_to_swap, flux1_training.double_blocks_to_swap, flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, ] configuration.button_open_config.click( diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 
01db63da2..ccd530998 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -19,8 +19,12 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, validate_toml_file, - validate_args_setting, setup_environment, + validate_file_path, + validate_folder_path, + validate_model_path, + validate_toml_file, + validate_args_setting, + setup_environment, ) from .class_accelerate_launch import AccelerateLaunch from .class_configuration_file import ConfigurationFile @@ -239,7 +243,7 @@ def save_configuration( loraplus_lr_ratio, loraplus_text_encoder_lr_ratio, loraplus_unet_lr_ratio, - #Flux1 + # Flux1 flux1_cache_text_encoder_outputs, flux1_cache_text_encoder_outputs_to_disk, ae, @@ -253,6 +257,7 @@ def save_configuration( t5xxl_max_token_length, guidance_scale, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -477,6 +482,7 @@ def open_configuration( t5xxl_max_token_length, guidance_scale, mem_eff_save, + apply_t5_attn_mask, training_preset, ): # Get list of function parameters and their values @@ -732,6 +738,7 @@ def train_model( t5xxl_max_token_length, guidance_scale, mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -759,42 +766,46 @@ def train_model( # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES: if not validate_toml_file(LyCORIS_preset): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(network_weights): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(reg_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(train_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(vae): return TRAIN_BUTTON_VISIBLE - + # # End of path validation # @@ -985,7 +996,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -1060,6 +1071,7 @@ def train_model( "conv_block_alphas", "rank_dropout", "module_dropout", + "train_blocks", ] network_module = "networks.lora_flux" kohya_lora_vars = { @@ -1067,11 +1079,16 @@ def train_model( for key, value in vars().items() if key in kohya_lora_var_list and value } - + if split_mode: + if train_blocks != "single": + log.warning( + f"train_blocks is currently set to '{train_blocks}'. split_mode is enabled, forcing train_blocks to 'single'." 
+ ) + kohya_lora_vars["train_blocks"] = "single" for key, value in kohya_lora_vars.items(): if value: network_args += f" {key}={value}" - + if LoRA_type in ["Kohya LoCon", "Standard"]: kohya_lora_var_list = [ "down_lr_weight", @@ -1091,7 +1108,7 @@ def train_model( for key, value in vars().items() if key in kohya_lora_var_list and value } - + # Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style... if LoRA_type in ["Kohya LoCon"]: network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' @@ -1175,13 +1192,15 @@ def train_model( network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0 # Flag to train unet only if its learning rate is non-zero and text encoder's is zero. network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0 - + if text_encoder_lr_float != 0 or unet_lr_float != 0: do_not_set_learning_rate = True - + config_toml_data = { "adaptive_noise_scale": ( - adaptive_noise_scale if (adaptive_noise_scale != 0 and noise_offset_type == "Original") else None + adaptive_noise_scale + if (adaptive_noise_scale != 0 and noise_offset_type == "Original") + else None ), "async_upload": async_upload, "bucket_no_upscale": bucket_no_upscale, @@ -1189,7 +1208,10 @@ def train_model( "cache_latents": cache_latents, "cache_latents_to_disk": cache_latents_to_disk, "cache_text_encoder_outputs": ( - True if (sdxl and sdxl_cache_text_encoder_outputs) or (flux1_checkbox and flux1_cache_text_encoder_outputs) else None + True + if (sdxl and sdxl_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + else None ), "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs), "caption_dropout_rate": caption_dropout_rate, @@ -1225,7 +1247,9 @@ def train_model( "log_tracker_name": log_tracker_name, "log_tracker_config": log_tracker_config, "loraplus_lr_ratio": loraplus_lr_ratio if not 0 else None, - "loraplus_text_encoder_lr_ratio": loraplus_text_encoder_lr_ratio if not 0 else None, + "loraplus_text_encoder_lr_ratio": ( + loraplus_text_encoder_lr_ratio if not 0 else None + ), "loraplus_unet_lr_ratio": loraplus_unet_lr_ratio if not 0 else None, "loss_type": loss_type, "lowvram": lowvram, @@ -1258,9 +1282,13 @@ def train_model( "min_snr_gamma": min_snr_gamma if min_snr_gamma != 0 else None, "min_timestep": min_timestep if min_timestep != 0 else None, "mixed_precision": mixed_precision, - "multires_noise_discount": multires_noise_discount if noise_offset_type == "Multires" else None, + "multires_noise_discount": ( + multires_noise_discount if noise_offset_type == "Multires" else None + ), "multires_noise_iterations": ( - multires_noise_iterations if (multires_noise_iterations != 0 and noise_offset_type == "Multires") else None + multires_noise_iterations + if (multires_noise_iterations != 0 and noise_offset_type == "Multires") + else None ), "network_alpha": network_alpha, "network_args": str(network_args).replace('"', "").split(), @@ -1271,11 +1299,21 @@ def train_model( "network_train_text_encoder_only": network_train_text_encoder_only, "network_weights": network_weights, "no_half_vae": True if sdxl and sdxl_no_half_vae else None, - "noise_offset": noise_offset if (noise_offset != 0 and noise_offset_type == "Original") else None, - "noise_offset_random_strength": noise_offset_random_strength if noise_offset_type == "Original" else None, + "noise_offset": ( + noise_offset + if (noise_offset != 0 and noise_offset_type == "Original") + else None + ), + "noise_offset_random_strength": 
( + noise_offset_random_strength if noise_offset_type == "Original" else None + ), "noise_offset_type": noise_offset_type, "optimizer_type": optimizer, - "optimizer_args": str(optimizer_args).replace('"', "").split() if optimizer_args != [] else None, + "optimizer_args": ( + str(optimizer_args).replace('"', "").split() + if optimizer_args != [] + else None + ), "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": int(persistent_data_loader_workers), @@ -1330,10 +1368,11 @@ def train_model( "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, - # Flux.1 specific parameters # "cache_text_encoder_outputs": see previous assignment above for code - "cache_text_encoder_outputs_to_disk": flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None, + "cache_text_encoder_outputs_to_disk": ( + flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None + ), "ae": ae if flux1_checkbox else None, "clip_l": clip_l if flux1_checkbox else None, "t5xxl": t5xxl if flux1_checkbox else None, @@ -1341,10 +1380,10 @@ def train_model( "model_prediction_type": model_prediction_type if flux1_checkbox else None, "timestep_sampling": timestep_sampling if flux1_checkbox else None, "split_mode": split_mode if flux1_checkbox else None, - "train_blocks": train_blocks if flux1_checkbox else None, "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None, "guidance_scale": float(guidance_scale) if flux1_checkbox else None, "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -1362,7 +1401,7 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_lora-{formatted_datetime}.toml" + tmpfilename = rf"{output_dir}/config_lora-{formatted_datetime}.toml" # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: @@ -2194,10 +2233,14 @@ def update_LoRA_settings( results.append(settings["gr_type"](**update_params)) return tuple(results) - + with gr.Group(): # Add FLUX1 Parameters - flux1_training = flux1Training(headless=headless, config=config, flux1_checkbox=source_model.flux1_checkbox) + flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + ) with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): # with gr.Accordion('Advanced Configuration', open=False): @@ -2493,6 +2536,7 @@ def update_LoRA_settings( flux1_training.t5xxl_max_token_length, flux1_training.guidance_scale, flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, ] configuration.button_open_config.click( diff --git a/sd-scripts b/sd-scripts index 6ab48b09d..2b07a92c8 160000 --- a/sd-scripts +++ b/sd-scripts @@ -1 +1 @@ -Subproject commit 6ab48b09d8e46973d5e5fa47baeae3a464d06d04 +Subproject commit 2b07a92c8d970a8538a47dd1bcad3122da4e195a diff --git a/test/config/dataset-multires.toml b/test/config/dataset-multires.toml new file mode 100644 index 000000000..9cba749c2 --- /dev/null +++ b/test/config/dataset-multires.toml @@ -0,0 +1,40 @@ +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" 
+min_bucket_reso = 64 +max_bucket_reso = 2048 + +[[datasets]] +# define the first resolution here +batch_size = 1 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 + +[[datasets]] +# define the second resolution here +batch_size = 1 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 + +[[datasets]] +# define the third resolution here +batch_size = 1 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 \ No newline at end of file
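
Note on the fix itself: the core behavioral change in `kohya_gui/lora_gui.py` above is the new guard that reconciles `split_mode` with `train_blocks` before they are serialized into `network_args` for `networks.lora_flux`. The sketch below is a minimal, standalone restatement of that guard for readers who want to see the logic outside the GUI plumbing; the function name `build_flux_network_args`, its parameter list, and the use of the standard `logging` module are illustrative assumptions and not part of the patch, but the warning message and the forced override to `"single"` mirror the added code.

```python
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)


def build_flux_network_args(split_mode: bool, train_blocks: str) -> str:
    """Sketch of the guard added in lora_gui.py (names here are illustrative).

    The patch collects LoRA options into kohya_lora_vars and, when split_mode
    is enabled, overrides any non-'single' train_blocks value with 'single'
    before the options are appended to network_args.
    """
    kohya_lora_vars = {}
    if train_blocks:
        kohya_lora_vars["train_blocks"] = train_blocks

    if split_mode:
        if kohya_lora_vars.get("train_blocks") != "single":
            log.warning(
                f"train_blocks is currently set to '{train_blocks}'. "
                "split_mode is enabled, forcing train_blocks to 'single'."
            )
        kohya_lora_vars["train_blocks"] = "single"

    # Serialize the surviving options the same way the GUI does:
    # space-separated "key=value" pairs appended to network_args.
    network_args = ""
    for key, value in kohya_lora_vars.items():
        if value:
            network_args += f" {key}={value}"
    return network_args


if __name__ == "__main__":
    # With split_mode on and train_blocks left at 'double',
    # the guard logs a warning and rewrites the value.
    print(build_flux_network_args(split_mode=True, train_blocks="double"))
    # -> " train_blocks=single"
```

A brief usage note: with this guard in place, a configuration that enables `split_mode` while `train_blocks` is still set to another value no longer produces an inconsistent `network_args` string; the GUI warns and forces `train_blocks=single` before launching training.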