Skip to content

Commit

Permalink
Add option to use predefined aspect ratio buckets in the cropping tra…
Browse files Browse the repository at this point in the history
…nsform (#157)
  • Loading branch information
coryMosaicML authored Jul 26, 2024
1 parent ef74f2b commit adebf01
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 10 deletions.
34 changes: 25 additions & 9 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams
from diffusion.models.text_encoder import MultiTokenizer

Expand Down Expand Up @@ -45,6 +46,7 @@ class StreamingImageCaptionDataset(StreamingDataset):
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
Expand All @@ -63,6 +65,7 @@ def __init__(
transform: Optional[Callable] = None,
image_key: str = 'image',
caption_key: str = 'caption',
aspect_ratio_bucket_key: Optional[str] = None,
sdxl_conditioning: bool = False,
zero_dropped_captions: bool = False,
**streaming_kwargs,
Expand Down Expand Up @@ -90,6 +93,9 @@ def __init__(
self.caption_selection = caption_selection
self.image_key = image_key
self.caption_key = caption_key
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.zero_dropped_captions = zero_dropped_captions

self.tokenizer = tokenizer
Expand All @@ -107,7 +113,9 @@ def __getitem__(self, index):
orig_w, orig_h = img.size

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
Expand Down Expand Up @@ -179,6 +187,7 @@ def build_streaming_image_caption_dataloader(
transform: Optional[List[Callable]] = None,
image_key: str = 'image',
caption_key: str = 'caption',
aspect_ratio_bucket_key: Optional[str] = None,
crop_type: Optional[str] = 'square',
zero_dropped_captions: bool = True,
sdxl_conditioning: bool = False,
Expand Down Expand Up @@ -212,7 +221,8 @@ def build_streaming_image_caption_dataloader(
transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
Expand All @@ -225,12 +235,14 @@ def build_streaming_image_caption_dataloader(
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Handle ``None`` kwargs
if streaming_kwargs is None:
streaming_kwargs = {}
Expand All @@ -246,7 +258,10 @@ def build_streaming_image_caption_dataloader(
elif crop_type == 'random':
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

Expand All @@ -265,6 +280,7 @@ def build_streaming_image_caption_dataloader(
transform=transform,
image_key=image_key,
caption_key=caption_key,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
batch_size=batch_size,
sdxl_conditioning=sdxl_conditioning,
zero_dropped_captions=zero_dropped_captions,
Expand Down
49 changes: 48 additions & 1 deletion diffusion/datasets/laion/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Transforms for the training and eval dataset."""

import math
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def __call__(self, img):
return img, c_top, c_left


class RandomCropAspectRatioTransorm:
class RandomCropAspectRatioTransform:
"""Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.
Args:
Expand Down Expand Up @@ -111,3 +112,49 @@ def __call__(self, img):
c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
img = crop(img, c_top, c_left, height, width)
return img, c_top, c_left


class RandomCropBucketedAspectRatioTransform:
"""Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.
This transform requires the desired aspect ratio bucket to be specified manually in the call to the transform.
Args:
resize_size (Tuple[Tuple[int, int], ...): A tuple of 2-tuple integers representing the aspect ratio buckets.
The format is ((height_bucket1, width_bucket1), (height_bucket2, width_bucket2), ...).
"""

def __init__(
self,
resize_size: Tuple[Tuple[int, int], ...],
):
self.height_buckets = torch.tensor([size[0] for size in resize_size])
self.width_buckets = torch.tensor([size[1] for size in resize_size])
self.aspect_ratio_buckets = self.height_buckets / self.width_buckets
self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)

def __call__(self, img, aspect_ratio):
orig_w, orig_h = img.size
orig_aspect_ratio = orig_h / orig_w
# Figure out target H/W given the input aspect ratio
bucket_ind = torch.abs(self.log_aspect_ratio_buckets - math.log(aspect_ratio)).argmin()
target_width, target_height = self.width_buckets[bucket_ind].item(), self.height_buckets[bucket_ind].item()
target_aspect_ratio = target_height / target_width

# Determine resize size
if orig_aspect_ratio > target_aspect_ratio:
# Resize width and crop height
w_scale = target_width / orig_w
resize_size = (round(w_scale * orig_h), target_width)
elif orig_aspect_ratio < target_aspect_ratio:
# Resize height and crop width
h_scale = target_height / orig_h
resize_size = (target_height, round(h_scale * orig_w))
else:
resize_size = (target_height, target_width)
img = transforms.functional.resize(img, resize_size, antialias=True)

# Crop based on aspect ratio
c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
img = crop(img, c_top, c_left, height, width)
return img, c_top, c_left

0 comments on commit adebf01

Please sign in to comment.