Commit

Add flux_fused_backward_pass to dreambooth and finetuning
bmaltais committed Aug 21, 2024
1 parent c50ecbb commit 5668ff0
Showing 4 changed files with 20 additions and 4 deletions.
kohya_gui/class_flux1.py (8 changes: 7 additions & 1 deletion)

@@ -176,7 +176,7 @@ def noise_offset_type_change(
         self.blockwise_fused_optimizers = gr.Checkbox(
             label="Blockwise Fused Optimizer",
             value=self.config.get("flux1.blockwise_fused_optimizers", False),
-            info="Enable blockwise optimizers for fused backward pass and optimizer step",
+            info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.",
             interactive=True,
         )
         self.cpu_offload_checkpointing = gr.Checkbox(
@@ -203,6 +203,12 @@ def noise_offset_type_change(
             step=1,
             interactive=True,
         )
+        self.flux_fused_backward_pass = gr.Checkbox(
+            label="Fused Backward Pass",
+            value=self.config.get("flux1.fused_backward_pass", False),
+            info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
+            interactive=True,
+        )

         self.flux1_checkbox.change(
             lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
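
The new Fused Backward Pass checkbox exposes the sd-scripts option described in its info text: the optimizer step for each parameter is folded into the backward pass, so gradients can be applied and released one parameter at a time instead of being held for the whole model (Adafactor only on the sd-scripts side). The snippet below is a minimal, hedged PyTorch sketch of that idea, not the sd-scripts implementation; it assumes PyTorch 2.1+ for register_post_accumulate_grad_hook and uses plain SGD only to stay dependency-free.

# Hedged sketch only: what "fusing the optimizer step into the backward pass"
# means in plain PyTorch terms; this is not the sd-scripts implementation.
import torch

def attach_fused_backward_pass(model: torch.nn.Module, lr: float = 1e-5) -> None:
    # Give every trainable parameter its own single-parameter optimizer and
    # step it from a gradient hook, so no full-model gradient buffer or
    # separate optimizer.step() call is needed after backward().
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # sd-scripts restricts this path to Adafactor; SGD here keeps the
        # sketch dependency-free.
        opt = torch.optim.SGD([param], lr=lr)

        def step_and_release(p: torch.Tensor, opt: torch.optim.Optimizer = opt) -> None:
            opt.step()                        # update this parameter now
            opt.zero_grad(set_to_none=True)   # free its gradient immediately

        # register_post_accumulate_grad_hook requires PyTorch 2.1 or newer.
        param.register_post_accumulate_grad_hook(step_and_release)

# Usage: call attach_fused_backward_pass(model) once before training; each
# loss.backward() then also performs the per-parameter optimizer steps.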

kohya_gui/dreambooth_gui.py (7 changes: 6 additions & 1 deletion)

@@ -222,6 +222,7 @@ def save_configuration(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -424,6 +425,7 @@ def open_configuration(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -621,6 +623,7 @@ def train_model(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -902,7 +905,7 @@ def train_model(
         "fp8_base": fp8_base,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
-        "fused_backward_pass": fused_backward_pass,
+        "fused_backward_pass": fused_backward_pass if not flux1_checkbox else flux_fused_backward_pass,
         "fused_optimizer_groups": (
             int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
         ),
@@ -1053,6 +1056,7 @@ def train_model(
         "blockwise_fused_optimizers": (
             blockwise_fused_optimizers if flux1_checkbox else None
         ),
+        # "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
         "cpu_offload_checkpointing": (
             cpu_offload_checkpointing if flux1_checkbox else None
         ),
@@ -1399,6 +1403,7 @@ def dreambooth_tab(
     flux1_training.t5xxl_max_token_length,
     flux1_training.guidance_scale,
     flux1_training.blockwise_fused_optimizers,
+    flux1_training.flux_fused_backward_pass,
     flux1_training.cpu_offload_checkpointing,
     flux1_training.single_blocks_to_swap,
     flux1_training.double_blocks_to_swap,
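
The fused_backward_pass entry above now depends on which training mode is active. As a reading aid, the selection logic can be written as a tiny standalone helper; this is an illustrative sketch whose names simply mirror the diff, not code from the repository.

def resolve_fused_backward_pass(
    flux1_checkbox: bool,
    fused_backward_pass: bool,
    flux_fused_backward_pass: bool,
) -> bool:
    # FLUX.1 runs use the new FLUX-specific checkbox; other runs keep the
    # pre-existing generic fused_backward_pass setting.
    return flux_fused_backward_pass if flux1_checkbox else fused_backward_pass

# Example: a FLUX.1 run with the new checkbox ticked enables the option...
assert resolve_fused_backward_pass(True, False, True) is True
# ...while a non-FLUX run still follows the original checkbox.
assert resolve_fused_backward_pass(False, True, False) is True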

kohya_gui/finetune_gui.py (7 changes: 6 additions & 1 deletion)

@@ -230,6 +230,7 @@ def save_configuration(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -438,6 +439,7 @@ def open_configuration(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -652,6 +654,7 @@ def train_model(
     t5xxl_max_token_length,
     guidance_scale,
     blockwise_fused_optimizers,
+    flux_fused_backward_pass,
     cpu_offload_checkpointing,
     single_blocks_to_swap,
     double_blocks_to_swap,
@@ -954,7 +957,7 @@ def train_model(
         "fp8_base": fp8_base,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
-        "fused_backward_pass": fused_backward_pass,
+        "fused_backward_pass": fused_backward_pass if not flux1_checkbox else flux_fused_backward_pass,
         "fused_optimizer_groups": (
             int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
         ),
@@ -1097,6 +1100,7 @@ def train_model(
         "blockwise_fused_optimizers": (
             blockwise_fused_optimizers if flux1_checkbox else None
         ),
+        # "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
        "cpu_offload_checkpointing": (
             cpu_offload_checkpointing if flux1_checkbox else None
         ),
@@ -1517,6 +1521,7 @@ def list_presets(path):
     flux1_training.t5xxl_max_token_length,
     flux1_training.guidance_scale,
     flux1_training.blockwise_fused_optimizers,
+    flux1_training.flux_fused_backward_pass,
     flux1_training.cpu_offload_checkpointing,
     flux1_training.single_blocks_to_swap,
     flux1_training.double_blocks_to_swap,

sd-scripts (2 changes: 1 addition & 1 deletion)