(doc/fluxion/lora) add/convert docstrings to mkdocstrings format
Laurent2916 committed Feb 1, 2024
1 parent 8d4c734 commit 82264f7
Showing 2 changed files with 125 additions and 4 deletions.
9 changes: 8 additions & 1 deletion src/refiners/fluxion/adapters/__init__.py
@@ -1,3 +1,10 @@
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter

__all__ = ["Adapter"]
__all__ = [
"Adapter",
"Lora",
"LinearLora",
"Conv2dLora",
"LoraAdapter",
]
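The expanded `__all__` re-exports the LoRA classes at the package level. As an illustration (not part of this commit), downstream code can now import everything from `refiners.fluxion.adapters` in one place:

```python
# All five names are now part of the package's public interface.
from refiners.fluxion.adapters import (
    Adapter,
    Conv2dLora,
    LinearLora,
    Lora,
    LoraAdapter,
)
```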
120 changes: 117 additions & 3 deletions src/refiners/fluxion/adapters/lora.py
@@ -9,6 +9,21 @@


class Lora(fl.Chain, ABC):
"""Low-rank approximation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
- `down`: initialized with a random normal distribution
- `up`: initialized with zeros
Note:
This layer is not meant to be used directly.
Instead, use one of its subclasses:
- [`LinearLora`][refiners.fluxion.adapters.lora.LinearLora]
- [`Conv2dLora`][refiners.fluxion.adapters.lora.Conv2dLora]
"""

def __init__(
self,
name: str,
@@ -18,11 +33,23 @@ def __init__(
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.name = name
self._rank = rank
self._scale = scale

-super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
+super().__init__(
+    *self.lora_layers(device=device, dtype=dtype),
+    fl.Multiply(scale),
+)

normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)
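Since `down` is drawn from a normal distribution with std `1 / rank` while `up` starts at zero, a freshly constructed LoRA contributes nothing to the adapted layer. A minimal sketch of that property (assuming the underlying layers are bias-free, as is conventional for LoRA; the name is illustrative):

```python
import torch

from refiners.fluxion.adapters import LinearLora

# `up` is zero-initialized, so up(down(x)) * scale is exactly zero:
# adapting a model leaves its outputs unchanged until training moves
# the LoRA weights.
lora = LinearLora("demo", in_features=16, out_features=16, rank=4)
x = torch.randn(2, 16)
assert torch.equal(lora(x), torch.zeros(2, 16))
```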
@@ -31,26 +58,36 @@ def __init__(
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
"""Create the down and up layers of the LoRA.
Args:
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
...

@property
def down(self) -> fl.WeightedModule:
"""The down layer."""
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer

@property
def up(self) -> fl.WeightedModule:
"""The up layer."""
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer

@property
def rank(self) -> int:
"""The rank of the low-rank approximation."""
return self._rank

@property
def scale(self) -> float:
"""The scale of the low-rank approximation."""
return self._scale

@scale.setter
@@ -119,13 +156,24 @@ def auto_attach(
return LoraAdapter(layer, self), parent

def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
"""Load the weights of the LoRA.
Args:
down_weight: The down weight.
up_weight: The up weight.
"""
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
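`load_weights` is the entry point for loading externally trained LoRA checkpoints; the tensors must match the shapes of the `down` and `up` weights. A sketch with random stand-ins for real checkpoint tensors:

```python
import torch

from refiners.fluxion.adapters import LinearLora

lora = LinearLora("demo", in_features=320, out_features=320, rank=4)

# For a linear LoRA, `down.weight` has shape (rank, in_features) and
# `up.weight` has shape (out_features, rank); the random tensors below
# stand in for values read from a real checkpoint.
down_weight = torch.randn(4, 320)
up_weight = torch.randn(320, 4)
lora.load_weights(down_weight, up_weight)
```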


class LinearLora(Lora):
"""Low-rank approximation (LoRA) layer for linear layers.
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
"""

def __init__(
self,
name: str,
@@ -137,10 +185,27 @@ def __init__(
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
in_features: The number of input features.
out_features: The number of output features.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.in_features = in_features
self.out_features = out_features

-super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
+super().__init__(
+    name,
+    rank=rank,
+    scale=scale,
+    device=device,
+    dtype=dtype,
+)
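Putting the constructor together with the `down`, `up`, `rank` and `scale` accessors documented above (weight shapes follow the usual `torch.nn.Linear` convention of `(out, in)`; the name is illustrative):

```python
from refiners.fluxion.adapters import LinearLora

lora = LinearLora("demo", in_features=8, out_features=32, rank=2, scale=0.5)

print(lora.down.weight.shape)  # torch.Size([2, 8]): in_features -> rank
print(lora.up.weight.shape)    # torch.Size([32, 2]): rank -> out_features
print(lora.rank, lora.scale)   # 2 0.5
```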

@classmethod
def from_weights(
@@ -190,6 +255,11 @@ def is_compatible(self, layer: fl.WeightedModule, /) -> bool:


class Conv2dLora(Lora):
"""Low-rank approximation (LoRA) layer for 2D convolutional layers.
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
"""

def __init__(
self,
name: str,
@@ -204,13 +274,33 @@ def __init__(
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize the LoRA layer.
Args:
name: The name of the LoRA.
in_channels: The number of input channels.
out_channels: The number of output channels.
rank: The rank of the LoRA.
scale: The scale of the LoRA.
kernel_size: The kernel size of the LoRA.
stride: The stride of the LoRA.
padding: The padding of the LoRA.
device: The device of the LoRA weights.
dtype: The dtype of the LoRA weights.
"""
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding

-super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)
+super().__init__(
+    name,
+    rank=rank,
+    scale=scale,
+    device=device,
+    dtype=dtype,
+)
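`Conv2dLora` mirrors `LinearLora` for convolutions, with `kernel_size`, `stride` and `padding` available to match the geometry of the target convolution. A minimal sketch that leans on the constructor's defaults for those three:

```python
from refiners.fluxion.adapters import Conv2dLora

# Only channels and rank are specified; kernel_size, stride and padding
# keep their default values. The name is illustrative.
conv_lora = Conv2dLora("demo_conv", in_channels=64, out_channels=64, rank=4)
print(conv_lora.down.weight.shape, conv_lora.up.weight.shape)
```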

@classmethod
def from_weights(
@@ -279,20 +369,34 @@ def is_compatible(self, layer: fl.WeightedModule, /) -> bool:


class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
"""Adapter for LoRA layers.
This adapter simply sums the target layer with the given LoRA layers.
"""

def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
"""Initialize the adapter.
Args:
target: The target layer.
loras: The LoRA layers.
"""
with self.setup_adapter(target):
super().__init__(target, *loras)

@property
def names(self) -> list[str]:
"""The names of the LoRA layers."""
return [lora.name for lora in self.layers(Lora)]

@property
def loras(self) -> dict[str, Lora]:
"""The LoRA layers."""
return {lora.name: lora for lora in self.layers(Lora)}

@property
def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers."""
return {lora.name: lora.scale for lora in self.layers(Lora)}

@scales.setter
@@ -301,10 +405,20 @@ def scale(self, values: dict[str, float]) -> None:
self.loras[name].scale = value

def add_lora(self, lora: Lora, /) -> None:
"""Add a LoRA layer to the adapter.
Args:
lora: The LoRA layer to add.
"""
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
self.append(lora)

def remove_lora(self, name: str, /) -> Lora | None:
"""Remove a LoRA layer from the adapter.
Args:
name: The name of the LoRA layer to remove.
"""
if name in self.names:
lora = self.loras[name]
self.remove(lora)
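End to end, the adapter sums the target with its LoRA branches and manages them by name. A sketch under two assumptions: `fl.Linear` accepts `in_features`/`out_features` as elsewhere in fluxion, and the LoRA names are illustrative:

```python
import refiners.fluxion.layers as fl

from refiners.fluxion.adapters import LinearLora, LoraAdapter

target = fl.Linear(in_features=16, out_features=16)
style_lora = LinearLora("style", in_features=16, out_features=16, rank=4)

# The adapter computes target(x) + style_lora(x).
adapter = LoraAdapter(target, style_lora)
print(adapter.names)             # ['style']
adapter.scales = {"style": 0.8}  # re-weight a LoRA by name

# LoRA layers can also be added and removed after construction.
subject_lora = LinearLora("subject", in_features=16, out_features=16, rank=4)
adapter.add_lora(subject_lora)
adapter.remove_lora("subject")
```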