Skip to content

Commit

Permalink
redefinition -> defaults kept in config
Browse files Browse the repository at this point in the history
Summary:
This is an internal change in the config systen. It allows redefining a pluggable implementation with new default values. This is useful in notebooks / interactive use. For example, this now works.

        class A(ReplaceableBase):
            pass

        registry.register
        class B(A):
            i: int = 4

        class C(Configurable):
            a: A
            a_class_type: str = "B"

            def __post_init__(self):
                run_auto_creation(self)

        expand_args_fields(C)

        registry.register
        class B(A):
            i: int = 5

        c = C()

        assert c.a.i == 5

Reviewed By: shapovalov

Differential Revision: D38219371

fbshipit-source-id: 72911a9bd3426d3359cf8802cc016fc7f6d7713b
  • Loading branch information
bottler authored and facebook-github-bot committed Jul 28, 2022
1 parent cb49550 commit 6b48159
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
39 changes: 37 additions & 2 deletions pytorch3d/implicitron/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,40 @@ def create():
return dataclasses.field(default_factory=create)


def _get_default_args_field_from_registry(
*,
base_class_wanted: Type[_X],
name: str,
_do_not_process: Tuple[type, ...] = (),
_hook: Optional[Callable[[DictConfig], None]] = None,
):
"""
Get a dataclass field which defaults to
get_default_args(registry.get(base_class_wanted, name)).
This is used internally in place of get_default_args_field in
order that default values are updated if a class is redefined.
Args:
base_class_wanted: As for registry.get.
name: As for registry.get.
_do_not_process: As for get_default_args
_hook: Function called on the result before returning.
Returns:
function to return new DictConfig object
"""

def create():
C = registry.get(base_class_wanted=base_class_wanted, name=name)
args = get_default_args(C, _do_not_process=_do_not_process)
if _hook is not None:
_hook(args)
return args

return dataclasses.field(default_factory=create)


def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
"""
If a member is annotated as `type_`, and that should expanded in
Expand Down Expand Up @@ -978,8 +1012,9 @@ def _process_member(
setattr(
some_class,
args_name,
get_default_args_field(
derived_type,
_get_default_args_field_from_registry(
base_class_wanted=type_,
name=derived_type.__name__,
_do_not_process=_do_not_process + (some_class,),
_hook=hook_closed,
),
Expand Down
18 changes: 12 additions & 6 deletions tests/implicitron/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,20 @@ def get_color(self):
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
self.assertIsInstance(bowl.main_fruit, Grape)
defaulted_bowl = FruitBowl()
self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
self.assertEqual(defaulted_bowl.main_fruit.large, True)
self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")

with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
args_bowl = FruitBowl(**bowl_args)
self.assertIsInstance(args_bowl.main_fruit, Grape)
# Redefining the same class won't help with defaults because encoded in args
self.assertEqual(bowl.main_fruit.large, False)

self.assertEqual(args_bowl.main_fruit.large, False)
# But the override worked.
self.assertEqual(bowl.main_fruit.get_color(), "green")
self.assertEqual(args_bowl.main_fruit.get_color(), "green")

# 2. Try redefining without the dataclass modifier
# This relies on the fact that default creation processes the class.
Expand All @@ -397,7 +403,7 @@ class Grape(Fruit): # noqa: F811
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
FruitBowl(**bowl_args)

# 3. Adding a new class doesn't get picked up, because the first
# get_default_args call has frozen FruitBowl. This is intrinsic to
Expand Down

0 comments on commit 6b48159

Please sign in to comment.