From 5ff2393c72a2a678535ac1c31779684552f18189 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 23 Oct 2023 11:04:35 -0700
Subject: [PATCH] further a guess to what to convert image to depending on number of channels

---
 denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | 7 ++++++-
 denoising_diffusion_pytorch/version.py                     | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 81ef4ad14..96c340945 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -873,7 +873,7 @@ def __init__(
         amp = False,
         mixed_precision_type = 'fp16',
         split_batches = True,
-        convert_image_to = 'RGB',
+        convert_image_to = None,
         calculate_fid = True,
         inception_block_idx = 2048,
         max_grad_norm = 1.,
@@ -895,6 +895,11 @@ def __init__(
         self.channels = diffusion_model.channels
         is_ddim_sampling = diffusion_model.is_ddim_sampling
 
+        # default convert_image_to depending on channels
+
+        if not exists(convert_image_to):
+            convert_image_to = {1: 'L', 3: 'RGB', 4: 'RGBA'}.get(self.channels)
+
         # sampling and training hyperparameters
 
         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index be0b61075..6c08d428e 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.9.3'
+__version__ = '1.9.4'