Skip to content

Commit

Permalink
Internal changes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 450380384
  • Loading branch information
stompchicken authored and ChexDev committed May 23, 2022
1 parent 2b994fa commit 1b05864
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
39 changes: 36 additions & 3 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from absl import logging
import jax
import tree


FrozenInstanceError = dataclasses.FrozenInstanceError
Expand Down Expand Up @@ -62,7 +63,8 @@ def new_init(self, *orig_args, **orig_kwargs):
all_kwargs = dict(*orig_args, **orig_kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")
raise ValueError(
f"__init__() got unexpected keyword arguments: {unknown_kwargs}.")

# Pass only arguments corresponding to fields with `init=True`.
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
Expand Down Expand Up @@ -91,7 +93,7 @@ def dataclass(
order=False,
unsafe_hash=False,
frozen=False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
mappable_dataclass=False, # pylint: disable=redefined-outer-name
):
"""JAX-friendly wrapper for :py:func:`dataclasses.dataclass`.
Expand Down Expand Up @@ -185,7 +187,7 @@ def __call__(self, cls):
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args)))

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand All @@ -202,6 +204,8 @@ def _getstate(self):
def _setstate(self, state):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
if not class_self.mappable_dataclass:
register_dataclass_type_with_dm_tree(dcls)
class_self.registered = True
self.__dict__.update(state)

Expand All @@ -213,6 +217,8 @@ def _setstate(self, state):
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
if not class_self.mappable_dataclass:
register_dataclass_type_with_dm_tree(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

Expand Down Expand Up @@ -246,3 +252,30 @@ def register_dataclass_type_with_jax_tree_util(data_class):
nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
except ValueError:
logging.info("%s is already registered as JAX PyTree node.", data_class)


def register_dataclass_type_with_dm_tree(data_class):
"""Register an existing dataclass with dm_tree node registry.
This will mean that functions in dm_tree will operate over fields of the
dataclass.
Args:
data_class: A class created using dataclasses.dataclass. It must be
constructable from keyword arguments corresponding to the members exposed
in instance.__dict__.
"""

def to_iterable(d):
keys, values = jax.util.unzip2(sorted(d.__dict__.items()))
return values, keys, keys

def from_iterable(keys, values):
return data_class(**dict(zip(keys, values)))

try:
tree.register_node(data_class, to_iterable, from_iterable)
except ValueError:
logging.log_first_n(logging.INFO,
"%s is already registered as dm_tree node.", 1,
data_class)
11 changes: 8 additions & 3 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,17 @@ class SimpleDataclass:
b: int = 2

SimpleDataclass(a=1, b=3)
with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'):
with self.assertRaisesRegex((ValueError, TypeError),
'.*unexpected keyword argument.*'):
SimpleDataclass(a=1, b=3, c=4)

def test_tuple_conversion(self):
@parameterized.named_parameters(
('non_mappable', False),
('mappable', True),
)
def test_tuple_conversion(self, mappable):

@chex_dataclass()
@chex_dataclass(mappable_dataclass=mappable)
class SimpleDataclass:
b: int
a: int
Expand Down

0 comments on commit 1b05864

Please sign in to comment.