From 012d08b1bcd74abbc05a9ef163e41c99bf0e6b2e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 21 Jan 2025 16:39:45 +0800 Subject: [PATCH] Enable dreambooth lora finetune example on other devices (#10602) * enable dreambooth_lora on other devices Signed-off-by: jiqing-feng * enable xpu Signed-off-by: jiqing-feng * check cuda device before empty cache Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * import free_memory Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 19 +++++++++++-------- src/diffusers/training_utils.py | 2 ++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 8175b7614429..83a24b778083 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -54,7 +54,11 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + free_memory, +) from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, @@ -151,14 +155,14 @@ def log_validation( if args.validation_images is None: images = [] for _ in range(args.num_validation_images): - with torch.cuda.amp.autocast(): + with torch.amp.autocast(accelerator.device.type): image = pipeline(**pipeline_args, generator=generator).images[0] images.append(image) else: images = [] for image in args.validation_images: image = Image.open(image) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(accelerator.device.type): image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) @@ -177,7 +181,7 @@ def log_validation( ) del pipeline - torch.cuda.empty_cache() + free_memory() return images @@ -793,7 +797,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -829,8 +833,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt): tokenizer = None gc.collect() - torch.cuda.empty_cache() + free_memory() else: pre_computed_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 2474ed5c2114..082640f37a17 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -299,6 +299,8 @@ def free_memory(): torch.mps.empty_cache() elif is_torch_npu_available(): torch_npu.npu.empty_cache() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14