From c759fc560f84eaff3577afac0083a2a2f07b349f Mon Sep 17 00:00:00 2001
From: Dejan Kovachev
Date: Fri, 31 Mar 2023 07:44:38 -0700
Subject: [PATCH] Hard population of registry system with pre_expand

Summary:
Provide an extension point pre_expand to let a configurable class A make sure
another class B is registered before A is expanded. This reduces top level
imports.

Reviewed By: bottler

Differential Revision: D44504122

fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
---
 pytorch3d/implicitron/dataset/data_source.py  | 25 +++++++--
 pytorch3d/implicitron/models/generic_model.py | 53 +++++++++++--------
 pytorch3d/implicitron/models/overfit_model.py | 25 +++++++++
 pytorch3d/implicitron/tools/config.py         |  7 +++
 tests/implicitron/test_config.py              | 34 ++++++++++++
 5 files changed, 117 insertions(+), 27 deletions(-)

diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py
index fcc2ed207..7ea5fe6f9 100644
--- a/pytorch3d/implicitron/dataset/data_source.py
+++ b/pytorch3d/implicitron/dataset/data_source.py
@@ -13,13 +13,8 @@
 )
 from pytorch3d.renderer.cameras import CamerasBase
 
-from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
-from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
-from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
-from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
-from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider  # noqa
 
 
 class DataSourceBase(ReplaceableBase):
@@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     data_loader_map_provider: DataLoaderMapProviderBase
     data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
 
+    @classmethod
+    def pre_expand(cls) -> None:
+        # use try/finally to bypass cinder's lazy imports
+        try:
+            from .blender_dataset_map_provider import (  # noqa: F401
+                BlenderDatasetMapProvider,
+            )
+            from .json_index_dataset_map_provider import (  # noqa: F401
+                JsonIndexDatasetMapProvider,
+            )
+            from .json_index_dataset_map_provider_v2 import (  # noqa: F401
+                JsonIndexDatasetMapProviderV2,
+            )
+            from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa: F401
+            from .rendered_mesh_dataset_map_provider import (  # noqa: F401
+                RenderedMeshDatasetMapProvider,
+            )
+        finally:
+            pass
+
     def __post_init__(self):
         run_auto_creation(self)
         self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py
index b903814f6..6e336105f 100644
--- a/pytorch3d/implicitron/models/generic_model.py
+++ b/pytorch3d/implicitron/models/generic_model.py
@@ -20,23 +20,8 @@
     ImplicitronRender,
 )
 from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
-from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (  # noqa
-    ResNetFeatureExtractor,
-)
 from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
 from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
-from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa
-    IdrFeatureField,
-)
-from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa
-    NeRFormerImplicitFunction,
-)
-from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa
-    SRNHyperNetImplicitFunction,
-)
-from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa
-    VoxelGridImplicitFunction,
-)
 from pytorch3d.implicitron.models.metrics import (
     RegularizationMetricsBase,
     ViewMetricsBase,
@@ -50,14 +35,7 @@
     RendererOutput,
     RenderSamplingMode,
 )
-from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer  # noqa
-from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa
-    MultiPassEmissionAbsorptionRenderer,
-)
 from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
-from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa
-    SignedDistanceFunctionRenderer,
-)
 
 from pytorch3d.implicitron.models.utils import (
     apply_chunked,
@@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         ]
     )
 
+    @classmethod
+    def pre_expand(cls) -> None:
+        # use try/finally to bypass cinder's lazy imports
+        try:
+            from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (  # noqa: F401, B950
+                ResNetFeatureExtractor,
+            )
+            from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
+                IdrFeatureField,
+            )
+            from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
+                NeRFormerImplicitFunction,
+            )
+            from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
+                SRNHyperNetImplicitFunction,
+            )
+            from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa: F401, B950
+                VoxelGridImplicitFunction,
+            )
+            from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
+                LSTMRenderer,
+            )
+            from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa
+                MultiPassEmissionAbsorptionRenderer,
+            )
+            from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
+                SignedDistanceFunctionRenderer,
+            )
+        finally:
+            pass
+
     def __post_init__(self):
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
diff --git a/pytorch3d/implicitron/models/overfit_model.py b/pytorch3d/implicitron/models/overfit_model.py
index b773f437c..52854d057 100644
--- a/pytorch3d/implicitron/models/overfit_model.py
+++ b/pytorch3d/implicitron/models/overfit_model.py
@@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase):  # pyre-ignore: 13
         ]
     )
 
+    @classmethod
+    def pre_expand(cls) -> None:
+        # use try/finally to bypass cinder's lazy imports
+        try:
+            from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
+                IdrFeatureField,
+            )
+            from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
+                NeuralRadianceFieldImplicitFunction,
+            )
+            from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
+                SRNImplicitFunction,
+            )
+            from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
+                LSTMRenderer,
+            )
+            from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa: F401
+                MultiPassEmissionAbsorptionRenderer,
+            )
+            from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
+                SignedDistanceFunctionRenderer,
+            )
+        finally:
+            pass
+
     def __post_init__(self):
         # The attribute will be filled by run_auto_creation
         run_auto_creation(self)
diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py
index d20759831..0fb4012e6 100644
--- a/pytorch3d/implicitron/tools/config.py
+++ b/pytorch3d/implicitron/tools/config.py
@@ -185,6 +185,7 @@ def __post_init__(self):
 IMPL_SUFFIX: str = "_impl"
 TWEAK_SUFFIX: str = "_tweak_args"
 _DATACLASS_INIT: str = "__dataclass_own_init__"
+PRE_EXPAND_NAME: str = "pre_expand"
 
 
 class ReplaceableBase:
@@ -838,6 +839,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
     In addition, if the class inherits torch.nn.Module, the generated __init__ will
     call torch.nn.Module's __init__ before doing anything else.
 
+    Before any transformation of the class, if the class has a classmethod called
+    `pre_expand`, it will be called with no arguments.
+
     Note that although the *_args members are intended to have type DictConfig, they
     are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
     in place of a dict, but not vice-versa. Allowing dict lets a class user specify
@@ -858,6 +862,9 @@
     if _is_actually_dataclass(some_class):
         return some_class
 
+    if hasattr(some_class, PRE_EXPAND_NAME):
+        getattr(some_class, PRE_EXPAND_NAME)()
+
     # The functions this class's run_auto_creation will run.
     creation_functions: List[str] = []
     # The classes which this type knows about from the registry
diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py
index e5ddec373..471451b45 100644
--- a/tests/implicitron/test_config.py
+++ b/tests/implicitron/test_config.py
@@ -10,6 +10,7 @@
 from dataclasses import dataclass, field, is_dataclass
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import Mock
 
 from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
 from pytorch3d.implicitron.tools.config import (
@@ -805,6 +806,39 @@ def __post_init__(self):
 
         self.assertEqual(control_args, ["Orange", "Orange", True, True])
 
+    def test_pre_expand(self):
+        # Check that the pre_expand classmethod of a class is called
+        # when expand_args_fields is called on the class.
+
+        class A(Configurable):
+            n: int = 9
+
+            @classmethod
+            def pre_expand(cls):
+                pass
+
+        A.pre_expand = Mock()
+        expand_args_fields(A)
+        A.pre_expand.assert_called()
+
+    def test_pre_expand_replaceable(self):
+        # Check that the pre_expand classmethod of a base class is called
+        # when expand_args_fields is called on one of its subclasses.
+
+        class A(ReplaceableBase):
+            pass
+
+            @classmethod
+            def pre_expand(cls):
+                pass
+
+        class A1(A):
+            n: int = 9
+
+        A.pre_expand = Mock()
+        expand_args_fields(A1)
+        A.pre_expand.assert_called()
+
 
 @dataclass(eq=False)
 class MockDataclass:
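
Usage sketch (not part of the diff above; all names are hypothetical): the following self-contained example shows how a user-defined Configurable class could use the new pre_expand hook so that a pluggable implementation is registered only when the class is actually expanded, mirroring the pattern added to ImplicitronDataSource, GenericModel and OverfitModel.

    from pytorch3d.implicitron.tools.config import (
        Configurable,
        ReplaceableBase,
        expand_args_fields,
        registry,
        run_auto_creation,
    )


    class MyRendererBase(ReplaceableBase):
        """Pluggable component, selected via my_renderer_class_type."""


    @registry.register
    class FancyRenderer(MyRendererBase):
        scale: float = 1.0


    class MyModel(Configurable):
        my_renderer: MyRendererBase
        my_renderer_class_type: str = "FancyRenderer"

        @classmethod
        def pre_expand(cls) -> None:
            # In a real codebase this would be a deferred import such as
            #   from my_project.fancy_renderer import FancyRenderer  # noqa: F401
            # so that FancyRenderer registers itself without a top-level import.
            # FancyRenderer is defined above to keep this sketch self-contained.
            pass

        def __post_init__(self):
            run_auto_creation(self)


    # expand_args_fields calls MyModel.pre_expand() before any transformation,
    # then expands the class; run_auto_creation later builds my_renderer from
    # the registry entry named by my_renderer_class_type.
    expand_args_fields(MyModel)
    model = MyModel()
    assert isinstance(model.my_renderer, FancyRenderer)

Per the in-code comments, the try/finally wrappers in the pre_expand bodies exist only to defeat Cinder's lazy imports, so the listed modules are loaded, and their classes registered, as soon as pre_expand runs.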