From 5d90dfa6a965b5382c96db3c34393c067b6756e4 Mon Sep 17 00:00:00 2001
From: charchit7
Date: Wed, 9 Oct 2024 00:03:13 +0530
Subject: [PATCH] refactor image_processor file

---
 src/diffusers/image_processor.py | 308 +++++++++++++++++++++++++------
 1 file changed, 255 insertions(+), 53 deletions(-)

diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index d58bd9e3e3758..280048bd2379f 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -38,16 +38,44 @@
 PipelineDepthInput = PipelineImageInput
 
 
-def is_valid_image(image):
+def is_valid_image(image) -> bool:
+    r"""
+    Checks if the input is a valid image.
+
+    A valid image can be:
+    - A `PIL.Image.Image`.
+    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
+
+    Args:
+        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+            The image to validate. It can be a PIL image, a numpy array, or a torch tensor.
+
+    Returns:
+        `bool`:
+            `True` if the input is a valid image, `False` otherwise.
+    """
     return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
 
 
 def is_valid_image_imagelist(images):
-    # check if the image input is one of the supported formats for image and image list:
-    # it can be either one of below 3
-    # (1) a 4d pytorch tensor or numpy array,
-    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
-    # (3) a list of valid image
+    r"""
+    Checks if the input is a valid image or list of images.
+
+    The input can be one of the following formats:
+    - A 4D tensor or numpy array (batch of images).
+    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
+      `torch.Tensor`.
+    - A list of valid images.
+
+    Args:
+        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
+            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
+            images.
+
+    Returns:
+        `bool`:
+            `True` if the input is valid, `False` otherwise.
+    """
     if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
         return True
     elif is_valid_image(images):
@@ -103,8 +131,16 @@ def __init__(
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
+        r"""
         Convert a numpy image or a batch of images to a PIL image.
+
+        Args:
+            images (`np.ndarray`):
+                The image array to convert to PIL format.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -119,8 +155,16 @@ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
 
     @staticmethod
     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+                The PIL image or list of images to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -131,8 +175,16 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
 
     @staticmethod
     def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
-        """
+        r"""
         Convert a NumPy image to a PyTorch tensor.
+
+        Args:
+            images (`np.ndarray`):
+                The NumPy image array to convert to PyTorch format.
+
+        Returns:
+            `torch.Tensor`:
+                A PyTorch tensor representation of the images.
         """
         if images.ndim == 3:
             images = images[..., None]
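As a quick illustration of how the conversion helpers documented above chain together, here is a minimal sketch (assuming a local `diffusers` install; the 8x8 test image is arbitrary):

```python
import PIL.Image

from diffusers.image_processor import VaeImageProcessor

# An arbitrary 8x8 RGB test image.
pil_image = PIL.Image.new("RGB", (8, 8), color=(255, 0, 0))

# PIL -> NumPy: float array in [0, 1] with shape [batch, height, width, channels].
np_images = VaeImageProcessor.pil_to_numpy(pil_image)
print(np_images.shape)  # (1, 8, 8, 3)

# NumPy -> PyTorch: channels-first tensor with shape [batch, channels, height, width].
pt_images = VaeImageProcessor.numpy_to_pt(np_images)
print(pt_images.shape)  # torch.Size([1, 3, 8, 8])
```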
""" if images.ndim == 3: images = images[..., None] @@ -142,30 +194,62 @@ def numpy_to_pt(images: np.ndarray) -> torch.Tensor: @staticmethod def pt_to_numpy(images: torch.Tensor) -> np.ndarray: - """ + r""" Convert a PyTorch tensor to a NumPy image. + + Args: + images (`torch.Tensor`): + The PyTorch tensor to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @staticmethod def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """ + r""" Normalize an image array to [-1,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to normalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The normalized image array. """ return 2.0 * images - 1.0 @staticmethod def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """ + r""" Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. """ return (images / 2 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: - """ + r""" Converts a PIL image to RGB format. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to RGB. + + Returns: + `PIL.Image.Image`: + The RGB-converted PIL image. """ image = image.convert("RGB") @@ -173,8 +257,16 @@ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: @staticmethod def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: - """ - Converts a PIL image to grayscale format. + r""" + Converts a given PIL image to grayscale. + + Args: + image (`PIL.Image.Image`): + The input image to convert. + + Returns: + `PIL.Image.Image`: + The image converted to grayscale. """ image = image.convert("L") @@ -182,8 +274,16 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: @staticmethod def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: - """ + r""" Applies Gaussian blur to an image. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to grayscale. + + Returns: + `PIL.Image.Image`: + The grayscale-converted PIL image. """ image = image.filter(ImageFilter.GaussianBlur(blur_factor)) @@ -191,7 +291,7 @@ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: @staticmethod def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): - """ + r""" Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. @@ -285,14 +385,21 @@ def _resize_and_fill( width: int, height: int, ) -> PIL.Image.Image: - """ + r""" Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. Args: - image: The image to resize. - width: The width to resize the image to. - height: The height to resize the image to. + image (`PIL.Image.Image`): + The image to resize and fill. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and filled image. 
""" ratio = width / height @@ -330,14 +437,21 @@ def _resize_and_crop( width: int, height: int, ) -> PIL.Image.Image: - """ + r""" Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Args: - image: The image to resize. - width: The width to resize the image to. - height: The height to resize the image to. + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. """ ratio = width / height src_ratio = image.width / image.height @@ -429,19 +543,23 @@ def get_default_height_width( height: Optional[int] = None, width: Optional[int] = None, ) -> Tuple[int, int]: - """ - This function return the height and width that are downscaled to the next integer multiple of - `vae_scale_factor`. + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. Args: - image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): - The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have - shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should - have shape `[batch, channel, height, width]`. - height (`int`, *optional*, defaults to `None`): - The height in preprocessed image. If `None`, will use the height of `image` input. - width (`int`, *optional*`, defaults to `None`): - The width in preprocessed. If `None`, will use the width of the `image` input. + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, numpy array, or PyTorch tensor. If it is a numpy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. """ if height is None: @@ -478,13 +596,13 @@ def preprocess( Preprocess the image input. Args: - image (`pipeline_image_input`): + image (`PipelineImageInput`): The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. - height (`int`, *optional*, defaults to `None`): + height (`int`, *optional*): The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. - width (`int`, *optional*`, defaults to `None`): + width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within @@ -496,6 +614,10 @@ def preprocess( supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. 
@@ -478,13 +596,13 @@ def preprocess(
         Preprocess the image input.
 
         Args:
-            image (`pipeline_image_input`):
+            image (`PipelineImageInput`):
                 The image input; accepted formats are PIL images, NumPy arrays, and PyTorch tensors. A list of
                 supported formats is also accepted.
-            height (`int`, *optional*, defaults to `None`):
+            height (`int`, *optional*):
                 The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                 default height.
-            width (`int`, *optional*`, defaults to `None`):
+            width (`int`, *optional*):
                 The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
@@ -496,6 +614,10 @@ def preprocess(
                 supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
+
+        Returns:
+            `torch.Tensor`:
+                The preprocessed image.
         """
 
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -655,8 +777,22 @@ def apply_overlay(
         image: PIL.Image.Image,
         crop_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
+        r"""
+        Applies an overlay of the mask and the inpainted image on the original image.
+
+        Args:
+            mask (`PIL.Image.Image`):
+                The mask image that highlights regions to overlay.
+            init_image (`PIL.Image.Image`):
+                The original image to which the overlay is applied.
+            image (`PIL.Image.Image`):
+                The image to overlay onto the original.
+            crop_coords (`Tuple[int, int, int, int]`, *optional*):
+                Coordinates to crop the image. If provided, the image will be cropped accordingly.
+
+        Returns:
+            `PIL.Image.Image`:
+                The final image with the overlay applied.
         """
         width, height = image.width, image.height
@@ -713,8 +849,16 @@ def __init__(
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy array.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -729,8 +873,16 @@ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
 
     @staticmethod
     def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
+                The input image or list of images to be converted.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array of the converted images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -741,18 +893,30 @@ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) ->
 
     @staticmethod
     def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Args:
-            image: RGB-like depth image
-
-        Returns: depth map
+        r"""
+        Convert an RGB-like depth image to a depth map.
+
+        Args:
+            image (`Union[np.ndarray, torch.Tensor]`):
+                The RGB-like depth image to convert.
+
+        Returns:
+            `Union[np.ndarray, torch.Tensor]`:
+                The corresponding depth map.
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
 
     def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy depth image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy depth image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of depth images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy depth images.
         """
         if images.ndim == 3:
             images = images[None, ...]
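`rgblike_to_depthmap` reassembles a 16-bit depth value whose high byte was stored in the G channel and low byte in the B channel. A small sketch (the array contents are arbitrary):

```python
import numpy as np

from diffusers.image_processor import VaeImageProcessorLDM3D

# An arbitrary 2x2 RGB-like depth image: channel 1 holds the high byte, channel 2 the low byte.
rgb_like = np.zeros((2, 2, 3), dtype=np.int64)
rgb_like[0, 0, 1] = 1   # high byte
rgb_like[0, 0, 2] = 44  # low byte

depth = VaeImageProcessorLDM3D.rgblike_to_depthmap(rgb_like)
print(depth[0, 0])  # 300 == 1 * 2**8 + 44
```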
@@ -833,8 +997,24 @@ def preprocess(
         width: Optional[int] = None,
         target_res: Optional[int] = None,
     ) -> torch.Tensor:
-        """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        r"""
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
+
+        Args:
+            rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The RGB input image, which can be a single image or a batch.
+            depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The depth input image, which can be a single image or a batch.
+            height (`int`, *optional*):
+                The desired height of the processed image. If `None`, defaults to the height of the input image.
+            width (`int`, *optional*):
+                The desired width of the processed image. If `None`, defaults to the width of the input image.
+            target_res (`int`, *optional*):
+                Target resolution for resizing the images. If specified, overrides `height` and `width`.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing the processed RGB and depth images as PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -1072,7 +1252,17 @@ def __init__(
 
     @staticmethod
     def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
-        """Returns binned height and width."""
+        r"""
+        Returns the binned height and width based on the aspect ratio.
+
+        Args:
+            height (`int`):
+                The height of the image.
+            width (`int`):
+                The width of the image.
+            ratios (`dict`):
+                A dictionary where keys are aspect ratios and values are tuples of `(height, width)`.
+
+        Returns:
+            `Tuple[int, int]`:
+                The closest binned height and width.
+        """
         ar = float(height / width)
         closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
         default_hw = ratios[closest_ratio]
@@ -1080,6 +1270,18 @@ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[in
 
     @staticmethod
     def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        r"""
+        Resizes and crops a tensor of images to the specified dimensions.
+
+        Args:
+            samples (`torch.Tensor`):
+                A tensor of shape `(N, C, H, W)` where `N` is the batch size, `C` is the number of channels, `H` is
+                the height, and `W` is the width.
+            new_width (`int`):
+                The desired width of the output images.
+            new_height (`int`):
+                The desired height of the output images.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor containing the resized and cropped images.
+        """
         orig_height, orig_width = samples.shape[2], samples.shape[3]
 
         # Check if resizing is needed
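To make the `ratios` contract of `classify_height_width_bin` concrete, here is a sketch with a hypothetical three-entry table (real tables, such as PixArt's aspect-ratio bins, have many more entries):

```python
from diffusers.image_processor import PixArtImageProcessor

# Hypothetical bins: keys are stringified height/width ratios, values are (height, width).
ratios = {
    "0.5": (704, 1408),   # wide
    "1.0": (1024, 1024),  # square
    "2.0": (1408, 704),   # tall
}

# height/width = 800/600 ~ 1.33, so the closest key is "1.0".
print(PixArtImageProcessor.classify_height_width_bin(800, 600, ratios))  # (1024, 1024)
```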