From 6b481595f096254817902d1dc0e1ead18e5610ca Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 28 Jul 2022 09:39:18 -0700 Subject: [PATCH] redefinition -> defaults kept in config Summary: This is an internal change in the config systen. It allows redefining a pluggable implementation with new default values. This is useful in notebooks / interactive use. For example, this now works. class A(ReplaceableBase): pass registry.register class B(A): i: int = 4 class C(Configurable): a: A a_class_type: str = "B" def __post_init__(self): run_auto_creation(self) expand_args_fields(C) registry.register class B(A): i: int = 5 c = C() assert c.a.i == 5 Reviewed By: shapovalov Differential Revision: D38219371 fbshipit-source-id: 72911a9bd3426d3359cf8802cc016fc7f6d7713b --- pytorch3d/implicitron/tools/config.py | 39 +++++++++++++++++++++++++-- tests/implicitron/test_config.py | 18 ++++++++----- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 79cda30a9..1605f8b44 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -887,6 +887,40 @@ def create(): return dataclasses.field(default_factory=create) +def _get_default_args_field_from_registry( + *, + base_class_wanted: Type[_X], + name: str, + _do_not_process: Tuple[type, ...] = (), + _hook: Optional[Callable[[DictConfig], None]] = None, +): + """ + Get a dataclass field which defaults to + get_default_args(registry.get(base_class_wanted, name)). + + This is used internally in place of get_default_args_field in + order that default values are updated if a class is redefined. + + Args: + base_class_wanted: As for registry.get. + name: As for registry.get. + _do_not_process: As for get_default_args + _hook: Function called on the result before returning. + + Returns: + function to return new DictConfig object + """ + + def create(): + C = registry.get(base_class_wanted=base_class_wanted, name=name) + args = get_default_args(C, _do_not_process=_do_not_process) + if _hook is not None: + _hook(args) + return args + + return dataclasses.field(default_factory=create) + + def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]: """ If a member is annotated as `type_`, and that should expanded in @@ -978,8 +1012,9 @@ def _process_member( setattr( some_class, args_name, - get_default_args_field( - derived_type, + _get_default_args_field_from_registry( + base_class_wanted=type_, + name=derived_type.__name__, _do_not_process=_do_not_process + (some_class,), _hook=hook_closed, ), diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 503f8ab54..590b9dea7 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -378,14 +378,20 @@ def get_color(self): with self.assertWarnsRegex( UserWarning, "New implementation of Grape is being chosen." ): - bowl = FruitBowl(**bowl_args) - self.assertIsInstance(bowl.main_fruit, Grape) + defaulted_bowl = FruitBowl() + self.assertIsInstance(defaulted_bowl.main_fruit, Grape) + self.assertEqual(defaulted_bowl.main_fruit.large, True) + self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green") + with self.assertWarnsRegex( + UserWarning, "New implementation of Grape is being chosen." + ): + args_bowl = FruitBowl(**bowl_args) + self.assertIsInstance(args_bowl.main_fruit, Grape) # Redefining the same class won't help with defaults because encoded in args - self.assertEqual(bowl.main_fruit.large, False) - + self.assertEqual(args_bowl.main_fruit.large, False) # But the override worked. - self.assertEqual(bowl.main_fruit.get_color(), "green") + self.assertEqual(args_bowl.main_fruit.get_color(), "green") # 2. Try redefining without the dataclass modifier # This relies on the fact that default creation processes the class. @@ -397,7 +403,7 @@ class Grape(Fruit): # noqa: F811 with self.assertWarnsRegex( UserWarning, "New implementation of Grape is being chosen." ): - bowl = FruitBowl(**bowl_args) + FruitBowl(**bowl_args) # 3. Adding a new class doesn't get picked up, because the first # get_default_args call has frozen FruitBowl. This is intrinsic to