diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 7bec9c799cae..7956efb4471e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -158,6 +158,9 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + if args.enable_vae_tiling: + pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024) + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -597,6 +600,7 @@ def parse_args(input_args=None): help="Whether to offload the VAE and the text encoder to CPU when they are not used.", ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation") if input_args is not None: args = parser.parse_args(input_args)