Fix problem in generator training and reward adjustment
WayneJin0918 committed Jun 10, 2024
1 parent 33973cc commit dfb6649
Showing 5 changed files with 44 additions and 35 deletions.
14 changes: 8 additions & 6 deletions semilearn/algorithms/srfixmatch/fixmatch.py
@@ -48,6 +48,7 @@ def __init__(self, args, net_builder, tb_log=None, logger=None):

         self.criterion = torch.nn.MSELoss()

+        self.max_reward = -float('inf')
     def init(self, T, p_cutoff, hard_label=True):
         self.T = T
         self.p_cutoff = p_cutoff
@@ -148,7 +149,7 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
         # Generate pseudo labels using the generator (your pseudo-labeling process)
         self.rewarder.train()
         self.generator.train()
-        generated_label = self.generator(feats_x_lb.detach()).detach()
+        generated_label = self.generator(feats_x_lb.detach())
         generated_label=generated_label.long()
         # Convert generated pseudo labels and true labels to tensors
         real_labels_tensor = y_lb.cuda(self.gpu)
@@ -157,16 +158,17 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
             filtered_pseudo_labels = pseudo_label.long()
             filtered_feats_x_ulb_w = feats_x_ulb_w.detach()
             rewarder = self.rewarder.eval()
-            max_reward = -float('inf')

             reward = self.rewarder(feats_x_ulb_w.detach(), pseudo_label.long())
             reward = reward.mean()
-            max_reward = torch.where(reward > max_reward, reward, max_reward)
-            filtered_pseudo_labels = torch.where(reward > max_reward, pseudo_label.detach(), filtered_pseudo_labels)
-            filtered_feats_x_ulb_w = torch.where(reward > max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
+            self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+            filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+            filtered_feats_x_ulb_w = torch.where(reward > self.max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
             if self.it % self.N_k == 0 and self.it > self.start_timing:
+                self.max_reward = -float('inf')
                 self.rewarder.train()
                 self.generator.train()
-                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1)).detach()
+                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1))
                 generated_label=generated_label.long()
                 reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
                 generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
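The same two-part fix appears in all five algorithm files. First, the per-step local max_reward = -float('inf') becomes a persistent self.max_reward, initialized once in __init__ and reset only every N_k iterations: with a fresh local every step, reward > max_reward was trivially true, so the torch.where filters never actually compared against earlier iterations. Second, the extra .detach() on the generator's output is dropped, keeping the generator's forward pass wired into the current graph (the immediate .long() cast is still non-differentiable, so this reading is an inference from the visible hunks, not something the commit states). Below is a minimal, self-contained sketch of the corrected bookkeeping; the RewardTracker class and its argument names are illustrative, not repository code:

import torch

class RewardTracker:
    # Illustrative sketch (not repo code): keep the best observed mean reward
    # across training steps, retain the pseudo-labels/features that achieved
    # it, and reset the running maximum every N_k iterations.
    def __init__(self, N_k, start_timing):
        self.N_k = N_k
        self.start_timing = start_timing
        self.max_reward = -float('inf')  # persists across steps, like self.max_reward in the diff

    def step(self, it, reward, pseudo_label, feats, filtered_labels, filtered_feats):
        reward = reward.mean()
        prev = torch.as_tensor(self.max_reward, dtype=reward.dtype, device=reward.device)
        improved = reward > prev  # 0-dim bool tensor; broadcasts through torch.where
        self.max_reward = torch.where(improved, reward, prev)
        filtered_labels = torch.where(improved, pseudo_label.detach(), filtered_labels)
        filtered_feats = torch.where(improved, feats.detach(), filtered_feats)
        if it % self.N_k == 0 and it > self.start_timing:
            self.max_reward = -float('inf')  # periodic reset, as in the diff
        return filtered_labels, filtered_feats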
14 changes: 8 additions & 6 deletions semilearn/algorithms/srflexmatch/srflexmatch.py
@@ -56,6 +56,7 @@ def __init__(self, args, net_builder, tb_log=None, logger=None):

         self.criterion = torch.nn.MSELoss()

+        self.max_reward = -float('inf')
     def init(self, T, p_cutoff, hard_label=True, thresh_warmup=True):
         self.T = T
         self.p_cutoff = p_cutoff
@@ -154,7 +155,7 @@ def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
         # Generate pseudo labels using the generator (your pseudo-labeling process)
         self.rewarder.train()
         self.generator.train()
-        generated_label = self.generator(feats_x_lb.detach()).detach()
+        generated_label = self.generator(feats_x_lb.detach())
         generated_label=generated_label.long()
         # Convert generated pseudo labels and true labels to tensors
         real_labels_tensor = y_lb.cuda(self.gpu)
@@ -163,16 +164,17 @@ def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
             filtered_pseudo_labels = pseudo_label.long()
             filtered_feats_x_ulb_w = feats_x_ulb_w.detach()
             rewarder = self.rewarder.eval()
-            max_reward = -float('inf')

             reward = self.rewarder(feats_x_ulb_w.detach(), pseudo_label.long())
             reward = reward.mean()
-            max_reward = torch.where(reward > max_reward, reward, max_reward)
-            filtered_pseudo_labels = torch.where(reward > max_reward, pseudo_label.detach(), filtered_pseudo_labels)
-            filtered_feats_x_ulb_w = torch.where(reward > max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
+            self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+            filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+            filtered_feats_x_ulb_w = torch.where(reward > self.max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
             if self.it % self.N_k == 0 and self.it > self.start_timing:
+                self.max_reward = -float('inf')
                 self.rewarder.train()
                 self.generator.train()
-                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1)).detach()
+                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1))
                 generated_label=generated_label.long()
                 reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
                 generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
14 changes: 8 additions & 6 deletions semilearn/algorithms/srfreematch/srfreematch.py
@@ -65,6 +65,7 @@ def __init__(self, args, net_builder, tb_log=None, logger=None):

         self.criterion = torch.nn.MSELoss()

+        self.max_reward = -float('inf')
     def init(self, T, hard_label=True, ema_p=0.999, use_quantile=True, clip_thresh=False):
         self.T = T
         self.use_hard_label = hard_label
@@ -159,7 +160,7 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
         # Generate pseudo labels using the generator (your pseudo-labeling process)
         self.rewarder.train()
         self.generator.train()
-        generated_label = self.generator(feats_x_lb.detach()).detach()
+        generated_label = self.generator(feats_x_lb.detach())
         generated_label=generated_label.long()
         # Convert generated pseudo labels and true labels to tensors
         real_labels_tensor = y_lb.cuda(self.gpu)
@@ -168,16 +169,17 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
             filtered_pseudo_labels = pseudo_label.long()
             filtered_feats_x_ulb_w = feats_x_ulb_w.detach()
             rewarder = self.rewarder.eval()
-            max_reward = -float('inf')

             reward = self.rewarder(feats_x_ulb_w.detach(), pseudo_label.long())
             reward = reward.mean()
-            max_reward = torch.where(reward > max_reward, reward, max_reward)
-            filtered_pseudo_labels = torch.where(reward > max_reward, pseudo_label.detach(), filtered_pseudo_labels)
-            filtered_feats_x_ulb_w = torch.where(reward > max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
+            self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+            filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+            filtered_feats_x_ulb_w = torch.where(reward > self.max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
             if self.it % self.N_k == 0 and self.it > self.start_timing:
+                self.max_reward = -float('inf')
                 self.rewarder.train()
                 self.generator.train()
-                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1)).detach()
+                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1))
                 generated_label=generated_label.long()
                 reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
                 generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
23 changes: 12 additions & 11 deletions semilearn/algorithms/srpseudolabel/srpseudolabel.py
@@ -46,6 +46,7 @@ def __init__(self, args, net_builder, tb_log=None, logger=None, **kwargs):

         self.criterion = torch.nn.MSELoss()

+        self.max_reward = -float('inf')
     def init(self, p_cutoff, unsup_warm_up=0.4):
         self.p_cutoff = p_cutoff
         self.unsup_warm_up = unsup_warm_up
@@ -135,30 +136,30 @@ def train_step(self, x_lb, y_lb, x_ulb_w):
         # Generate pseudo labels using the generator (your pseudo-labeling process)
         self.rewarder.train()
         self.generator.train()
-        generated_label = self.generator(feats_x_lb.detach()).detach()
+        generated_label = self.generator(feats_x_lb.detach())
         generated_label=generated_label.long()
         # Convert generated pseudo labels and true labels to tensors
         real_labels_tensor = y_lb.cuda(self.gpu)
         reward = self.rewarder(feats_x_lb.detach(),generated_label.squeeze(1))
         if self.it >= self.start_timing:
             filtered_pseudo_labels = pseudo_label.long()
-            filtered_feats_x_ulb_w = feats_x_ulb.detach()
+            filtered_feats_x_ulb = feats_x_ulb.detach()
             rewarder = self.rewarder.eval()
-            max_reward = -float('inf')

             reward = self.rewarder(feats_x_ulb.detach(), pseudo_label.long())
             reward = reward.mean()
-            max_reward = torch.where(reward > max_reward, reward, max_reward)
-            filtered_pseudo_labels = torch.where(reward > max_reward, pseudo_label.detach(), filtered_pseudo_labels)
-            filtered_feats_x_ulb_w = torch.where(reward > max_reward, feats_x_ulb.detach(), filtered_feats_x_ulb_w)
+            self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+            filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+            filtered_feats_x_ulb = torch.where(reward > self.max_reward, feats_x_ulb.detach(), filtered_feats_x_ulb)
             if self.it % self.N_k == 0 and self.it > self.start_timing:
+                self.max_reward = -float('inf')
                 self.rewarder.train()
                 self.generator.train()
-                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1)).detach()
+                generated_label = self.generator(filtered_feats_x_ulb.squeeze(1))
                 generated_label=generated_label.long()
-                reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
-                num_classes=self.num_classes if self.task_type == 'cls' else self.range
-                generated_label = F.one_hot(generated_label.squeeze(1), num_classes=num_classes)
-                filtered_pseudo_labels= F.one_hot(filtered_pseudo_labels.long(), num_classes=num_classes)
+                reward = self.rewarder(filtered_feats_x_ulb, generated_label.squeeze(1))
+                generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
+                filtered_pseudo_labels= F.one_hot(filtered_pseudo_labels.long(), num_classes=self.num_classes)
                 cosine_similarity_score = cosine_similarity_n(generated_label.float(), filtered_pseudo_labels.float())
                 generator_loss = self.criterion(reward, torch.ones_like(reward).cuda(self.gpu))
                 rewarder_loss = self.criterion(reward, cosine_similarity_score)
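Beyond the shared fix, this file renames filtered_feats_x_ulb_w to filtered_feats_x_ulb (matching the feature tensor it actually holds) and drops the task_type == 'cls' branch, so both F.one_hot calls now use self.num_classes directly. The surrounding context also shows how the two losses are formed. A self-contained sketch of that objective follows; cosine_similarity_n is assumed to be a row-wise cosine similarity over one-hot labels, since the repo's helper is not visible in this diff:

import torch
import torch.nn.functional as F

def cosine_similarity_n(a, b, dim=-1, eps=1e-8):
    # Assumed helper: row-wise cosine similarity. For one-hot rows this is
    # 1.0 when the two labels agree and 0.0 when they differ.
    return F.cosine_similarity(a, b, dim=dim, eps=eps)

num_classes = 10
generated = torch.randint(0, num_classes, (8,))   # generator's hard labels (stand-in)
filtered = torch.randint(0, num_classes, (8,))    # best-reward pseudo-labels (stand-in)
gen_onehot = F.one_hot(generated, num_classes=num_classes).float()
fil_onehot = F.one_hot(filtered, num_classes=num_classes).float()

reward = torch.rand(8)                            # stand-in for the rewarder's outputs
target = cosine_similarity_n(gen_onehot, fil_onehot)

criterion = torch.nn.MSELoss()
generator_loss = criterion(reward, torch.ones_like(reward))  # generator: push rewards toward 1
rewarder_loss = criterion(reward, target)                    # rewarder: predict label agreement

In the actual train_step the reward comes from self.rewarder on GPU feature tensors; the random stand-ins above only exercise shapes and the loss wiring.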
14 changes: 8 additions & 6 deletions semilearn/algorithms/srsoftmatch/srsoftmatch.py
@@ -48,6 +48,7 @@ def __init__(self, args, net_builder, tb_log=None, logger=None):

         self.criterion = torch.nn.MSELoss()

+        self.max_reward = -float('inf')
     def init(self, T, hard_label=True, dist_align=True, dist_uniform=True, ema_p=0.999, n_sigma=2, per_class=False):
         self.T = T
         self.use_hard_label = hard_label
@@ -158,7 +159,7 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
         # Generate pseudo labels using the generator (your pseudo-labeling process)
         self.rewarder.train()
         self.generator.train()
-        generated_label = self.generator(feats_x_lb.detach()).detach()
+        generated_label = self.generator(feats_x_lb.detach())
         generated_label=generated_label.long()
         # Convert generated pseudo labels and true labels to tensors
         real_labels_tensor = y_lb.cuda(self.gpu)
@@ -167,16 +168,17 @@ def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
             filtered_pseudo_labels = pseudo_label.long()
             filtered_feats_x_ulb_w = feats_x_ulb_w.detach()
             rewarder = self.rewarder.eval()
-            max_reward = -float('inf')

             reward = self.rewarder(feats_x_ulb_w.detach(), pseudo_label.long())
             reward = reward.mean()
-            max_reward = torch.where(reward > max_reward, reward, max_reward)
-            filtered_pseudo_labels = torch.where(reward > max_reward, pseudo_label.detach(), filtered_pseudo_labels)
-            filtered_feats_x_ulb_w = torch.where(reward > max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
+            self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+            filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+            filtered_feats_x_ulb_w = torch.where(reward > self.max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
             if self.it % self.N_k == 0 and self.it > self.start_timing:
+                self.max_reward = -float('inf')
                 self.rewarder.train()
                 self.generator.train()
-                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1)).detach()
+                generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1))
                 generated_label=generated_label.long()
                 reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
                 generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
