Commit
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban authored Oct 7, 2024
2 parents 9a83f27 + 1287822 commit f110e4e
Showing 4 changed files with 29 additions and 9 deletions.
10 changes: 9 additions & 1 deletion examples/controlnet/train_controlnet_sd3.py
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
+    parser.add_argument(
+        "--upcast_vae",
+        action="store_true",
+        help="Whether or not to upcast vae to fp32",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.bfloat16
 
     # Move vae, transformer and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.upcast_vae:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
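
Taken together, the two hunks above make the VAE precision opt-in: by default the VAE now follows the training weight_dtype, and only when --upcast_vae is passed is it kept in fp32. A minimal sketch of the resulting casting logic follows; the _Args namespace, device choice, and placeholder module are illustrative stand-ins, not part of the commit.

# Sketch of the new casting behavior only; `vae`, `args`, and the bf16 setting
# stand in for objects that train_controlnet_sd3.py builds itself.
import torch

class _Args:
    upcast_vae = True  # corresponds to passing --upcast_vae on the command line

args = _Args()
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16  # mirrors the bf16 mixed-precision branch above
vae = torch.nn.Conv2d(3, 4, kernel_size=3)  # placeholder for the SD3 VAE

if args.upcast_vae:
    vae.to(device, dtype=torch.float32)  # keep the VAE in fp32
else:
    vae.to(device, dtype=weight_dtype)   # default: follow the training dtype
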
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -601,6 +601,7 @@ def load_sub_model(
     variant: str,
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
+    use_safetensors: bool,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""

@@ -670,6 +671,7 @@ def load_sub_model(
         loading_kwargs["offload_folder"] = offload_folder
         loading_kwargs["offload_state_dict"] = offload_state_dict
         loading_kwargs["variant"] = model_variants.pop(name, None)
+        loading_kwargs["use_safetensors"] = use_safetensors
 
         if from_flax:
             loading_kwargs["from_flax"] = True
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -905,6 +905,7 @@ def load_module(name, value):
                     variant=variant,
                     low_cpu_mem_usage=low_cpu_mem_usage,
                     cached_folder=cached_folder,
+                    use_safetensors=use_safetensors,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
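
The two one-line additions above thread the user-supplied use_safetensors flag from DiffusionPipeline.from_pretrained down into load_sub_model, so each sub-model load honors it instead of falling back to its own default. A minimal usage sketch of the user-facing effect; the checkpoint id is illustrative.

# With this change, the `use_safetensors` kwarg given to the pipeline is
# forwarded to every sub-model load. The model id below is illustrative.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16,
    use_safetensors=True,  # now passed through load_sub_model to each component
)
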
25 changes: 17 additions & 8 deletions src/diffusers/utils/import_utils.py
@@ -668,8 +668,9 @@ def __getattr__(cls, key):
 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
 def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
     """
-    Args:
     Compares a library version to some requirement using a given operation.
+
+    Args:
         library_or_version (`str` or `packaging.version.Version`):
             A library name or a version to check.
         operation (`str`):
@@ -688,8 +689,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
 def is_torch_version(operation: str, version: str):
     """
-    Args:
     Compares the current PyTorch version to a given reference with an operation.
+
+    Args:
         operation (`str`):
             A string representation of an operator, such as `">"` or `"<="`
         version (`str`):
@@ -700,8 +702,9 @@ def is_torch_version(operation: str, version: str):
 
 def is_transformers_version(operation: str, version: str):
     """
-    Args:
     Compares the current Transformers version to a given reference with an operation.
+
+    Args:
         operation (`str`):
             A string representation of an operator, such as `">"` or `"<="`
         version (`str`):
@@ -714,8 +717,9 @@ def is_transformers_version(operation: str, version: str):
 
 def is_accelerate_version(operation: str, version: str):
     """
-    Args:
     Compares the current Accelerate version to a given reference with an operation.
+
+    Args:
         operation (`str`):
             A string representation of an operator, such as `">"` or `"<="`
         version (`str`):
@@ -728,8 +732,9 @@ def is_accelerate_version(operation: str, version: str):
 
 def is_peft_version(operation: str, version: str):
     """
-    Args:
     Compares the current PEFT version to a given reference with an operation.
+
+    Args:
         operation (`str`):
             A string representation of an operator, such as `">"` or `"<="`
         version (`str`):
@@ -742,8 +747,9 @@ def is_peft_version(operation: str, version: str):
 
 def is_k_diffusion_version(operation: str, version: str):
     """
-    Args:
     Compares the current k-diffusion version to a given reference with an operation.
+
+    Args:
         operation (`str`):
             A string representation of an operator, such as `">"` or `"<="`
         version (`str`):
@@ -756,8 +762,9 @@ def is_k_diffusion_version(operation: str, version: str):
 
 def get_objects_from_module(module):
     """
-    Args:
     Returns a dict of object names and values in a module, while skipping private/internal objects
+
+    Args:
         module (ModuleType):
             Module to extract the objects from.
@@ -775,7 +782,9 @@ def get_objects_from_module(module):
 
 
 class OptionalDependencyNotAvailable(BaseException):
-    """An error indicating that an optional dependency of Diffusers was not found in the environment."""
+    """
+    An error indicating that an optional dependency of Diffusers was not found in the environment.
+    """
 
 
 class _LazyModule(ModuleType):
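
Most of the hunks above only move each one-line description ahead of its Args: block (and reflow one class docstring); the helpers' behavior is unchanged. For reference, a small usage sketch of the version-comparison utilities those docstrings describe (a sketch, assuming torch is installed):

# Usage sketch for the version helpers documented above.
from diffusers.utils.import_utils import compare_versions, is_torch_version

# Compare any installed library's version against a requirement string.
if compare_versions("torch", ">=", "2.0.0"):
    print("torch satisfies >= 2.0.0")

# Shorthand for comparing the installed PyTorch version.
if is_torch_version("<", "3.0.0"):
    print("running a pre-3.0 torch build")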
