From 36da4489c90099ef23507caa07db9f88338f9ab3 Mon Sep 17 00:00:00 2001 From: liukai Date: Mon, 30 Jan 2023 15:23:59 +0800 Subject: [PATCH] add args min_channel --- projects/group_fisher/modules/group_fisher_channel_mutator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/projects/group_fisher/modules/group_fisher_channel_mutator.py b/projects/group_fisher/modules/group_fisher_channel_mutator.py index bb9255dcf..8db1403c2 100644 --- a/projects/group_fisher/modules/group_fisher_channel_mutator.py +++ b/projects/group_fisher/modules/group_fisher_channel_mutator.py @@ -37,10 +37,12 @@ def __init__(self, demo_input=(1, 3, 224, 224), tracer_type='FxTracer'), min_ratio=0.0, + min_channel=0, **kwargs) -> None: super().__init__(channel_unit_cfg, parse_cfg, **kwargs) self.mutable_units: List[GroupFisherChannelUnit] self.min_ratio = min_ratio + self.min_channel = min_channel def start_record_info(self) -> None: """Start recording the related information.""" @@ -64,7 +66,7 @@ def try_prune(self) -> None: min_unit = self.mutable_units[0] for unit in self.mutable_units: if unit.mutable_channel.activated_channels > max( - 20, (unit.num_channels * self.min_ratio)): + self.min_channel, (unit.num_channels * self.min_ratio), 0): imp = unit.importance() if imp.isnan().any(): if dist.get_rank() == 0: