From 68bc5429ef4d78ac6fad967ff2fef6b4de52dc02 Mon Sep 17 00:00:00 2001
From: Jonas Breuling <jonas.breuling@inm.uni-stuttgart.de>
Date: Wed, 12 Jun 2024 17:55:48 +0200
Subject: [PATCH] Working on Radau's error estimate.

---
 .github/workflows/main.yml        |  2 +-
 scipy_dae/integrate/_dae/radau.py | 61 ++++++++++++-------------------
 2 files changed, 24 insertions(+), 39 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index e02b855..ed9f688 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -23,7 +23,7 @@ jobs:
     - name: Update pip
       run: pip install --upgrade pip
 
-    - name: Install dependencies and scipy_dae (scipy_dae)
+    - name: Install dependencies and scipy_dae
       run: pip install -e .
 
     - name: Install test dependencies
diff --git a/scipy_dae/integrate/_dae/radau.py b/scipy_dae/integrate/_dae/radau.py
index eaf3531..0a39142 100644
--- a/scipy_dae/integrate/_dae/radau.py
+++ b/scipy_dae/integrate/_dae/radau.py
@@ -518,55 +518,40 @@ def _step_impl(self):
                 # # Fabien (5.59) and (5.60)
                 # LU_real = self.lu(MU_REAL / h * Jyp + Jy)
                 LU_real_ODE_mass = LU_real
+                # LU_real_ODE_mass = self.lu(MU_REAL / h * Jyp + Jy)
                 # err_ODE_mass = h * MU_REAL * Jyp @ ((yp - f) + Z.T.dot(E * MU_REAL))
                 # error_ODE_mass = self.solve_lu(LU_real_ODE_mass, err_ODE_mass) / (MU_REAL * h)
+                # error_ODE_mass = self.solve_lu(
+                #     LU_real_ODE_mass, 
+                #     Jyp @ ((yp - f) + Z.T.dot(E) / h),
+                # )
                 error_ODE_mass = self.solve_lu(
-                    LU_real_ODE_mass, 
-                    Jyp @ ((yp - f) + Z.T.dot(E) / h),
+                    LU_real, 
+                    # # (yp - f) + Jyp @ Z.T @ E / h,
+                    # yp + Jyp @ Z.T @ E / h,
+                    Jyp @ (yp + Z.T @ E / h), # TODO: Why is this a good estimate?
                 )
+                # error_ODE_mass = (yp + Z.T @ E / h) / (MU_REAL / h)
                 error = error_ODE_mass
                 
                 # ###############
                 # # Fabien (5.65)
                 # ###############
-                # b0_hat = 0.02
-                # # b0_hat = 1 / MU_REAL # TODO: I prefer this...
+                # # b0_hat = 0.02
+                # b0_hat = 1 / MU_REAL # TODO: I prefer this...
                 # gamma_hat = 1 / MU_REAL
-                # yp_hat_new = (y_new - (y + h * b_hat @ Yp + h * b0_hat * yp)) / (h * gamma_hat)
-                # yp_hat_new = MU_REAL * (b - b_hat) @ Yp - b0_hat * yp
-                # # yp_hat_new = MU_REAL * ((b - b_hat) @ Yp - b0_hat * yp)
-                # # # for b0_hat = 1 / MU_REAL this reduces to
-                # # yp_hat_new = MU_REAL * (b - b_hat) @ Yp - yp
-                # # yp_hat_new *= h
-                # # print(f"yp_new:     {yp_new}")
-                # # print(f"yp_hat_new: {yp_hat_new}")
-                # # y_hat_new = y + h * b_hat @ Yp + h * b0_hat * yp + h * gamma_hat * yp_hat_new
-                # # y_hat_new = y + h * b_hat @ Yp + h * b0_hat * yp + h * gamma_hat * yp_hat_new
-                # # error = y_hat_new - y_new
-
-                # error = np.zeros_like(y_new)
-
-                # # F = self.fun(t_new, y_new, yp_hat_new)
-                # F = self.fun(t_new, y_new, MU_REAL * (b - b_hat) @ Yp - b0_hat * yp)
-                # # error = np.linalg.solve(Jy + Jyp / (h * gamma_hat), -F)
-                # error = self.solve_lu(LU_real, -F)
-
-                # yp_hat_new = MU_REAL * (b - b_hat) @ Yp - b0_hat * yp
+                # # yp_hat_new = (y_new - (y + h * b_hat @ Yp + h * b0_hat * yp)) / (h * gamma_hat)
+                # # yp_hat_new = (1 / gamma_hat) * ((b - b_hat) @ Yp - b0_hat * yp)
+                # yp_hat_new = (1 / gamma_hat) * ((b - b_hat) @ Yp - b0_hat * yp)
+                # # yp_hat_new = (Z.T.dot(E) / h - yp)
                 # F = self.fun(t_new, y_new, yp_hat_new)
-                # error = self.solve_lu(LU_real, -F) #* (h * gamma_hat)
-                # # y1 = Y((ride_data.s-1)*ride_data.n+1:ride_data.s*ride_data.n,1);
-                # # t1 = t0+h;
-                # # yp1 = ride_data.mu(1)*(kron(ride_data.v',eye(ride_data.n))*Yp - ride_data.b0*yp0);
-                # # g = feval(IDEFUN,y1,yp1,t1);
-                # # ride_data.nfun = ride_data.nfun + 1;
-                # # r = -(ride_data.U\(ride_data.L\(ride_data.P*g)));
-
-                # LU_real = self.lu(MU_REAL / h * Jyp + Jy)
-
-                # error = -self.solve_lu(LU_real, self.fun(t_new, y_new, yp_hat_new))
-                # print(f"error: {error}")
-                # # error = -self.solve_lu(LU_real, self.fun(t_new, y_hat_new, yp_hat_new))
-
+                # # LU_real = self.lu(MU_REAL / h * Jyp + Jy)
+                # error = self.solve_lu(LU_real, -F)
+                # print(f"yp_new:         {yp_new}")
+                # print(f"yp_hat_new:     {yp_hat_new}")
+                # print(f"error_ODE_mass: {error_ODE_mass}")
+                # print(f"error:          {error}")
+                # print(f"")
                 
                 error_norm = norm(error / scale)