diff --git a/iopaint/model/brushnet/brushnet_wrapper.py b/iopaint/model/brushnet/brushnet_wrapper.py index c7343d2e..eb3098e1 100644 --- a/iopaint/model/brushnet/brushnet_wrapper.py +++ b/iopaint/model/brushnet/brushnet_wrapper.py @@ -15,8 +15,12 @@ ) from .brushnet import BrushNetModel from .brushnet_unet_forward import brushnet_unet_forward -from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \ - UpBlock2D_forward +from .unet_2d_blocks import ( + CrossAttnDownBlock2D_forward, + DownBlock2D_forward, + CrossAttnUpBlock2D_forward, + UpBlock2D_forward, +) from ...schema import InpaintRequest, ModelType @@ -26,6 +30,7 @@ class BrushNetWrapper(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from .pipeline_brushnet import StableDiffusionBrushNetPipeline + self.model_info = kwargs["model_info"] self.brushnet_method = kwargs["brushnet_method"] @@ -52,7 +57,9 @@ def init_model(self, device: torch.device, **kwargs): ) logger.info(f"Loading BrushNet model from {self.brushnet_method}") - brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype) + brushnet = BrushNetModel.from_pretrained( + self.brushnet_method, torch_dtype=torch_dtype + ) if self.model_info.is_single_file_diffusers: if self.model_info.model_type == ModelType.DIFFUSERS_SD: @@ -64,7 +71,7 @@ def init_model(self, device: torch.device, **kwargs): self.model_id_or_path, torch_dtype=torch_dtype, load_safety_checker=not disable_nsfw_checker, - original_config_file=get_config_files()['v1'], + original_config_file=get_config_files()["v1"], brushnet=brushnet, **model_kwargs, ) @@ -94,31 +101,42 @@ def init_model(self, device: torch.device, **kwargs): self.callback = kwargs.pop("callback", None) # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method - self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__) + self.model.unet.forward = brushnet_unet_forward.__get__( + self.model.unet, self.model.unet.__class__ + ) for down_block in self.model.brushnet.down_blocks: - down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) + down_block.forward = DownBlock2D_forward.__get__( + down_block, down_block.__class__ + ) for up_block in self.model.brushnet.up_blocks: up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward for down_block in self.model.unet.down_blocks: if down_block.__class__.__name__ == "CrossAttnDownBlock2D": - down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__) + down_block.forward = CrossAttnDownBlock2D_forward.__get__( + down_block, down_block.__class__ + ) else: - down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) + down_block.forward = DownBlock2D_forward.__get__( + down_block, down_block.__class__ + ) for up_block in self.model.unet.up_blocks: if up_block.__class__.__name__ == "CrossAttnUpBlock2D": - up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__) + up_block.forward = CrossAttnUpBlock2D_forward.__get__( + up_block, up_block.__class__ + ) else: - up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) + up_block.forward = UpBlock2D_forward.__get__( + up_block, up_block.__class__ + ) def switch_brushnet_method(self, new_method: str): self.brushnet_method = new_method brushnet = BrushNetModel.from_pretrained( new_method, - resume_download=True, local_files_only=self.local_files_only, torch_dtype=self.torch_dtype, ).to(self.model.device)