Skip to content

Commit

Permalink
Added Flux.1 parameters to GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Aug 10, 2024
1 parent 03532bb commit 5c53db4
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 3 deletions.
221 changes: 221 additions & 0 deletions kohya_gui/class_flux1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import gradio as gr
from typing import Tuple
from .common_gui import (
get_folder_path,
get_any_file_path,
list_files,
list_dirs,
create_refresh_button,
document_symbol,
)


class flux1Training:
    """
    Builds the "Flux.1" accordion of the training GUI.

    The accordion is created hidden and holds the path pickers for the VAE
    (``ae``), CLIP-L and T5-XXL models plus the text-encoder caching options.
    Its visibility follows the value of ``flux1_checkbox`` when one is given.

    Attributes:
        headless (bool): When True the file-picker buttons are created hidden.
        finetuning (bool): Stored as given; not read by this class.
        training_type (str): Stored as given; not read by this class.
        config: Object exposing ``get(key, default)``; supplies the initial
            value of each field under the ``flux1.*`` keys.
        flux1_checkbox: Checkbox whose ``change`` event toggles the accordion.
        ae, clip_l, t5xxl (gr.Textbox): Model path fields.
        ae_button, clip_l_button, t5xxl_button (gr.Button): Buttons wired to
            ``get_any_file_path`` that write into the matching path field.
        flux1_cache_text_encoder_outputs (gr.Checkbox): Cache option.
        flux1_cache_text_encoder_outputs_to_disk (gr.Checkbox): Cache-to-disk
            option.
    """

    def __init__(
        self,
        headless: bool = False,
        finetuning: bool = False,
        training_type: str = "",
        config: dict = None,
        flux1_checkbox: gr.Checkbox = False,
    ) -> None:
        """
        Creates the Flux.1 components inside the currently open Gradio context.

        Parameters:
            headless (bool): Run in headless mode (hides the picker buttons).
            finetuning (bool): Stored on the instance only.
            training_type (str): Stored on the instance only.
            config (dict): Source of initial field values; anything with a
                ``get(key, default)`` method works. ``None`` means "no saved
                values" and is replaced by a fresh empty dict.
            flux1_checkbox (gr.Checkbox): Checkbox controlling the accordion's
                visibility. When omitted, no visibility handler is registered.
        """
        self.headless = headless
        self.finetuning = finetuning
        self.training_type = training_type
        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared by every call.
        self.config = {} if config is None else config
        self.flux1_checkbox = flux1_checkbox

        with gr.Accordion(
            "Flux.1", open=True, elem_id="flux1_tab", visible=False
        ) as flux1_accordion:
            with gr.Group():
                # gr.Markdown("### Flux.1 Specific Parameters")
                # with gr.Row():
                #     self.weighting_scheme = gr.Dropdown(
                #         label="Weighting Scheme",
                #         choices=["logit_normal", "sigma_sqrt", "mode", "cosmap"],
                #         value=self.config.get("flux1.weighting_scheme", "logit_normal"),
                #         interactive=True,
                #     )
                #     self.logit_mean = gr.Number(
                #         label="Logit Mean",
                #         value=self.config.get("flux1.logit_mean", 0.0),
                #         interactive=True,
                #     )
                #     self.logit_std = gr.Number(
                #         label="Logit Std",
                #         value=self.config.get("flux1.logit_std", 1.0),
                #         interactive=True,
                #     )
                #     self.mode_scale = gr.Number(
                #         label="Mode Scale",
                #         value=self.config.get("flux1.mode_scale", 1.29),
                #         interactive=True,
                #     )

                with gr.Row():
                    # Creation order matters: Gradio lays components out in the
                    # order they are instantiated inside the row.
                    self.ae, self.ae_button = self._model_path_picker(
                        label="VAE Path",
                        placeholder="Path to VAE model",
                        config_key="flux1.ae",
                    )
                    self.clip_l, self.clip_l_button = self._model_path_picker(
                        label="CLIP-L Path",
                        placeholder="Path to CLIP-L model",
                        config_key="flux1.clip_l",
                    )
                    # self.clip_g, self.clip_g_button = self._model_path_picker(
                    #     label="CLIP-G Path",
                    #     placeholder="Path to CLIP-G model",
                    #     config_key="flux1.clip_g",
                    # )
                    self.t5xxl, self.t5xxl_button = self._model_path_picker(
                        label="T5-XXL Path",
                        placeholder="Path to T5-XXL model",
                        config_key="flux1.t5xxl",
                    )

                # with gr.Row():
                #     self.save_clip = gr.Checkbox(
                #         label="Save CLIP models",
                #         value=self.config.get("flux1.save_clip", False),
                #         interactive=True,
                #     )
                #     self.save_t5xxl = gr.Checkbox(
                #         label="Save T5-XXL model",
                #         value=self.config.get("flux1.save_t5xxl", False),
                #         interactive=True,
                #     )

                with gr.Row():
                    # self.t5xxl_device = gr.Textbox(
                    #     label="T5-XXL Device",
                    #     placeholder="Device for T5-XXL (e.g., cuda:0)",
                    #     value=self.config.get("flux1.t5xxl_device", ""),
                    #     interactive=True,
                    # )
                    # self.t5xxl_dtype = gr.Dropdown(
                    #     label="T5-XXL Dtype",
                    #     choices=["float32", "fp16", "bf16"],
                    #     value=self.config.get("flux1.t5xxl_dtype", "bf16"),
                    #     interactive=True,
                    # )
                    # self.flux1_text_encoder_batch_size = gr.Number(
                    #     label="Text Encoder Batch Size",
                    #     value=self.config.get("flux1.text_encoder_batch_size", 1),
                    #     minimum=1,
                    #     maximum=1024,
                    #     step=1,
                    #     interactive=True,
                    # )
                    self.flux1_cache_text_encoder_outputs = gr.Checkbox(
                        label="Cache Text Encoder Outputs",
                        value=self.config.get(
                            "flux1.cache_text_encoder_outputs", False
                        ),
                        info="Cache text encoder outputs to speed up inference",
                        interactive=True,
                    )
                    self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                        label="Cache Text Encoder Outputs to Disk",
                        value=self.config.get(
                            "flux1.cache_text_encoder_outputs_to_disk", False
                        ),
                        info="Cache text encoder outputs to disk to speed up inference",
                        interactive=True,
                    )

            # The declared default for flux1_checkbox is the plain bool False,
            # which has no `change` event; only wire the handler when a real
            # component was supplied, so default construction does not raise.
            if hasattr(self.flux1_checkbox, "change"):
                self.flux1_checkbox.change(
                    lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
                    inputs=[self.flux1_checkbox],
                    outputs=[flux1_accordion],
                )

    def _model_path_picker(self, label: str, placeholder: str, config_key: str):
        """
        Creates one path textbox with its companion picker button.

        Parameters:
            label (str): Label shown on the textbox.
            placeholder (str): Placeholder text for the empty textbox.
            config_key (str): Key looked up in ``self.config`` for the initial
                value; an empty string is used when the key is absent.

        Returns:
            tuple: ``(textbox, button)`` in creation order.
        """
        textbox = gr.Textbox(
            label=label,
            placeholder=placeholder,
            value=self.config.get(config_key, ""),
            interactive=True,
        )
        button = gr.Button(
            document_symbol,
            elem_id="open_folder_small",
            visible=(not self.headless),
            interactive=True,
        )
        # get_any_file_path takes no inputs here; its return value becomes the
        # textbox content (presumably the path chosen in a file dialog --
        # confirm against common_gui).
        button.click(
            get_any_file_path,
            outputs=textbox,
            show_progress=False,
        )
        return textbox, button
