From 3e8a4455fd4de046225344dda9b516d0d013ee36 Mon Sep 17 00:00:00 2001 From: Watcharin Kriengwatana <37610745+himkwtn@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:08:51 -0700 Subject: [PATCH] ENH: #448 add keyword args to model.print `model.print` now respects kwargs of `builtins.print` --- pysindy/pysindy.py | 16 ++++++++-------- test/test_pysindy.py | 17 ++++------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 0c9071ea3..431a34fcd 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -351,7 +351,7 @@ def equations(self, precision=3): precision=precision, ) - def print(self, lhs=None, precision=3): + def print(self, lhs=None, precision=3, **kwargs): """Print the SINDy model equations. Parameters @@ -362,6 +362,8 @@ def print(self, lhs=None, precision=3): precision: int, optional (default 3) Precision to be used when printing out model coefficients. + + **kwargs: Additional keyword arguments passed to the builtin print function """ eqns = self.equations(precision) if sindy_pi_flag and isinstance(self.optimizer, SINDyPI): @@ -370,17 +372,15 @@ def print(self, lhs=None, precision=3): feature_names = self.feature_names for i, eqn in enumerate(eqns): if self.discrete_time: - names = "(" + feature_names[i] + ")" - print(names + "[k+1] = " + eqn) + names = f"({feature_names[i]})[k+1]" elif lhs is None: if not sindy_pi_flag or not isinstance(self.optimizer, SINDyPI): - names = "(" + feature_names[i] + ")" - print(names + "' = " + eqn) + names = f"({feature_names[i]})'" else: - names = feature_names[i] - print(names + " = " + eqn) + names = f"({feature_names[i]})" else: - print(lhs[i] + " = " + eqn) + names = f"{lhs[i]}" + print(f"{names} = {eqn}", **kwargs) def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws): """ diff --git a/test/test_pysindy.py b/test/test_pysindy.py index c12c98179..31e31e637 100644 --- a/test/test_pysindy.py +++ b/test/test_pysindy.py @@ -490,7 +490,9 @@ def test_equations(data, capsys): model.print(precision=2) out, _ = capsys.readouterr() + assert len(out) > 0 + assert "(x0)' = " in out def test_print_discrete_time(data_discrete_time, capsys): @@ -500,20 +502,9 @@ def test_print_discrete_time(data_discrete_time, capsys): model.print() out, _ = capsys.readouterr() - assert len(out) > 0 - - -def test_print_discrete_time_multiple_trajectories( - data_discrete_time_multiple_trajectories, capsys -): - x = data_discrete_time_multiple_trajectories - model = SINDy(discrete_time=True) - model.fit(x) - - model.print() - out, _ = capsys.readouterr() - assert len(out) > 1 + assert len(out) > 0 + assert "(x0)[k+1] = " in out def test_differentiate(data_lorenz, data_multiple_trajectories):