diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index b1afa0b76b..8c895eb852 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -442,3 +442,12 @@ def test_edata_repr(): assert expected_str in repr(e) # avoid double delete!! edata_ptr.release() + + +def test_edata_equality_operator(): + e1 = amici.ExpData(1, 2, 3, [3]) + e2 = amici.ExpData(1, 2, 3, [3]) + assert e1 == e2 + # check that comparison with other types works + # this is not implemented by swig by default + assert e1 != 1 diff --git a/swig/edata.i b/swig/edata.i index 59dcb4fa8a..d8afffa43e 100644 --- a/swig/edata.i +++ b/swig/edata.i @@ -74,6 +74,9 @@ def _edata_repr(self: "ExpData"): %pythoncode %{ def __repr__(self): return _edata_repr(self) + +def __eq__(self, other): + return isinstance(other, self.__class__) and __eq__(self, other) %} }; %extend std::unique_ptr {