Expose klinen #386

Open
wants to merge 1 commit into
base: main
186 changes: 186 additions & 0 deletions etils/klinen/README.md
@@ -0,0 +1,186 @@
# Klinen - Torch-like API for flax

Klinen is a small wrapper around `flax.linen`. The goal is to provide a
stateless, object-oriented API that supports auto-complete and type checking.

## Documentation

### Model creation

Model creation is similar to flax, except that modules should inherit from
`knn.Module` (from `klinen`) instead of `nn.Module` (from `linen`):

```python
import jax
from flax import linen as nn
from kauldron import klinen as knn


class MLP(knn.Module):

  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    return nn.Dense(32)(x)


class AutoEncoder(knn.Module):
  encoder: knn.Module
  decoder: knn.Module

  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    return self.decoder(self.encoder(x))
```

* Inside `knn.Module`, any linen modules can be used.

### Model initialization / usage

To initialize the model, use `model.init_bind()` instead of `model.init()`. It
returns a copy of the module with bound parameters.

```python
model = AutoEncoder(
    encoder=MLP(),
    decoder=MLP(),
)
model = model.init_bind(rng, jnp.zeros((batch_size, 64)))

# After the model is initialized, it can be called directly
y = model(x)

# You can also call individual sub-modules
y = model.encoder(x)
```

### Randomness

`klinen` modules are stateless, which means they are fully deterministic:
calling the same model twice always returns the same result. If your model uses
randomness (e.g. `nn.Dropout`), the `rng` key has to be provided explicitly:

```python
model = model.with_rng(rng) # Replace the current rng.

y0 = model(x)
y1 = model(x)

assert jnp.allclose(y0, y1)  # Calling the model twice gives the same output

model = model.with_rng(rng2) # Set a new rng
```

`with_rng` accepts several kinds of values:

* `model.with_rng({'dropout': rng})`: Rng streams are explicitly defined.
* `model.with_rng(rng)`: The key is distributed among streams (with
  `rng.fold_in(stream_name)`).
* `model.with_rng()`: No argument provided; the current `rng` is split to get
  the next one.

Calling `model(x)` before a key was provided with `.with_rng` will yield an
error the first time.

Currently, there is no guarantee that the encoder uses the same rng when it is
called through `model(x)` as when it is called directly with
`model.encoder(x)`. This will be fixed in the future.
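The determinism described in this section can be illustrated with plain JAX. The sketch below is not the klinen implementation: `stream_key` and `dropout` are hypothetical helpers, and hashing the stream name with `zlib.crc32` is an assumption made here because `jax.random.fold_in` expects an integer.

```python
import zlib

import jax
import jax.numpy as jnp


def stream_key(rng: jax.Array, stream_name: str) -> jax.Array:
  """Derives a per-stream key from a single root key (hypothetical helper)."""
  # `fold_in` takes integer data, so the name is hashed to a 31-bit integer.
  data = zlib.crc32(stream_name.encode()) & 0x7FFFFFFF
  return jax.random.fold_in(rng, data)


def dropout(rng: jax.Array, x: jax.Array, rate: float = 0.5) -> jax.Array:
  """Functional dropout: the mask depends only on `rng`."""
  keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
  return jnp.where(keep, x / (1.0 - rate), 0.0)


rng = jax.random.PRNGKey(0)
x = jnp.ones((8, 16))

key = stream_key(rng, 'dropout')
y0 = dropout(key, x)
y1 = dropout(key, x)  # Same key, same mask

next_rng, _ = jax.random.split(rng)  # Splitting yields a new key
y2 = dropout(stream_key(next_rng, 'dropout'), x)  # New key, new mask
```

Because the key is an explicit input, the only way to get a different dropout mask is to supply a different key, which mirrors the role of `with_rng` above.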

### Training/eval mode

To disable determinism, models can use the `self.training` attribute:

```python
class MLP(knn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(2)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    return x
```

By default, `model.training == True`. You can switch the model to eval mode
with `.eval()`:

```python
model = model.eval() # Switch to eval mode
assert not model.training


model = model.train() # Switch back to train mode
assert model.training
```
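Since `.eval()` and `.train()` return a model rather than mutating one in place, the pattern resembles `dataclasses.replace` on an immutable object. The following is a minimal sketch of that pattern with a hypothetical `TinyModule` class; it is not how klinen is implemented.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TinyModule:
  """Hypothetical stand-in for a stateless module with a training flag."""

  training: bool = True

  def eval(self) -> 'TinyModule':
    # Returns a modified copy; `self` is left untouched.
    return dataclasses.replace(self, training=False)

  def train(self) -> 'TinyModule':
    return dataclasses.replace(self, training=True)


model = TinyModule()
eval_model = model.eval()
```

The original `model` keeps `training == True`, so the returned value must be reassigned for the mode switch to take effect.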

### Parameters

You can access the flax parameters, either at the root level or for individual
modules.

```python
model.params
model.encoder.params
model.encoder.params['Dense_0'] # nn.Dense params defined inside `nn.compact`
```

### Jit, auto-diff

`knn.Module` objects are compatible with `jax.tree_util`, so the parameters can
be mapped over. This means modules can be used natively inside `jax.jit`:

```python
@jax.jit
def eval_step(model: knn.Module, x: jax.Array, y: jax.Array) -> jax.Array:
  model = model.eval()
  y_pred = model(x)
  return loss(y_pred, y)
```
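To show why pytree compatibility makes this possible, here is a minimal sketch of an object registered with `jax.tree_util` and passed straight into `jax.jit` and `jax.grad`. `TinyLinear` is a hypothetical class written for this example, not part of klinen; treating the `training` flag as static auxiliary data is likewise an assumption.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class TinyLinear:
  """Hypothetical model object whose arrays are pytree leaves."""

  w: jax.Array
  b: jax.Array
  training: bool = True

  def tree_flatten(self):
    # Arrays are children (traced); the flag is static auxiliary data.
    return (self.w, self.b), self.training

  @classmethod
  def tree_unflatten(cls, aux, children):
    w, b = children
    return cls(w=w, b=b, training=aux)

  def __call__(self, x: jax.Array) -> jax.Array:
    return x @ self.w + self.b


@jax.jit
def eval_step(model: TinyLinear, x: jax.Array, y: jax.Array) -> jax.Array:
  y_pred = model(x)
  return jnp.mean((y_pred - y) ** 2)


model = TinyLinear(w=jnp.ones((3, 2)), b=jnp.zeros((2,)))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

loss_value = eval_step(model, x, y)
grads = jax.grad(eval_step)(model, x, y)  # Gradients share the model structure
```

The gradient comes back as another `TinyLinear`, which is what lets an optimizer update the model with a simple tree map.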

### Intermediate values

Often, it's very convenient to be able to store/access intermediate values
in the module tree. It is possible by annotating module fields as
`knn.Intermediate[T] = dataclasses.field(init=False)`.

```python
class Sequential(knn.Module):
  childs: list[nn.Module]

  tmp_values: knn.Intermediate[list[jax.Array]] = dataclasses.field(
      init=False,
      default_factory=list,
  )

  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    for child in self.childs:
      x = child(x)
      self.tmp_values.append(x)
    return x
```
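The `knn.Intermediate[T]` annotation is defined in `intermediate.py` as an `Annotated` alias carrying a marker object. The sketch below shows, with the standard library only, how such a marker can be found on dataclass fields. `Marked`, `_MARKER` and `marked_fields` are hypothetical names; klinen delegates this detection to `edc.helpers`, whose internals are not shown in this PR.

```python
import dataclasses
import typing
from typing import Annotated, TypeVar

_T = TypeVar('_T')
_MARKER = object()

# Generic alias: `Marked[int]` expands to `Annotated[int, _MARKER]`.
Marked = Annotated[_T, _MARKER]


@dataclasses.dataclass
class Example:
  size: int
  cache: Marked[list] = dataclasses.field(init=False, default_factory=list)


def marked_fields(cls: type) -> list[str]:
  """Returns the names of the fields annotated with `Marked[...]`."""
  hints = typing.get_type_hints(cls, include_extras=True)
  return [
      name
      for name, hint in hints.items()
      if typing.get_origin(hint) is Annotated
      and _MARKER in hint.__metadata__
  ]


names = marked_fields(Example)
```

Type checkers see `Marked[list]` as a plain `list`, so auto-complete keeps working on the field while the marker stays available at runtime.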

The intermediate values are reset at each call (calling `model()` twice
will create a new `tmp_values` list). Intermediate values are not bound to
the model object. Instead they need to be explicitly fetched:

```python
model = AutoEncoder(
    encoder=Sequential([
        nn.Dense(32),
        nn.Dense(32),
    ]),
    decoder=MLP(),
)
model = model.init_bind(rng, x)

y = model(x)  # Standard call (no intermediate)

with model.capture_intermediates() as intermediates:
  y = model(x)

# Convenience wrapper around `model.capture_intermediates()`
y, intermediates = model.call_with_intermediate(x)


# `intermediates` has the same structure as the `model`, but only sub-modules
# and `knn.Intermediate` fields are available.
tmp_values = intermediates.encoder.tmp_values
```
21 changes: 21 additions & 0 deletions etils/klinen/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2023 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper around `flax.linen.Module` to add torch-like API."""

from kauldron.klinen.intermediate import Intermediate
from kauldron.klinen.layers import Dense
from kauldron.klinen.layers import Dropout
from kauldron.klinen.layers import Sequential
from kauldron.klinen.module import Module
23 changes: 23 additions & 0 deletions etils/klinen/collections.py
@@ -0,0 +1,23 @@
# Copyright 2023 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collections."""


class Collection:
  """Flax collection constants."""

  PARAMS = 'params'
  INTERMEDIATES = 'intermediates'
  DROPOUT = 'dropout'
162 changes: 162 additions & 0 deletions etils/klinen/intermediate.py
@@ -0,0 +1,162 @@
# Copyright 2023 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Intermediate utils."""

from __future__ import annotations

import dataclasses
import typing
from typing import Any, Optional, TypeVar

from etils import edc
from kauldron.klinen.collections import Collection
from typing_extensions import Annotated

if typing.TYPE_CHECKING:
  from kauldron.klinen import module as module_lib


_T = TypeVar('_T')
_SelfT = TypeVar('_SelfT')

_IS_INTERMEDIATE = object()

if typing.TYPE_CHECKING:
  # TODO(b/254514368): Remove hack
  class _IntermediateMeta(type):

    def __getitem__(cls, value):
      return value

  class Intermediate(metaclass=_IntermediateMeta):
    pass

else:
  Intermediate = Annotated[_T, _IS_INTERMEDIATE]  # pytype: disable=invalid-typevar


KEY_PREFIX = '_attribute__'


@dataclasses.dataclass
class IntermediateDescriptor:
  """Descriptor to read-write individual contextvar."""

  field: dataclasses.Field[Any]
  objtype: type[Any] = dataclasses.field(init=False)
  attribute_name: str = dataclasses.field(init=False)

  @classmethod
  def from_field(
      cls, field: dataclasses.Field[Any], hint: edc.helpers.Hint
  ) -> IntermediateDescriptor:
    if field.init:
      raise ValueError(
          '`knn.Intermediate[T]` fields should be'
          f' `dataclasses.field(init=False)` for `{hint}`'
      )
    return cls(field)

  @property
  def _collection_name(self) -> str:
    """Name of the attribute in the `.sow('intermediates', name)` collection."""
    return f'{KEY_PREFIX}{self.attribute_name}'

  @property
  def _default(self) -> Any:
    """Default value."""
    if self.field.default is not dataclasses.MISSING:
      default = self.field.default
    elif self.field.default_factory is not dataclasses.MISSING:
      default = self.field.default_factory()
    else:
      raise AttributeError(
          f'{self.objtype.__name__!r} object cannot access intermediate'
          f' attribute {self.attribute_name!r}. Attribute was not set during'
          ' the call.'
      )
    return default

  def __set_name__(self, objtype: type[module_lib.Module], name: str) -> None:
    """Bind the descriptor to the class (PEP 487)."""
    self.objtype = objtype
    self.attribute_name = name

  def __get__(
      self,
      obj: Optional[module_lib.Module],
      objtype: Optional[type[module_lib.Module]] = None,
  ):
    """`x = module.my_intermediate`."""

    if obj is None:
      return self

    if not obj.scope:
      raise AttributeError(
          f'Intermediate field `{objtype.__name__}.{self.attribute_name}`'
          ' can only be accessed inside module functions. Use '
          '`model.capture_intermediate()` to access the intermediate values.'
      )

    if not obj.scope.has_variable(
        Collection.INTERMEDIATES, self._collection_name
    ):
      obj.sow(
          Collection.INTERMEDIATES,
          self._collection_name,
          self._default,
          reduce_fn=_replace_previous_value,
      )
    return obj.get_variable(Collection.INTERMEDIATES, self._collection_name)

  def __set__(self, obj: module_lib.Module, value: Any) -> None:
    """`module.my_intermediate = x`."""

    if not obj.scope:
      if not hasattr(obj, '_kd_init_finished'):
        # No-op during `__init__`.
        # This is to support `dataclasses.field(default=...)`
        return
      raise AttributeError(
          f'Intermediate field `{type(obj).__name__}.{self.attribute_name}`'
          ' can only be set inside module functions.'
      )
    obj.sow(
        Collection.INTERMEDIATES,
        self._collection_name,
        value,
        reduce_fn=_replace_previous_value,
    )


def _replace_previous_value(old: Any, new: _T) -> _T:
  """Merge function for `.sow` which always overwrites the value."""
  del old
  return new


def setup_cls(cls: type[module_lib.Module]) -> None:
  """Wraps `Intermediate[T]` fields in `IntermediateDescriptor` descriptors."""
  # Replace fields annotated with `Intermediate[T]` by their descriptor
  edc.helpers.wrap_new(
      cls,
      descriptor_infos=[
          edc.helpers.DescriptorInfo(
              annotation=Intermediate,
              descriptor_fn=IntermediateDescriptor.from_field,
          )
      ],
  )
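
For readers less familiar with the descriptor protocol used by `IntermediateDescriptor`, the following standalone sketch shows the same three hooks (`__set_name__`, `__get__`, `__set__`) redirecting attribute access to an external store. `StoredInDict`, `Layer` and the `store` dict are hypothetical stand-ins for the descriptor, the module and the flax `'intermediates'` collection; this is a simplified illustration, not the code under review.

```python
from typing import Any, Callable


class StoredInDict:
  """Descriptor that redirects attribute access to `obj.store`."""

  def __init__(self, default_factory: Callable[[], Any]):
    self.default_factory = default_factory

  def __set_name__(self, objtype: type, name: str) -> None:
    # Called once at class creation with the attribute name (PEP 487).
    self.attribute_name = name

  def __get__(self, obj: Any, objtype: Any = None) -> Any:
    if obj is None:  # Access on the class returns the descriptor itself
      return self
    if self.attribute_name not in obj.store:
      # Lazily create the default on first read, like the `_default` property.
      obj.store[self.attribute_name] = self.default_factory()
    return obj.store[self.attribute_name]

  def __set__(self, obj: Any, value: Any) -> None:
    obj.store[self.attribute_name] = value  # Always overwrite


class Layer:
  tmp_values = StoredInDict(default_factory=list)

  def __init__(self):
    self.store = {}


layer = Layer()
layer.tmp_values.append(1.0)  # Read creates the default list, then appends
first = dict(layer.store)
layer.tmp_values = [2.0]  # Write replaces the stored value
```

Keeping the value outside the instance is what allows the module object itself to stay stateless while intermediate values are still reachable through ordinary attribute syntax.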