Skip to content

Commit

Permalink
Merge pull request #987 from HERA-Team/cho_solve
Browse files Browse the repository at this point in the history
Add Cholesky decomposition to `nucal._linear_fit`
  • Loading branch information
tyler-a-cox authored Jan 29, 2025
2 parents 0d72e19 + d165336 commit 6655c23
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
16 changes: 15 additions & 1 deletion hera_cal/nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,10 +761,11 @@ def _linear_fit(XTX, Xy, solver='lu_solve', alpha=1e-15, cached_input={}):
# Assert that the method is valid
assert solver in [
"lu_solve",
"cho_solve",
"solve",
"pinv",
"lstsq",
], "method must be one of {}".format(["lu_solve", "solve", "pinv", "lstsq"])
], "method must be one of {}".format(["lu_solve", "cho_solve", "solve", "pinv", "lstsq"])

# Assert that the regularization tolerance is non-negative
assert alpha >= 0.0, "alpha must be non-negative."
Expand All @@ -789,6 +790,19 @@ def _linear_fit(XTX, Xy, solver='lu_solve', alpha=1e-15, cached_input={}):
# Save info
cached_output = {'LU': L}

elif solver == "cho_solve":
# Factor XTX using scipy.linalg.cho_factor
if "c_and_lower" in cached_input:
c_and_lower = cached_input.get('c_and_lower')
else:
c_and_lower = linalg.cho_factor(XTX)

# Solve the linear system of equations using scipy.linalg.cho_solve
beta = linalg.cho_solve(c_and_lower, Xy)

# Save info
cached_output = {'c_and_lower': c_and_lower}

elif solver == "solve":
# Solve the linear system of equations using np.linalg.solve
beta = np.linalg.solve(XTX, Xy)
Expand Down
10 changes: 8 additions & 2 deletions hera_cal/tests/test_nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,14 @@ def test_linear_fit():
np.testing.assert_allclose(b1, b1_cached)
b2, _ = nucal._linear_fit(XTX, Xy, solver='solve')
b3, _ = nucal._linear_fit(XTX, Xy, solver='lstsq')
b4, cached_input = nucal._linear_fit(XTX, Xy, solver='pinv')
assert cached_input.get('XTXinv') is not None
b4, _ = nucal._linear_fit(XTX, Xy, solver='pinv')
b5, _ = nucal._linear_fit(XTX, Xy, solver='cho_solve')

# Show that all modes give the same result
np.testing.assert_allclose(b1, b2, atol=1e-6)
np.testing.assert_allclose(b1, b3, atol=1e-6)
np.testing.assert_allclose(b1, b4, atol=1e-6)
np.testing.assert_allclose(b1, b5, atol=1e-6)

# Test that the fit is correct
model = np.dot(X, b4)
Expand All @@ -445,6 +446,11 @@ def test_linear_fit():
with pytest.raises(AssertionError):
b = nucal._linear_fit(XTX, Xy, alpha=-1)

for mode in ['lu_solve', 'cho_solve', 'pinv', 'lstsq', 'solve']:
b, cached_input = nucal._linear_fit(XTX, Xy, solver=mode)
b_cached, _ = nucal._linear_fit(XTX, Xy, solver=mode, cached_input=cached_input)
np.testing.assert_allclose(b, b_cached)


def test_compute_spectral_filters():
# Create a set of mock data to fit
Expand Down

0 comments on commit 6655c23

Please sign in to comment.