TorchRec fuses the backward and optimizer steps for performance, but in some cases, when part of the computation is not differentiable, I have to calculate part of the gradients manually and apply them manually.
For example, in the code snippet below there are two losses: one can be calculated directly with nn.MSELoss, while the other has a non-differentiable part and must be calculated manually.
Without fusion, there are two gradient accumulations and one optimizer step.
With fusion, there are two fused (gradient + optimize) steps.
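For reference, this is the un-fused pattern I have in mind (just a sketch with a hypothetical plain optimizer; model, loss1, logits, and custom_gradient refer to the snippet further down):

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # hypothetical plain (non-fused) optimizer

optimizer.zero_grad()
loss1.backward(retain_graph=True)                              # 1st backward: only accumulates into .grad
torch.autograd.backward(logits, grad_tensors=custom_gradient)  # 2nd backward: also only accumulates into .grad
optimizer.step()                                               # single optimizer step on the summed gradients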
My question is: since I now effectively perform two optimizer updates per iteration, will this affect convergence?
What is the best practice for such a situation?
import torch
import torch.nn as nn

# linear layer (placeholder for the actual model with a fused optimizer) and loss function
model = nn.Linear(5, 1)
mse_loss = nn.MSELoss()

# input data and labels
x = torch.rand(3, 5)
labels = torch.randint(low=0, high=2, size=(3, 1)).float()  # MSELoss expects a float target

# logits
logits = model(x)

# first loss
loss1 = mse_loss(logits, labels)
# retain_graph=True so the second backward below can reuse the shared graph
loss1.backward(retain_graph=True)  # this will perform a backward & optimize fusion, and parameters will be updated?

# another loss with a non-differentiable part, must be calculated manually
custom_gradient = NonDifferentiableLogic(...)
torch.autograd.backward(logits, grad_tensors=custom_gradient)  # this will perform another backward & optimize fusion, and parameters will be updated again?
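Would something like the following be the right approach instead? Just a sketch of what I imagine, assuming the two gradient contributions can be passed to a single torch.autograd.backward call so that the shared part of the graph is traversed (and the fused update applied) only once:

# combine both gradient sources into one backward call:
# the scalar loss1 gets a gradient of ones, logits get the manually computed gradient
torch.autograd.backward(
    tensors=[loss1, logits],
    grad_tensors=[torch.ones_like(loss1), custom_gradient],
)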