diff --git a/portfolyo/tools/testing.py b/portfolyo/tools/testing.py index ed5c241..48da7eb 100644 --- a/portfolyo/tools/testing.py +++ b/portfolyo/tools/testing.py @@ -44,10 +44,12 @@ def assert_series_equal(left: pd.Series, right: pd.Series, *args, **kwargs): pd.testing.assert_series_equal(leftm, rightm, *args, **kwargs) # Units must be the same. + assert type(leftu) is type(rightu) if leftu is None: assert leftu is rightu elif isinstance(leftu, pint.Unit): # all values share the same unit, leftu is Unit - assert leftu == rightu + # Use `1*` to turn back into quantity. Ensures 'MWh' and 'MW*h' are the same. + assert 1 * leftu == 1 * rightu else: # each value has its own unit; leftu is Series pd.testing.assert_series_equal(leftu, rightu) diff --git a/tests/core/pfline/test_flat_helper.py b/tests/core/pfline/test_flat_helper.py index d629148..8f43ab5 100644 --- a/tests/core/pfline/test_flat_helper.py +++ b/tests/core/pfline/test_flat_helper.py @@ -69,7 +69,7 @@ def test_makedataframe_freqtz(freq, tz): @pytest.mark.parametrize("data,expected", TESTCASES_INPUTTYPES) -def test_makedataframe_inputtypes(data: Any, expected: pd.DataFrame): +def test_makedataframe_inputtypes(data: Any, expected: pd.DataFrame | type): """Test if dataframe can be created from various input types.""" if type(expected) is type and issubclass(expected, Exception): with pytest.raises(expected):