diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index 00d615815..ce27d8293 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -11,6 +11,9 @@ from mmrazor.models.algorithms.pruning.ite_prune_algorithm import ( ItePruneAlgorithm, ItePruneConfigManager) from mmrazor.registry import MODELS +from projects.group_fisher.modules.group_fisher_algorthm import \ + GroupFisherAlgorithm +from projects.group_fisher.modules.group_fisher_ops import GroupFisherConv2d from ...utils.set_dist_env import SetDistEnv @@ -262,3 +265,63 @@ def test_resume(self): print(algorithm2.mutator.current_choices) self.assertDictEqual(algorithm.mutator.current_choices, algorithm2.mutator.current_choices) + + +class TestGroupFisherPruneAlgorithm(TestItePruneAlgorithm): + + def test_group_fisher_prune(self): + data = self.fake_cifar_data() + + MUTATOR_CONFIG = dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict(type='GroupFisherChannelUnit')) + + epoch = 2 + interval = 1 + delta = 'flops' + + algorithm = GroupFisherAlgorithm( + MODEL_CFG, + pruning=True, + mutator=MUTATOR_CONFIG, + delta=delta, + interval=interval, + save_ckpt_delta_thr=[1.1]).to(DEVICE) + mutator = algorithm.mutator + + ckpt_path = os.path.dirname(__file__) + f'/{delta}_0.99.pth' + + fake_cfg_path = os.path.dirname(__file__) + '/cfg.py' + self.gen_fake_cfg(fake_cfg_path) + self.assertTrue(os.path.exists(fake_cfg_path)) + + message_hub = MessageHub.get_current_instance() + cfg_str = open(fake_cfg_path).read() + message_hub.update_info('cfg', cfg_str) + + for e in range(epoch): + for ite in range(10): + self._set_epoch_ite(e, ite, epoch) + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.gen_fake_grad(mutator) + self.assertEqual(delta, algorithm.delta) + self.assertEqual(interval, algorithm.interval) + self.assertTrue(os.path.exists(ckpt_path)) + os.remove(ckpt_path) + os.remove(fake_cfg_path) + self.assertTrue(not os.path.exists(ckpt_path)) + self.assertTrue(not os.path.exists(fake_cfg_path)) + + def gen_fake_grad(self, mutator): + for unit in mutator.mutable_units: + for channel in unit.input_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.recorded_grad = module.recorded_input + + def gen_fake_cfg(self, fake_cfg_path): + with open(fake_cfg_path, 'a', encoding='utf-8') as cfg: + cfg.write(f'work_dir = \'{os.path.dirname(__file__)}\'') + cfg.write('\n') diff --git a/tests/test_projects/__init__.py b/tests/test_projects/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_projects/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_projects/test_expand/__init__.py b/tests/test_projects/test_expand/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_projects/test_expand/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_projects/test_expand/test_expand.py b/tests/test_projects/test_expand/test_expand.py new file mode 100644 index 000000000..3af0ae7a5 --- /dev/null +++ b/tests/test_projects/test_expand/test_expand.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmrazor.models.mutables import SimpleMutableChannel +from mmrazor.models.mutators import ChannelMutator +from projects.cores.expandable_ops.ops import ExpandLinear +from projects.cores.expandable_ops.unit import (ExpandableUnit, expand_model, + expand_static_model) +from ...data.models import MultiConcatModel, SingleLineModel + + +class TestExpand(unittest.TestCase): + + def test_expand(self): + x = torch.rand([1, 3, 224, 224]) + model = MultiConcatModel() + print(model) + mutator = ChannelMutator[ExpandableUnit]( + channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(model) + print(mutator.choice_template) + print(model) + y1 = model(x) + + for unit in mutator.mutable_units: + unit.expand(10) + print(unit.mutable_channel.mask.shape) + expand_model(model, zero=True) + print(model) + y2 = model(x) + self.assertTrue((y1 - y2).abs().max() < 1e-3) + + def test_expand_static_model(self): + x = torch.rand([1, 3, 224, 224]) + model = SingleLineModel() + y1 = model(x) + expand_static_model(model, divisor=4) + y2 = model(x) + print(y1.reshape([-1])[:5]) + print(y2.reshape([-1])[:5]) + self.assertTrue((y1 - y2).abs().max() < 1e-3) + + def test_ExpandConv2d(self): + linear = ExpandLinear(3, 3) + mutable_in = SimpleMutableChannel(3) + mutable_out = SimpleMutableChannel(3) + linear.register_mutable_attr('in_channels', mutable_in) + linear.register_mutable_attr('out_channels', mutable_out) + + print(linear.weight) + + mutable_in.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0]) + mutable_out.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0]) + linear_ex = linear.expand(zero=True) + print(linear_ex.weight)