key concept docs
catwell committed Feb 1, 2024
1 parent ab63a01 commit a737643
Showing 11 changed files with 277 additions and 22 deletions.
File renamed without changes
File renamed without changes
44 changes: 26 additions & 18 deletions docs/concepts/adapters.md → docs/concepts/adapter/index.md
@@ -1,45 +1,47 @@
# What is an Adapter in Refiners? A technical overview
# Adapter

An Adapter is a Chain that replaces a Module (the target) in another Chain (the parent). Typically the target will become a child of the adapter.
Adapters are the final and most high-level abstraction in Refiners. They are the concept of adaptation turned into code.

An Adapter is [generally](#higher-level-adapters) a Chain that replaces a Module (the target) in another Chain (the parent). Typically the target will become a child of the adapter.

In code terms, `Adapter` is a generic mixin. Adapters subclass `type(parent)` and `Adapter[type(target)]`. For instance, if you adapt a Conv2d in a Sum, the definition of the Adapter could look like:

```py
class MyAdapter(Sum, Adapter[Conv2d]):
class MyAdapter(fl.Sum, fl.Adapter[fl.Conv2d]):
    ...
```

## A simple example: adapting a Linear

Let us take a simple example to see how this works. Consider this model:

![before](assets/adapters/linear-before.png)
![before](linear-before.png)

In pseudo-code, it could look like this:
In code, it could look like this:

```py
my_model = MyModel(Chain(Linear(), Chain(...)))
my_model = MyModel(fl.Chain(fl.Linear(), fl.Chain(...)))
```

Suppose we want to adapt the Linear to sum its output with the result of another chain. We can define and initialize an adapter like this:

```py
class MyAdapter(Sum, Adapter[Linear]):
    def __init__(self, target: Linear) -> None:
class MyAdapter(fl.Sum, fl.Adapter[fl.Linear]):
    def __init__(self, target: fl.Linear) -> None:
        with self.setup_adapter(target):
            super().__init__(Chain(...), target)
            super().__init__(fl.Chain(...), target)

# Find the target and its parent in the chain.
# For simplicity let us assume it is the only Linear.
for target, parent in my_model.walk(Linear):
for target, parent in my_model.walk(fl.Linear):
    break

adapter = MyAdapter(target)
```

The result is now this:

![ejected](assets/adapters/linear-ejected.png)
![ejected](linear-ejected.png)

Note that the original chain is unmodified. You can still run inference on it as if the adapter did not exist. To use the adapter, you must inject it into the chain:

@@ -49,7 +51,7 @@ adapter.inject(parent)

The result will be:

![injected](assets/adapters/linear-injected.png)
![injected](linear-injected.png)

Now if you run inference it will go through the Adapter. You can go back to the previous situation by calling `adapter.eject()`.
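
To make the inject/eject mechanics concrete, here is a minimal pure-Python sketch. It does not use Refiners: `ToyChain` and `ToySumAdapter` are hypothetical stand-ins invented for illustration, and the real `Adapter` mixin does more than this (for instance through `setup_adapter`). The sketch only shows the core idea: injecting swaps the target for the adapter in the parent, and ejecting swaps it back.

```python
class ToyChain:
    """Hypothetical stand-in for a Chain: calls its children in order."""

    def __init__(self, *children):
        self.children = list(children)

    def __call__(self, x):
        for child in self.children:
            x = child(x)
        return x


class ToySumAdapter:
    """Hypothetical adapter that returns target(x) + extra(x)."""

    def __init__(self, target, extra):
        self.target = target
        self.extra = extra
        self.parent = None

    def __call__(self, x):
        return self.target(x) + self.extra(x)

    def inject(self, parent):
        # Replace the target with the adapter in the parent's children.
        index = next(i for i, c in enumerate(parent.children) if c is self.target)
        parent.children[index] = self
        self.parent = parent
        return self

    def eject(self):
        # Put the target back where the adapter currently sits.
        index = next(i for i, c in enumerate(self.parent.children) if c is self)
        self.parent.children[index] = self.target
        self.parent = None


def double(x):
    return 2 * x


def add_one(x):
    return x + 1


model = ToyChain(double, add_one)
adapter = ToySumAdapter(target=double, extra=lambda x: 10)

before = model(3)      # the adapter exists but is not in the chain yet
adapter.inject(model)
injected = model(3)    # inference now goes through the adapter
adapter.eject()
after = model(3)       # back to the original behaviour
```

Note that creating the adapter leaves `model` untouched; only `inject` changes what the parent calls.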

@@ -65,18 +67,18 @@ Starting from the same model as earlier, let us assume we want to:
This Adapter will perform a `structural_copy` of part of its target, which means it will duplicate all Chain nodes but keep pointers to the same `WeightedModule`s, and hence not use extra GPU memory.

```py
class MyAdapter(Chain, Adapter[Chain]):
    def __init__(self, target: Linear) -> None:
class MyAdapter(fl.Chain, fl.Adapter[fl.Chain]):
    def __init__(self, target: fl.Linear) -> None:
        with self.setup_adapter(target):
            new_b = Chain(target, target.Chain.Chain_2.structural_copy())
            new_b = fl.Chain(target, target.Chain.Chain_2.structural_copy())
            super().__init__(new_b, target.Linear)

adapter = MyAdapter(my_model.Chain_1)  # Chain A in the diagram
```

We end up with this:

![chain-ejected](assets/adapters/chain-ejected.png)
![chain-ejected](chain-ejected.png)

We can now inject it into the original graph. It is not even needed to pass the parent this time, since Chains know their parents.

@@ -86,12 +88,18 @@ adapter.inject()

We obtain this:

![chain-injected](assets/adapters/chain-injected.png)
![chain-injected](chain-injected.png)

Note that the Linear is in the Chain twice now, but that does not matter as long as you really want it to be the same Linear layer with the same weights.

As before, we can eject the adapter to go back to the original model.
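
The sharing behaviour of the `structural_copy` used in this section can be illustrated with a toy tree. This is only a sketch under assumptions: `Leaf` and `Node` are hypothetical classes, not Refiners types, and the list in `weights` stands in for tensors. It shows inner nodes being duplicated while leaf objects (and therefore their weights) are reused.

```python
class Leaf:
    """Hypothetical weighted leaf; `weights` stands in for tensors."""

    def __init__(self, name, weights):
        self.name = name
        self.weights = weights


class Node:
    """Hypothetical inner node holding an ordered list of children."""

    def __init__(self, *children):
        self.children = list(children)

    def structural_copy(self):
        # Duplicate inner nodes recursively, but reuse the same leaf objects.
        return Node(
            *(c.structural_copy() if isinstance(c, Node) else c for c in self.children)
        )


linear = Leaf("linear", [0.1, 0.2])
original = Node(Node(linear))
copy = original.structural_copy()

# The copy can be restructured without touching the original tree.
copy.children.append(Leaf("extra", [0.3]))
```

Because the leaf is shared, a change to `linear.weights` would be visible from both trees, which is the point: the structure is duplicated, the weights are not.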

## A real-world example: LoraAdapter

A popular example of adaptation is [LoRA](https://arxiv.org/abs/2106.09685). You can check out [how we implement it in Refiners](../src/refiners/fluxion/adapters/lora.py).
A popular example of adaptation is [LoRA](https://arxiv.org/abs/2106.09685). You can check out [how we implement it in Refiners](https://github.com/finegrain-ai/refiners/blob/main/src/refiners/fluxion/adapters/lora.py).

## Higher-level adapters

If you use Refiners, you will find Adapters that go beyond the simple definition given at the top of this page. Some adapters inject multiple smaller adapters in models, others implement helper methods to be used by their caller...

From a bird's eye view, you can just consider Adapters as things you inject into models to adapt them, and that can be ejected to return the model to its original state. You will get a better feel for what an adapter is and how to leverage it by actually using the framework.
File renamed without changes
File renamed without changes
File renamed without changes
116 changes: 116 additions & 0 deletions docs/concepts/chain.md
@@ -0,0 +1,116 @@
# Chain


When we say models are implemented in a declarative way in Refiners, what this means in practice is they are implemented as Chains. `Chain` is a Python class to implement trees of modules. It is a subclass of Refiners' `Module`, which is in turn a subclass of PyTorch's `Module`. All inner nodes of a Chain are subclasses of `Chain`, and leaf nodes are subclasses of Refiners' `Module`.

## A first example

To give you an idea of how it looks, let us take an example similar to the one from the PyTorch paper:

```py
from torch import nn

class BasicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 128, 3)
        self.linear_1 = nn.Linear(128, 40)
        self.linear_2 = nn.Linear(40, 10)

    def forward(self, x):
        t1 = self.conv(x)
        t2 = nn.functional.relu(t1)
        t3 = self.linear_1(t2)
        t4 = self.linear_2(t3)
        return nn.functional.softmax(t4)
```

Here is how we could implement the same model in Refiners:

```py
import torch
import refiners.fluxion.layers as fl

class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, 128, 3),
            fl.ReLU(),
            fl.Linear(128, 40),
            fl.Linear(40, 10),
            fl.Lambda(torch.nn.functional.softmax),
        )
```

> **Note** - We often use the namespace `fl`, which stands for `fluxion`, the part of Refiners that implements basic layers.

As of writing, Refiners does not include a `Softmax` layer by default, but as you can see you can easily call arbitrary code using `fl.Lambda`. Alternatively, if you just wanted to write `Softmax()`, you could implement it like this:

```py
class Softmax(fl.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softmax(x)
```

> Note that we use type hints here. All of Refiners' codebase is typed, which makes it a pleasure to use if your downstream code is typed too.

## Inspecting and manipulating

Let us instantiate the `BasicModel` we just defined and inspect its representation in a Python REPL:

```
>>> m = BasicModel()
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU()
├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
├── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```

The children of a `Chain` are stored in a dictionary and can be accessed by name or index. When layers of the same type appear in the Chain, distinct suffixed keys are automatically generated.


```
>>> m[0]
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m.Conv2d
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m[3]
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
>>> m.Linear_2
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
```
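
The naming pattern visible in the REPL output above (a unique type keeps its bare name, repeated types get `_1`, `_2`, ...) can be sketched in a few lines. This is an illustration of the pattern only, not Refiners' actual implementation; `generate_keys` is a hypothetical helper.

```python
from collections import Counter


def generate_keys(type_names):
    """Return one key per child: bare name if unique, suffixed if repeated."""
    totals = Counter(type_names)
    seen = Counter()
    keys = []
    for name in type_names:
        if totals[name] == 1:
            keys.append(name)
        else:
            seen[name] += 1
            keys.append(f"{name}_{seen[name]}")
    return keys


keys = generate_keys(["Conv2d", "ReLU", "Linear", "Linear", "Softmax"])
```

With such keys, `m.Linear_2` and `m[3]` can refer to the same child, as in the session above.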

The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to wrap the two `Linear`s in a subchain. Here is how I could do it:


```py
m.insert_after_type(fl.ReLU, fl.Chain(m.pop(2), m.pop(2)))
```

Did it work? Let's see:

```
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU()
├── (CHAIN)
│ ├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
│ └── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```

## Accessing and iterating

There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named `walk`. However, most of the time, you can use simpler helpers. For instance, to iterate all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:

```py
for x in m.layers(fl.WeightedModule):
    print(x)
```

It prints:

```
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
Linear(in_features=128, out_features=40, device=cpu, dtype=float32)
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
```
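
To give an intuition for what an iterator like `walk` does, here is a toy depth-first traversal that yields `(module, parent)` pairs, the same shape used when unpacking `walk` results in the Adapter page. It is a sketch under assumptions: `Leaf`, `Tree` and the `predicate` argument are hypothetical and the real `walk` may accept different arguments.

```python
class Leaf:
    """Hypothetical leaf module; `weighted` marks modules that hold weights."""

    def __init__(self, name, weighted=False):
        self.name = name
        self.weighted = weighted


class Tree:
    """Hypothetical inner node standing in for a Chain."""

    def __init__(self, *children):
        self.children = list(children)

    def walk(self, predicate=None):
        # Depth-first traversal yielding (module, parent) pairs.
        for child in self.children:
            if predicate is None or predicate(child):
                yield child, self
            if isinstance(child, Tree):
                yield from child.walk(predicate)


# Same shape as the model after wrapping the two Linear layers in a subchain.
sub = Tree(Leaf("linear_1", weighted=True), Leaf("linear_2", weighted=True))
model = Tree(Leaf("conv", weighted=True), Leaf("relu"), sub, Leaf("softmax"))


def is_weighted(module):
    return isinstance(module, Leaf) and module.weighted


weighted_names = [module.name for module, _ in model.walk(is_weighted)]
parents = {module.name: parent for module, parent in model.walk(is_weighted)}
```

Yielding the parent alongside each module is what makes it possible to replace a node in place once it has been found.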
110 changes: 110 additions & 0 deletions docs/concepts/context.md
@@ -0,0 +1,110 @@
# Context

## Motivation: avoiding "props drilling"

Chains are a powerful tool to represent computational graphs, but they are not always convenient.

Many adapters add extra input to the model. For instance, ControlNet and T2i-Adapter require a guide (condition image), inpainting adapters require a mask, Latent Consistency Models use a condition scale embedding, other adapters may leverage time or context embeddings... Those inputs are often passed by the user in a high-level format (numbers, text...) and converted to embeddings by the model before being consumed in downstream layers.

Managing this extra input is inconvenient. Typically, you would add them to the inputs and outputs of each layer somehow. But if you add them as channels or concatenate them you get composability issues, and if you try to pass them as extra arguments you end up needing to deal with them in layers that should not be concerned with their presence.

The same kind of problem, having to pass extra contextual information up and down a tree, exists in other fields, in particular in JavaScript frameworks that deal with a Virtual DOM such as React, where it is called "props drilling". To make it easier to manage, the [Context API](https://react.dev/learn/passing-data-deeply-with-context) was introduced, and we went with a similar idea in Refiners.

## A simple example

Here is an example of how contexts work:


```py
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl

class MyProvider(fl.Chain):
    def init_context(self) -> Contexts:
        return {"my context": {"my key": None}}

m = MyProvider(
    fl.Chain(
        fl.Sum(
            fl.UseContext("my context", "my key"),
            fl.Lambda(lambda: 2),
        ),
        fl.SetContext("my context", "my key"),
    ),
    fl.Chain(
        fl.UseContext("my context", "my key"),
        fl.Lambda(print),
    ),
)

m.set_context("my context", {"my key": 4})
m()  # prints 6
```

As you can see, to use the context, you define it by subclassing any `Chain` and defining `init_context`. You can set the context with the `set_context` method or the `SetContext` layer, and you can access it anywhere down the provider's tree with `UseContext`.
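
The idea of a store that any node under a provider can reach, without threading it through every call, can be sketched in plain Python. This is not how Refiners implements contexts; `Node`, `Provider`, `UseValue` and `SetValue` are hypothetical classes, and resolving the context by walking up parent pointers is an assumption made for the sake of the illustration.

```python
class Node:
    """Hypothetical tree node that knows its parent."""

    def __init__(self, *children):
        self.parent = None
        self.children = list(children)
        for child in self.children:
            child.parent = self

    def find_context(self, name):
        # Walk up the tree until a node providing the context is found.
        node = self
        while node is not None:
            contexts = getattr(node, "contexts", None)
            if contexts is not None and name in contexts:
                return contexts[name]
            node = node.parent
        raise KeyError(name)


class Provider(Node):
    """Hypothetical provider: owns the context store for its whole subtree."""

    def __init__(self, *children):
        super().__init__(*children)
        self.contexts = {"my context": {"my key": None}}


class UseValue(Node):
    def read(self):
        return self.find_context("my context")["my key"]


class SetValue(Node):
    def write(self, value):
        self.find_context("my context")["my key"] = value


writer = SetValue()
reader = UseValue()
provider = Provider(Node(writer), Node(reader))

provider.contexts["my context"]["my key"] = 4
first = reader.read()    # the value set on the provider
writer.write(first + 2)  # one branch updates the shared value
second = reader.read()   # the other branch sees the update
```

The two branches never reference each other; they only share the provider above them, which mirrors the 4 then 6 sequence of the example above.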

## Simplifying complex models with Context

Another use of the context is simplifying complex models, in particular those with long-range nested skip connections.

To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:

```py
square = fl.Lambda(lambda x: x ** 2)
sqrt = fl.Lambda(lambda x: x ** 0.5)

m1 = fl.Chain(
    fl.Residual(
        square,
        fl.Residual(
            square,
            fl.Residual(
                square,
            ),
            sqrt,
        ),
        sqrt,
    ),
    sqrt,
)
```

You can see two problems here:

- nesting increases by one level with every residual; in a real case this would become unreadable;
- you could not isolate the part that computes the squares (similar to down blocks in a U-Net) from the part that computes the square roots (similar to up blocks in a U-Net).

Let us solve those two issues using the context:

```py
from refiners.fluxion.context import Contexts

class MyModel(fl.Chain):
    def init_context(self) -> Contexts:
        return {"mymodel": {"residuals": []}}

push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x))

class ApplyResidual(fl.Sum):
    def __init__(self):
        super().__init__(
            fl.Identity(),
            fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
        )

squares = fl.Chain(x for _ in range(3) for x in (push_residual, square))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt))
m2 = MyModel(squares, sqrts)
```

As you can see, despite `squares` and `sqrts` being completely independent chains, they can access the same context due to being nested under the same provider.

Does it work?

```
>>> m1(2.0)
2.5547711633552384
>>> m2(2.0)
2.5547711633552384
```

Yes!✨
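
The equivalence can also be checked without Refiners. The sketch below re-implements both computations with plain functions, assuming only that a residual computes `x + f(x)`; `nested` and `flat` are hypothetical names. The flat version makes the mechanism explicit: residuals are pushed on a list on the way down and popped in reverse order on the way up.

```python
def square(x):
    return x ** 2


def sqrt(x):
    return x ** 0.5


def nested(x):
    """Mirror of m1: three residuals nested inside each other."""

    def inner(a):   # Residual(square)
        return a + square(a)

    def middle(a):  # Residual(square, inner, sqrt)
        return a + sqrt(inner(square(a)))

    def outer(a):   # Residual(square, middle, sqrt)
        return a + sqrt(middle(square(a)))

    return sqrt(outer(x))


def flat(x):
    """Mirror of m2: push residuals going down, pop them going up."""
    residuals = []
    for _ in range(3):
        residuals.append(x)
        x = square(x)
    for _ in range(3):
        x = x + residuals.pop()
        x = sqrt(x)
    return x


nested_result = nested(2.0)
flat_result = flat(2.0)
```

For an input of 2.0 the intermediate values are 4, 16 and 256 on the way down, then 272, about 20.49 and about 6.53 just before each square root on the way up.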
1 change: 0 additions & 1 deletion docs/concepts/index.md

This file was deleted.

23 changes: 22 additions & 1 deletion docs/index.md
@@ -1,3 +1,24 @@
# Refiners - Docs

WIP
## Why Refiners?

PyTorch is a great framework to implement deep learning models, widely adopted in academia and industry around the globe. A core design principle of PyTorch is that users write *imperative* Python code that manipulates Tensors[^1]. This code can be organized in Modules, which are just Python classes whose constructors typically initialize parameters and load weights, and which implement a `forward` method that computes the forward pass. Reconstructing an inference graph, backpropagation and so on are left to the framework.

This approach works very well in general, as demonstrated by the popularity of PyTorch. However, the growing importance of the Adaptation pattern is challenging it.

Adaptation is the idea of *patching* existing powerful models to implement new capabilities. Those models are called foundation models; they are typically trained from scratch on amounts of data inaccessible to most individuals, small companies or research labs, and exhibit emergent properties. Examples of such models are LLMs (GPT, LLaMa, Mistral), image generation models (Stable Diffusion, Muse), vision models (BLIP-2, LLaVA 1.5, Fuyu-8B) but also models trained on more specific tasks such as embedding extraction (CLIP, DINOv2) or image segmentation (SAM).

Adaptation of foundational models can take many forms. One of the simplest but most powerful derives from fine-tuning: re-training a subset of the weights of the model on a specific task, then distributing only those weights. Add to this a trick to significantly reduce the size of the fine-tuned weights and you get LoRA[^2], which is probably the most well-known adaptation method. However, adaptation can go beyond that and change the shape of the model or its inputs.

There are several approaches to patch the code of a foundation model implemented in typical PyTorch imperative style to support adaptation, including:

- Just duplicate the original code base and edit it in place unconditionally. This approach is often adopted by researchers today.
- Change the original code base to optionally support the adapter. This approach is often used by frameworks and libraries built on top of PyTorch and works well for a single adapter. However, as you start adding support for multiple adapters to the same foundational module the cyclomatic complexity explodes and the code becomes hard to maintain and error-prone. The end result is that adapters typically do not compose well.
- Change the original code to abstract adaptation by adding ad-hoc hooks everywhere. This approach has the advantage of keeping the foundational model independent from its adapter, but it makes the code extremely non-linear and hard to reason about - so-called "spaghetti code".

As believers in adaptation, we found none of those approaches appealing, so we designed Refiners as a better option. Refiners is a micro-framework built on top of PyTorch which does away with its imperative style. In Refiners, models are implemented in a *declarative* way instead, which makes them by nature easier to manipulate and patch.

Now that you know *why* we do that, you can check out [*how*](/concepts/chain/). It's not that hard, we promise!

[^1]: Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.
[^2]: Hu et al., 2022. LoRA: Low-Rank Adaptation of Large Language Models.
5 changes: 3 additions & 2 deletions mkdocs.yml
@@ -36,8 +36,9 @@ nav:
  - Guides:
    - guides/index.md
  - Key Concepts:
    - concepts/index.md
    - concepts/adapters.md
    - concepts/chain.md
    - concepts/context.md
    - concepts/adapter/index.md
  - API Reference:
    - index.md
extra: