Skip to content

Commit

Permalink
ENH: #448 add keyword args to model.print
Browse files Browse the repository at this point in the history
`model.print` now respects kwargs of `builtins.print`
  • Loading branch information
himkwtn authored Aug 6, 2024
1 parent 889ac6d commit 3e8a445
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 21 deletions.
16 changes: 8 additions & 8 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def equations(self, precision=3):
precision=precision,
)

def print(self, lhs=None, precision=3):
def print(self, lhs=None, precision=3, **kwargs):
"""Print the SINDy model equations.
Parameters
Expand All @@ -362,6 +362,8 @@ def print(self, lhs=None, precision=3):
precision: int, optional (default 3)
Precision to be used when printing out model coefficients.
**kwargs: Additional keyword arguments passed to the builtin print function
"""
eqns = self.equations(precision)
if sindy_pi_flag and isinstance(self.optimizer, SINDyPI):
Expand All @@ -370,17 +372,15 @@ def print(self, lhs=None, precision=3):
feature_names = self.feature_names
for i, eqn in enumerate(eqns):
if self.discrete_time:
names = "(" + feature_names[i] + ")"
print(names + "[k+1] = " + eqn)
names = f"({feature_names[i]})[k+1]"
elif lhs is None:
if not sindy_pi_flag or not isinstance(self.optimizer, SINDyPI):
names = "(" + feature_names[i] + ")"
print(names + "' = " + eqn)
names = f"({feature_names[i]})'"
else:
names = feature_names[i]
print(names + " = " + eqn)
names = f"({feature_names[i]})"
else:
print(lhs[i] + " = " + eqn)
names = f"{lhs[i]}"
print(f"{names} = {eqn}", **kwargs)

def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws):
"""
Expand Down
17 changes: 4 additions & 13 deletions test/test_pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def test_equations(data, capsys):
model.print(precision=2)

out, _ = capsys.readouterr()

assert len(out) > 0
assert "(x0)' = " in out


def test_print_discrete_time(data_discrete_time, capsys):
Expand All @@ -500,20 +502,9 @@ def test_print_discrete_time(data_discrete_time, capsys):
model.print()

out, _ = capsys.readouterr()
assert len(out) > 0


def test_print_discrete_time_multiple_trajectories(
data_discrete_time_multiple_trajectories, capsys
):
x = data_discrete_time_multiple_trajectories
model = SINDy(discrete_time=True)
model.fit(x)

model.print()

out, _ = capsys.readouterr()
assert len(out) > 1
assert len(out) > 0
assert "(x0)[k+1] = " in out


def test_differentiate(data_lorenz, data_multiple_trajectories):
Expand Down

0 comments on commit 3e8a445

Please sign in to comment.