
rename and deprecate flux option names #1264

Merged · 2 commits · Jan 5, 2025
28 changes: 14 additions & 14 deletions OPTIONS.md
@@ -381,13 +381,13 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
{pixart_sigma,kolors,sd3,flux,smoldit,sdxl,legacy}
[--model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora}]
[--flux_lora_target {mmdit,context,context+ffs,all,all+ffs,ai-toolkit,tiny,nano}]
[--flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE]
[--flux_fast_schedule] [--flux_use_uniform_schedule]
[--flux_use_beta_schedule]
[--flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA]
[--flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA]
[--flux_schedule_shift FLUX_SCHEDULE_SHIFT]
[--flux_schedule_auto_shift]
[--flow_sigmoid_scale flow_sigmoid_scale]
[--flux_fast_schedule] [--flow_use_uniform_schedule]
[--flow_use_beta_schedule]
[--flow_beta_schedule_alpha flow_beta_schedule_alpha]
[--flow_beta_schedule_beta flow_beta_schedule_beta]
[--flow_schedule_shift flow_schedule_shift]
[--flow_schedule_auto_shift]
[--flux_guidance_mode {constant,random-range,mobius}]
[--flux_guidance_value FLUX_GUIDANCE_VALUE]
[--flux_guidance_min FLUX_GUIDANCE_MIN]
@@ -602,29 +602,29 @@ options:
norms (based on ostris/ai-toolkit). If 'tiny' is
provided, only two layers will be trained. If 'nano'
is provided, only one layer will be trained.
--flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE
--flow_sigmoid_scale flow_sigmoid_scale
Scale factor for sigmoid timestep sampling for flow-
matching models.
--flux_fast_schedule An experimental feature to train Flux.1S using a noise
schedule closer to what it was trained with, which has
improved results in short experiments. Thanks to
@mhirki for the contribution.
--flux_use_uniform_schedule
--flow_use_uniform_schedule
Whether or not to use a uniform schedule with Flux
instead of sigmoid. Using uniform sampling may help
preserve more capabilities from the base model. Some
tasks may not benefit from this.
--flux_use_beta_schedule
--flow_use_beta_schedule
Whether or not to use a beta schedule with Flux
instead of sigmoid. The default values of alpha and
beta approximate a sigmoid.
--flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA
--flow_beta_schedule_alpha flow_beta_schedule_alpha
The alpha value of the flux beta schedule. Default is
2.0
--flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA
--flow_beta_schedule_beta flow_beta_schedule_beta
The beta value of the flux beta schedule. Default is
2.0
--flux_schedule_shift FLUX_SCHEDULE_SHIFT
--flow_schedule_shift flow_schedule_shift
Shift the noise schedule. This is a value between 0
and ~4.0, where 0 disables the timestep-dependent
shift, and anything greater than 0 will shift the
@@ -637,7 +637,7 @@ options:
contrast is learnt by the model, and whether fine
details are ignored or accentuated, removing fine
details and making the outputs blurrier.
--flux_schedule_auto_shift
--flow_schedule_auto_shift
Shift the noise schedule depending on image
resolution. The shift value calculation is taken from
the official Flux inference code. Shift value is
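The renamed `--flow_*` options above all control how timesteps are sampled when training flow-matching models. As a rough sketch of how they relate (a hypothetical helper, not SimpleTuner's actual implementation):

```python
import torch

def sample_timesteps(
    batch_size: int,
    mode: str = "sigmoid",       # default behaviour
    sigmoid_scale: float = 1.0,  # --flow_sigmoid_scale
    alpha: float = 2.0,          # --flow_beta_schedule_alpha
    beta: float = 2.0,           # --flow_beta_schedule_beta
    shift: float = 0.0,          # --flow_schedule_shift
) -> torch.Tensor:
    if mode == "uniform":        # --flow_use_uniform_schedule
        t = torch.rand(batch_size)
    elif mode == "beta":         # --flow_use_beta_schedule; Beta(2, 2) is bell-shaped, approximating the sigmoid default
        t = torch.distributions.Beta(alpha, beta).sample((batch_size,))
    else:                        # sigmoid of a scaled normal draw ("logit-normal" sampling)
        t = torch.sigmoid(torch.randn(batch_size) * sigmoid_scale)
    if shift > 0:                # static timestep shift, per section 5.3.2 of the SD3 paper
        t = (shift * t) / (1 + (shift - 1) * t)
    return t
```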
6 changes: 3 additions & 3 deletions documentation/quickstart/FLUX.md
@@ -206,15 +206,15 @@ Flow-matching models such as Flux and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value.
By default, no schedule shift is applied to flux, which results in a sigmoid bell-shape to the timestep sampling distribution. This is unlikely to be the ideal approach for Flux, but it results in a greater amount of learning in a shorter period of time than auto-shift.

##### Auto-shift
A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flux_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results.
A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flow_schedule_auto_shift`, which uses higher shift values for larger images and lower shift values for smaller images. This produces stable but potentially mediocre training results.
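For reference, the option's help text says the auto-shift calculation is taken from the official Flux inference code, which computes roughly the following (the constants here are the commonly cited defaults and are an assumption, not necessarily SimpleTuner's exact values):

```python
import math

def calculate_auto_shift(height: int, width: int) -> float:
    # Sequence length of the image in latent patches (8x VAE downsampling, then 2x2 patching).
    image_seq_len = (height // 16) * (width // 16)
    base_seq_len, max_seq_len = 256, 4096
    base_shift, max_shift = 0.5, 1.15
    # Linearly interpolate in log-shift space between the two anchor points.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    mu = image_seq_len * m + (base_shift - m * base_seq_len)
    return math.exp(mu)  # about 3.2 for a 1024x1024 image, close to SD3's static shift of 3
```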

##### Manual specification
_Thanks to General Awareness from Discord for the following examples_

When using a `--flux_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected:
When using a `--flow_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected:
![image](https://github.com/user-attachments/assets/991ca0ad-e25a-4b13-a3d6-b4f2de1fe982)

When using a `--flux_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted:
When using a `--flow_schedule_shift` value of 4.0 (a very high value), the large compositional features, and potentially the colour space of the model, become impacted:
![image](https://github.com/user-attachments/assets/857a1f8a-07ab-4b75-8e6a-eecff616a28d)


6 changes: 3 additions & 3 deletions documentation/quickstart/SANA.md
@@ -181,15 +181,15 @@ If you wish to enable evaluations to score the model's performance, see [this do
Flow-matching models such as Sana, Flux, and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value.

##### Auto-shift
A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flux_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results.
A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flow_schedule_auto_shift`, which uses higher shift values for larger images and lower shift values for smaller images. This produces stable but potentially mediocre training results.

##### Manual specification
_Thanks to General Awareness from Discord for the following examples_

When using a `--flux_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected:
When using a `--flow_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected:
![image](https://github.com/user-attachments/assets/991ca0ad-e25a-4b13-a3d6-b4f2de1fe982)

When using a `--flux_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted:
When using a `--flow_schedule_shift` value of 4.0 (a very high value), the large compositional features, and potentially the colour space of the model, become impacted:
![image](https://github.com/user-attachments/assets/857a1f8a-07ab-4b75-8e6a-eecff616a28d)

#### Dataset considerations
24 changes: 12 additions & 12 deletions documentation/quickstart/SD3.md
@@ -195,12 +195,12 @@ In your `/home/user/simpletuner/config` directory, create a multidatabackend.json
"crop": true,
"crop_aspect": "square",
"crop_style": "center",
"resolution": 1.0,
"resolution": 1024,
"minimum_image_size": 0,
"maximum_image_size": 1.0,
"target_downsample_size": 1.0,
"resolution_type": "area",
"cache_dir_vae": "cache/vae/sd3/pseudo-camera-10k",
"maximum_image_size": 1024,
"target_downsample_size": 1024,
"resolution_type": "pixel_area",
"cache_dir_vae": "/home/user/simpletuner/output/cache/vae/sd3/pseudo-camera-10k",
"instance_data_dir": "/home/user/simpletuner/datasets/pseudo-camera-10k",
"disabled": false,
"skip_file_discovery": "",
@@ -277,8 +277,8 @@ The following values are recommended for `config.json`:
"--validation_guidance_skip_layers_stop": 0.2,
"--validation_guidance_skip_scale": 2.8,
"--validation_guidance": 4.0,
"--flux_use_uniform_schedule": true,
"--flux_schedule_auto_shift": true
"--flow_use_uniform_schedule": true,
"--flow_schedule_auto_shift": true
}
```

@@ -309,17 +309,17 @@ Some changes were made to SimpleTuner's SD3.5 support:
- No longer zeroing T5 padding space by default (`--t5_padding`)
- Offering a switch (`--sd3_clip_uncond_behaviour` and `--sd3_t5_uncond_behaviour`) to use empty encoded blank captions for unconditional predictions (`empty_string`, **default**) or zeros (`zero`); tweaking this is not recommended.
- SD3.5 training loss function was updated to match that found in the upstream StabilityAI/SD3.5 repository
- Updated default `--flux_schedule_shift` value to 3 to match the static 1024px value for SD3
- StabilityAI followed-up with documentation to use `--flux_schedule_shift=1` with `--flux_use_uniform_schedule`
- Community members have reported that `--flux_schedule_auto_shift` works better when using mult-aspect or multi-resolution training
- Updated the hard-coded tokeniser sequence length limit to **256** with the option to revert it to **77** tokens to save disk space or compute at the cost of output quality degradation
- Updated default `--flow_schedule_shift` value to 3 to match the static 1024px value for SD3
- StabilityAI followed-up with documentation to use `--flow_schedule_shift=1` with `--flow_use_uniform_schedule` (see the sketch below)
- Community members have reported that `--flow_schedule_auto_shift` works better when using multi-aspect or multi-resolution training
- Updated the hard-coded tokeniser sequence length limit to **154** with the option to revert it to **77** tokens to save disk space or compute at the cost of output quality degradation
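A sketch of that StabilityAI-recommended combination in `config.json` form (values taken from the bullets above):

```json
{
  "--flow_schedule_shift": 1,
  "--flow_use_uniform_schedule": true
}
```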


#### Stable configuration values

These options have been known to keep SD3.5 intact for as long as possible:
- optimizer=adamw_bf16
- flux_schedule_shift=1
- flow_schedule_shift=1
- learning_rate=1e-4
- batch_size=4 * 3 GPUs
- max_grad_norm=0.1
92 changes: 78 additions & 14 deletions helpers/configuration/cmd_args.py
@@ -132,7 +132,7 @@ def get_argument_parser():
],
default="all",
help=(
"Flux has single and joint attention blocks."
"This option only applies to Standard LoRA, not Lycoris. Flux has single and joint attention blocks."
" By default, all attention layers are trained, but not the feed-forward layers"
" If 'mmdit' is provided, the text input layers will not be trained."
" If 'context' is provided, then ONLY the text attention layers are trained"
@@ -147,8 +147,8 @@
parser.add_argument(
"--flow_matching_sigmoid_scale",
type=float,
default=None,
help="Deprecated option. Replaced with --flow_sigmoid_scale.",
)
parser.add_argument(
"--flow_sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for sigmoid timestep sampling for flow-matching models..",
help="Scale factor for sigmoid timestep sampling for flow-matching models.",
)
parser.add_argument(
"--flux_fast_schedule",
@@ -160,16 +166,30 @@
)
parser.add_argument(
"--flux_use_uniform_schedule",
default=None,
action="store_true",
help="Deprecated option. Replaced with --flow_use_uniform_schedule.",
)
parser.add_argument(
"--flow_use_uniform_schedule",
default=False,
action="store_true",
help=(
"Whether or not to use a uniform schedule with Flux instead of sigmoid."
"Whether or not to use a uniform schedule with flow-matching models instead of sigmoid."
" Using uniform sampling may help preserve more capabilities from the base model."
" Some tasks may not benefit from this."
),
)
parser.add_argument(
"--flux_use_beta_schedule",
action="store_true",
default=None,
help="Deprecated option. Replaced with --flow_use_beta_schedule.",
)
parser.add_argument(
"--flow_use_beta_schedule",
action="store_true",
default=False,
help=(
"Whether or not to use a beta schedule with Flux instead of sigmoid. The default values of alpha"
" and beta approximate a sigmoid."
@@ -178,31 +198,54 @@
parser.add_argument(
"--flux_beta_schedule_alpha",
type=float,
default=None,
help=("Deprecated option. Replaced with --flux_beta_schedule_alpha."),
)
parser.add_argument(
"--flow_beta_schedule_alpha",
type=float,
default=2.0,
help=("The alpha value of the flux beta schedule. Default is 2.0"),
)
parser.add_argument(
"--flux_beta_schedule_beta",
type=float,
default=None,
help=("Deprecated option. Replaced with --flow_beta_schedule_beta."),
)
parser.add_argument(
"--flow_beta_schedule_beta",
type=float,
default=2.0,
help=("The beta value of the flux beta schedule. Default is 2.0"),
)
parser.add_argument(
"--flux_schedule_shift",
type=float,
default=None,
help=("Deprecated option. Replaced with --flow_schedule_shift."),
)
parser.add_argument(
"--flow_schedule_shift",
type=float,
default=3,
help=(
"Shift the noise schedule. This is a value between 0 and ~4.0, where 0 disables the timestep-dependent shift,"
" and anything greater than 0 will shift the timestep sampling accordingly. The SD3 model was trained with"
" a shift value of 3. The value for Flux is unknown. Higher values result in less noisy timesteps sampled,"
" which results in a lower mean loss value, but not necessarily better results. Early reports indicate"
" that modification of this value can change how the contrast is learnt by the model, and whether fine"
" details are ignored or accentuated, removing fine details and making the outputs blurrier."
" and anything greater than 0 will shift the timestep sampling accordingly. Sana and SD3 were trained with"
" a shift value of 3. This value can change how contrast/brightness are learnt by the model, and whether fine"
" details are ignored or accentuated. A higher value will focus more on large compositional features,"
" and a lower value will focus on the high frequency fine details."
),
)
parser.add_argument(
"--flux_schedule_auto_shift",
action="store_true",
default=None,
help="Deprecated option. Replaced with --flow_schedule_auto_shift.",
)
parser.add_argument(
"--flow_schedule_auto_shift",
action="store_true",
default=False,
help=(
"Shift the noise schedule depending on image resolution. The shift value calculation is taken from the official"
Expand All @@ -214,17 +257,12 @@ def get_argument_parser():
parser.add_argument(
"--flux_guidance_mode",
type=str,
choices=["constant", "random-range", "mobius"],
choices=["constant", "random-range"],
default="constant",
help=(
"Flux has a 'guidance' value used during training time that reflects the CFG range of your training samples."
" The default mode 'constant' will use a single value for every sample."
" The mode 'random-range' will randomly select a value from the range of the CFG for each sample."
" The mode 'mobius' will use a value that is a function of the remaining steps in the epoch, constructively"
" deconstructing the constructed deconstructions to then Mobius them back into the constructed reconstructions,"
" possibly resulting in the exploration of what is known as the Mobius space, a new continuous"
" realm of possibility brought about by destroying the model so that you can make it whole once more."
" Or so according to DataVoid, anyway. This is just a Flux-specific implementation of Mobius."
" Set the range using --flux_guidance_min and --flux_guidance_max."
),
)
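The two remaining guidance modes amount to the following (a hypothetical helper for illustration, not SimpleTuner's code):

```python
import random

def sample_guidance(mode: str, value: float, g_min: float, g_max: float) -> float:
    if mode == "random-range":
        # One CFG value drawn per training sample from [g_min, g_max].
        return random.uniform(g_min, g_max)
    return value  # "constant": the same guidance value for every sample.
```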
@@ -2595,4 +2633,30 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
"For non-CUDA systems, only Diffusers attention mechanism is officially supported."
)

deprecated_options = {
"flux_beta_schedule_alpha": "flow_beta_schedule_alpha",
"flux_beta_schedule_beta": "flow_beta_schedule_beta",
"flux_use_beta_schedule": "flow_use_beta_schedule",
"flux_use_uniform_schedule": "flow_use_uniform_schedule",
"flux_schedule_shift": "flow_schedule_shift",
"flux_schedule_auto_shift": "flow_schedule_auto_shift",
"flow_matching_sigmoid_scale": "flow_sigmoid_scale",
}

# Remap legacy --flux_* spellings onto their --flow_* replacements so existing configurations keep working.
for deprecated_option, replacement_option in deprecated_options.items():
if (
getattr(args, replacement_option) is not None
and getattr(args, deprecated_option) is not None
and type(getattr(args, deprecated_option)) is not object
):
warning_log(
f"The option --{deprecated_option} has been replaced with --{replacement_option}."
)
setattr(args, replacement_option, getattr(args, deprecated_option))
elif getattr(args, deprecated_option) is not None:
error_log(
f"The option {deprecated_option} has been deprecated without a replacement option. Please remove it from your configuration."
)
sys.exit(1)

return args
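The deprecation shim above follows a common argparse pattern; here is a minimal self-contained sketch with hypothetical option names:

```python
import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--old_flag", type=float, default=None,
                    help="Deprecated option. Replaced with --new_flag.")
parser.add_argument("--new_flag", type=float, default=1.0)
args = parser.parse_args()

# If the deprecated spelling was supplied, warn and copy its value onto the
# replacement so downstream code only ever reads the new attribute.
if args.old_flag is not None:
    print("The option --old_flag has been replaced with --new_flag.", file=sys.stderr)
    args.new_flag = args.old_flag
```

Old configurations keep working, while the warning points users at the new spelling.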
8 changes: 4 additions & 4 deletions helpers/models/flux/__init__.py
@@ -8,13 +8,13 @@
)


def apply_flux_schedule_shift(args, noise_scheduler, sigmas, noise):
def apply_flow_schedule_shift(args, noise_scheduler, sigmas, noise):
# Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper
shift = None
if args.flux_schedule_shift is not None and args.flux_schedule_shift > 0:
if args.flow_schedule_shift is not None and args.flow_schedule_shift > 0:
# Static shift value for every resolution
shift = args.flux_schedule_shift
elif args.flux_schedule_auto_shift:
shift = args.flow_schedule_shift
elif args.flow_schedule_auto_shift:
# Resolution-dependent shift value calculation used by official Flux inference implementation
image_seq_len = (noise.shape[-1] * noise.shape[-2]) // 4
mu = calculate_shift_flux(