
autoencoder_dc tiling
chenjy2003 committed Jan 9, 2025
1 parent 95c5ce4 commit 2643c74
Showing 1 changed file with 101 additions and 6 deletions.
src/diffusers/models/autoencoders/autoencoder_dc.py (101 additions, 6 deletions)
@@ -479,12 +479,15 @@ def __init__(
         self.use_tiling = False

         # The minimal tile height and width for spatial tiling to be used
-        self.tile_sample_min_height = 512
-        self.tile_sample_min_width = 512
+        self.tile_sample_min_height = 1024
+        self.tile_sample_min_width = 1024

         # The minimal distance between two spatial tiles
-        self.tile_sample_stride_height = 448
-        self.tile_sample_stride_width = 448
+        self.tile_sample_stride_height = 896
+        self.tile_sample_stride_width = 896
+
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

     def enable_tiling(
         self,
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +611,101 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             return (decoded,)
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        # Linearly cross-fade the bottom `blend_extent` rows of tile `a` into the top rows of tile `b`.
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        # Linearly cross-fade the rightmost `blend_extent` columns of tile `a` into the left columns of tile `b`.
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                # Edge tiles may need right/bottom padding so the encoder sees
+                # dimensions divisible by the spatial compression ratio.
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then crop the stride-sized region and append it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        # Decoded tiles live in sample space, so the blend extents are measured in pixels here.
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then crop the stride-sized region and append it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)

     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
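For context, the tiling path added here can be exercised end to end roughly as follows. This is a minimal sketch, not part of the commit: the checkpoint name and input size are assumptions, and it relies on encode/decode dispatching to the tiled variants once tiling is enabled and the input exceeds the tile minimums.

import torch
from diffusers import AutoencoderDC

# Checkpoint name is an assumption; any AutoencoderDC checkpoint with
# spatial_compression_ratio == 32 gives the shapes shown below.
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
ae.enable_tiling()  # uses the new defaults: 1024 tile size, 896 stride

# 2048 exceeds tile_sample_min_height/width, so encode/decode should take the
# tiled path, splitting the image into overlapping 1024x1024 tiles at stride 896.
image = torch.randn(1, 3, 2048, 2048)
with torch.no_grad():
    latent = ae.encode(image).latent   # shape: (1, 32, 64, 64)
    recon = ae.decode(latent).sample   # shape: (1, 3, 2048, 2048)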

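The seam removal in blend_v/blend_h is a plain linear cross-fade across the overlap. Below is a standalone sketch of the same arithmetic under the new defaults; the function and variable names are illustrative, not from the diff.

import torch

def blend_rows(a: torch.Tensor, b: torch.Tensor, extent: int) -> torch.Tensor:
    # Same arithmetic as AutoencoderDC.blend_v: fade the last `extent` rows of
    # `a` into the first `extent` rows of `b` with linearly increasing weight.
    extent = min(a.shape[2], b.shape[2], extent)
    for y in range(extent):
        w = y / extent  # 0.0 at the top of the overlap, approaching 1.0 at the bottom
        b[:, :, y, :] = a[:, :, -extent + y, :] * (1 - w) + b[:, :, y, :] * w
    return b

# With min=1024 and stride=896, the overlap is 1024 - 896 = 128 pixels, i.e.
# 128 // 32 = 4 latent rows when spatial_compression_ratio == 32.
top = torch.ones(1, 32, 28, 28)
bottom = torch.zeros(1, 32, 28, 28)
blended = blend_rows(top, bottom, extent=4)
print(blended[0, 0, :4, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])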
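Similarly, the padding branch in tiled_encode only fires for right/bottom edge tiles whose height or width is not a multiple of the compression ratio. The same arithmetic in isolation, with illustrative values:

import torch
import torch.nn.functional as F

ratio = 32
tile = torch.randn(1, 3, 1000, 1024)    # a bottom-edge tile; 1000 is not divisible by 32
pad_h = (ratio - tile.shape[2]) % ratio  # (32 - 1000) % 32 == 24
pad_w = (ratio - tile.shape[3]) % ratio  # 0
# F.pad orders padding as (left, right, top, bottom) for the last two dims.
tile = F.pad(tile, (0, pad_w, 0, pad_h))
print(tile.shape)  # torch.Size([1, 3, 1024, 1024])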