Simple LoRA Finetuning (#164)
rishab-partha authored Aug 21, 2024
1 parent 1d4e6fa commit c63006e
Showing 2 changed files with 78 additions and 1 deletion.
77 changes: 77 additions & 0 deletions diffusion/models/models.py
@@ -10,6 +10,7 @@
 import torch
 from composer.devices import DeviceGPU
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
+from peft import LoraConfig
 from torchmetrics import MeanSquaredError
 from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig

@@ -67,6 +68,8 @@ def stable_diffusion_2(
     fsdp: bool = True,
     clip_qkv: Optional[float] = None,
     use_xformers: bool = True,
+    lora_rank: Optional[int] = None,
+    lora_alpha: Optional[int] = None,
 ):
     """Stable diffusion v2 training setup.
@@ -108,6 +111,8 @@ def stable_diffusion_2(
         fsdp (bool): Whether to use FSDP. Defaults to True.
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None.
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
+        lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
+        lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -215,6 +220,40 @@ def stable_diffusion_2(
         mask_pad_tokens=mask_pad_tokens,
         fsdp=fsdp,
     )
+    if lora_rank is not None:
+        assert lora_alpha is not None
+        model.unet.requires_grad_(False)
+        for param in model.unet.parameters():
+            param.requires_grad_(False)
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights='gaussian',
+            target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
+        )
+        model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
     if is_xformers_installed and use_xformers:
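For orientation, here is a minimal sketch of how the new arguments might be exercised once this commit is installed. The import path follows diffusion/models/models.py above; the rank and alpha values are illustrative choices, not repository defaults:

    from diffusion.models.models import stable_diffusion_2

    # Build the model with LoRA adapters injected into the UNet attention
    # projections (to_q, to_k, to_v, to_out.0), as configured above.
    model = stable_diffusion_2(lora_rank=8, lora_alpha=16)

    # After add_adapter(), only the LoRA weights should require gradients.
    trainable = sum(p.numel() for p in model.unet.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.unet.parameters())
    print(f'trainable UNet params: {trainable:,} of {total:,}')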
@@ -262,6 +301,8 @@ def stable_diffusion_xl(
     fsdp: bool = True,
     clip_qkv: Optional[float] = None,
     use_xformers: bool = True,
+    lora_rank: Optional[int] = None,
+    lora_alpha: Optional[int] = None,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
@@ -315,6 +356,8 @@ def stable_diffusion_xl(
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Improves stability of training.
             Default: ``None``.
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
+        lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
+        lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -481,6 +524,40 @@ def stable_diffusion_xl(
         fsdp=fsdp,
         sdxl=True,
     )
+
+    if lora_rank is not None:
+        assert lora_alpha is not None
+        model.unet.requires_grad_(False)
+        for param in model.unet.parameters():
+            param.requires_grad_(False)
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights='gaussian',
+            target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
+        )
+        model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
     if is_xformers_installed and use_xformers:
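Both hunks lean on peft's LoraConfig plus add_adapter to give each targeted projection (to_q, to_k, to_v, to_out.0) a frozen base weight plus a trainable low-rank update, and then tag every attention and resnet submodule with _fsdp_wrap = True so Composer's FSDP integration wraps them as separate units. As a conceptual sketch of what the adapter computes for a single nn.Linear (not peft's actual implementation; the 'gaussian' init convention assumed here is normal init for A and zeros for B):

    import torch
    import torch.nn as nn

    class LoRALinearSketch(nn.Module):
        """Illustrative stand-in for the adapter peft attaches to a target module."""

        def __init__(self, base: nn.Linear, rank: int, alpha: int):
            super().__init__()
            self.base = base  # pretrained projection, frozen via requires_grad_(False)
            self.lora_A = nn.Linear(base.in_features, rank, bias=False)
            self.lora_B = nn.Linear(rank, base.out_features, bias=False)
            nn.init.normal_(self.lora_A.weight, std=1.0 / rank)  # 'gaussian' init
            nn.init.zeros_(self.lora_B.weight)  # update is zero at initialization
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # y = W x + (alpha / r) * B A x; only lora_A and lora_B train.
            return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))

Because lora_B starts at zero, the adapted UNet initially reproduces the pretrained outputs, and the alpha / rank ratio controls how strongly the learned update is scaled in.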
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
     'diffusers[torch]==0.26.3', 'transformers[torch]==4.38.2', 'huggingface_hub==0.21.2', 'wandb==0.16.3',
     'xformers==0.0.23.post1', 'triton==2.1.0', 'torchmetrics[image]==1.3.1', 'lpips==0.1.4', 'clean-fid==0.1.35',
     'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33', 'gradio==4.19.2',
-    'datasets==2.19.2'
+    'datasets==2.19.2', 'peft==0.12.0'
 ]

 extras_require = {}
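The packaging side of the change is this single new pin. A quick post-install sanity check (a sketch; it only verifies the version that setup.py now pins):

    import peft
    from peft import LoraConfig  # the import added to diffusion/models/models.py

    assert peft.__version__ == '0.12.0', peft.__version__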
