Skip to content

Commit

Permalink
move scaling forward
Browse files: browse the repository at this point in the history
  • Loading branch information
yl4579 authored Nov 24, 2023
1 parent f871208 commit cbc8916
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions train_finetune.py
Original file line number Diff line number Diff line change
[Expand Up]
@@ -498,16 +498,6 @@ def main(config_path):
# SLM generator loss
optimizer.zero_grad()
loss_gen_lm.backward()
- optimizer.step('bert_encoder')
- optimizer.step('bert')
- optimizer.step('predictor')
- optimizer.step('diffusion')
-
- # SLM discriminator loss
- if d_loss_slm != 0:
- optimizer.zero_grad()
- d_loss_slm.backward(retain_graph=True)
- optimizer.step('wd')

# compute the gradient norm
total_norm = {}
[Expand Down] [Expand Up]
@@ -537,6 +527,17 @@ def main(config_path):
for p in model.diffusion.parameters():
if p.grad is not None:
p.grad *= slmadv_params.scale

+ optimizer.step('bert_encoder')
+ optimizer.step('bert')
+ optimizer.step('predictor')
+ optimizer.step('diffusion')
+
+ # SLM discriminator loss
+ if d_loss_slm != 0:
+ optimizer.zero_grad()
+ d_loss_slm.backward(retain_graph=True)
+ optimizer.step('wd')
+
iters = iters + 1

Expand Down

0 comments on commit cbc8916

Please sign in to comment.