diff --git a/README.md b/README.md
index 09328da..37f3333 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,14 @@ Please download the pre-trained models from the following links and save them to
 ```bash
 MODEL_NAME='styleganinv_ffhq256'
 IMAGE_LIST='examples/test.list'
-python invert.py $MODEL_NAME $IMAGE_LIST
+python invert.py --model_name $MODEL_NAME --image_list $IMAGE_LIST
+```
+
+```bash
+MODEL_NAME='styleganinv_ffhq256'
+IMAGE_DIR='examples/images/'
+
+python invert.py --model_name $MODEL_NAME --test_dir $IMAGE_DIR
 ```
 
 **NOTE:** We find that 100 iterations are good enough for inverting an image, which takes about 8s (on P40). But users can always use more iterations (much slower) for a more precise reconstruction.
@@ -47,7 +54,7 @@
 MODEL_NAME='styleganinv_ffhq256'
 TARGET_LIST='examples/target.list'
 CONTEXT_LIST='examples/context.list'
-python diffuse.py $MODEL_NAME $TARGET_LIST $CONTEXT_LIST
+python diffuse.py --model_name $MODEL_NAME --target_list $TARGET_LIST --context_list $CONTEXT_LIST
 ```
 
 NOTE: The diffusion process is highly similar to image inversion. The main difference is that only the target patch is used to compute loss for **masked** optimization.
@@ -57,7 +64,7 @@ NOTE: The diffusion process is highly similar to image inversion. The main diffe
 ```bash
 SRC_DIR='results/inversion/test'
 DST_DIR='results/inversion/test'
-python interpolate.py $MODEL_NAME $SRC_DIR $DST_DIR
+python interpolate.py --model_name $MODEL_NAME --src_dir $SRC_DIR --dst_dir $DST_DIR
 ```
 
 ### Manipulation
@@ -65,7 +72,7 @@ python interpolate.py $MODEL_NAME $SRC_DIR $DST_DIR
 ```bash
 IMAGE_DIR='results/inversion/test'
 BOUNDARY='boundaries/expression.npy'
-python manipulate.py $MODEL_NAME $IMAGE_DIR $BOUNDARY
+python manipulate.py --model_name $MODEL_NAME --image_dir $IMAGE_DIR --boundary_path $BOUNDARY
 ```
 
 **NOTE:** Boundaries are obtained using [InterFaceGAN](https://github.com/genforce/interfacegan).
@@ -75,7 +82,7 @@ python manipulate.py $MODEL_NAME $IMAGE_DIR $BOUNDARY
 ```bash
 STYLE_DIR='results/inversion/test'
 CONTENT_DIR='results/inversion/test'
-python mix_style.py $MODEL_NAME $STYLE_DIR $CONTENT_DIR
+python mix_style.py --model_name $MODEL_NAME --style_dir $STYLE_DIR --content_dir $CONTENT_DIR
 ```
 
 ## BibTeX
diff --git a/diffuse.py b/diffuse.py
index fc1e37e..542e5aa 100644
--- a/diffuse.py
+++ b/diffuse.py
@@ -24,10 +24,10 @@ def parse_args():
   """Parses arguments."""
   parser = argparse.ArgumentParser()
 
-  parser.add_argument('model_name', type=str, help='Name of the GAN model.')
-  parser.add_argument('target_list', type=str,
+  parser.add_argument('--model_name', type=str, help='Name of the GAN model.')
+  parser.add_argument('--target_list', type=str,
                       help='List of target images to diffuse from.')
-  parser.add_argument('context_list', type=str,
+  parser.add_argument('--context_list', type=str,
                       help='List of context images to diffuse to.')
   parser.add_argument('-o', '--output_dir', type=str, default='',
                       help='Directory to save the results. If not specified, '
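The README note above says diffusion scores only the target patch. For reference, a minimal sketch of that masked pixel loss, mirroring the `loss_pix` term inside `diffuse()` in `utils/inverter.py`; the helper name `masked_pixel_loss` and the concrete shapes are illustrative, not part of the patch:

```python
import torch

def masked_pixel_loss(x, x_rec, mask):
    """Per-sample MSE restricted to the pasted target crop.

    x, x_rec: [N, C, H, W] tensors; mask: [1, C, H, W] with 1s inside the
    crop and 0s elsewhere, exactly as built at the top of diffuse().
    """
    return torch.mean(((x - x_rec) * mask) ** 2, dim=[1, 2, 3])

x = torch.rand(2, 3, 256, 256)
x_rec = torch.rand(2, 3, 256, 256)
mask = torch.zeros(1, 3, 256, 256)
mask[:, :, 96:160, 96:160] = 1.0  # a 64x64 target patch
print(masked_pixel_loss(x, x_rec, mask).shape)  # torch.Size([2])
```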
diff --git a/examples/test.list b/examples/test.list
index 46fd527..e3879b8 100644
--- a/examples/test.list
+++ b/examples/test.list
@@ -16,4 +16,4 @@ examples/000015.png
 examples/000016.png
 examples/000017.png
 examples/000018.png
-examples/000019.png
+examples/000019.png
\ No newline at end of file
diff --git a/interpolate.py b/interpolate.py
index 882dfcf..bbe430f 100644
--- a/interpolate.py
+++ b/interpolate.py
@@ -23,11 +23,11 @@ def parse_args():
   """Parses arguments."""
   parser = argparse.ArgumentParser()
 
-  parser.add_argument('model_name', type=str, help='Name of the GAN model.')
-  parser.add_argument('src_dir', type=str,
+  parser.add_argument('--model_name', type=str, help='Name of the GAN model.')
+  parser.add_argument('--src_dir', type=str,
                       help='Source directory, which includes original images, '
                            'inverted codes, and image list.')
-  parser.add_argument('dst_dir', type=str,
+  parser.add_argument('--dst_dir', type=str,
                       help='Target directory, which includes original images, '
                            'inverted codes, and image list.')
   parser.add_argument('-o', '--output_dir', type=str, default='',
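The `invert.py` changes below accept either `--image_list` or `--test_dir` and reject the combination with an explicit check. For comparison, `argparse` can encode the same constraint natively; a sketch of that alternative (not what this patch does):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--image_list', type=str, help='List of images to invert.')
group.add_argument('--test_dir', type=str, help='Directory of images to invert.')

args = parser.parse_args(['--test_dir', 'examples/images/'])
print(args.test_dir)  # examples/images/
# Passing both flags makes argparse exit with a usage error on its own.
```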
diff --git a/invert.py b/invert.py
index 169e6d9..389e5b6 100644
--- a/invert.py
+++ b/invert.py
@@ -21,9 +21,12 @@ def parse_args():
   """Parses arguments."""
   parser = argparse.ArgumentParser()
 
-  parser.add_argument('model_name', type=str, help='Name of the GAN model.')
-  parser.add_argument('image_list', type=str,
+  parser.add_argument('--model_name', type=str, help='Name of the GAN model.')
+  parser.add_argument('--image_list', type=str, default='',
                       help='List of images to invert.')
+
+  parser.add_argument('--test_dir', type=str, default='',
+                      help='Directory of images to invert.')
   parser.add_argument('-o', '--output_dir', type=str, default='',
                       help='Directory to save the results. If not specified, '
                            '`./results/inversion/${IMAGE_LIST}` '
@@ -38,6 +41,11 @@ def parse_args():
   parser.add_argument('--loss_weight_feat', type=float, default=5e-5,
                       help='The perceptual loss scale for optimization. '
                            '(default: 5e-5)')
+
+  parser.add_argument('--loss_weight_ssim', type=float, default=1.0,
+                      help='The SSIM loss scale for optimization. '
+                           '(default: 1.0)')
+
   parser.add_argument('--loss_weight_enc', type=float, default=2.0,
                       help='The encoder loss scale for optimization.'
                            '(default: 2.0)')
@@ -52,9 +60,28 @@ def main():
   """Main function."""
   args = parse_args()
   os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
-  assert os.path.exists(args.image_list)
-  image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
+  if args.image_list != '' and args.test_dir == '':
+    assert os.path.exists(args.image_list)
+    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
+  elif args.test_dir != '' and args.image_list == '':
+    assert os.path.exists(args.test_dir)
+    image_list_name = os.path.basename(os.path.normpath(args.test_dir))
+  else:
+    raise ValueError('Use either `--image_list` or `--test_dir`; using both '
+                     'at the same time is not supported.')
+
+  MODEL_DIR = os.path.join('models', 'pretrain')
+  os.makedirs(MODEL_DIR, exist_ok=True)
+  if any(x not in os.listdir(MODEL_DIR) for x in
+         ['styleganinv_ffhq256_encoder.pth',
+          'styleganinv_ffhq256_generator.pth',
+          'vgg16.pth']):
+    raise FileNotFoundError('Please download styleganinv_ffhq256_encoder.pth, '
+                            'styleganinv_ffhq256_generator.pth and vgg16.pth '
+                            f'into `{MODEL_DIR}`.')
+
   output_dir = args.output_dir or f'results/inversion/{image_list_name}'
+  if not os.path.exists(output_dir):
+    os.makedirs(output_dir)
 
   logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger')
   logger.info(f'Loading model.')
@@ -65,15 +92,24 @@ def main():
       reconstruction_loss_weight=1.0,
       perceptual_loss_weight=args.loss_weight_feat,
       regularization_loss_weight=args.loss_weight_enc,
+      loss_weight_ssim=args.loss_weight_ssim,
       logger=logger)
   image_size = inverter.G.resolution
 
   # Load image list.
   logger.info(f'Loading image list.')
   image_list = []
-  with open(args.image_list, 'r') as f:
-    for line in f:
-      image_list.append(line.strip())
+  if args.image_list != '':
+    with open(args.image_list, 'r') as f:
+      for line in f:
+        image_list.append(line.strip())
+
+  if args.test_dir != '':
+    for root, dirs, files in os.walk(args.test_dir):
+      for file in files:
+        image_list.append(file)
+
+  logger.info(f'Loaded {len(image_list)} images.')
 
   # Initialize visualizer.
   save_interval = args.num_iterations // args.num_results
@@ -90,10 +126,14 @@ def main():
   logger.info(f'Start inversion.')
   latent_codes = []
   for img_idx in tqdm(range(len(image_list)), leave=False):
-    image_path = image_list[img_idx]
-    image_name = os.path.splitext(os.path.basename(image_path))[0]
+    if args.image_list != '':
+      image_path = image_list[img_idx]
+    elif args.test_dir != '':
+      image_path = os.path.join(args.test_dir, image_list[img_idx])
+    image_name = os.path.splitext(os.path.basename(image_path))[0]
+
     image = resize_image(load_image(image_path), (image_size, image_size))
-    code, viz_results = inverter.easy_invert(image, num_viz=args.num_results)
+    code, viz_results, ssim_loss = inverter.easy_invert(image, num_viz=args.num_results)
     latent_codes.append(code)
     save_image(f'{output_dir}/{image_name}_ori.png', image)
     save_image(f'{output_dir}/{image_name}_enc.png', viz_results[1])
@@ -103,6 +143,7 @@ def main():
       for viz_idx, viz_img in enumerate(viz_results[1:]):
         visualizer.set_cell(img_idx, viz_idx + 2, image=viz_img)
 
+
   # Save results.
   os.system(f'cp {args.image_list} {output_dir}/image_list.txt')
   np.save(f'{output_dir}/inverted_codes.npy',
diff --git a/manipulate.py b/manipulate.py
index c4a70dd..a3ed33f 100644
--- a/manipulate.py
+++ b/manipulate.py
@@ -20,11 +20,11 @@ def parse_args():
   """Parses arguments."""
   parser = argparse.ArgumentParser()
 
-  parser.add_argument('model_name', type=str, help='Name of the GAN model.')
-  parser.add_argument('image_dir', type=str,
+  parser.add_argument('--model_name', type=str, help='Name of the GAN model.')
+  parser.add_argument('--image_dir', type=str,
                       help='Image directory, which includes original images, '
                            'inverted codes, and image list.')
-  parser.add_argument('boundary_path', type=str,
+  parser.add_argument('--boundary_path', type=str,
                       help='Path to the boundary for semantic manipulation.')
   parser.add_argument('-o', '--output_dir', type=str, default='',
                       help='Directory to save the results. If not specified, '
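`manipulate.py` pairs the inverted codes with an InterFaceGAN boundary. At its core the edit is a linear move of each code along the boundary direction; a minimal sketch under that assumption (file names taken from the README examples, `step` is an illustrative value, and the actual script sweeps a range of steps):

```python
import numpy as np

codes = np.load('results/inversion/test/inverted_codes.npy')  # [N, num_layers, w_dim]
boundary = np.load('boundaries/expression.npy')               # [1, w_dim]

step = 2.0  # positive/negative steps move along/against the attribute
edited_codes = codes + step * boundary.reshape(1, 1, -1)
np.save('results/inversion/test/edited_codes.npy', edited_codes)
```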
diff --git a/mix_style.py b/mix_style.py
index 1c2cdfc..d04a40a 100644
--- a/mix_style.py
+++ b/mix_style.py
@@ -23,11 +23,11 @@ def parse_args():
   """Parses arguments."""
   parser = argparse.ArgumentParser()
 
-  parser.add_argument('model_name', type=str, help='Name of the GAN model.')
-  parser.add_argument('style_dir', type=str,
+  parser.add_argument('--model_name', type=str, help='Name of the GAN model.')
+  parser.add_argument('--style_dir', type=str,
                       help='Style directory, which includes original images, '
                            'inverted codes, and image list.')
-  parser.add_argument('content_dir', type=str,
+  parser.add_argument('--content_dir', type=str,
                       help='Content directory, which includes original images, '
                            'inverted codes, and image list.')
   parser.add_argument('-o', '--output_dir', type=str, default='',
diff --git a/models/model_settings.py b/models/model_settings.py
index b6bc732..42f554c 100644
--- a/models/model_settings.py
+++ b/models/model_settings.py
@@ -14,6 +14,12 @@
     'final_tanh': True,
     'use_bn': True,
   },
+  'styleganinv_ffhq256B': {
+    'resolution': 256,
+    'repeat_w': False,
+    'final_tanh': True,
+    'use_bn': True,
+  },
   'styleganinv_bedroom256': {
     'resolution': 256,
     'repeat_w': False,
diff --git a/models/stylegan_generator.py b/models/stylegan_generator.py
index 7089e0c..9ba39b3 100644
--- a/models/stylegan_generator.py
+++ b/models/stylegan_generator.py
@@ -138,6 +138,21 @@ def preprocess(self, latent_codes, latent_space_type='z', **kwargs):
 
     return latent_codes.astype(np.float32)
 
+
+  def synthesizeImages(self,latent_codes,latent_space_type='Z'):
+ zs = latent_codes + zs = zs.to(self.run_device) + ws = self.net.mapping(zs) + ws = ws.to(self.run_device) + wps = self.net.truncation(ws) + wps = wps.to(self.run_device) + + images = self.net.synthesis(wps) + images = images.to(self.run_device) + + return images + + def _synthesize(self, latent_codes, latent_space_type='z', diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/__init__.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/utils/inverter.py b/utils/inverter.py index dc029d2..017a9ea 100644 --- a/utils/inverter.py +++ b/utils/inverter.py @@ -1,327 +1,340 @@ -# python 3.7 -"""Utility functions to invert a given image back to a latent code.""" - -from tqdm import tqdm -import cv2 -import numpy as np - -import torch - -from models.stylegan_generator import StyleGANGenerator -from models.stylegan_encoder import StyleGANEncoder -from models.perceptual_model import PerceptualModel - -__all__ = ['StyleGANInverter'] - - -def _softplus(x): - """Implements the softplus function.""" - return torch.nn.functional.softplus(x, beta=1, 
threshold=10000) - -def _get_tensor_value(tensor): - """Gets the value of a torch Tensor.""" - return tensor.cpu().detach().numpy() - - -class StyleGANInverter(object): - """Defines the class for StyleGAN inversion. - - Even having the encoder, the output latent code is not good enough to recover - the target image satisfyingly. To this end, this class optimize the latent - code based on gradient descent algorithm. In the optimization process, - following loss functions will be considered: - - (1) Pixel-wise reconstruction loss. (required) - (2) Perceptual loss. (optional, but recommended) - (3) Regularization loss from encoder. (optional, but recommended for in-domain - inversion) - - NOTE: The encoder can be missing for inversion, in which case the latent code - will be randomly initialized and the regularization loss will be ignored. - """ - - def __init__(self, - model_name, - learning_rate=1e-2, - iteration=100, - reconstruction_loss_weight=1.0, - perceptual_loss_weight=5e-5, - regularization_loss_weight=2.0, - logger=None): - """Initializes the inverter. - - NOTE: Only Adam optimizer is supported in the optimization process. - - Args: - model_name: Name of the model on which the inverted is based. The model - should be first registered in `models/model_settings.py`. - logger: Logger to record the log message. - learning_rate: Learning rate for optimization. (default: 1e-2) - iteration: Number of iterations for optimization. (default: 100) - reconstruction_loss_weight: Weight for reconstruction loss. Should always - be a positive number. (default: 1.0) - perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual - loss. (default: 5e-5) - regularization_loss_weight: Weight for regularization loss from encoder. - This is essential for in-domain inversion. However, this loss will - automatically ignored if the generative model does not include a valid - encoder. 0 disables regularization loss. (default: 2.0) - """ - self.logger = logger - self.model_name = model_name - self.gan_type = 'stylegan' - - self.G = StyleGANGenerator(self.model_name, self.logger) - self.E = StyleGANEncoder(self.model_name, self.logger) - self.F = PerceptualModel(min_val=self.G.min_val, max_val=self.G.max_val) - self.encode_dim = [self.G.num_layers, self.G.w_space_dim] - self.run_device = self.G.run_device - assert list(self.encode_dim) == list(self.E.encode_dim) - - assert self.G.gan_type == self.gan_type - assert self.E.gan_type == self.gan_type - - self.learning_rate = learning_rate - self.iteration = iteration - self.loss_pix_weight = reconstruction_loss_weight - self.loss_feat_weight = perceptual_loss_weight - self.loss_reg_weight = regularization_loss_weight - assert self.loss_pix_weight > 0 - - - def preprocess(self, image): - """Preprocesses a single image. - - This function assumes the input numpy array is with shape [height, width, - channel], channel order `RGB`, and pixel range [0, 255]. - - The returned image is with shape [channel, new_height, new_width], where - `new_height` and `new_width` are specified by the given generative model. - The channel order of returned image is also specified by the generative - model. The pixel range is shifted to [min_val, max_val], where `min_val` and - `max_val` are also specified by the generative model. 
- """ - if not isinstance(image, np.ndarray): - raise ValueError(f'Input image should be with type `numpy.ndarray`!') - if image.dtype != np.uint8: - raise ValueError(f'Input image should be with dtype `numpy.uint8`!') - - if image.ndim != 3 or image.shape[2] not in [1, 3]: - raise ValueError(f'Input should be with shape [height, width, channel], ' - f'where channel equals to 1 or 3!\n' - f'But {image.shape} is received!') - if image.shape[2] == 1 and self.G.image_channels == 3: - image = np.tile(image, (1, 1, 3)) - if image.shape[2] != self.G.image_channels: - raise ValueError(f'Number of channels of input image, which is ' - f'{image.shape[2]}, is not supported by the current ' - f'inverter, which requires {self.G.image_channels} ' - f'channels!') - - if self.G.image_channels == 3 and self.G.channel_order == 'BGR': - image = image[:, :, ::-1] - if image.shape[1:3] != [self.G.resolution, self.G.resolution]: - image = cv2.resize(image, (self.G.resolution, self.G.resolution)) - image = image.astype(np.float32) - image = image / 255.0 * (self.G.max_val - self.G.min_val) + self.G.min_val - image = image.astype(np.float32).transpose(2, 0, 1) - - return image - - def get_init_code(self, image): - """Gets initial latent codes as the start point for optimization. - - The input image is assumed to have already been preprocessed, meaning to - have shape [self.G.image_channels, self.G.resolution, self.G.resolution], - channel order `self.G.channel_order`, and pixel range [self.G.min_val, - self.G.max_val]. - """ - x = image[np.newaxis] - x = self.G.to_tensor(x.astype(np.float32)) - z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim)) - return z.astype(np.float32) - - def invert(self, image, num_viz=0): - """Inverts the given image to a latent code. - - Basically, this function is based on gradient descent algorithm. - - Args: - image: Target image to invert, which is assumed to have already been - preprocessed. - num_viz: Number of intermediate outputs to visualize. (default: 0) - - Returns: - A two-element tuple. First one is the inverted code. Second one is a list - of intermediate results, where first image is the input image, second - one is the reconstructed result from the initial latent code, remainings - are from the optimization process every `self.iteration // num_viz` - steps. - """ - x = image[np.newaxis] - x = self.G.to_tensor(x.astype(np.float32)) - x.requires_grad = False - init_z = self.get_init_code(image) - z = torch.Tensor(init_z).to(self.run_device) - z.requires_grad = True - - optimizer = torch.optim.Adam([z], lr=self.learning_rate) - - viz_results = [] - viz_results.append(self.G.postprocess(_get_tensor_value(x))[0]) - x_init_inv = self.G.net.synthesis(z) - viz_results.append(self.G.postprocess(_get_tensor_value(x_init_inv))[0]) - pbar = tqdm(range(1, self.iteration + 1), leave=True) - for step in pbar: - loss = 0.0 - - # Reconstruction loss. - x_rec = self.G.net.synthesis(z) - loss_pix = torch.mean((x - x_rec) ** 2) - loss = loss + loss_pix * self.loss_pix_weight - log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}' - - # Perceptual loss. - if self.loss_feat_weight: - x_feat = self.F.net(x) - x_rec_feat = self.F.net(x_rec) - loss_feat = torch.mean((x_feat - x_rec_feat) ** 2) - loss = loss + loss_feat * self.loss_feat_weight - log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}' - - # Regularization loss. 
- if self.loss_reg_weight: - z_rec = self.E.net(x_rec).view(1, *self.encode_dim) - loss_reg = torch.mean((z - z_rec) ** 2) - loss = loss + loss_reg * self.loss_reg_weight - log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}' - - log_message += f', loss: {_get_tensor_value(loss):.3f}' - pbar.set_description_str(log_message) - if self.logger: - self.logger.debug(f'Step: {step:05d}, ' - f'lr: {self.learning_rate:.2e}, ' - f'{log_message}') - - # Do optimization. - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if num_viz > 0 and step % (self.iteration // num_viz) == 0: - viz_results.append(self.G.postprocess(_get_tensor_value(x_rec))[0]) - - return _get_tensor_value(z), viz_results - - def easy_invert(self, image, num_viz=0): - """Wraps functions `preprocess()` and `invert()` together.""" - return self.invert(self.preprocess(image), num_viz) - - def diffuse(self, - target, - context, - center_x, - center_y, - crop_x, - crop_y, - num_viz=0): - """Diffuses the target image to a context image. - - Basically, this function is a motified version of `self.invert()`. More - concretely, the encoder regularizer is removed from the objectives and the - reconstruction loss is computed from the masked region. - - Args: - target: Target image (foreground). - context: Context image (background). - center_x: The x-coordinate of the crop center. - center_y: The y-coordinate of the crop center. - crop_x: The crop size along the x-axis. - crop_y: The crop size along the y-axis. - num_viz: Number of intermediate outputs to visualize. (default: 0) - - Returns: - A two-element tuple. First one is the inverted code. Second one is a list - of intermediate results, where first image is the direct copy-paste - image, second one is the reconstructed result from the initial latent - code, remainings are from the optimization process every - `self.iteration // num_viz` steps. - """ - image_shape = (self.G.image_channels, self.G.resolution, self.G.resolution) - mask = np.zeros((1, *image_shape), dtype=np.float32) - xx = center_x - crop_x // 2 - yy = center_y - crop_y // 2 - mask[:, :, yy:yy + crop_y, xx:xx + crop_x] = 1.0 - - target = target[np.newaxis] - if context.ndim == 3: - context = self.preprocess(context)[np.newaxis] - else: - contexts = [] - for i in range(context.shape[0]): - contexts.append(self.preprocess(context[i])) - context = np.asarray(contexts) - x = target * mask + context * (1 - mask) - x = self.G.to_tensor(x.astype(np.float32)) - x.requires_grad = False - mask = self.G.to_tensor(mask.astype(np.float32)) - mask.requires_grad = False - - init_z = _get_tensor_value(self.E.net(x).view(-1, *self.encode_dim)) - init_z = init_z.astype(np.float32) - z = torch.Tensor(init_z).to(self.run_device) - z.requires_grad = True - - optimizer = torch.optim.Adam([z], lr=self.learning_rate) - - copy_and_paste = self.G.postprocess(_get_tensor_value(x)) - x_init_inv = self.G.net.synthesis(z) - encoder_out = self.G.postprocess(_get_tensor_value(x_init_inv)) - viz_results = {} - for it in range(context.shape[0]): - viz_results[it] = [] - viz_results[it].append(copy_and_paste[it]) - viz_results[it].append(encoder_out[it]) - - pbar = tqdm(range(1, self.iteration + 1), leave=True) - for step in pbar: - loss = 0.0 - - # Reconstruction loss. - x_rec = self.G.net.synthesis(z) - loss_pix = torch.mean(((x - x_rec) * mask) ** 2, dim=[1, 2, 3]) - loss = loss + loss_pix * self.loss_pix_weight - log_message = f'loss_pix: {np.mean(_get_tensor_value(loss_pix)):.3f}' - - # Perceptual loss. 
- if self.loss_feat_weight: - x_feat = self.F.net(x * mask) - x_rec_feat = self.F.net(x_rec * mask) - loss_feat = torch.mean((x_feat - x_rec_feat) ** 2, dim=[1, 2, 3]) - loss = loss + loss_feat * self.loss_feat_weight - log_message += f', loss_feat: {np.mean(_get_tensor_value(loss_feat)):.3f}' - - log_message += f', loss: {np.mean(_get_tensor_value(loss)):.3f}' - pbar.set_description_str(log_message) - if self.logger: - self.logger.debug(f'Step: {step:05d}, ' - f'lr: {self.learning_rate:.2e}, ' - f'{log_message}') - - # Do optimization. - optimizer.zero_grad() - loss.backward(torch.ones_like(loss)) - optimizer.step() - - if num_viz > 0 and step % (self.iteration // num_viz) == 0: - rec_res = self.G.postprocess(_get_tensor_value(x_rec)) - for it in range(rec_res.shape[0]): - viz_results[it].append(rec_res[it]) - - return _get_tensor_value(z), viz_results - - def easy_diffuse(self, target, context, *args, **kwargs): - """Wraps functions `preprocess()` and `diffuse()` together.""" - return self.diffuse(self.preprocess(target), - context, - *args, **kwargs) +# python 3.7 +"""Utility functions to invert a given image back to a latent code.""" + +from tqdm import tqdm +import cv2 +import numpy as np + +import torch + +from models.stylegan_generator import StyleGANGenerator +from models.stylegan_encoder import StyleGANEncoder +from models.perceptual_model import PerceptualModel +import pytorch_ssim + +__all__ = ['StyleGANInverter'] + + +def _softplus(x): + """Implements the softplus function.""" + return torch.nn.functional.softplus(x, beta=1, threshold=10000) + +def _get_tensor_value(tensor): + """Gets the value of a torch Tensor.""" + return tensor.cpu().detach().numpy() + + +class StyleGANInverter(object): + """Defines the class for StyleGAN inversion. + + Even having the encoder, the output latent code is not good enough to recover + the target image satisfyingly. To this end, this class optimize the latent + code based on gradient descent algorithm. In the optimization process, + following loss functions will be considered: + + (1) Pixel-wise reconstruction loss. (required) + (2) Perceptual loss. (optional, but recommended) + (3) Regularization loss from encoder. (optional, but recommended for in-domain + inversion) + + NOTE: The encoder can be missing for inversion, in which case the latent code + will be randomly initialized and the regularization loss will be ignored. + """ + + def __init__(self, + model_name, + learning_rate=1e-2, + iteration=100, + reconstruction_loss_weight=1.0, + perceptual_loss_weight=5e-5, + regularization_loss_weight=2.0, + loss_weight_ssim = 1.0, + logger=None): + """Initializes the inverter. + + NOTE: Only Adam optimizer is supported in the optimization process. + + Args: + model_name: Name of the model on which the inverted is based. The model + should be first registered in `models/model_settings.py`. + logger: Logger to record the log message. + learning_rate: Learning rate for optimization. (default: 1e-2) + iteration: Number of iterations for optimization. (default: 100) + reconstruction_loss_weight: Weight for reconstruction loss. Should always + be a positive number. (default: 1.0) + perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual + loss. (default: 5e-5) + regularization_loss_weight: Weight for regularization loss from encoder. + This is essential for in-domain inversion. However, this loss will + automatically ignored if the generative model does not include a valid + encoder. 0 disables regularization loss. 
(default: 2.0)
+      loss_weight_ssim: Weight for SSIM loss between the input image and its
+        reconstruction. 0 disables SSIM loss. (default: 1.0)
+    """
+    self.logger = logger
+    self.model_name = model_name
+    self.gan_type = 'stylegan'
+
+    self.G = StyleGANGenerator(self.model_name, self.logger)
+    self.E = StyleGANEncoder(self.model_name, self.logger)
+    self.F = PerceptualModel(min_val=self.G.min_val, max_val=self.G.max_val)
+    self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
+    self.run_device = self.G.run_device
+    assert list(self.encode_dim) == list(self.E.encode_dim)
+
+    assert self.G.gan_type == self.gan_type
+    assert self.E.gan_type == self.gan_type
+
+    self.learning_rate = learning_rate
+    self.iteration = iteration
+    self.loss_pix_weight = reconstruction_loss_weight
+    self.loss_feat_weight = perceptual_loss_weight
+    self.loss_reg_weight = regularization_loss_weight
+    self.loss_weight_ssim = loss_weight_ssim
+    assert self.loss_pix_weight > 0
+
+  def preprocess(self, image):
+    """Preprocesses a single image.
+
+    This function assumes the input numpy array is with shape [height, width,
+    channel], channel order `RGB`, and pixel range [0, 255].
+
+    The returned image is with shape [channel, new_height, new_width], where
+    `new_height` and `new_width` are specified by the given generative model.
+    The channel order of returned image is also specified by the generative
+    model. The pixel range is shifted to [min_val, max_val], where `min_val` and
+    `max_val` are also specified by the generative model.
+    """
+    if not isinstance(image, np.ndarray):
+      raise ValueError(f'Input image should be with type `numpy.ndarray`!')
+    if image.dtype != np.uint8:
+      raise ValueError(f'Input image should be with dtype `numpy.uint8`!')
+
+    if image.ndim != 3 or image.shape[2] not in [1, 3]:
+      raise ValueError(f'Input should be with shape [height, width, channel], '
+                       f'where channel equals to 1 or 3!\n'
+                       f'But {image.shape} is received!')
+    if image.shape[2] == 1 and self.G.image_channels == 3:
+      image = np.tile(image, (1, 1, 3))
+    if image.shape[2] != self.G.image_channels:
+      raise ValueError(f'Number of channels of input image, which is '
+                       f'{image.shape[2]}, is not supported by the current '
+                       f'inverter, which requires {self.G.image_channels} '
+                       f'channels!')
+
+    if self.G.image_channels == 3 and self.G.channel_order == 'BGR':
+      image = image[:, :, ::-1]
+    if image.shape[1:3] != [self.G.resolution, self.G.resolution]:
+      image = cv2.resize(image, (self.G.resolution, self.G.resolution))
+    image = image.astype(np.float32)
+    image = image / 255.0 * (self.G.max_val - self.G.min_val) + self.G.min_val
+    image = image.astype(np.float32).transpose(2, 0, 1)
+
+    return image
+
+  def get_init_code(self, image):
+    """Gets initial latent codes as the start point for optimization.
+
+    The input image is assumed to have already been preprocessed, meaning to
+    have shape [self.G.image_channels, self.G.resolution, self.G.resolution],
+    channel order `self.G.channel_order`, and pixel range [self.G.min_val,
+    self.G.max_val].
+    """
+    x = image[np.newaxis]
+    x = self.G.to_tensor(x.astype(np.float32))
+    z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim))
+    return z.astype(np.float32)
+
+  def invert(self, image, num_viz=0):
+    """Inverts the given image to a latent code.
+
+    Basically, this function is based on gradient descent algorithm.
+
+    Args:
+      image: Target image to invert, which is assumed to have already been
+        preprocessed.
+      num_viz: Number of intermediate outputs to visualize. (default: 0)
+
+    Returns:
+      A three-element tuple. First one is the inverted code. Second one is a
+      list of intermediate results, where the first image is the input image,
+      the second one is the reconstruction from the initial latent code, and
+      the remaining ones come from the optimization process every
+      `self.iteration // num_viz` steps. Third one is the SSIM score between
+      the input image and the final reconstruction.
+    """
+    x = image[np.newaxis]
+    x = self.G.to_tensor(x.astype(np.float32))
+    x.requires_grad = False
+    init_z = self.get_init_code(image)
+    z = torch.Tensor(init_z).to(self.run_device)
+    z.requires_grad = True
+
+    optimizer = torch.optim.Adam([z], lr=self.learning_rate)
+
+    viz_results = []
+    viz_results.append(self.G.postprocess(_get_tensor_value(x))[0])
+    x_init_inv = self.G.net.synthesis(z)
+    viz_results.append(self.G.postprocess(_get_tensor_value(x_init_inv))[0])
+    # Build the SSIM module once instead of re-creating it at every step.
+    ssim_module = pytorch_ssim.SSIM()
+    pbar = tqdm(range(1, self.iteration + 1), leave=True)
+    for step in pbar:
+      loss = 0.0
+
+      # Reconstruction loss.
+      x_rec = self.G.net.synthesis(z)
+      loss_pix = torch.mean((x - x_rec) ** 2)
+      loss = loss + loss_pix * self.loss_pix_weight
+      log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'
+
+      # SSIM loss, reusing `x_rec` instead of running synthesis again.
+      ssim_out = -ssim_module(x, x_rec)
+      loss = loss + ssim_out * self.loss_weight_ssim
+      log_message += f', loss_ssim: {(-ssim_out.item()):.3f}'
+
+      # Perceptual loss.
+      if self.loss_feat_weight:
+        x_feat = self.F.net(x)
+        x_rec_feat = self.F.net(x_rec)
+        loss_feat = torch.mean((x_feat - x_rec_feat) ** 2)
+        loss = loss + loss_feat * self.loss_feat_weight
+        log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'
+
+      # Regularization loss.
+      if self.loss_reg_weight:
+        z_rec = self.E.net(x_rec).view(1, *self.encode_dim)
+        loss_reg = torch.mean((z - z_rec) ** 2)
+        loss = loss + loss_reg * self.loss_reg_weight
+        log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}'
+
+      log_message += f', loss: {_get_tensor_value(loss):.3f}'
+      pbar.set_description_str(log_message)
+      if self.logger:
+        self.logger.debug(f'Step: {step:05d}, '
+                          f'lr: {self.learning_rate:.2e}, '
+                          f'{log_message}')
+
+      # Do optimization.
+      optimizer.zero_grad()
+      loss.backward()
+      optimizer.step()
+
+      if num_viz > 0 and step % (self.iteration // num_viz) == 0:
+        viz_results.append(self.G.postprocess(_get_tensor_value(x_rec))[0])
+
+    return _get_tensor_value(z), viz_results, -ssim_out.item()
+
+  def easy_invert(self, image, num_viz=0):
+    """Wraps functions `preprocess()` and `invert()` together."""
+    return self.invert(self.preprocess(image), num_viz)
+
+  def diffuse(self,
+              target,
+              context,
+              center_x,
+              center_y,
+              crop_x,
+              crop_y,
+              num_viz=0):
+    """Diffuses the target image to a context image.
+
+    Basically, this function is a modified version of `self.invert()`. More
+    concretely, the encoder regularizer is removed from the objectives and the
+    reconstruction loss is computed from the masked region.
+
+    Args:
+      target: Target image (foreground).
+      context: Context image (background).
+      center_x: The x-coordinate of the crop center.
+      center_y: The y-coordinate of the crop center.
+      crop_x: The crop size along the x-axis.
+      crop_y: The crop size along the y-axis.
+      num_viz: Number of intermediate outputs to visualize. (default: 0)
+
+    Returns:
+      A two-element tuple. First one is the inverted code. Second one is a list
+      of intermediate results, where the first image is the direct copy-paste
+      image, the second one is the reconstruction from the initial latent
+      code, and the remaining ones come from the optimization process every
+      `self.iteration // num_viz` steps.
+ """ + image_shape = (self.G.image_channels, self.G.resolution, self.G.resolution) + mask = np.zeros((1, *image_shape), dtype=np.float32) + xx = center_x - crop_x // 2 + yy = center_y - crop_y // 2 + mask[:, :, yy:yy + crop_y, xx:xx + crop_x] = 1.0 + + target = target[np.newaxis] + if context.ndim == 3: + context = self.preprocess(context)[np.newaxis] + else: + contexts = [] + for i in range(context.shape[0]): + contexts.append(self.preprocess(context[i])) + context = np.asarray(contexts) + x = target * mask + context * (1 - mask) + x = self.G.to_tensor(x.astype(np.float32)) + x.requires_grad = False + mask = self.G.to_tensor(mask.astype(np.float32)) + mask.requires_grad = False + + init_z = _get_tensor_value(self.E.net(x).view(-1, *self.encode_dim)) + init_z = init_z.astype(np.float32) + z = torch.Tensor(init_z).to(self.run_device) + z.requires_grad = True + + optimizer = torch.optim.Adam([z], lr=self.learning_rate) + + copy_and_paste = self.G.postprocess(_get_tensor_value(x)) + x_init_inv = self.G.net.synthesis(z) + encoder_out = self.G.postprocess(_get_tensor_value(x_init_inv)) + viz_results = {} + for it in range(context.shape[0]): + viz_results[it] = [] + viz_results[it].append(copy_and_paste[it]) + viz_results[it].append(encoder_out[it]) + + pbar = tqdm(range(1, self.iteration + 1), leave=True) + for step in pbar: + loss = 0.0 + + # Reconstruction loss. + x_rec = self.G.net.synthesis(z) + loss_pix = torch.mean(((x - x_rec) * mask) ** 2, dim=[1, 2, 3]) + loss = loss + loss_pix * self.loss_pix_weight + log_message = f'loss_pix: {np.mean(_get_tensor_value(loss_pix)):.3f}' + + # Perceptual loss. + if self.loss_feat_weight: + x_feat = self.F.net(x * mask) + x_rec_feat = self.F.net(x_rec * mask) + loss_feat = torch.mean((x_feat - x_rec_feat) ** 2, dim=[1, 2, 3]) + loss = loss + loss_feat * self.loss_feat_weight + log_message += f', loss_feat: {np.mean(_get_tensor_value(loss_feat)):.3f}' + + log_message += f', loss: {np.mean(_get_tensor_value(loss)):.3f}' + pbar.set_description_str(log_message) + if self.logger: + self.logger.debug(f'Step: {step:05d}, ' + f'lr: {self.learning_rate:.2e}, ' + f'{log_message}') + + # Do optimization. + optimizer.zero_grad() + loss.backward(torch.ones_like(loss)) + optimizer.step() + + if num_viz > 0 and step % (self.iteration // num_viz) == 0: + rec_res = self.G.postprocess(_get_tensor_value(x_rec)) + for it in range(rec_res.shape[0]): + viz_results[it].append(rec_res[it]) + + return _get_tensor_value(z), viz_results + + def easy_diffuse(self, target, context, *args, **kwargs): + """Wraps functions `preprocess()` and `diffuse()` together.""" + return self.diffuse(self.preprocess(target), + context, + *args, **kwargs)