diff --git a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py index 75a1db293..63c60e8cb 100644 --- a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py @@ -6,6 +6,7 @@ from mmrazor.registry import TASK_UTILS from mmrazor.utils import get_placeholder +from ...algorithms.base import BaseAlgorithm from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, DefaultMMDemoInput, DefaultMMDetDemoInput, DefaultMMPoseDemoInput, DefaultMMRotateDemoInput, @@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope): def defaul_demo_inputs(model, input_shape, training=False, scope=None): """Get demo input according to a model and scope.""" - demo_input = get_default_demo_input_class(model, scope) - return demo_input().get_data(model, input_shape, training) + if isinstance(model, BaseAlgorithm): + return defaul_demo_inputs(model.architecture, input_shape, training, + scope) + else: + demo_input = get_default_demo_input_class(model, scope) + return demo_input().get_data(model, input_shape, training) @TASK_UTILS.register_module() diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index 8664f3a2d..ab0dfb4b5 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -51,7 +51,9 @@ def _get_data(self, model, input_shape=None, training=None): return data def _get_mm_data(self, model, input_shape, training=False): - return {'inputs': torch.rand(input_shape), 'data_samples': None} + data = {'inputs': torch.rand(input_shape), 'data_samples': None} + data = model.data_preprocessor(data, training) + return data @TASK_UTILS.register_module() @@ -132,7 +134,7 @@ def _get_mm_data(self, model, input_shape, training=False): from mmpose.models import TopdownPoseEstimator from .mmpose_demo_input import demo_mmpose_inputs - assert isinstance(model, TopdownPoseEstimator) + assert isinstance(model, TopdownPoseEstimator), f'{type(model)}' data = demo_mmpose_inputs(model, input_shape) return data diff --git a/projects/cores/hooks/prune_hook.py b/projects/cores/hooks/prune_hook.py index c6ca37137..52f30a602 100644 --- a/projects/cores/hooks/prune_hook.py +++ b/projects/cores/hooks/prune_hook.py @@ -83,8 +83,10 @@ def __init__(self, def before_run(self, runner) -> None: model = get_model_from_runner(runner) - self.origin_delta = self._evaluate(model)[self.delta_type] - print_log(f'get original {self.delta_type}: {self.origin_delta}') + original_resource = self._evaluate(model) + print_log(f'get original resource: {original_resource}') + + self.origin_delta = original_resource[self.delta_type] # save checkpoint