# Chain

When we say models are implemented in a declarative way in Refiners, what this means in practice is that they are implemented as Chains. `Chain` is a Python class to implement trees of modules. It is a subclass of Refiners' `Module`, which is in turn a subclass of PyTorch's `Module`. All inner nodes of a Chain are subclasses of `Chain`, and leaf nodes are subclasses of Refiners' `Module`.

## A first example

To give you an idea of how it looks, let us take an example similar to the one from the PyTorch paper:

```py
from torch import nn

class BasicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 128, 3)
        self.linear_1 = nn.Linear(128, 40)
        self.linear_2 = nn.Linear(40, 10)

    def forward(self, x):
        t1 = self.conv(x)
        t2 = nn.functional.relu(t1)
        t3 = self.linear_1(t2)
        t4 = self.linear_2(t3)
        return nn.functional.softmax(t4)
```

Here is how we could implement the same model in Refiners:

```py
import torch
import refiners.fluxion.layers as fl

class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, 128, 3),
            fl.ReLU(),
            fl.Linear(128, 40),
            fl.Linear(40, 10),
            fl.Lambda(torch.nn.functional.softmax),
        )
```

> **Note** - We often use the namespace `fl`, short for `fluxion`, the part of Refiners that implements basic layers.

As of writing, Refiners does not include a `Softmax` layer by default, but as you can see, you can easily call arbitrary code using `fl.Lambda`. Alternatively, if you just wanted to write `Softmax()`, you could implement it like this:

```py
class Softmax(fl.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softmax(x)
```

> Note that we use type hints here. All of Refiners' codebase is typed, which makes it a pleasure to use if your downstream code is typed too.

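The representations in the next section display a `Softmax()` leaf rather than a `Lambda`, so for the rest of this page, assume `BasicModel` has been redefined to use this custom module as its last layer:

```py
class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, 128, 3),
            fl.ReLU(),
            fl.Linear(128, 40),
            fl.Linear(40, 10),
            Softmax(),  # the custom module defined above
        )
```
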
## Inspecting and manipulating

Let us instantiate the `BasicModel` we just defined and inspect its representation in a Python REPL:

```
>>> m = BasicModel()
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU()
├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
├── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```

The children of a `Chain` are stored in a dictionary and can be accessed by name or index. When layers of the same type appear in the Chain, distinct suffixed keys are automatically generated.

```
>>> m[0]
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m.Conv2d
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m[3]
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
>>> m.Linear_2
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
```

The `Chain` class includes several helpers to manipulate the tree. For instance, imagine we want to wrap the two `Linear`s in a subchain. Here is how we could do it:

```py
# The first pop(2) removes Linear_1; the second Linear then shifts down
# to index 2, so the second pop(2) removes it too.
m.insert_after_type(fl.ReLU, fl.Chain(m.pop(2), m.pop(2)))
```

Did it work? Let's see:

```
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU()
├── (CHAIN)
│   ├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
│   └── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```

## Accessing and iterating

There are also many ways to access or iterate over nodes, even when they are deep in the tree. Most of them are implemented using a powerful iterator named `walk`. However, most of the time, you can use simpler helpers. For instance, to iterate over all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:

```py
for x in m.layers(fl.WeightedModule):
    print(x)
```

It prints:

```
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
Linear(in_features=128, out_features=40, device=cpu, dtype=float32)
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
```

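If you also need each matched module's parent, you can use `walk` directly. Here is a minimal sketch, assuming, as of writing, that `walk` accepts a module type and yields `(module, parent)` pairs:

```py
for module, parent in m.walk(fl.WeightedModule):
    print(f"{module.__class__.__name__} (parent: {parent.__class__.__name__})")
```
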
# Context

## Motivation: avoiding "props drilling"

Chains are a powerful tool to represent computational graphs, but they are not always convenient.

Many adapters add extra inputs to the model. For instance, ControlNet and T2I-Adapter require a guide (condition image), inpainting adapters require a mask, Latent Consistency Models use a condition scale embedding, and other adapters may leverage time or context embeddings... Those inputs are often passed by the user in a high-level format (numbers, text...) and converted to embeddings by the model before being consumed in downstream layers.

Managing these extra inputs is inconvenient. Typically, you would add them to the inputs and outputs of each layer somehow. But if you add them as channels or concatenate them, you get composability issues, and if you try to pass them as extra arguments, you end up having to deal with them in layers that should not be concerned with their presence.

The same problem of having to pass extra contextual information up and down a tree exists in other fields, in particular in JavaScript frameworks that deal with a Virtual DOM, such as React, where it is called "props drilling". To make it easier to manage, the [Context API](https://react.dev/learn/passing-data-deeply-with-context) was introduced, and we went with a similar idea in Refiners.

## A simple example

Here is an example of how contexts work:

```py
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts

class MyProvider(fl.Chain):
    def init_context(self) -> Contexts:
        return {"my context": {"my key": None}}

m = MyProvider(
    fl.Chain(
        fl.Sum(
            fl.UseContext("my context", "my key"),  # reads the value (4)
            fl.Lambda(lambda: 2),
        ),
        fl.SetContext("my context", "my key"),  # stores the sum (6)
    ),
    fl.Chain(
        fl.UseContext("my context", "my key"),  # reads the stored value
        fl.Lambda(print),
    ),
)

m.set_context("my context", {"my key": 4})
m()  # prints 6
```

As you can see, to use the context, you define it by subclassing any `Chain` and defining `init_context`. You can set the context with the `set_context` method or the `SetContext` layer, and you can access it anywhere down the provider's tree with `UseContext`.

## Simplifying complex models with Context

Another use of the context is simplifying complex models, in particular those with long-range nested skip connections.

To illustrate this, let us consider a toy example with a structure somewhat similar to a U-Net:

```py
square = fl.Lambda(lambda x: x ** 2)
sqrt = fl.Lambda(lambda x: x ** 0.5)

m1 = fl.Chain(
    fl.Residual(
        square,
        fl.Residual(
            square,
            fl.Residual(
                square,
            ),
            sqrt,
        ),
        sqrt,
    ),
    sqrt,
)
```

You can see two problems here:

- nesting increases by one level with every residual; in a real case, this would become unreadable;
- you cannot isolate the part that computes the squares (similar to the down blocks in a U-Net) from the part that computes the square roots (similar to the up blocks in a U-Net).

Let us solve those two issues using the context:

```py
from refiners.fluxion.context import Contexts

class MyModel(fl.Chain):
    def init_context(self) -> Contexts:
        return {"mymodel": {"residuals": []}}

# Appends the current value to the list stored in the context.
push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x))

class ApplyResidual(fl.Sum):
    def __init__(self):
        super().__init__(
            fl.Identity(),
            # Pops the most recent residual and adds it to the input.
            fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
        )

squares = fl.Chain(x for _ in range(3) for x in (push_residual, square))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt))
m2 = MyModel(squares, sqrts)
```

As you can see, despite `squares` and `sqrts` being completely independent chains, they can access the same context due to being nested under the same provider.

Does it work?

```
>>> m1(2.0)
2.5547711633552384
>>> m2(2.0)
2.5547711633552384
```

Yes! ✨

# Refiners - Docs

## Why Refiners?

PyTorch is a great framework to implement deep learning models, widely adopted in academia and industry around the globe. A core design principle of PyTorch is that users write *imperative* Python code that manipulates Tensors[^1]. This code can be organized in Modules, which are just Python classes whose constructors typically initialize parameters and load weights, and which implement a `forward` method that computes the forward pass. Reconstructing an inference graph, backpropagation and so on are left to the framework.

This approach works very well in general, as demonstrated by the popularity of PyTorch. However, the growing importance of the Adaptation pattern is challenging it.

Adaptation is the idea of *patching* existing powerful models to implement new capabilities. Those models are called foundation models; they are typically trained from scratch on amounts of data inaccessible to most individuals, small companies or research labs, and exhibit emergent properties. Examples of such models are LLMs (GPT, LLaMa, Mistral), image generation models (Stable Diffusion, Muse), vision models (BLIP-2, LLaVA 1.5, Fuyu-8B), but also models trained on more specific tasks such as embedding extraction (CLIP, DINOv2) or image segmentation (SAM).

Adaptation of foundation models can take many forms. One of the simplest but most powerful derives from fine-tuning: re-training a subset of the weights of the model on a specific task, then distributing only those weights. Add to this a trick to significantly reduce the size of the fine-tuned weights and you get LoRA[^2], which is probably the most well-known adaptation method. However, adaptation can go beyond that and change the shape of the model or its inputs.

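To make the idea concrete, here is a minimal sketch of LoRA in plain PyTorch (the class name and rank are illustrative; this is not Refiners' adapter API): the original weights are frozen, and only a low-rank update, the product of two small matrices, is trained and distributed.

```py
import torch

class LoRALinear(torch.nn.Module):
    def __init__(self, linear: torch.nn.Linear, rank: int = 4):
        super().__init__()
        self.linear = linear.requires_grad_(False)  # frozen foundation weights
        # Low-rank update: only `down` and `up` are trained.
        self.down = torch.nn.Linear(linear.in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, linear.out_features, bias=False)
        torch.nn.init.zeros_(self.up.weight)  # the update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.up(self.down(x))
```
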
There are several approaches to patching the code of a foundation model, implemented in typical PyTorch imperative style, to support adaptation, including:

- Just duplicate the original code base and edit it in place unconditionally. This approach is often adopted by researchers today.
- Change the original code base to optionally support the adapter. This approach is often used by frameworks and libraries built on top of PyTorch, and it works well for a single adapter. However, as you start adding support for multiple adapters to the same foundation model, the cyclomatic complexity explodes and the code becomes hard to maintain and error-prone. The end result is that adapters typically do not compose well.
- Change the original code to abstract adaptation by adding ad-hoc hooks everywhere. This approach has the advantage of keeping the foundation model independent from its adapters, but it makes the code extremely non-linear and hard to reason about: so-called "spaghetti code".

As believers in adaptation, we found none of those approaches appealing, so we designed Refiners as a better option. Refiners is a micro-framework built on top of PyTorch which does away with its imperative style. In Refiners, models are implemented in a *declarative* way instead, which makes them by nature easier to manipulate and patch.

Now that you know *why* we do that, you can check out [*how*](/concepts/chain/). It's not that hard, we promise!

[^1]: Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.
[^2]: Hu et al., 2022. LoRA: Low-Rank Adaptation of Large Language Models.