diff --git a/etils/klinen/README.md b/etils/klinen/README.md new file mode 100644 index 00000000..ffe763f4 --- /dev/null +++ b/etils/klinen/README.md @@ -0,0 +1,186 @@ +# Klinen - Torch-like API for flax + +Klinen is a small wrapper around `flax.linen`. The goal is to provide a +stateless, object-oriented API, supporting auto-complete and type checking. + +## Documentation + +### Model creation + +Model creation is similar to flax, except the modules should inherit from +`klinen` instead of `linen`: + +```python +from flax import linen as nn +from kauldron import klinen as knn + + +class MLP(knn.Module): + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + return nn.Dense(32)(x) + + +class AutoEncoder(knn.Module): + encoder: knn.Module + decoder: knn.Module + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + return self.decoder(self.encoder(x)) +``` + +* Inside `knn.Module`, any linen modules can be used. + +### Model initialization / usage + +To initialize the model, use `model.init_bind()` instead of `model.init()`. It +will return a copy of the module with bound parameters. + +```python +model = AutoEncoder( + encoder=MLP(), + decoder=MLP(), +) +model = model.init_bind(rng, jnp.zeros((batch_size, 64))) + +# After the model is initialized, it can be called directly +y = model(x) + +# You can also call individual sub-modules +y = model.encoder(x) +``` + +### Randomness + +`klinen` modules are stateless, this means they are fully deterministic. Calling +the same model twice will always return the same result. If your model uses +randomness (e.g. `nn.Dropout`), the `rng` key has to be explicitly provided: + +```python +model = model.with_rng(rng) # Replace the current rng.
+ +y0 = model(x) +y1 = model(x) + +assert jnp.allclose(y0, y1) # Calling the model twice gives the same output + +model = model.with_rng(rng2) # Set a new rng +``` + +Multiple values are accepted: + +* `model.with_rng({'dropout': rng})`: Rng streams explicitly defined +* `model.with_rng(rng)`: Key distributed among streams (with + `rng.fold_in(stream_name)`) +* `model.with_rng()`: no argument provided, split the current `rng` to get the + next one. + +Calling `model(x)` before a key was provided with `.with_rng` will yield an +error the first time. + +Currently, there's no guarantee that the encoder called in `model(x)` or +directly with `model.encoder(x)` uses the same rng. This will be fixed in the +future. + +### Training/eval mode + +To disable determinism, models can use the `self.training` attribute: + +```python +class MLP(knn.Module): + + @nn.compact + def __call__(self, x): + x = nn.Dense(2)(x) + x = nn.Dropout(0.5)(x, deterministic=not self.training) + return x +``` + +By default, `model.training == True`. You can switch the model to eval mode +with `.eval()` + +```python +model = model.eval() # Switch to eval mode +assert not model.training + + +model = model.train() # Switch back to train mode +assert model.training +``` + +### Parameters + +You can access the flax parameters, either at the root level or for individual +modules. + +```python +model.params +model.encoder.params +model.encoder.params['Dense_0'] # nn.Dense params defined inside `nn.compact` +``` + +### Jit, auto-diff + +`knn.Module`s are compatible with `jax.tree_util` to map over the parameters. +This means modules can be used natively inside `jax.jit`: + +```python +@jax.jit +def eval_step(model: knn.Module, x: jax.Array, y: jax.Array) -> jax.Array: + model = model.eval() + y_pred = model(x) + return loss(y_pred, y) +``` + +### Intermediate values + +Often, it's very convenient to be able to store/access intermediate values +in the module tree.
+ It is possible by annotating module fields as +`knn.Intermediate[T] = dataclasses.field(init=False)`. + +```python +class Sequential(knn.Module): + childs: list[nn.Module] + + tmp_values: knn.Intermediate[list[jax.Array]] = dataclasses.field( + init=False, + default_factory=list, + ) + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + for child in self.childs: + x = child(x) + self.tmp_values.append(x) + return x +``` + +The intermediate values are reset at each call (calling `model()` twice +will create a new `tmp_values` list). Intermediate values are not bound to +the model object. Instead they need to be explicitly fetched: + +```python +model = AutoEncoder( + encoder=Sequential([ + nn.Dense(32), + nn.Dense(32), + ]), + decoder=MLP(), +) +model = model.init_bind(rng, x) + +y = model(x) # Standard call (no intermediate) + +with model.capture_intermediates() as intermediates: + y = model(x) + +# Convenience wrapper around `model.capture_intermediates()` +y, intermediates = model.call_with_intermediates(x) + + +# `intermediates` has the same structure as the `model`, but only sub-modules +# and `knn.Intermediate` fields are available. +tmp_values = intermediates.encoder.tmp_values +``` diff --git a/etils/klinen/__init__.py b/etils/klinen/__init__.py new file mode 100644 index 00000000..568f1664 --- /dev/null +++ b/etils/klinen/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Wrapper around `flax.linen.Module` to add a torch-like API.""" + +from kauldron.klinen.intermediate import Intermediate +from kauldron.klinen.layers import Dense +from kauldron.klinen.layers import Dropout +from kauldron.klinen.layers import Sequential +from kauldron.klinen.module import Module diff --git a/etils/klinen/collections.py b/etils/klinen/collections.py new file mode 100644 index 00000000..d8169159 --- /dev/null +++ b/etils/klinen/collections.py @@ -0,0 +1,23 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Collections.""" + + +class Collection: + """Flax collection constants.""" + + PARAMS = 'params' + INTERMEDIATES = 'intermediates' + DROPOUT = 'dropout' diff --git a/etils/klinen/intermediate.py b/etils/klinen/intermediate.py new file mode 100644 index 00000000..f628d02f --- /dev/null +++ b/etils/klinen/intermediate.py @@ -0,0 +1,162 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Intermediate utils.""" + +from __future__ import annotations + +import dataclasses +import typing +from typing import Any, Optional, TypeVar + +from etils import edc +from kauldron.klinen.collections import Collection +from typing_extensions import Annotated + +if typing.TYPE_CHECKING: + from kauldron.klinen import module as module_lib + + +_T = TypeVar('_T') +_SelfT = TypeVar('_SelfT') + +_IS_INTERMEDIATE = object() + +if typing.TYPE_CHECKING: + # TODO(b/254514368): Remove hack + class _IntermediateMeta(type): + + def __getitem__(cls, value): + return value + + class Intermediate(metaclass=_IntermediateMeta): + pass + +else: + Intermediate = Annotated[_T, _IS_INTERMEDIATE] # pytype: disable=invalid-typevar + + +KEY_PREFIX = '_attribute__' + + +@dataclasses.dataclass +class IntermediateDescriptor: + """Descriptor to read-write individual contextvar.""" + + field: dataclasses.Field[Any] + objtype: type[Any] = dataclasses.field(init=False) + attribute_name: str = dataclasses.field(init=False) + + @classmethod + def from_field( + cls, field: dataclasses.Field[Any], hint: edc.helpers.Hint + ) -> IntermediateDescriptor: + if field.init: + raise ValueError( + '`knn.Intermediate[T]` fields should be' + f' `dataclasses.field(init=False)` for `{hint}`' + ) + return cls(field) + + @property + def _collection_name(self) -> str: + """Name of the attribute in the `.sow('intermediates', name)` collection.""" + return f'{KEY_PREFIX}{self.attribute_name}' + + @property + def _default(self) -> Any: + """Default value.""" + if self.field.default is not dataclasses.MISSING: + default = self.field.default + elif self.field.default_factory is not dataclasses.MISSING: + default = self.field.default_factory() + else: + raise AttributeError( + f'{self.objtype.__name__!r} object cannot access intermediate' + f' attribute {self.attribute_name!r}. 
+ Attribute was not set during' + ' the call.' + ) + return default + + def __set_name__(self, objtype: type[module_lib.Module], name: str) -> None: + """Bind the descriptor to the class (PEP 487).""" + self.objtype = objtype + self.attribute_name = name + + def __get__( + self, + obj: Optional[module_lib.Module], + objtype: Optional[type[module_lib.Module]] = None, + ): + """`x = module.my_intermediate`.""" + + if obj is None: + return self + + if not obj.scope: + raise AttributeError( + f'Intermediate field `{objtype.__name__}.{self.attribute_name}`' + ' can only be accessed inside module functions. Use ' + '`model.capture_intermediates()` to access the intermediate values.' + ) + + if not obj.scope.has_variable( + Collection.INTERMEDIATES, self._collection_name + ): + obj.sow( + Collection.INTERMEDIATES, + self._collection_name, + self._default, + reduce_fn=_replace_previous_value, + ) + return obj.get_variable(Collection.INTERMEDIATES, self._collection_name) + + def __set__(self, obj: module_lib.Module, value: Any) -> None: + """`module.my_intermediate = x`.""" + + if not obj.scope: + if not hasattr(obj, '_kd_init_finished'): + # No-op during `__init__`. + # This is to support `dataclasses.field(default=...)` + return + raise AttributeError( + f'Intermediate field `{type(obj).__name__}.{self.attribute_name}`' + ' can only be set inside module functions.'
+ ) + obj.sow( + Collection.INTERMEDIATES, + self._collection_name, + value, + reduce_fn=_replace_previous_value, + ) + + +def _replace_previous_value(old: Any, new: _T) -> _T: + """Merge function for `.sow` which always overwrite the value.""" + del old + return new + + +def setup_cls(cls: type[module_lib.Module]) -> None: + """Wraps `Intermediate[T]` fields in `IntermediateDescriptor` descriptors.""" + # Replace fields annotated with `Intermediate[T]` by their descriptor + edc.helpers.wrap_new( + cls, + descriptor_infos=[ + edc.helpers.DescriptorInfo( + annotation=Intermediate, + descriptor_fn=IntermediateDescriptor.from_field, + ) + ], + ) diff --git a/etils/klinen/intermediate_proxy.py b/etils/klinen/intermediate_proxy.py new file mode 100644 index 00000000..eaee27fc --- /dev/null +++ b/etils/klinen/intermediate_proxy.py @@ -0,0 +1,358 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Intermediate proxy output.""" + +from __future__ import annotations + +import dataclasses +import functools +from typing import Any, Optional, TypeAlias, TypeVar + +from etils import epy +import flax +import jax +from kauldron.klinen import intermediate +from kauldron.klinen import module as module_lib +from kauldron.klinen import traverse + +_FnT = TypeVar('_FnT') +_T = TypeVar('_T') + + +# TODO(epot): Better support for `nn.compact` (`inter['Dense_0'].tmp`) +# TODO(epot): Fix non-tree modules (e.g. 
`Sequential([nn.relu])`) + + +@jax.tree_util.register_pytree_node_class +@dataclasses.dataclass +class _ModuleProxy: + """Proxy of a single module exposing the captured intermediates.""" + _cls_name: str + # Fields: `xxx: Intermediate[T]` + _intermediate_fields: dict[str, intermediate.IntermediateDescriptor] + # Fields: Nested modules (`xxx: nn.Module`) + _module_attributes: dict[str, Any] + # Flax name and parent + _name: str + _future_parent: traverse.Future[Optional[_ModuleProxy]] + + # Intermediate values + # Accessible as `model.intermediate_val` + _intermediate_attributes: dict[str, Any] = dataclasses.field( + default_factory=dict + ) + # Accessible as `model['intermediate_val']` + _intermediate_items: dict[str, Any] = dataclasses.field(default_factory=dict) + + @classmethod + def from_module( + cls, + module: module_lib.Module, + *, + attributes: dict[str, Any], + name: Optional[str], + future_parent: traverse.Future[Optional[_ModuleProxy]], + cache: _Cache, + ) -> _ModuleProxy: + """Create the proxy.""" + assert id(module) not in cache + proxy = cls( + _cls_name=type(module).__name__, + _intermediate_fields=_get_intermediate_fields(module), + _module_attributes=attributes, + _name=name, + _future_parent=future_parent, + ) + cache[id(module)] = proxy + return proxy + + def __getattr__(self, name: str) -> Any: + # Accessing a child module + if name in self._module_attributes: + return self._module_attributes[name] + + if name not in self._intermediate_fields: + raise AttributeError( + f'No attribute {self._cls_name}.{name}. Only `Intermediate[]` and' + ' sub-module attributes are available.' 
+ ) + # Accessing an intermediate value + return self._intermediate_attributes[name] + + def __getitem__(self, key: str) -> Any: + return self._intermediate_items[key] + + @property + def _attribute_names(self) -> list[str]: + """List of defined attribute names.""" + return list(self._module_attributes) + list(self._intermediate_attributes) + + def __repr__(self) -> str: + content = {k: getattr(self, k) for k in self._attribute_names} + # extra_values are the `.sow` and uncaptured values. + if self._intermediate_items: + content['__getitem__'] = self._intermediate_items + return epy.Lines.make_block( + self._cls_name, + content=content, + ) + + def __dir__(self) -> list[str]: + """Available attributes.""" + return self._attribute_names + + @property + def _parent(self) -> Optional[_ModuleProxy]: + """Returns the parent.""" + if self._future_parent is not None: + return self._future_parent.value + return None + + @property + def _parent_names(self) -> list[str]: + """List of path (excluding the first).""" + + parent_names = [] + parent = self + while parent._parent is not None: # pylint: disable=protected-access # pytype: disable=attribute-error + parent_names.append(parent._name) # pylint: disable=protected-access + parent = parent._parent # pylint: disable=protected-access + + return list(reversed(parent_names)) + + def _set_intermediate(self, intermediate_dict: dict[str, Any]) -> None: + """Assign the intermediate values to `self`.""" + # Get the inner-most dict + values = intermediate_dict + for name in self._parent_names: + values = values.get(name, {}) + + # Pop all intermediates + for name, descriptor in self._intermediate_fields.items(): + # Intermediate was set, pop + if descriptor._collection_name in values: # pylint: disable=protected-access + value = values.pop(descriptor._collection_name) # pylint: disable=protected-access + # Intermediate not set, but default exists + elif descriptor.field.default is not dataclasses.MISSING: + value = 
descriptor._default # pylint: disable=protected-access + # Internediate not set and missing + else: + continue + self._intermediate_attributes[name] = value + + # Eventually other intermediates (from `.sow()`) + self._intermediate_items = dict(values) + values.clear() + + # Eventually clear the dict + all_dicts = {} + values = intermediate_dict + for name in self._parent_names: + all_dicts[name] = values + values = values.get(name, {}) + + for name, dict_ in reversed(all_dicts.items()): + if not dict_.get(name, None): # Empty dict, pop + dict_.pop(name, None) + + def tree_flatten( + self, + ) -> tuple[list[Any], _ModuleProxy]: + """`jax.tree_utils` support.""" + flat_values = [ + self._module_attributes, + self._intermediate_attributes, + self._intermediate_items, + ] + return (flat_values, self) # pytype: disable=bad-return-type + + @classmethod + def tree_unflatten( + cls, + metadata: _ModuleProxy, + array_field_values: list[Any], + ) -> _ModuleProxy: + """`jax.tree_utils` support.""" + [ + module_attributes, + intermediate_attributes, + intermediate_items, + ] = array_field_values + return dataclasses.replace( + metadata, + _module_attributes=module_attributes, + _intermediate_attributes=intermediate_attributes, + _intermediate_items=intermediate_items, + ) + + +def _get_intermediate_fields( + module: module_lib.Module, +) -> dict[str, intermediate.IntermediateDescriptor]: + """Extract only `Intermediate` fields.""" + intermediate_fields = {} + for cls in type(module).mro(): + for name, value in cls.__dict__.items(): + if isinstance(value, intermediate.IntermediateDescriptor): + intermediate_fields[name] = value + return intermediate_fields + + +@jax.tree_util.register_pytree_node_class +class ModuleIntermediateProxy: + """Module-like object which contain the intermediate values.""" + _proxy: _ModuleProxy + + def __init__(self, module: module_lib.Module): + # While `_finalized` is `False`, the proxy cannot be used by the user. 
+ self._finalized: bool = False + self._module = module + self._intermediate_dict: Optional[flax.core.scope.FrozenVariableDict] = None + + def _bind( + self, + intermediate_dict: flax.core.scope.FrozenVariableDict, + module: module_lib.Module, + ) -> None: + """Bind the intermediate context to self.""" + if self._intermediate_dict is not None: + # Nested `context.set_in_call()`. Should not be possible. + raise RuntimeError('Intermediate context already set.') + if self._module is not module: + raise ValueError( + 'Intermediate capture and call instances do not match:' + f' {self._module.name} ({type(self._module).__name__}) vs' + f' {module.name} ({type(module).__name__})' + ) + self._intermediate_dict = intermediate_dict + + def _finalize(self) -> None: + """Create the nested intermediate proxies.""" + assert not self._finalized + self._finalized = True + cache: dict[int, _ModuleProxy] = {} + # Traverse tree twice: + # 1. To create the proxies + # 2. To set the intermediate values (can't be done in `1` as futures not + # yet created) + self._proxy = traverse.recursive_replace( # pytype: disable=annotation-type-mismatch + self._module, + replace_fn=functools.partial(_ModuleProxy.from_module, cache=cache), + ) + intermediate_dict = self._intermediate_dict + assert intermediate_dict is not None + intermediate_dict = flax.core.unfreeze(intermediate_dict) + traverse.recursive_replace( + self._module, + replace_fn=functools.partial( + _set_intermediate_values, + cache=cache, + intermediate_dict=intermediate_dict, + ), + ) + assert not intermediate_dict # Intermediate dict should have been cleared + self._intermediate_dict = None + self._module = None + + def __getattr__(self, name): + if not self._finalized: + raise AttributeError( + 'Cannot access intermediate values from within the' + ' `capture_intermediates()` contextmanager.'
+ ) + else: + return getattr(self._proxy, name) + + def __getitem__(self, key: str) -> Any: + if not self._finalized: + raise KeyError( + 'Cannot access intermediate values from within the' + ' `capture_intermediates()` contextmanager.' + ) + else: + return self._proxy.__getitem__(key) + + def __repr__(self) -> str: + if not self._finalized: + return f'{type(self).__name__}()' + else: + return self._proxy.__repr__() + + def __dir__(self) -> list[str]: + """List attributes for Colab support.""" + if not self._finalized: + return [] + else: + return self._proxy.__dir__() + + def tree_flatten( + self, + ) -> tuple[list[Any], ModuleIntermediateProxy]: + """`jax.tree_util` support.""" + if not self._finalized: + raise ValueError( + 'Cannot pass intermediates to tree_utils inside the' + ' `capture_intermediates`' + ) + + # TODO(epot): This does not support model sharing (as tree_utils will + # duplicate the node.) + return ([self._proxy], self) # pytype: disable=bad-return-type + + @classmethod + def tree_unflatten( + cls, + metadata: ModuleIntermediateProxy, + array_field_values: list[Any], + ) -> ModuleIntermediateProxy: + (proxy,) = array_field_values + self = cls(metadata._module) # pylint: disable=protected-access + self._finalized = True + self._proxy = proxy # pylint: disable=protected-access + return self + + +def _without_kwargs(fn: _FnT) -> _FnT: + """Drop the traversal kwargs (`attributes`, `name`, `future_parent`) before calling `fn`.""" + + @functools.wraps(fn) + def decorated( + module: _T, + *, + attributes: dict[str, Any], + name: Optional[str], + future_parent: traverse.Future[module_lib.Module], + **kwargs, + ) -> _T: + """Call `fn` with `module` and the remaining kwargs, returning its result.""" + del attributes, name, future_parent + return fn(module, **kwargs) + + return decorated + + +@_without_kwargs +def _set_intermediate_values( + module: module_lib.Module, + *, + cache: _Cache, + intermediate_dict: dict[str, Any], +) -> module_lib.Module: + """Assign the captured intermediate values to the proxy cached for `module`.""" + proxy = cache[id(module)] + proxy._set_intermediate(intermediate_dict) # pylint: disable=protected-access + return module + +_Cache:
TypeAlias = dict[int, _ModuleProxy] diff --git a/etils/klinen/layers.py b/etils/klinen/layers.py new file mode 100644 index 00000000..a2714342 --- /dev/null +++ b/etils/klinen/layers.py @@ -0,0 +1,35 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flax layers.""" + +from flax import linen as nn +from kauldron.klinen import module as knn + +# TODO(epot): Merge this with flax + + +class Dense(nn.Dense, knn.Module): # pytype: disable=signature-mismatch + pass + + +class Sequential(nn.Sequential, knn.Module): # pytype: disable=signature-mismatch + pass + + +class Dropout(nn.Dropout, knn.Module): # pytype: disable=signature-mismatch + + @nn.compact + def __call__(self, x): + return super().__call__(x, deterministic=not self.training) diff --git a/etils/klinen/module.py b/etils/klinen/module.py new file mode 100644 index 00000000..95785f83 --- /dev/null +++ b/etils/klinen/module.py @@ -0,0 +1,533 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base Module class.""" + +from __future__ import annotations + +from collections.abc import Callable +import contextlib +import dataclasses +import functools +import typing +from typing import Any, Iterator, Optional, TypeVar + +from etils import edc +from etils import enp +from etils.etree import jax as etree +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from kauldron import random +from kauldron.klinen import intermediate +from kauldron.klinen import intermediate_proxy +from kauldron.klinen import traverse +from kauldron.klinen.collections import Collection +import numpy as np + +_FnT = TypeVar('_FnT', bound=Callable) +_SelfT = TypeVar('_SelfT') + + +def _bind_only_method(fn: _FnT) -> _FnT: + """Validate the method is only called after `init_bind()`.""" + + @functools.wraps(fn) + def new_fn(self: Module, *args, **kwargs): + if not self._is_bind: # pylint: disable=protected-access + raise ValueError( + f'Cannot call {fn.__qualname__} before calling .init_bind()' + ) + return fn(self, *args, **kwargs) + + return new_fn + + +def _skip_wrap_call(fn: _FnT) -> _FnT: + """Skip the flax function auto-wrapping.""" + # flax wrap all method inside `_call_wrapped_method`. Do not wrap + # `_call_wrapped_method` to avoid infinite recursion + fn.method_handler_wrapped = True + return fn + + +@edc.dataclass +@dataclasses.dataclass(frozen=True, kw_only=True) +class _ModuleState: + """Module state. + + Only the root module has a state. The childs modules only uses the + root state. + + Attributes: + params: Bounded variables + streams: Rng stream names + rngs: Current rngs + training: Whether model is in training or evaluation mode. + tree_params_only: Whether only the tree are mapped over + """ + + params: Optional[flax.core.scope.FrozenVariableDict] + streams: tuple[str, ...] 
+ rngs: dict[str, random.PRNGKey] + training: bool = True + tree_params_only: bool = False + + def replace(self: _SelfT, **kwargs: Any) -> _SelfT: + return dataclasses.replace(self, **kwargs) + + +@edc.dataclass +@dataclasses.dataclass +class _Context: + """Global context. + + Attributes: + in_call_state: `_ModuleState` of the top level `y = model(x)` call. + capture_proxy: If set, the intermediate values are forwarded to this proxy. + """ + + in_call_state: edc.ContextVar[Optional[_ModuleState]] = None + capture_proxy: edc.ContextVar[ + Optional[intermediate_proxy.ModuleIntermediateProxy] + ] = None + + @contextlib.contextmanager + def set_in_call_state(self, module: Module) -> Iterator[None]: + self.in_call_state = module._kd_state # pylint: disable=protected-access + try: + yield + finally: + self.in_call_state = None + + +context = _Context() + + +class Module(nn.Module): # pytype: disable=invalid-function-definition + """Base Module class.""" + + _: dataclasses.KW_ONLY # Required to allow sub-classing + # TODO(epot): Should be hidden from the public API + # Fields for auto-complete/type-checking, but ignored by `@dataclass` + _kd_state: Optional[_ModuleState] = dataclasses.field( + repr=False, + compare=False, + hash=False, + default=None, + ) + + if typing.TYPE_CHECKING: + # Set by `traverse.recursive_set_parent` + _kd_name: str = dataclasses.field(init=False) + _kd_future_parent: Optional[traverse.Future[Module]] = dataclasses.field( + init=False + ) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + intermediate.setup_cls(cls) + jax.tree_util.register_pytree_node_class(cls) + + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + object.__setattr__(self, '_kd_init_finished', True) + + def init_bind( + self: _SelfT, + rng: jax.random.KeyArray, + *args, + streams: tuple[str, ...] 
= (Collection.DROPOUT,), + **kwargs, + ) -> _SelfT: + """Initialize the module, returning a binded version.""" + # Note: Only top-level has a state. Childs recover the state from the + # parent/scope. + self._kd_state = _ModuleState( + params=None, + streams=streams, + rngs={}, + ) + + # Normalize the args/kwargs + # Is it possible to have non-array kwargs ? + args, kwargs = etree.spec_like((args, kwargs), ignore_other=False) + args, kwargs = etree.map(_as_empty, (args, kwargs)) + + # Set the childs `_kd_parent` + self = self._replace_state() # pylint: disable=self-cls-assignment + + # Generate the rngs for initialization + rngs = _normalize_rngs(rng, streams=streams, add_params=True) + + # Initialize the state + with context.set_in_call_state(self): + variables = flax.core.freeze(self.init( + rngs, + *args, + **kwargs, + mutable=(Collection.PARAMS, Collection.INTERMEDIATES), + )) + + return self._replace_state(params=variables.get(Collection.PARAMS, {})) + + @_bind_only_method + def with_rng( + self: _SelfT, + rng: int + | jax.random.KeyArray + | dict[str, jax.random.KeyArray] + | None = None, + ) -> _SelfT: + """Replace the rngs keys. + + Can be called: + + * `model = model.with_rng()`: Replace key with next key + * `model = model.with_rng(0)`: Create a key from the seed. + * `model = model.with_rng(key)`: Key distributed among streams + * `model = model.with_rng({'dropout': key})`: streams explicitly defined + + Args: + rng: Random key. + + Returns: + The updated model with the next key. 
+ """ + # When rng is None, auto-increment rng + if rng is None: + # TODO(epot): When LazyRng, should instead increment rng counter to match + # flax behavior, but difficult because we do not have access to flax + # internal counter + rng = {name: k.next() for name, k in self.rngs.items()} + if isinstance(rng, int): + rng = random.PRNGKey(rng) + return self._replace_state( + rngs=_normalize_rngs(rng, streams=self._root_state.streams) + ) + + @property + @_bind_only_method + def rngs(self) -> dict[str, random.PRNGKey]: + """Returns `dict[str, PRNGKey]` mapping key to.""" + if self._kd_state is None: + rngs = self._root_state.rngs + # Fold-in the random info + rngs = { + k: flax.core.scope.LazyRng.create(rng, *self._kd_parent_names) + for k, rng in rngs.items() + } + return rngs + else: + return self._kd_state.rngs # pytype: disable=bad-return-type + + @_bind_only_method + def train(self: _SelfT) -> _SelfT: + """Switch mode to training.""" + return self._replace_state(training=True) + + @_bind_only_method + def eval(self: _SelfT) -> _SelfT: + """Switch mode to evaluation (disable dropout,...).""" + return self._replace_state(training=False) + + @property + @_bind_only_method + def training(self) -> bool: + """Returns `True` if mode is training.""" + return self._root_state.training + + @property + @_bind_only_method + def params(self) -> flax.core.scope.FrozenVariableDict: + """Model weights.""" + params = self._root_state.params + for name in self._kd_parent_names: + params = params.get(name, {}) # pytype: disable=attribute-error + return params # pytype: disable=bad-return-type + + @_bind_only_method + def param_tree_on(self: _SelfT) -> _SelfT: + """Makes `tree_utils` only act on params.""" + return self._replace_state(tree_params_only=True) + + @_bind_only_method + def param_tree_off(self: _SelfT) -> _SelfT: + """Makes `tree_utils` act on everything.""" + return self._replace_state(tree_params_only=False) + + def call_with_intermediates( + self: _SelfT, *args: 
Any, **kwargs: Any + ) -> tuple[Any, _SelfT]: + """Call the module with intermediates. + + Wrapper around `__call__` which also return the intermediate values: + + ``` + y = model(x) + + y, intermediates = model.call_with_intermediates(x) + ``` + + The intermediate values have the same structure as the model. + + Args: + *args: Arguments forwarded to `module.__call__` + **kwargs: Arguments forwarded to `module.__call__` + + Returns: + `module.__call__` output + Intermediate values. + """ + with self.capture_intermediates() as intermediates: + return self(*args, **kwargs), intermediates + + # ========== Internal methods ========== + + @property + @_skip_wrap_call + def _kd_parent(self) -> Optional[Module]: + """Returns the parent.""" + if self._kd_future_parent is not None: + return self._kd_future_parent.value + return None + + @property + @_skip_wrap_call + def _kd_parent_names(self) -> list[str]: + """List of path (excluding the first).""" + + parent_names = [] + parent = self + while parent._kd_parent is not None: # pylint: disable=protected-access # pytype: disable=attribute-error + parent_names.append(parent._kd_name) # pylint: disable=protected-access + parent = parent._kd_parent # pylint: disable=protected-access + + return list(reversed(parent_names)) + + # Return type should be Optional[_ModuleState] but we would then loose + # auto-complete. 
+ @property + @_skip_wrap_call + @_bind_only_method + def _root_state(self) -> _ModuleState: + """Returns the root parent state.""" + # Inside a call context, we directly get the state + # Indeed, modules defined inside `nn.compact` are not availabe + if context.in_call_state: + return context.in_call_state + parent = self + while parent._kd_parent is not None: # pylint: disable=protected-access # pytype: disable=attribute-error + parent = parent._kd_parent # pylint: disable=protected-access + return parent._kd_state # pylint: disable=protected-access # pytype: disable=bad-return-type + + @property + @_skip_wrap_call + def _is_bind(self) -> bool: + return hasattr(self, '_kd_name') or bool(context.in_call_state) + + @_skip_wrap_call + def _replace_state(self: _SelfT, **kwargs) -> _SelfT: + """Recursivelly update all the childs parents.""" + # TODO(epot): Support attributes defined in `.setup()` + if self._kd_state is None: + new_state_kwargs = dict( + params=self.params, + rngs=self.rngs, + ) + new_state_kwargs.update(kwargs) + new_state = self._root_state.replace(**new_state_kwargs) + else: + new_state = self._kd_state.replace(**kwargs) # pytype: disable=attribute-error + + # First update the state + new_self = dataclasses.replace(self, _kd_state=new_state) + + # Recursivelly update the modules to link to the new parents + new_self = traverse.recursive_set_parent(new_self) + + return new_self + + @contextlib.contextmanager + def capture_intermediates(self: _SelfT) -> Iterator[_SelfT]: + """Track the intermediate values. + + Note that this function isn't meant to be called directly but instead + through `y, intermediates = model.call_and_capture(x)`. + + Usage: + + ```python + with model.capture_intermediates() as intermediates: + y = model(x) # Model set `model.xxx` + + # After the contextmanager end, `intermediates` contain the captured + # intermediate values. 
+ intermediates.xxx + ``` + + Yields: + The module proxy containing the intermediate values + + Raises: + RuntimeError: If contextmanager are nested. + """ + + proxy = intermediate_proxy.ModuleIntermediateProxy(self) + if context.capture_proxy: + raise RuntimeError('`capture_intermediates()` calls cannot be nested.') + try: + context.capture_proxy = proxy + yield proxy # pytype: disable=bad-return-type + finally: + proxy._finalize() # pylint: disable=protected-access + context.capture_proxy = None + + @_bind_only_method + def tree_flatten( + self, + ) -> tuple[list[flax.core.scope.FrozenVariableDict], Module]: + """`jax.tree_utils` support.""" + if not self._kd_state: + self = self._replace_state() # Detach the child module # pylint: disable=self-cls-assignment + if self._kd_state.tree_params_only: + vals = [self._kd_state.params] + else: + vals = [self._kd_state.params, self._kd_state.rngs] + return (vals, self) # pytype: disable=bad-return-type + + @classmethod + def tree_unflatten( + cls: type[_SelfT], + metadata: Module, + array_field_values: list[flax.core.scope.FrozenVariableDict], + ) -> _SelfT: + assert metadata._kd_state # pylint: disable=protected-access + if metadata._kd_state.tree_params_only: # pylint: disable=protected-access + (params,) = array_field_values + return metadata._replace_state(params=params) # pylint: disable=protected-access + else: + (params, rngs) = array_field_values + return metadata._replace_state(params=params, rngs=rngs) # pylint: disable=protected-access + + @_skip_wrap_call + def _call_wrapped_method(self, fn, args, kwargs): + """All function calls (`__call__`,...).""" + + # No-op for `Module` functions + # TODO(epot): Better heuristic ? 
+ if fn.__module__ == 'kauldron.klinen.module': + return super()._call_wrapped_method(fn, args, kwargs) + + # In-call set: Use default flax behavior + if context.in_call_state: + return super()._call_wrapped_method(fn, args, kwargs) + + if not self._is_bind: + try: + return super()._call_wrapped_method(fn, args, kwargs) + except flax.errors.CallCompactUnboundModuleError: # pylint: disable=try-except-raise + raise + else: + raise RuntimeError("Calling without scope didn't raise unbound error") + + state = self._kd_state # pylint: disable=protected-access + + if not state: + # No scope: binding not called (module non-initialized) + if self._root_state is None: + # Call original method, to raise flax `CallCompactUnboundModuleError` + try: + return super()._call_wrapped_method(fn, args, kwargs) + except flax.errors.CallCompactUnboundModuleError: # pylint: disable=try-except-raise + raise + else: + raise RuntimeError("Calling without scope didn't raise unbound error") + else: + # Detach module: e.g. 
`model.encoder(x)` + return getattr(self._replace_state(), fn.__name__)(*args, **kwargs) + elif state.params is None: + # Should never happens in the `model.init()` function + raise RuntimeError('Module not initialized.') + else: + # Top-level bind call + with context.set_in_call_state(self): + y, variables = self.apply( + {Collection.PARAMS: state.params}, + rngs=state.rngs, + method=getattr(self, fn.__name__), + *args, + **kwargs, + mutable=(Collection.INTERMEDIATES,), + ) + if context.capture_proxy: + context.capture_proxy._bind( # pylint: disable=protected-access + module=self, + intermediate_dict=variables.get(Collection.INTERMEDIATES, {}), + ) + return y + + raise RuntimeError('Should have returned before.') + + if not typing.TYPE_CHECKING: + + @_skip_wrap_call + def __getattr__(self, name: str) -> Any: + maybe_descriptor = getattr(type(self), name, None) + if isinstance(maybe_descriptor, intermediate.IntermediateDescriptor): + # If `__get__` raise `AttributeError`, getattr will be called so + # explicitly call `__get__` a second time to propagate the error. + maybe_descriptor.__get__(self, type(self)) + else: # Default flax behavior + super().__getattr__(name) + + @_skip_wrap_call + def __setattr__(self, name: str, value: Any) -> None: + maybe_descriptor = getattr(type(self), name, None) + if isinstance(maybe_descriptor, intermediate.IntermediateDescriptor): + # Bypass flax setattr to use the descriptor + object.__setattr__(self, name, value) + else: # Default flax behavior + super().__setattr__(name, value) + + +def _as_empty(arr: enp.ArraySpec) -> jax.Array: + """Create empty array.""" + # Downcast float64 to float32 to avoid Jax warning. Or better that the + # user is aware of it ? 
+ dtype = np.float32 if arr.dtype == np.float64 else arr.dtype + return jnp.empty(shape=arr.shape, dtype=dtype) + + +def _normalize_rngs( + rng: jax.random.KeyArray | dict[str, jax.random.KeyArray], + streams: list[str] | tuple[str, ...], + add_params: bool = False, +) -> dict[str, jax.random.KeyArray]: + """Normalize the rngs keys.""" + rng = jax.tree_util.tree_map(random.PRNGKey, rng) + if isinstance(rng, dict): + return rng + elif isinstance(rng, random.PRNGKey): + # Could we collect the streams from the childs modules ? Difficult as + # some modules are only available inside `__call__`. + rngs = {name: rng.fold_in(name) for name in streams} + if add_params: + # Do not `fold_in('params')` so `.init_bind(key)` is consistent with + # `.init(key)` + rngs[Collection.PARAMS] = rng + return rngs + else: + raise TypeError(f'Unexpected key {rng}') diff --git a/etils/klinen/module_rng_test.py b/etils/klinen/module_rng_test.py new file mode 100644 index 00000000..4a59b759 --- /dev/null +++ b/etils/klinen/module_rng_test.py @@ -0,0 +1,72 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test.""" + +import chex +from flax import linen as nn +from jax import numpy as jnp +from kauldron import klinen as knn +from kauldron import random +import numpy as np + + +class AutoEncoder(knn.Module): + encoder: nn.Module + decoder: nn.Module + + @nn.compact + def __call__(self, x): + return self.decoder(self.encoder(x)) + + +def test_init(): + key = random.PRNGKey(0) + + m0 = knn.Dense(3) + m1 = nn.Dense(3) + + x = jnp.zeros((2,)) + + p0 = m0.init_bind(key, x).params + p1 = m1.init(key, x)['params'] + + chex.assert_trees_all_close(p0, p1) # klinen and linen have same params + + +def test_randomness(): + model = AutoEncoder( + encoder=knn.Sequential([ + knn.Dropout(0.5), + knn.Dense(32), + ]), + decoder=knn.Sequential([ + knn.Dropout(0.5), + knn.Dense(32), + ]), + ) + + key = random.PRNGKey(0) + + x = jnp.ones((5,)) + + model = model.init_bind(key.fold_in('init'), x) + + model = model.with_rng({'dropout': key.fold_in('dropout')}) + + # Calling the model directly or indirectly should yield the same result + y0 = model(x) + y1 = model.decoder(model.encoder(x)) + + np.testing.assert_allclose(y0, y1) diff --git a/etils/klinen/module_test.py b/etils/klinen/module_test.py new file mode 100644 index 00000000..ed837fa4 --- /dev/null +++ b/etils/klinen/module_test.py @@ -0,0 +1,308 @@ +# Copyright 2023 The etils Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test.""" + +from __future__ import annotations + +from collections.abc import Callable +import dataclasses + +import chex +from etils.array_types import f32 +import flax +from flax import linen as nn +import jax +import jax.numpy as jnp +from kauldron import klinen as knn +from kauldron import random as krandom +import numpy as np +import pytest + +_IN_SHAPE = (3, 2) + +# TODO(epot): Test when mixing `knn` inside `nn` modules (both attribute and +# nested attribute (e.g. Sequential(childs=))) +# TODO(epot): Test when having `nn` module as attribute of `knn` + + +class Nested(knn.Module): + """Nested module.""" + + child: knn.Module + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + return self.child(x) + + +class ModelRoot(knn.Module): + """Root model.""" + + # Module can be: + # * Root + # * Attribute + # * Nested attribute + # * Compact + child: knn.Module + childs: list[knn.Module] + hidden: Callable[[], knn.Module] + nested: Nested + + @nn.compact + def __call__(self, x: jax.Array) -> jax.Array: + x = self.child(x) + for child in self.childs: + x = child(x) + x = self.hidden()(x) + x = self.nested(x) + return x + + +class DenseAndDropout(knn.Module): + """Simple model.""" + + @nn.compact + def __call__(self, x): + x = nn.Dense(2)(x) + x = nn.Dropout(0.1)(x, deterministic=not self.training) + return x + + +class WithIntermediate(knn.Module): + """Simple model.""" + + tmp_val: knn.Intermediate[jax.Array] = dataclasses.field(init=False) + tmp_list: knn.Intermediate[list[jax.Array]] = dataclasses.field( + init=False, default_factory=list + ) + + @nn.compact + def __call__(self, x): + self.tmp_list.append(x) + + if self.name != 'childs_1': # Sharded intermediates are shared between call + # TODO(epot): Why the explicitly given `name='shared'` is lost here. Flax + # bug ? 
+ with pytest.raises( + AttributeError, match='Attribute was not set during the call' + ): + _ = self.tmp_val + + # Overwrite tmp_val + self.tmp_val = x + self.tmp_val = nn.Dense(2)(self.tmp_val) + + self.tmp_list.append(x) + return self.tmp_val + + +def _make_model(module_cls: type[knn.Module]) -> ModelRoot: + """Model factory.""" + shared = module_cls(name='shared') + model_raw = ModelRoot( + child=module_cls(), + childs=[ + module_cls(), + shared, + shared, + ], + hidden=lambda: module_cls(), # pylint: disable=unnecessary-lambda + nested=Nested(module_cls()), + ) + + rng = jax.random.PRNGKey(0) + + return model_raw.init_bind(rng, f32[(*_IN_SHAPE,)]) + + +def test_non_bind(): + model_raw = DenseAndDropout() + + with pytest.raises(ValueError, match='before calling'): + _ = model_raw.params # Unbind, function not available + + with pytest.raises(flax.errors.CallCompactUnboundModuleError): + model_raw(jnp.ones(_IN_SHAPE)) + + +def test_train_mode(): + model = _make_model(DenseAndDropout) + + model = model.with_rng(0) + + model_train = model + assert model.training + assert model.child.training + assert model.childs[0].training + assert model.childs[1].training + assert model.childs[2] is model.childs[1] + assert model.nested.training + assert model.nested.child.training + + x = jnp.ones(_IN_SHAPE) + y = model(x) + assert not np.allclose(y, x) # Dropout applied + assert not np.allclose(y[0], y[1]) # Batch have different dropout + y2 = model(x) + np.testing.assert_allclose(y, y2) # Calling model twice yield same result + + model = model.eval() + assert not model.child.training + assert not model.childs[0].training + assert not model.childs[1].training + assert model.childs[2] is model.childs[1] + assert model.childs[2] is not model_train.childs[1] + assert not model.nested.training + assert not model.nested.child.training + + # Dropout disabled: No-op + y = model(x) + assert not np.allclose(y, y2) # Dropout disabled + # In eval, same example yield same result + 
np.testing.assert_allclose(y[0], y[1]) + + nested = model.nested.train() + # Model isn't mutated + assert not model.child.training + assert not model.childs[0].training + assert not model.childs[1].training + assert not model.nested.training + assert not model.nested.child.training + # But nested is updated + assert nested.training + assert nested.child.training + + +def test_rng(): + x = jnp.ones(_IN_SHAPE) + model = _make_model(DenseAndDropout) + + with pytest.raises(flax.errors.InvalidRngError): + model(x) # No rng by default + + assert model.rngs == {} # pylint: disable=g-explicit-bool-comparison + assert model.child.rngs == {} # pylint: disable=g-explicit-bool-comparison + assert model.childs[0].rngs == {} # pylint: disable=g-explicit-bool-comparison + + key = krandom.PRNGKey(0) + x = jnp.ones(_IN_SHAPE) + + model = model.with_rng(key) + + y = model(x) + y2 = model(x) + np.testing.assert_allclose(y, y2) + + # Test with_rng + model = model.with_rng() # Next key + y = model(x) + assert not jnp.allclose(y, y2) # Old pred is different from the new key + np.testing.assert_allclose(y, model(x)) + + # rng is constant between train/eval + jax.tree_util.tree_map( + np.testing.assert_allclose, model.rngs, model.eval().rngs + ) + + # Key can be explicitly passed + model = model.with_rng({'dropout': key}) + y = model(x) + y2 = model(x) + np.testing.assert_allclose(y, y2) + + +def test_jit(): + @jax.jit + def fn(model, x): + y = model(x) + return y + + x = jnp.ones(_IN_SHAPE) + model = _make_model(DenseAndDropout) + model = model.with_rng(0) + + y = fn(model, x) + y2 = model(x) + np.testing.assert_allclose(y, y2, atol=1e-6) + + new_y = fn(model.with_rng(), x) + new_y2 = model.with_rng()(x) + np.testing.assert_allclose(new_y, new_y2, atol=1e-6) + + # Calling jit with new rng should yield new results + assert not np.allclose(new_y, y, atol=1e-6) + + +def test_param(): + model = _make_model(DenseAndDropout) + model = model.with_rng(0) + + assert isinstance(model.params, 
flax.core.FrozenDict) + assert isinstance(model.child.params, flax.core.FrozenDict) + + # Values are as expected + chex.assert_trees_all_close( + model.nested.params, + flax.core.FrozenDict({'child': model.nested.child.params}), + ) + + +def test_intermediate(): + model = _make_model(WithIntermediate) + + x = jnp.ones(_IN_SHAPE) + + y = model(x) # Model call works without intermediate + + with pytest.raises( + AttributeError, match='can only be accessed inside module functions' + ): + _ = model.child.tmp_val + + with pytest.raises( + AttributeError, match='can only be accessed inside module functions' + ): + _ = model.child.tmp_list + + with model.capture_intermediates() as intermediates: + y2 = model(x) + + np.testing.assert_allclose(y, y2) + # Last called captured value should match output + np.testing.assert_allclose(y, intermediates.nested.child.tmp_val) + # But not the first one + _assert_not_all_close(y, intermediates.child.tmp_val) + + assert len(intermediates.child.tmp_list) == 2 + assert len(intermediates.childs[0].tmp_list) == 2 + assert len(intermediates.childs[1].tmp_list) == 4 # Shared called twice + + # Calling intermediate twice should reset the values + with model.capture_intermediates() as intermediates: + y2 = model(x) + + np.testing.assert_allclose(y, y2) + # Last called captured value should match output + np.testing.assert_allclose(y, intermediates.nested.child.tmp_val) + # But not the first one + _assert_not_all_close(y, intermediates.child.tmp_val) + + assert len(intermediates.child.tmp_list) == 2 + assert len(intermediates.childs[0].tmp_list) == 2 + assert len(intermediates.childs[1].tmp_list) == 4 # Shared called twice + + +def _assert_not_all_close(x, y): + assert not np.allclose(x, y) diff --git a/etils/klinen/traverse.py b/etils/klinen/traverse.py new file mode 100644 index 00000000..007ba0c0 --- /dev/null +++ b/etils/klinen/traverse.py @@ -0,0 +1,154 @@ +# Copyright 2023 The etils Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module helper.""" + +from __future__ import annotations + +from collections.abc import Callable +import dataclasses +import functools +from typing import Any, Generic, Optional, TypeVar + +from flax import linen as nn +from kauldron.klinen import module as module_lib + +_T = TypeVar('_T') + + +class Future(Generic[_T]): + """Value wrapper which is set later.""" + + value: _T + + +def recursive_set_parent(module: _T) -> _T: + """Reccursivelly assign the parent to the modules. + + Ater this function is called, all the `model.childs[0].child` will + have their `_kd_name` (e.g. `'childs_3'`) and `_kd_future_parent` set. + + Args: + module: The module to traverse. + + Returns: + A copy of the module and sub-modules with parent set. 
+ """ + return recursive_replace(module, replace_fn=_replace_module) + + +def _replace_module( + module: _T, + *, + attributes: dict[str, Any], + name: Optional[str], + future_parent: Future[module_lib.Module], +) -> _T: + """Replace the module with updated name and parent fields.""" + new_module = dataclasses.replace(module, **attributes) + assert new_module is not module + new_module._kd_name = name # pylint: disable=protected-access + new_module._kd_future_parent = ( # pylint: disable=protected-access + future_parent + ) + return new_module + + +def recursive_replace( + module: _T, + *, + name: Optional[str] = None, + future_parent: Optional[Future[module_lib.Module]] = None, + replace_fn: Callable[..., _T], +) -> _T: + """Reccursivelly traverse the modules attribute and replace them. + + Args: + module: The module to traverse. + name: Module name + future_parent: Module parent + replace_fn: Function to replace the module with updated attributes. + + Returns: + A copy of the module and sub-modules with parent set. + """ + self_parent = Future() + + # Cache the modules shared across fields. + # Note: This does not support cicles (module.child is module) + cache: dict[int, module_lib.Module] = {} + + # TODO(epot): Could have more optimized implementation (skip non-module + # fields) + attributes = { + f.name: getattr(module, f.name) + for f in dataclasses.fields(module) + if f.init + } + module_attributes = {} + for k, v in attributes.items(): + is_module_field = Future() + is_module_field.value = False + # `_map_over_modules_in_tree` implementation could likely be + # optimized to be applied using list/dict comprehension. + v = nn.module._map_over_modules_in_tree( # pylint: disable=protected-access + functools.partial( + _set_module_name_and_parent, + future_parent=self_parent, + cache=cache, + replace_fn=replace_fn, + is_module_field=is_module_field, + ), + # Need to wrap inside dict, otherwise, flax do not infer attribute name. 
+ {k: v}, + ) + if is_module_field.value: + module_attributes[k] = v[k] + + new_module = replace_fn( + module, + attributes=module_attributes, + name=name, + future_parent=future_parent, + ) + + self_parent.value = new_module + return new_module + + +def _set_module_name_and_parent( + prefix: str, + leaf: _T, + *, + future_parent: Future[module_lib.Module], + cache: dict[int, module_lib.Module], + replace_fn: Callable[..., _T], + is_module_field: Future[bool], +) -> _T: + """Set the `_kd_name` and `_kd_future_parent` attribute.""" + if isinstance(leaf, module_lib.Module): + is_module_field.value = True + id_ = id(leaf) + if id_ in cache: # Already created + return cache[id_] + leaf = recursive_replace( + leaf, + name=prefix.removeprefix('_'), + future_parent=future_parent, + replace_fn=replace_fn, + ) + cache[id_] = leaf + return leaf + else: + return leaf