From d096af72874af63f0260f79332c1af9318169221 Mon Sep 17 00:00:00 2001 From: tinyzqh Date: Thu, 31 Oct 2024 21:12:27 +0800 Subject: [PATCH] fix index of num_features_to_replace bug --- lop/algos/convGnT.py | 2 +- lop/algos/gnt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lop/algos/convGnT.py b/lop/algos/convGnT.py index 7a2e574..10caaeb 100644 --- a/lop/algos/convGnT.py +++ b/lop/algos/convGnT.py @@ -207,7 +207,7 @@ def update_optim_params(self, features_to_replace_input_indices, features_to_rep if self.opt_type == 'AdamGnT': for i in range(self.num_hidden_layers): # input weights - if num_features_to_replace == 0: + if num_features_to_replace[i] == 0: continue # input weights self.opt.state[self.net[i * 2].bias]['exp_avg'][features_to_replace_input_indices[i]] = 0.0 diff --git a/lop/algos/gnt.py b/lop/algos/gnt.py index 2804525..c999209 100644 --- a/lop/algos/gnt.py +++ b/lop/algos/gnt.py @@ -211,7 +211,7 @@ def update_optim_params(self, features_to_replace, num_features_to_replace): if self.opt_type == 'adam': for i in range(self.num_hidden_layers): # input weights - if num_features_to_replace == 0: + if num_features_to_replace[i] == 0: continue self.opt.state[self.net[i * 2].weight]['exp_avg'][features_to_replace[i], :] = 0.0 self.opt.state[self.net[i * 2].bias]['exp_avg'][features_to_replace[i]] = 0.0