67 changes: 64 additions & 3 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .class_huggingface import HuggingFace
from .class_metadata import MetaData
from .class_gui_config import KohyaSSGUIConfig
from .class_flux1 import flux1Training

from .dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
Expand Down Expand Up @@ -238,6 +239,11 @@ def save_configuration(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
clip_l,
t5xxl,
):
# Get list of function parameters and values
parameters = list(locals().items())
Expand Down Expand Up @@ -449,6 +455,11 @@ def open_configuration(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
clip_l,
t5xxl,
training_preset,
):
# Get list of function parameters and their values
Expand Down Expand Up @@ -691,6 +702,11 @@ def train_model(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
clip_l,
t5xxl,
):
# Get list of function parameters and values
parameters = list(locals().items())
Expand Down Expand Up @@ -1007,7 +1023,35 @@ def train_model(
network_module = "lycoris.kohya"
network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}"

if LoRA_type in ["Flux1", "Kohya LoCon", "Standard"]:
if LoRA_type in ["Flux1"]:
kohya_lora_var_list = [
"down_lr_weight",
"mid_lr_weight",
"up_lr_weight",
"block_lr_zero_threshold",
"block_dims",
"block_alphas",
"conv_block_dims",
"conv_block_alphas",
"rank_dropout",
"module_dropout",
]
network_module = "networks.lora_flux"
kohya_lora_vars = {
key: value
for key, value in vars().items()
if key in kohya_lora_var_list and value
}

# Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style...
if LoRA_type in ["Flux1"]:
network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"'

for key, value in kohya_lora_vars.items():
if value:
network_args += f" {key}={value}"

if LoRA_type in ["Kohya LoCon", "Standard"]:
kohya_lora_var_list = [
"down_lr_weight",
"mid_lr_weight",
Expand All @@ -1028,7 +1072,7 @@ def train_model(
}

# Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style...
if LoRA_type in ["Flux1", "Kohya LoCon"]:
if LoRA_type in ["Kohya LoCon"]:
network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"'

for key, value in kohya_lora_vars.items():
Expand Down Expand Up @@ -1208,7 +1252,7 @@ def train_model(
"noise_offset_random_strength": noise_offset_random_strength if noise_offset_type == "Original" else None,
"noise_offset_type": noise_offset_type,
"optimizer_type": optimizer,
"optimizer_args": str(optimizer_args).replace('"', "").split(),
"optimizer_args": str(optimizer_args).replace('"', "").split() if optimizer_args != [] else None,
"output_dir": output_dir,
"output_name": output_name,
"persistent_data_loader_workers": int(persistent_data_loader_workers),
Expand Down Expand Up @@ -1263,6 +1307,14 @@ def train_model(
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weighted_captions": weighted_captions,
"xformers": True if xformers == "xformers" else None,

# Flux.1 specific parameters
"flux1_cache_text_encoder_outputs": flux1_cache_text_encoder_outputs if flux1_checkbox else None,
"flux1_cache_text_encoder_outputs_to_disk": flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None,
"ae": ae if flux1_checkbox else None,
"clip_l": clip_l if flux1_checkbox else None,
"t5xxl": t5xxl if flux1_checkbox else None,

}

# Given dictionary `config_toml_data`
Expand Down Expand Up @@ -1530,6 +1582,9 @@ def list_presets(path):
sdxl_params = SDXLParameters(
source_model.sdxl_checkbox, config=config
)

# Add FLUX1 Parameters
flux1_training = flux1Training(headless=headless, config=config, flux1_checkbox=source_model.flux1_checkbox)

# LyCORIS Specific parameters
with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion:
Expand Down Expand Up @@ -2392,6 +2447,12 @@ def update_LoRA_settings(
loraplus_lr_ratio,
loraplus_text_encoder_lr_ratio,
loraplus_unet_lr_ratio,
# Flux1 parameters
flux1_training.flux1_cache_text_encoder_outputs,
flux1_training.flux1_cache_text_encoder_outputs_to_disk,
flux1_training.ae,
flux1_training.clip_l,
flux1_training.t5xxl,
]

configuration.button_open_config.click(
Expand Down

0 comments on commit 5c53db4

Please sign in to comment.