diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0fcbe2000ce7..15d048f0106e 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -934,6 +934,7 @@ class DreamBoothDataset(Dataset):
 
     def __init__(
         self,
+        args,
         instance_data_root,
         instance_prompt,
         class_prompt,
@@ -943,10 +944,8 @@ def __init__(
         class_num=None,
         size=1024,
         repeats=1,
-        center_crop=False,
     ):
         self.size = size
-        self.center_crop = center_crop
 
         self.instance_prompt = instance_prompt
         self.custom_instance_prompts = None
@@ -1035,11 +1034,11 @@ def __init__(
                     # flip
                     image = train_flip(image)
                 if args.center_crop:
-                    y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                    x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                    y1 = max(0, int(round((image.height - self.size) / 2.0)))
+                    x1 = max(0, int(round((image.width - self.size) / 2.0)))
                     image = train_crop(image)
                 else:
-                    y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                    y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
                     image = crop(image, y1, x1, h, w)
                 image = train_transforms(image)
                 self.pixel_values.append(image)
@@ -1875,6 +1874,7 @@ def load_model_hook(models, input_dir):
 
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
+        args=args,
        instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1884,7 +1884,6 @@ def load_model_hook(models, input_dir):
         class_num=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
-        center_crop=args.center_crop,
     )
 
     train_dataloader = torch.utils.data.DataLoader(
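
For reference, here is a minimal standalone sketch of the cropping path these hunks touch, showing why the crop math should key off `self.size` (set from `args.resolution` at construction time). It assumes only torchvision and PIL; `size`, `center_crop`, `train_resize`, and `train_crop` mirror the script's names, but the snippet is illustrative and not part of the patch.

# Illustrative sketch (not part of the patch): the dataset's crop logic,
# built from torchvision's standard transforms. `size` stands in for
# self.size, which the patched __init__ sets from args.resolution.
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop

size = 1024          # mirrors self.size
center_crop = True   # mirrors args.center_crop

train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)

image = Image.new("RGB", (1280, 960))  # stand-in for one training image
image = train_resize(image)            # shorter edge resized to `size`

if center_crop:
    # Record the top-left corner implied by CenterCrop before cropping,
    # so the crop coordinates stay available to the caller.
    y1 = max(0, int(round((image.height - size) / 2.0)))
    x1 = max(0, int(round((image.width - size) / 2.0)))
    image = train_crop(image)
else:
    # Sample the random crop box explicitly for the same reason.
    y1, x1, h, w = train_crop.get_params(image, (size, size))
    image = crop(image, y1, x1, h, w)

Net effect of the patch: `args` is passed to `DreamBoothDataset` explicitly instead of being resolved from an enclosing scope inside `__init__`, the now-unused `center_crop` parameter is dropped, and the crop coordinates are computed from `self.size` rather than reaching back to `args.resolution`.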