Enable dreambooth lora finetune example on other devices (#10602)
* enable dreambooth_lora on other devices

Signed-off-by: jiqing-feng <[email protected]>

* enable xpu

Signed-off-by: jiqing-feng <[email protected]>

* check cuda device before empty cache

Signed-off-by: jiqing-feng <[email protected]>

* fix comment

Signed-off-by: jiqing-feng <[email protected]>

* import free_memory

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng authored Jan 21, 2025
1 parent 4ace7d0 commit 012d08b
Showing 2 changed files with 13 additions and 8 deletions.
19 changes: 11 additions & 8 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -54,7 +54,11 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+    _set_state_dict_into_text_encoder,
+    cast_training_params,
+    free_memory,
+)
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
             images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
             images.append(image)

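The two autocast changes above replace the CUDA-only context manager with the device-agnostic torch.amp API. A minimal sketch of the pattern, assuming an Accelerate setup like the script's (the tensor math is illustrative only):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # torch.cuda.amp.autocast() works only on CUDA (and is deprecated in
    # recent PyTorch); torch.amp.autocast(device_type) picks the backend
    # from a string such as "cuda", "xpu", or "cpu", so the same line runs
    # on whatever device Accelerate hands back.
    with torch.amp.autocast(accelerator.device.type):
        x = torch.randn(4, 4, device=accelerator.device)
        y = x @ x  # computed in the autocast dtype where supported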
@@ -177,7 +181,7 @@ def log_validation(
             )

     del pipeline
-    torch.cuda.empty_cache()
+    free_memory()

     return images

@@ -793,7 +797,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
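This hunk extends the fp16 default to XPU. A minimal sketch of the selection logic, using prior_generation_dtype as a hypothetical helper name; the bf16 branch is an assumption, since the diff truncates after the fp16 elif:

    import torch

    def prior_generation_dtype(device_type, precision=None):
        # fp16 by default on accelerators with fast half precision
        # (CUDA and, after this patch, XPU); fp32 everywhere else,
        # unless the user pins a precision explicitly.
        torch_dtype = torch.float16 if device_type in ("cuda", "xpu") else torch.float32
        if precision == "fp32":
            torch_dtype = torch.float32
        elif precision == "fp16":
            torch_dtype = torch.float16
        elif precision == "bf16":  # assumed branch; not shown in the diff
            torch_dtype = torch.bfloat16
        return torch_dtype

    # e.g. prior_generation_dtype("xpu") -> torch.float16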
@@ -829,8 +833,7 @@ def main(args):
                     image.save(image_filename)

             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
         tokenizer = None

         gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
     else:
         pre_computed_encoder_hidden_states = None
         validation_prompt_encoder_hidden_states = None
2 changes: 2 additions & 0 deletions src/diffusers/training_utils.py
@@ -299,6 +299,8 @@ def free_memory():
         torch.mps.empty_cache()
     elif is_torch_npu_available():
         torch_npu.npu.empty_cache()
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()


 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
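With the new branch, free_memory covers CUDA, MPS, NPU, and XPU, which is why the call sites above could drop their per-device guards. A self-contained sketch of what the helper plausibly looks like after this patch; the NPU arm is omitted because it needs the optional torch_npu package, and the leading gc.collect() is an assumption based on how the helper replaces bare empty-cache calls:

    import gc

    import torch

    def free_memory_sketch():
        # Approximation of diffusers.training_utils.free_memory: collect
        # Python garbage, then clear whichever accelerator cache exists.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()

    # Callers no longer need "if torch.cuda.is_available():
    # torch.cuda.empty_cache()" at each cleanup site.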
