Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
max-de-rooij committed May 13, 2024
1 parent 63c02a4 commit bac031b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
3 changes: 2 additions & 1 deletion stringencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base.encoders import OneHotEncoder, LabelEncoder
from .base.base_encoder import BaseStringEncoder

__all__ = ['OneHotEncoder', 'LabelEncoder']
__all__ = ['BaseStringEncoder', 'OneHotEncoder', 'LabelEncoder']
2 changes: 1 addition & 1 deletion stringencoder/base/base_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ def __repr__(self):
return self.__str__()

def __eq__(self, other):
    """Return whether ``other`` has the same table and feature matrix.

    Only ``_table`` and ``_mat`` take part in the comparison. Comparing the
    instances' ``__dict__`` directly is not an option: a dict comparison
    has to take the truth value of ``array == array``, which NumPy rejects
    for arrays holding more than one element.

    Args:
        other: The object to compare against.

    Returns:
        ``True`` or ``False`` when ``other`` exposes ``_table`` and
        ``_mat``; ``NotImplemented`` otherwise, so Python can fall back to
        its default handling instead of raising ``AttributeError``.
    """
    try:
        other_table = other._table
        other_mat = other._mat
    except AttributeError:
        # ``other`` is not encoder-like; let Python try the reflected
        # comparison and ultimately treat the objects as unequal.
        return NotImplemented

    # ``np.array_equal`` compares shapes before values, so two matrices that
    # merely broadcast against each other (e.g. (3, 2) vs (3, 1)) are not
    # reported as equal, which ``np.all(a == b)`` would do.
    return bool(
        self._table == other_table and np.array_equal(self._mat, other_mat)
    )

52 changes: 50 additions & 2 deletions test_encoders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,45 @@
from stringencoder import OneHotEncoder, LabelEncoder
from stringencoder import OneHotEncoder, LabelEncoder, BaseStringEncoder
import pytest
import numpy as np

class TestBaseExceptions:
    """Checks on the errors raised while constructing a BaseStringEncoder."""

    def test_matrix_shape(self):
        """A three-character table with a four-row matrix raises ValueError."""
        four_rows = np.array([[1, 0], [0, 1], [0, 0], [1, 1]])
        with pytest.raises(ValueError, match=r"Table length and matrix shape mismatch"):
            BaseStringEncoder(table='abc', mat=four_rows)

    def test_duplicate_chars(self):
        """A table that repeats a character raises ValueError."""
        distinct_rows = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
        with pytest.raises(ValueError, match=r"Table has duplicate characters:.*"):
            BaseStringEncoder(table='abbc', mat=distinct_rows)

    def test_duplicate_rows(self):
        """A matrix whose first and last rows coincide raises ValueError."""
        repeated_rows = np.array([[1, 0], [0, 1], [1, 1], [1, 0]])
        with pytest.raises(ValueError, match=r"Matrix has duplicate rows"):
            BaseStringEncoder(table='abcd', mat=repeated_rows)

    def test_no_exceptions_duplicate_rows(self):
        """The same repeated-row matrix is accepted with force_unique_features=False."""
        repeated_rows = np.array([[1, 0], [0, 1], [1, 1], [1, 0]])
        BaseStringEncoder(
            table='abcd',
            mat=repeated_rows,
            force_unique_features=False,
        )

class TestBaseAuxiliary:
    """Checks on the helper and dunder methods of BaseStringEncoder."""

    @staticmethod
    def _abc_encoder(last_row=(1, 1), **kwargs):
        """Build an encoder for table 'abc' whose third matrix row is ``last_row``."""
        rows = [[1, 0], [0, 1], list(last_row)]
        return BaseStringEncoder(table='abc', mat=np.array(rows), **kwargs)

    def test_duplicate_chars(self):
        """_find_duplicate_chars returns a per-character count of its argument."""
        counts = self._abc_encoder()._find_duplicate_chars('abbc')
        assert counts == {'a': 1, 'b': 2, 'c': 1}

    def test_str_repr(self):
        """str() and repr() agree and show empty braces without extra kwargs."""
        encoder = self._abc_encoder()
        expected = 'BaseStringEncoder{}'
        assert str(encoder) == expected
        assert repr(encoder) == expected

    def test_str_repr_kwargs(self):
        """str() and repr() agree and show the extra keyword argument."""
        encoder = self._abc_encoder(extra_argument=False)
        expected = 'BaseStringEncoder{\'extra_argument\': False}'
        assert str(encoder) == expected
        assert repr(encoder) == expected

    def test_eq(self):
        """Two encoders built from identical table and matrix compare equal."""
        first = self._abc_encoder()
        second = self._abc_encoder()
        assert first == second

    def test_eq_false(self):
        """Encoders whose matrices differ in the last row compare unequal."""
        first = self._abc_encoder()
        second = self._abc_encoder(last_row=(0, 0))
        assert first != second

class TestEncoders:
def test_onehot_encoder(self):
Expand Down Expand Up @@ -36,4 +75,13 @@ def test_label_encoder_decoder(self):
encoded = encoder.encode('abc')
decoded = encoder.decode(encoded)
assert decoded == 'abc'


class TestEncodersAuxiliary:
    """Checks on the string representations of the concrete encoders."""

    def test_str_repr_onehot(self):
        """str() and repr() of a OneHotEncoder agree and show empty braces."""
        onehot = OneHotEncoder(table='abc')
        expected = 'OneHotEncoder{}'
        assert str(onehot) == expected
        assert repr(onehot) == expected

    def test_str_repr_label(self):
        """str() and repr() of a LabelEncoder agree and show empty braces."""
        label = LabelEncoder(table='abc')
        expected = 'LabelEncoder{}'
        assert str(label) == expected
        assert repr(label) == expected

0 comments on commit bac031b

Please sign in to comment.