From 2643c74ef61f124948b06c9f8b692238bffe5c55 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 04:36:40 -0800 Subject: [PATCH 1/8] autoencoder_dc tiling --- .../models/autoencoders/autoencoder_dc.py | 107 +++++++++++++++++- 1 file changed, 101 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 109e37c23e1b..89465d704cb0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -479,12 +479,15 @@ def __init__( self.use_tiling = False # The minimal tile height and width for spatial tiling to be used - self.tile_sample_min_height = 512 - self.tile_sample_min_width = 512 + self.tile_sample_min_height = 1024 + self.tile_sample_min_width = 1024 # The minimal distance between two spatial tiles - self.tile_sample_stride_height = 448 - self.tile_sample_stride_width = 448 + self.tile_sample_stride_height = 896 + self.tile_sample_stride_width = 896 + + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio def enable_tiling( self, @@ -515,6 +518,8 @@ def enable_tiling( self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio def disable_tiling(self) -> None: r""" @@ -606,11 +611,101 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor: - raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
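+        # Each encoded tile is blended with the tile above and to its left in latent space (via blend_v/blend_h)
+        # and then cropped to the latent stride, so only each tile's interior contributes to the final latent.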
+ rows = [] + for i in range(0, x.shape[2], self.tile_sample_stride_height): + row = [] + for j in range(0, x.shape[3], self.tile_sample_stride_width): + tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0: + tile = F.pad(tile, (0, (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio, 0, (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio)) + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + + if not return_dict: + return (encoded,) + return EncoderOutput(latent=encoded) def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = z.shape + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
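+        # Each decoded tile is blended with the tile above and to its left in pixel space and then cropped to the
+        # sample stride before the rows are concatenated into the final image.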
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + decoded = torch.cat(result_rows, dim=2) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor: encoded = self.encode(sample, return_dict=False)[0] From c0b1ca5300f964fcb5d91e856a8e7b041400087d Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:06:54 -0800 Subject: [PATCH 2/8] add tiling and slicing support in SANA pipelines --- src/diffusers/pipelines/sana/pipeline_sana.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index afc2f74c9e8f..8b318597c12d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -218,6 +218,35 @@ def __init__( ) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]], From d80dea51e46320d93bb17e779b26d320d60eb665 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:12:48 -0800 Subject: [PATCH 3/8] create variables for padding length because the line becomes too long --- src/diffusers/models/autoencoders/autoencoder_dc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 89465d704cb0..7f6aeef4d8fd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -643,7 +643,9 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso for j in range(0, x.shape[3], self.tile_sample_stride_width): tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0: - tile = F.pad(tile, (0, (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio, 0, (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio)) + pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio + pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio + tile = F.pad(tile, (0, pad_w, 0, pad_h)) tile = self.encoder(tile) row.append(tile) rows.append(row) From c3b9a8ef81553a78763482ad7b82e5ced4e489d9 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:35:23 -0800 Subject: [PATCH 4/8] add tiling and slicing support in pag SANA pipelines --- .../pipelines/pag/pipeline_pag_sana.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index f363a1a557bc..2cdc1c70cdcc 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -183,6 +183,35 @@ def __init__( pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), ) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]], From 2fad76292eae31c3bc1932ad2f530949658cffa7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 11 Jan 2025 02:20:21 +0100 Subject: [PATCH 5/8] revert changes to tile size --- src/diffusers/models/autoencoders/autoencoder_dc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 7f6aeef4d8fd..58bddb981d1b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -479,12 +479,12 @@ def __init__( self.use_tiling = False # The minimal tile height and width for spatial tiling to be used - self.tile_sample_min_height = 1024 - self.tile_sample_min_width = 1024 + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 # The minimal distance between two spatial tiles - self.tile_sample_stride_height = 896 - self.tile_sample_stride_width = 896 + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio From d2f2d2bca74867bd38ed62cd76b7b71f5722832d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 11 Jan 2025 02:21:32 +0100 Subject: [PATCH 6/8] make style --- src/diffusers/models/autoencoders/autoencoder_dc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 58bddb981d1b..1e6a26dddca8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -627,7 +627,7 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso batch_size, num_channels, height, width = x.shape latent_height = height // self.spatial_compression_ratio latent_width = width // self.spatial_compression_ratio - + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio @@ -642,7 +642,10 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso row = [] for j in range(0, x.shape[3], self.tile_sample_stride_width): tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] - if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0: + if ( + tile.shape[2] % self.spatial_compression_ratio != 0 + or tile.shape[3] % self.spatial_compression_ratio != 0 + ): pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio tile = F.pad(tile, (0, pad_w, 0, pad_h)) From 477937e6fae932144dfa5959a198ac354334bd41 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 11 Jan 2025 02:28:01 +0100 Subject: [PATCH 7/8] add vae tiling test --- tests/pipelines/sana/test_sana.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index 21de4e04437a..7109a700403c 100644 --- a/tests/pipelines/sana/test_sana.py +++ 
b/tests/pipelines/sana/test_sana.py @@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass( "Attention slicing should not affect the inference results", ) + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + # TODO(aryan): Create a dummy gemma model with smol vocab size @unittest.skip( "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." From bba83a4ab8848763f73f7232fa4be5c425182243 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Wed, 15 Jan 2025 21:53:04 -0800 Subject: [PATCH 8/8] fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16 --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4d7ae6bef26e..967ebf8649ba 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -899,7 +899,7 @@ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, valu scores = torch.matmul(key.transpose(-1, -2), query) scores = scores.to(dtype=torch.float32) scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps) - hidden_states = torch.matmul(value, scores) + hidden_states = torch.matmul(value, scores.to(value.dtype)) return hidden_states def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
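Usage sketch (illustrative, not part of the patch series): with these patches applied, VAE tiling and slicing can be toggled from the SANA pipelines. The checkpoint name, prompt, and output resolution below are assumptions for demonstration only.

    import torch
    from diffusers import SanaPipeline

    # Illustrative checkpoint; any SANA checkpoint whose VAE is an AutoencoderDC should behave the same.
    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    # Convenience wrappers added in patch 2/8 (and 4/8 for the PAG pipeline); they forward to the
    # AutoencoderDC enable_tiling()/enable_slicing() methods.
    pipe.enable_vae_tiling()   # tiled encode/decode for large resolutions
    pipe.enable_vae_slicing()  # per-sample decode for large batches

    # Tile geometry can also be set directly on the VAE; the values here match the defaults
    # restored in patch 5/8 and are only shown for illustration.
    pipe.vae.enable_tiling(
        tile_sample_min_height=512,
        tile_sample_min_width=512,
        tile_sample_stride_height=448,
        tile_sample_stride_width=448,
    )

    image = pipe(prompt="a photo of an astronaut riding a horse", height=2048, width=2048).images[0]
    image.save("sana_tiled.png")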