From f1df80a2fce9dfcfa138550e2f615c8e8aba3687 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Tue, 28 Jan 2025 21:54:18 -0800 Subject: [PATCH 1/2] add Cholesky decomp to _linear_fit --- hera_cal/nucal.py | 16 +++++++++++++++- hera_cal/tests/test_nucal.py | 2 ++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/hera_cal/nucal.py b/hera_cal/nucal.py index cfea09fa4..d5e6a8cbb 100644 --- a/hera_cal/nucal.py +++ b/hera_cal/nucal.py @@ -761,10 +761,11 @@ def _linear_fit(XTX, Xy, solver='lu_solve', alpha=1e-15, cached_input={}): # Assert that the method is valid assert solver in [ "lu_solve", + "cho_solve", "solve", "pinv", "lstsq", - ], "method must be one of {}".format(["lu_solve", "solve", "pinv", "lstsq"]) + ], "method must be one of {}".format(["lu_solve", "cho_solve", "solve", "pinv", "lstsq"]) # Assert that the regularization tolerance is non-negative assert alpha >= 0.0, "alpha must be non-negative." @@ -789,6 +790,19 @@ def _linear_fit(XTX, Xy, solver='lu_solve', alpha=1e-15, cached_input={}): # Save info cached_output = {'LU': L} + elif solver == "cho_solve": + # Factor XTX using scipy.linalg.cho_factor + if "c_and_lower" in cached_input: + c_and_lower = cached_input.get('c_and_lower') + else: + c_and_lower = linalg.cho_factor(XTX) + + # Solve the linear system of equations using scipy.linalg.cho_solve + beta = linalg.cho_solve(c_and_lower, Xy) + + # Save info + cached_output = {'c_and_lower': c_and_lower} + elif solver == "solve": # Solve the linear system of equations using np.linalg.solve beta = np.linalg.solve(XTX, Xy) diff --git a/hera_cal/tests/test_nucal.py b/hera_cal/tests/test_nucal.py index 4cbe4dd17..990d3822d 100644 --- a/hera_cal/tests/test_nucal.py +++ b/hera_cal/tests/test_nucal.py @@ -426,12 +426,14 @@ def test_linear_fit(): b2, _ = nucal._linear_fit(XTX, Xy, solver='solve') b3, _ = nucal._linear_fit(XTX, Xy, solver='lstsq') b4, cached_input = nucal._linear_fit(XTX, Xy, solver='pinv') + b5, _ = nucal._linear_fit(XTX, Xy, solver='cho_solve') assert cached_input.get('XTXinv') is not None # Show that all modes give the same result np.testing.assert_allclose(b1, b2, atol=1e-6) np.testing.assert_allclose(b1, b3, atol=1e-6) np.testing.assert_allclose(b1, b4, atol=1e-6) + np.testing.assert_allclose(b1, b5, atol=1e-6) # Test that the fit is correct model = np.dot(X, b4) From d16533643378ac93c9cd2a9ef5c54b57c4a3a7d2 Mon Sep 17 00:00:00 2001 From: Tyler Cox Date: Wed, 29 Jan 2025 10:09:43 -0800 Subject: [PATCH 2/2] add test with cache to _linear_fit --- hera_cal/tests/test_nucal.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hera_cal/tests/test_nucal.py b/hera_cal/tests/test_nucal.py index 990d3822d..638ca0103 100644 --- a/hera_cal/tests/test_nucal.py +++ b/hera_cal/tests/test_nucal.py @@ -425,9 +425,8 @@ def test_linear_fit(): np.testing.assert_allclose(b1, b1_cached) b2, _ = nucal._linear_fit(XTX, Xy, solver='solve') b3, _ = nucal._linear_fit(XTX, Xy, solver='lstsq') - b4, cached_input = nucal._linear_fit(XTX, Xy, solver='pinv') + b4, _ = nucal._linear_fit(XTX, Xy, solver='pinv') b5, _ = nucal._linear_fit(XTX, Xy, solver='cho_solve') - assert cached_input.get('XTXinv') is not None # Show that all modes give the same result np.testing.assert_allclose(b1, b2, atol=1e-6) @@ -447,6 +446,11 @@ def test_linear_fit(): with pytest.raises(AssertionError): b = nucal._linear_fit(XTX, Xy, alpha=-1) + for mode in ['lu_solve', 'cho_solve', 'pinv', 'lstsq', 'solve']: + b, cached_input = nucal._linear_fit(XTX, Xy, solver=mode) + b_cached, _ = nucal._linear_fit(XTX, Xy, solver=mode, cached_input=cached_input) + np.testing.assert_allclose(b, b_cached) + def test_compute_spectral_filters(): # Create a set of mock data to fit