From b0c8973834717f8f52ea5384a8c31de3a88f4d59 Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:36:07 -0700 Subject: [PATCH] [Sana 4K] Add vae tiling option to avoid OOM (#10583) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J石页 --- examples/dreambooth/train_dreambooth_lora_sana.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 7bec9c799cae..7956efb4471e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -158,6 +158,9 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + if args.enable_vae_tiling: + pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024) + pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -597,6 +600,7 @@ def parse_args(input_args=None): help="Whether to offload the VAE and the text encoder to CPU when they are not used.", ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable vae tiling in log validation") if input_args is not None: args = parser.parse_args(input_args)