Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor autoencoders] feat: introduce autoencoders module #6129

Merged
merged 7 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/en/api/models/asymmetricautoencoderkl.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)

## AsymmetricAutoencoderKL

[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL

## AutoencoderKLOutput

[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
2 changes: 1 addition & 1 deletion docs/source/en/api/models/autoencoder_tiny.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ image

## AutoencoderTinyOutput

[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput
4 changes: 2 additions & 2 deletions docs/source/en/api/models/autoencoderkl.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)

## AutoencoderKLOutput

[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

## FlaxAutoencoderKL

Expand Down
2 changes: 1 addition & 1 deletion scripts/convert_consistency_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from tqdm import tqdm

from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
from diffusers.models.vae import Encoder


args = ArgumentParser()
Expand Down
22 changes: 12 additions & 10 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@

if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
Expand Down Expand Up @@ -58,11 +58,13 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .controlnetxs import ControlNetXSModel
from .dual_transformer_2d import DualTransformer2DModel
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers import ConsistencyDecoderScheduler
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from ..utils.torch_utils import randn_tensor
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers import ConsistencyDecoderScheduler
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_2d import UNet2DModel
from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder


Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
self.use_slicing = False
self.use_tiling = False

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
Expand All @@ -162,23 +162,23 @@ def enable_tiling(self, use_tiling: bool = True):
"""
self.use_tiling = use_tiling

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
Expand Down Expand Up @@ -333,14 +333,14 @@ def decode(

return DecoderOutput(sample=x_0)

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b

# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import torch
import torch.nn as nn

from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import (
from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnetxs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .attention_processor import (
AttentionProcessor,
)
from .autoencoder_kl import AutoencoderKL
from .autoencoders import AutoencoderKL
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/vq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput
from ...utils.accelerate_utils import apply_forward_hook

Expand Down