Skip to content

Commit

Permalink
check_dtype=False
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 6, 2025
1 parent 806dfa6 commit 233bde5
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tests/unit/single_table/data_augmentation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
@pytest.fixture
def real_training_data():
return pd.DataFrame({
'target': pd.Series([1, 0, 0], dtype='int64'),
'numerical': pd.Series([1, 2, 3], dtype='int64'),
'target': [1, 0, 0],
'numerical': [1, 2, 3],
'categorical': ['a', 'b', 'b'],
'boolean': [True, False, True],
'datetime': pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03']),
Expand All @@ -28,8 +28,8 @@ def real_training_data():
@pytest.fixture
def synthetic_data():
return pd.DataFrame({
'target': pd.Series([0, 1, 0], dtype='int64'),
'numerical': pd.Series([2, 2, 3], dtype='int64'),
'target': [0, 1, 0],
'numerical': [2, 2, 3],
'categorical': ['b', 'a', 'b'],
'boolean': [False, True, False],
'datetime': pd.to_datetime(['2021-01-25', '2021-01-02', '2021-01-03']),
Expand All @@ -39,8 +39,8 @@ def synthetic_data():
@pytest.fixture
def real_validation_data():
return pd.DataFrame({
'target': pd.Series([1, 0, 0], dtype='int64'),
'numerical': pd.Series([3, 3, 3], dtype='int64'),
'target': [1, 0, 0],
'numerical': [3, 3, 3],
'categorical': ['a', 'b', 'b'],
'boolean': [True, False, True],
'datetime': pd.to_datetime(['2021-01-01', '2021-01-12', '2021-01-03']),
Expand Down Expand Up @@ -219,26 +219,26 @@ def test__transform_preprocess(self, real_training_data, synthetic_data, real_va
# Assert
expected_transformed = {
'real_training_data': pd.DataFrame({
'target': pd.Series([1, 0, 0], dtype='int64'),
'numerical': pd.Series([1, 2, 3], dtype='int64'),
'target': [1, 0, 0],
'numerical': [1, 2, 3],
'categorical': pd.Categorical(['a', 'b', 'b']),
'boolean': pd.Categorical([True, False, True]),
'datetime': pd.to_numeric(
pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03'])
),
}),
'synthetic_data': pd.DataFrame({
'target': pd.Series([0, 1, 0], dtype='int64'),
'numerical': pd.Series([2, 2, 3], dtype='int64'),
'target': [0, 1, 0],
'numerical': [2, 2, 3],
'categorical': pd.Categorical(['b', 'a', 'b']),
'boolean': pd.Categorical([False, True, False]),
'datetime': pd.to_numeric(
pd.to_datetime(['2021-01-25', '2021-01-02', '2021-01-03'])
),
}),
'real_validation_data': pd.DataFrame({
'target': pd.Series([1, 0, 0], dtype='int64'),
'numerical': pd.Series([3, 3, 3], dtype='int64'),
'target': [1, 0, 0],
'numerical': [3, 3, 3],
'categorical': pd.Categorical(['a', 'b', 'b']),
'boolean': pd.Categorical([True, False, True]),
'datetime': pd.to_numeric(
Expand All @@ -247,7 +247,9 @@ def test__transform_preprocess(self, real_training_data, synthetic_data, real_va
}),
}
for table_name, table in transformed.items():
pd.testing.assert_frame_equal(table, expected_transformed[table_name])
pd.testing.assert_frame_equal(
table, expected_transformed[table_name], check_dtype=False
)

def test__fit_transform(
self, real_training_data, synthetic_data, real_validation_data, metadata
Expand Down

0 comments on commit 233bde5

Please sign in to comment.