From 237f2e07ad0c6891ab02598842f02f935e61ca74 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 4 Jan 2025 10:04:08 -0600 Subject: [PATCH 1/2] sana: add LoRA (PEFT) loading support to modeling code, fix positional embeds for larger models --- helpers/models/sana/transformer.py | 32 ++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/helpers/models/sana/transformer.py b/helpers/models/sana/transformer.py index 2243a9b9..a23d06d6 100644 --- a/helpers/models/sana/transformer.py +++ b/helpers/models/sana/transformer.py @@ -18,7 +18,13 @@ from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import is_torch_version, logging +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) from diffusers.models.attention_processor import ( Attention, AttentionProcessor, @@ -256,6 +262,7 @@ def __init__( patch_size: int = 1, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, ) -> None: super().__init__() @@ -270,7 +277,7 @@ def __init__( in_channels=in_channels, embed_dim=inner_dim, interpolation_scale=None, - pos_embed_type=None, + pos_embed_type="sincos" if interpolation_scale is not None else None, ) # 2. Additional condition embeddings @@ -401,6 +408,23 @@ def forward( attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
@@ -499,6 +523,10 @@ def create_block_forward(block): batch_size, -1, post_patch_height * p, post_patch_width * p ) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) From 04c23399079b3e46a1aeb0f62133516343fb1d27 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 4 Jan 2025 10:05:15 -0600 Subject: [PATCH 2/2] deprecate flux-specific flow matching names for general flow-matching options --- OPTIONS.md | 28 +++--- documentation/quickstart/FLUX.md | 6 +- documentation/quickstart/SANA.md | 6 +- documentation/quickstart/SD3.md | 24 ++--- helpers/configuration/cmd_args.py | 92 ++++++++++++++++--- helpers/models/flux/__init__.py | 8 +- helpers/publishing/metadata.py | 32 +++---- .../training/default_settings/safety_check.py | 8 +- helpers/training/trainer.py | 22 ++--- tests/test_model_card.py | 16 ++-- tests/test_trainer.py | 22 ++++- 11 files changed, 172 insertions(+), 92 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index 79c77603..7a8bca46 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -381,13 +381,13 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] {pixart_sigma,kolors,sd3,flux,smoldit,sdxl,legacy} [--model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora}] [--flux_lora_target {mmdit,context,context+ffs,all,all+ffs,ai-toolkit,tiny,nano}] - [--flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE] - [--flux_fast_schedule] [--flux_use_uniform_schedule] - [--flux_use_beta_schedule] - [--flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA] - [--flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA] - [--flux_schedule_shift FLUX_SCHEDULE_SHIFT] - [--flux_schedule_auto_shift] + [--flow_sigmoid_scale FLOW_SIGMOID_SCALE] + [--flux_fast_schedule] [--flow_use_uniform_schedule] + [--flow_use_beta_schedule] + [--flow_beta_schedule_alpha FLOW_BETA_SCHEDULE_ALPHA] + [--flow_beta_schedule_beta FLOW_BETA_SCHEDULE_BETA] + [--flow_schedule_shift FLOW_SCHEDULE_SHIFT] + [--flow_schedule_auto_shift] [--flux_guidance_mode {constant,random-range,mobius}] [--flux_guidance_value FLUX_GUIDANCE_VALUE] [--flux_guidance_min FLUX_GUIDANCE_MIN] @@ -602,29 +602,29 @@ options: norms (based on ostris/ai-toolkit). If 'tiny' is provided, only two layers will be trained. If 'nano' is provided, only one layers will be trained. - --flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE + --flow_sigmoid_scale FLOW_SIGMOID_SCALE Scale factor for sigmoid timestep sampling for flow- matching models.. --flux_fast_schedule An experimental feature to train Flux.1S using a noise schedule closer to what it was trained with, which has improved results in short experiments. Thanks to @mhirki for the contribution. - --flux_use_uniform_schedule + --flow_use_uniform_schedule Whether or not to use a uniform schedule with Flux instead of sigmoid. Using uniform sampling may help preserve more capabilities from the base model. Some tasks may not benefit from this. - --flux_use_beta_schedule + --flow_use_beta_schedule Whether or not to use a beta schedule with Flux instead of sigmoid. The default values of alpha and beta approximate a sigmoid. - --flux_beta_schedule_alpha FLUX_BETA_SCHEDULE_ALPHA + --flow_beta_schedule_alpha FLOW_BETA_SCHEDULE_ALPHA The alpha value of the flux beta schedule. Default is 2.0 - --flux_beta_schedule_beta FLUX_BETA_SCHEDULE_BETA + --flow_beta_schedule_beta FLOW_BETA_SCHEDULE_BETA The beta value of the flux beta schedule.
Default is 2.0 - --flux_schedule_shift FLUX_SCHEDULE_SHIFT + --flow_schedule_shift FLOW_SCHEDULE_SHIFT Shift the noise schedule. This is a value between 0 and ~4.0, where 0 disables the timestep-dependent shift, and anything greater than 0 will shift the @@ -637,7 +637,7 @@ options: contrast is learnt by the model, and whether fine details are ignored or accentuated, removing fine details and making the outputs blurrier. - --flux_schedule_auto_shift + --flow_schedule_auto_shift Shift the noise schedule depending on image resolution. The shift value calculation is taken from the official Flux inference code. Shift value is diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 97555310..d29e0c65 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -206,15 +206,15 @@ Flow-matching models such as Flux and SD3 have a property called "shift" that al By default, no schedule shift is applied to flux, which results in a sigmoid bell-shape to the timestep sampling distribution. This is unlikely to be the ideal approach for Flux, but it results in a greater amount of learning in a shorter period of time than auto-shift. ##### Auto-shift -A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flux_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results. +A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flow_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results. ##### Manual specification _Thanks to General Awareness from Discord for the following examples_ -When using a `--flux_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected: +When using a `--flow_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected: ![image](https://github.com/user-attachments/assets/991ca0ad-e25a-4b13-a3d6-b4f2de1fe982) -When using a `--flux_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted: +When using a `--flow_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted: ![image](https://github.com/user-attachments/assets/857a1f8a-07ab-4b75-8e6a-eecff616a28d) diff --git a/documentation/quickstart/SANA.md b/documentation/quickstart/SANA.md index 2dd355b9..5134531a 100644 --- a/documentation/quickstart/SANA.md +++ b/documentation/quickstart/SANA.md @@ -181,15 +181,15 @@ If you wish to enable evaluations to score the model's performance, see [this do Flow-matching models such as Sana, Sana, and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value. ##### Auto-shift -A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flux_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results.
+A commonly-recommended approach is to follow several recent works and enable resolution-dependent timestep shift, `--flow_schedule_auto_shift` which uses higher shift values for larger images, and lower shift values for smaller images. This results in stable but potentially mediocre training results. ##### Manual specification _Thanks to General Awareness from Discord for the following examples_ -When using a `--flux_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected: +When using a `--flow_schedule_shift` value of 0.1 (a very low value), only the finer details of the image are affected: ![image](https://github.com/user-attachments/assets/991ca0ad-e25a-4b13-a3d6-b4f2de1fe982) -When using a `--flux_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted: +When using a `--flow_schedule_shift` value of 4.0 (a very high value), the large compositional features and potentially colour space of the model becomes impacted: ![image](https://github.com/user-attachments/assets/857a1f8a-07ab-4b75-8e6a-eecff616a28d) #### Dataset considerations diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 832eaa9f..0a5c8bab 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -195,12 +195,12 @@ In your `/home/user/simpletuner/config` directory, create a multidatabackend.jso "crop": true, "crop_aspect": "square", "crop_style": "center", - "resolution": 1.0, + "resolution": 1024, "minimum_image_size": 0, - "maximum_image_size": 1.0, - "target_downsample_size": 1.0, - "resolution_type": "area", - "cache_dir_vae": "cache/vae/sd3/pseudo-camera-10k", + "maximum_image_size": 1024, + "target_downsample_size": 1024, + "resolution_type": "pixel_area", + "cache_dir_vae": "/home/user/simpletuner/output/cache/vae/sd3/pseudo-camera-10k", "instance_data_dir": "/home/user/simpletuner/datasets/pseudo-camera-10k", "disabled": false, "skip_file_discovery": "", @@ -277,8 +277,8 @@ The following values are recommended for `config.json`: "--validation_guidance_skip_layers_stop": 0.2, "--validation_guidance_skip_scale": 2.8, "--validation_guidance": 4.0, - "--flux_use_uniform_schedule": true, - "--flux_schedule_auto_shift": true + "--flow_use_uniform_schedule": true, + "--flow_schedule_auto_shift": true } ``` @@ -309,17 +309,17 @@ Some changes were made to SimpleTuner's SD3.5 support: - No longer zeroing T5 padding space by default (`--t5_padding`) - Offering a switch (`--sd3_clip_uncond_behaviour` and `--sd3_t5_uncond_behaviour`) to use empty encoded blank captions for unconditional predictions (`empty_string`, **default**) or zeros (`zero`), not a recommended setting to tweak. 
- SD3.5 training loss function was updated to match that found in the upstream StabilityAI/SD3.5 repository -- Updated default `--flux_schedule_shift` value to 3 to match the static 1024px value for SD3 - - StabilityAI followed-up with documentation to use `--flux_schedule_shift=1` with `--flux_use_uniform_schedule` - - Community members have reported that `--flux_schedule_auto_shift` works better when using mult-aspect or multi-resolution training -- Updated the hard-coded tokeniser sequence length limit to **256** with the option to revert it to **77** tokens to save disk space or compute at the cost of output quality degradation +- Updated default `--flow_schedule_shift` value to 3 to match the static 1024px value for SD3 + - StabilityAI followed up with documentation to use `--flow_schedule_shift=1` with `--flow_use_uniform_schedule` + - Community members have reported that `--flow_schedule_auto_shift` works better when using multi-aspect or multi-resolution training +- Updated the hard-coded tokeniser sequence length limit to **154** with the option to revert it to **77** tokens to save disk space or compute at the cost of output quality degradation #### Stable configuration values These options have been known to keep SD3.5 in-tact for as long as possible: - optimizer=adamw_bf16 -- flux_schedule_shift=1 +- flow_schedule_shift=1 - learning_rate=1e-4 - batch_size=4 * 3 GPUs - max_grad_norm=0.1 diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 5601df3f..1d1db93a 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -132,7 +132,7 @@ def get_argument_parser(): ], default="all", help=( - "Flux has single and joint attention blocks." + "This option only applies to Standard LoRA, not Lycoris. Flux has single and joint attention blocks." " By default, all attention layers are trained, but not the feed-forward layers" " If 'mmdit' is provided, the text input layers will not be trained." " If 'context' is provided, then ONLY the text attention layers are trained" @@ -147,8 +147,14 @@ def get_argument_parser(): parser.add_argument( "--flow_matching_sigmoid_scale", type=float, + default=None, + help="Deprecated option. Replaced with --flow_sigmoid_scale.", + ) + parser.add_argument( + "--flow_sigmoid_scale", + type=float, default=1.0, - help="Scale factor for sigmoid timestep sampling for flow-matching models..", + help="Scale factor for sigmoid timestep sampling for flow-matching models.", ) parser.add_argument( "--flux_fast_schedule", @@ -160,9 +166,16 @@ def get_argument_parser(): ) parser.add_argument( "--flux_use_uniform_schedule", + default=None, + action="store_true", + help="Deprecated option. Replaced with --flow_use_uniform_schedule.", + ) + parser.add_argument( + "--flow_use_uniform_schedule", + default=False, action="store_true", help=( - "Whether or not to use a uniform schedule with Flux instead of sigmoid." + "Whether or not to use a uniform schedule with flow-matching models instead of sigmoid." " Using uniform sampling may help preserve more capabilities from the base model." " Some tasks may not benefit from this." ), ) parser.add_argument( "--flux_use_beta_schedule", action="store_true", + default=None, + help="Deprecated option. Replaced with --flow_use_beta_schedule.", + ) + parser.add_argument( + "--flow_use_beta_schedule", + action="store_true", + default=False, help=( "Whether or not to use a beta schedule with Flux instead of sigmoid.
The default values of alpha" " and beta approximate a sigmoid." ), ) @@ -178,31 +198,54 @@ def get_argument_parser(): parser.add_argument( "--flux_beta_schedule_alpha", type=float, + default=None, + help=("Deprecated option. Replaced with --flow_beta_schedule_alpha."), + ) + parser.add_argument( + "--flow_beta_schedule_alpha", + type=float, default=2.0, help=("The alpha value of the flux beta schedule. Default is 2.0"), ) parser.add_argument( "--flux_beta_schedule_beta", type=float, + default=None, + help=("Deprecated option. Replaced with --flow_beta_schedule_beta."), + ) + parser.add_argument( + "--flow_beta_schedule_beta", + type=float, default=2.0, help=("The beta value of the flux beta schedule. Default is 2.0"), ) parser.add_argument( "--flux_schedule_shift", type=float, + default=None, + help=("Deprecated option. Replaced with --flow_schedule_shift."), + ) + parser.add_argument( + "--flow_schedule_shift", + type=float, default=3, help=( "Shift the noise schedule. This is a value between 0 and ~4.0, where 0 disables the timestep-dependent shift," - " and anything greater than 0 will shift the timestep sampling accordingly. The SD3 model was trained with" - " a shift value of 3. The value for Flux is unknown. Higher values result in less noisy timesteps sampled," - " which results in a lower mean loss value, but not necessarily better results. Early reports indicate" - " that modification of this value can change how the contrast is learnt by the model, and whether fine" - " details are ignored or accentuated, removing fine details and making the outputs blurrier." + " and anything greater than 0 will shift the timestep sampling accordingly. Sana and SD3 were trained with" + " a shift value of 3. This value can change how contrast/brightness are learnt by the model, and whether fine" + " details are ignored or accentuated. A higher value will focus more on large compositional features," + " and a lower value will focus on the high frequency fine details." ), ) parser.add_argument( "--flux_schedule_auto_shift", action="store_true", + default=None, + help="Deprecated option. Replaced with --flow_schedule_auto_shift.", + ) + parser.add_argument( + "--flow_schedule_auto_shift", + action="store_true", default=False, help=( "Shift the noise schedule depending on image resolution. The shift value calculation is taken from the official" @@ -214,17 +257,12 @@ def get_argument_parser(): parser.add_argument( "--flux_guidance_mode", type=str, - choices=["constant", "random-range", "mobius"], + choices=["constant", "random-range"], default="constant", help=( "Flux has a 'guidance' value used during training time that reflects the CFG range of your training samples." " The default mode 'constant' will use a single value for every sample." " The mode 'random-range' will randomly select a value from the range of the CFG for each sample." - " The mode 'mobius' will use a value that is a function of the remaining steps in the epoch, constructively" - " deconstructing the constructed deconstructions to then Mobius them back into the constructed reconstructions," - " possibly resulting in the exploration of what is known as the Mobius space, a new continuous" - " realm of possibility brought about by destroying the model so that you can make it whole once more." - " Or so according to DataVoid, anyway. This is just a Flux-specific implementation of Mobius." " Set the range using --flux_guidance_min and --flux_guidance_max."
), ) @@ -2595,4 +2633,30 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False): "For non-CUDA systems, only Diffusers attention mechanism is officially supported." ) + deprecated_options = { + "flux_beta_schedule_alpha": "flow_beta_schedule_alpha", + "flux_beta_schedule_beta": "flow_beta_schedule_beta", + "flux_use_beta_schedule": "flow_use_beta_schedule", + "flux_use_uniform_schedule": "flow_use_uniform_schedule", + "flux_schedule_shift": "flow_schedule_shift", + "flux_schedule_auto_shift": "flow_schedule_auto_shift", + "flow_matching_sigmoid_scale": "flow_sigmoid_scale", + } + + for deprecated_option, replacement_option in deprecated_options.items(): + if ( + getattr(args, replacement_option) is not None + and getattr(args, deprecated_option) is not None + and type(getattr(args, deprecated_option)) is not object + ): + warning_log( + f"The option --{deprecated_option} has been replaced with --{replacement_option}." + ) + setattr(args, replacement_option, getattr(args, deprecated_option)) + elif getattr(args, deprecated_option) is not None: + error_log( + f"The option {deprecated_option} has been deprecated without a replacement option. Please remove it from your configuration." + ) + sys.exit(1) + return args diff --git a/helpers/models/flux/__init__.py b/helpers/models/flux/__init__.py index 2b4c16e9..fb20b88b 100644 --- a/helpers/models/flux/__init__.py +++ b/helpers/models/flux/__init__.py @@ -8,13 +8,13 @@ ) -def apply_flux_schedule_shift(args, noise_scheduler, sigmas, noise): +def apply_flow_schedule_shift(args, noise_scheduler, sigmas, noise): # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper shift = None - if args.flux_schedule_shift is not None and args.flux_schedule_shift > 0: + if args.flow_schedule_shift is not None and args.flow_schedule_shift > 0: # Static shift value for every resolution - shift = args.flux_schedule_shift - elif args.flux_schedule_auto_shift: + shift = args.flow_schedule_shift + elif args.flow_schedule_auto_shift: # Resolution-dependent shift value calculation used by official Flux inference implementation image_seq_len = (noise.shape[-1] * noise.shape[-2]) // 4 mu = calculate_shift_flux( diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index ea7c00b6..f33b2710 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -325,10 +325,10 @@ def flux_schedule_info(args): output_args = [] if args.flux_fast_schedule: output_args.append("flux_fast_schedule") - if args.flux_schedule_auto_shift: - output_args.append("flux_schedule_auto_shift") - if args.flux_schedule_shift is not None: - output_args.append(f"shift={args.flux_schedule_shift}") + if args.flow_schedule_auto_shift: + output_args.append("flow_schedule_auto_shift") + if args.flow_schedule_shift is not None: + output_args.append(f"shift={args.flow_schedule_shift}") output_args.append(f"flux_guidance_mode={args.flux_guidance_mode}") if args.flux_guidance_value: output_args.append(f"flux_guidance_value={args.flux_guidance_value}") @@ -337,9 +337,9 @@ def flux_schedule_info(args): if args.flux_guidance_mode == "random-range": output_args.append(f"flux_guidance_max={args.flux_guidance_max}") output_args.append(f"flux_guidance_min={args.flux_guidance_min}") - if args.flux_use_beta_schedule: - output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}") - output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}") + if args.flow_use_beta_schedule: + 
output_args.append(f"flow_beta_schedule_alpha={args.flow_beta_schedule_alpha}") + output_args.append(f"flow_beta_schedule_beta={args.flow_beta_schedule_beta}") if args.flux_attention_masked_training: output_args.append("flux_attention_masked_training") if args.t5_padding != "unmodified": @@ -364,15 +364,15 @@ def sd3_schedule_info(args): if args.model_family.lower() != "sd3": return "" output_args = [] - if args.flux_schedule_auto_shift: - output_args.append("flux_schedule_auto_shift") - if args.flux_schedule_shift is not None: - output_args.append(f"shift={args.flux_schedule_shift}") - if args.flux_use_beta_schedule: - output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}") - output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}") - if args.flux_use_uniform_schedule: - output_args.append(f"flux_use_uniform_schedule") + if args.flow_schedule_auto_shift: + output_args.append("flow_schedule_auto_shift") + if args.flow_schedule_shift is not None: + output_args.append(f"shift={args.flow_schedule_shift}") + if args.flow_use_beta_schedule: + output_args.append(f"flow_beta_schedule_alpha={args.flow_beta_schedule_alpha}") + output_args.append(f"flow_beta_schedule_beta={args.flow_beta_schedule_beta}") + if args.flow_use_uniform_schedule: + output_args.append(f"flow_use_uniform_schedule") # if args.model_type == "lora" and args.lora_type == "standard": # output_args.append(f"flux_lora_target={args.flux_lora_target}") output_str = ( diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 73d5d824..d37a5f49 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -109,12 +109,12 @@ def safety_check(args, accelerator): sys.exit(1) if ( - args.flux_schedule_shift is not None - and args.flux_schedule_shift > 0 - and args.flux_schedule_auto_shift + args.flow_schedule_shift is not None + and args.flow_schedule_shift > 0 + and args.flow_schedule_auto_shift ): logger.error( - f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." + f"--flow_schedule_auto_shift cannot be combined with --flow_schedule_shift. Please set --flow_schedule_shift to 0 if you want to train with --flow_schedule_auto_shift." 
) sys.exit(1) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 6de9c34f..6b006ac3 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -121,7 +121,7 @@ pack_latents, unpack_latents, get_mobius_guidance, - apply_flux_schedule_shift, + apply_flow_schedule_shift, ) is_optimi_available = False @@ -2416,28 +2416,28 @@ def train(self): if self.config.flow_matching: if not self.config.flux_fast_schedule and not any( [ - self.config.flux_use_beta_schedule, - self.config.flux_use_uniform_schedule, + self.config.flow_use_beta_schedule, + self.config.flow_use_uniform_schedule, ] ): # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF # also used by: https://github.com/XLabs-AI/x-flux/tree/main # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f sigmas = torch.sigmoid( - self.config.flow_matching_sigmoid_scale + self.config.flow_sigmoid_scale * torch.randn((bsz,), device=self.accelerator.device) ) - sigmas = apply_flux_schedule_shift( + sigmas = apply_flow_schedule_shift( self.config, self.noise_scheduler, sigmas, noise ) - elif self.config.flux_use_uniform_schedule: + elif self.config.flow_use_uniform_schedule: sigmas = torch.rand((bsz,), device=self.accelerator.device) - sigmas = apply_flux_schedule_shift( + sigmas = apply_flow_schedule_shift( self.config, self.noise_scheduler, sigmas, noise ) - elif self.config.flux_use_beta_schedule: - alpha = self.config.flux_beta_schedule_alpha - beta = self.config.flux_beta_schedule_beta + elif self.config.flow_use_beta_schedule: + alpha = self.config.flow_beta_schedule_alpha + beta = self.config.flow_beta_schedule_beta # Create a Beta distribution instance beta_dist = Beta(alpha, beta) @@ -2447,7 +2447,7 @@ def train(self): device=self.accelerator.device ) - sigmas = apply_flux_schedule_shift( + sigmas = apply_flow_schedule_shift( self.config, self.noise_scheduler, sigmas, noise ) else: diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 2e2ed5b1..28b5c206 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -38,16 +38,16 @@ def setUp(self): self.args.validation_using_datasets = False self.args.flow_matching_loss = "compatible" self.args.flux_fast_schedule = False - self.args.flux_schedule_auto_shift = False - self.args.flux_schedule_shift = None + self.args.flow_schedule_auto_shift = False + self.args.flow_schedule_shift = None self.args.flux_guidance_value = None self.args.flux_guidance_min = None self.args.flux_guidance_max = None - self.args.flux_use_beta_schedule = False - self.args.flux_beta_schedule_alpha = None - self.args.flux_beta_schedule_beta = None + self.args.flow_use_beta_schedule = False + self.args.flow_beta_schedule_alpha = None + self.args.flow_beta_schedule_beta = None self.args.flux_attention_masked_training = False - self.args.flux_use_uniform_schedule = False + self.args.flow_use_uniform_schedule = False self.args.flux_lora_target = None self.args.validation_guidance_skip_layers = None self.args.validation_seed = 1234 @@ -221,9 +221,9 @@ def test_sd3_schedule_info(self): output = sd3_schedule_info(self.args) self.assertIn("(no special parameters set)", output) - self.args.flux_schedule_auto_shift = True + self.args.flow_schedule_auto_shift = True output = sd3_schedule_info(self.args) - self.assertIn("flux_schedule_auto_shift", output) + self.assertIn("flow_schedule_auto_shift", output) def test_model_schedule_info(self): with patch( diff --git a/tests/test_trainer.py 
b/tests/test_trainer.py index 3a036762..183a42ae 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -124,23 +124,39 @@ def test_stats_memory_used_none( @patch("accelerate.state.AcceleratorState", Mock()) @patch( "argparse.ArgumentParser.parse_args", - return_value=MagicMock( + return_value=Mock( torch_num_threads=2, train_batch_size=1, weight_dtype=torch.float32, + model_type="full", optimizer="adamw_bf16", + optimizer_config=None, max_train_steps=2, num_train_epochs=0, timestep_bias_portion=0, metadata_update_interval=100, gradient_accumulation_steps=1, + validation_resolution=1024, mixed_precision="bf16", report_to="none", output_dir="output_dir", - flux_schedule_shift=3, - flux_schedule_auto_shift=False, + logging_dir="logging_dir", + learning_rate=1, + flow_schedule_shift=3, + user_prompt_library=None, + flow_schedule_auto_shift=False, validation_guidance_skip_layers=None, + pretrained_model_name_or_path="some/path", + base_model_precision="no_change", gradient_checkpointing_interval=None, + # deprecated options + flux_beta_schedule_alpha=None, + flux_beta_schedule_beta=None, + flux_use_beta_schedule=None, + flux_use_uniform_schedule=None, + flux_schedule_shift=None, + flux_schedule_auto_shift=None, + flow_matching_sigmoid_scale=None, ), ) def test_misc_init(
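
Usage note on the option rename: the deprecation shim added to cmd_args.py keeps the old Flux-specific spellings working by copying an explicitly supplied deprecated value onto its flow_* replacement and logging a warning. Below is a minimal, standalone sketch of that pattern for two of the renamed flags; it does not use the project's parser, and the option subset and defaults shown here are illustrative only.

    import argparse
    import warnings

    # Standalone sketch of the back-compat pattern used in helpers/configuration/cmd_args.py:
    # each deprecated Flux-specific flag keeps a parser entry whose default is None, so an
    # explicitly supplied value can be detected and copied onto the new flow_* option.
    DEPRECATED_TO_NEW = {
        "flux_schedule_shift": "flow_schedule_shift",
        "flux_schedule_auto_shift": "flow_schedule_auto_shift",
    }

    parser = argparse.ArgumentParser()
    # New, model-agnostic flow-matching options.
    parser.add_argument("--flow_schedule_shift", type=float, default=3.0)
    parser.add_argument("--flow_schedule_auto_shift", action="store_true", default=False)
    # Deprecated aliases (default None so "not passed" is distinguishable from a real value).
    parser.add_argument("--flux_schedule_shift", type=float, default=None)
    parser.add_argument("--flux_schedule_auto_shift", action="store_true", default=None)

    args = parser.parse_args(["--flux_schedule_shift", "1.0"])
    for old, new in DEPRECATED_TO_NEW.items():
        if getattr(args, old) is not None:
            warnings.warn(f"--{old} is deprecated; use --{new} instead.")
            setattr(args, new, getattr(args, old))

    print(args.flow_schedule_shift)  # 1.0 -- the deprecated value carries over to the new option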