Skip to content

Commit

Permalink
huggingface_hub deprecate resume_download
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanster committed Nov 23, 2024
1 parent 36e65f0 commit b58c333
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions iopaint/model/brushnet/brushnet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
)
from .brushnet import BrushNetModel
from .brushnet_unet_forward import brushnet_unet_forward
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
UpBlock2D_forward
from .unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
from ...schema import InpaintRequest, ModelType


Expand All @@ -26,6 +30,7 @@ class BrushNetWrapper(DiffusionInpaintModel):

def init_model(self, device: torch.device, **kwargs):
from .pipeline_brushnet import StableDiffusionBrushNetPipeline

self.model_info = kwargs["model_info"]
self.brushnet_method = kwargs["brushnet_method"]

Expand All @@ -52,7 +57,9 @@ def init_model(self, device: torch.device, **kwargs):
)

logger.info(f"Loading BrushNet model from {self.brushnet_method}")
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
brushnet = BrushNetModel.from_pretrained(
self.brushnet_method, torch_dtype=torch_dtype
)

if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
Expand All @@ -64,7 +71,7 @@ def init_model(self, device: torch.device, **kwargs):
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker,
original_config_file=get_config_files()['v1'],
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
**model_kwargs,
)
Expand Down Expand Up @@ -94,31 +101,42 @@ def init_model(self, device: torch.device, **kwargs):
self.callback = kwargs.pop("callback", None)

# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
self.model.unet.forward = brushnet_unet_forward.__get__(
self.model.unet, self.model.unet.__class__
)

for down_block in self.model.brushnet.down_blocks:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in self.model.brushnet.up_blocks:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)

# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in self.model.unet.down_blocks:
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)

for up_block in self.model.unet.up_blocks:
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)

def switch_brushnet_method(self, new_method: str):
self.brushnet_method = new_method
brushnet = BrushNetModel.from_pretrained(
new_method,
resume_download=True,
local_files_only=self.local_files_only,
torch_dtype=self.torch_dtype,
).to(self.model.device)
Expand Down

0 comments on commit b58c333

Please sign in to comment.