After running train.py, the error "Unable to solve the normal equations in LinearFeatureBaseline. The matrix X^T*X (with X the design matrix) is not full-rank, regardless of the regularization (maximum regularization: 1.0)." occurs. #72

Open
bianca-li-bupt opened this issue Mar 1, 2023 · 0 comments

bianca-li-bupt commented Mar 1, 2023

The fix is as follows.
In baseline.py, line 60:

  1. Replace 'torch.lstsq' with 'torch.linalg.lstsq', because the former has been removed from PyTorch. Note that the argument order is reversed: torch.lstsq(B, A) becomes torch.linalg.lstsq(A, B).
  2. The return value is now a named tuple, not a tensor, so use coeffs.solution to get the solution tensor.

        for _ in range(5):
            try:
                # torch.linalg.lstsq takes the coefficient matrix first, then the
                # right-hand side. 'gelsy' is a CPU driver; CUDA inputs only accept 'gels'.
                coeffs = torch.linalg.lstsq(XT_X + reg_coeff * self._eye, XT_y, driver='gelsy')
                # coeffs, _ = torch.lstsq(XT_y, XT_X + reg_coeff * self._eye)
                # An extra round of increasing regularization eliminated
                # inf or nan in the least-squares solution most of the time

                if torch.isnan(coeffs.solution).any() or torch.isinf(coeffs.solution).any():
                    raise RuntimeError

                break
            except RuntimeError:
                reg_coeff *= 10
        else:
            raise RuntimeError('Unable to solve the normal equations in '
                '`LinearFeatureBaseline`. The matrix X^T*X (with X the design '
                'matrix) is not full-rank, regardless of the regularization '
                '(maximum regularization: {0}).'.format(reg_coeff))
        self.weight.copy_(coeffs.solution.flatten())
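
For anyone who wants to check the change outside the repository, here is a minimal standalone sketch of the same retry loop. The function name `solve_regularized_normal_equations` and the arguments `features`, `returns`, and `max_tries` are hypothetical and are not taken from baseline.py. It assumes a PyTorch version that provides `torch.linalg.lstsq`, which expects the coefficient matrix as the first argument and the right-hand side as the second.

```python
import torch


def solve_regularized_normal_equations(features, returns, reg_coeff=1e-5, max_tries=5):
    """Solve (X^T X + reg * I) w = X^T y, growing reg tenfold after each failure.

    Hypothetical helper for illustration; not the repository's implementation.
    """
    XT_X = features.t() @ features                  # shape (d, d)
    XT_y = features.t() @ returns.reshape(-1, 1)    # shape (d, 1)
    eye = torch.eye(features.shape[1], dtype=features.dtype, device=features.device)

    for _ in range(max_tries):
        try:
            # Coefficient matrix first, right-hand side second.
            result = torch.linalg.lstsq(XT_X + reg_coeff * eye, XT_y)
            weight = result.solution.flatten()
            if torch.isfinite(weight).all():
                return weight, reg_coeff
        except RuntimeError:
            # Solver failure: fall through and retry with more regularization.
            pass
        reg_coeff *= 10

    raise RuntimeError('Unable to solve the normal equations '
                       '(maximum regularization: {0}).'.format(reg_coeff))


if __name__ == '__main__':
    torch.manual_seed(0)
    X = torch.randn(50, 3, dtype=torch.float64)
    y = X @ torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
    weight, used_reg = solve_regularized_normal_equations(X, y)
    print(weight, used_reg)
```

With a well-conditioned design matrix like the random one above, the first attempt should succeed and the recovered weights should be close to the ones used to generate `y`. If the two arguments to `torch.linalg.lstsq` are swapped, the call can still run without an error on these shapes but solves a different system, so comparing against known weights is a quick sanity check.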