Skip to content

Commit

Permalink
The ContingencySimilarity metric should be able to discretize continuous columns (#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h authored Jan 9, 2025
1 parent debf91f commit 799c4ad
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 18 deletions.
43 changes: 40 additions & 3 deletions sdmetrics/column_pairs/statistical/contingency_similarity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Contingency Similarity Metric."""

import pandas as pd

from sdmetrics.column_pairs.base import ColumnPairsMetric
from sdmetrics.goal import Goal
from sdmetrics.utils import discretize_column


class ContingencySimilarity(ColumnPairsMetric):
Expand All @@ -23,23 +26,57 @@ class ContingencySimilarity(ColumnPairsMetric):
min_value = 0.0
max_value = 1.0

@staticmethod
def _validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins):
for data in [real_data, synthetic_data]:
if not isinstance(data, pd.DataFrame) or len(data.columns) != 2:
raise ValueError('The data must be a pandas DataFrame with two columns.')

if set(real_data.columns) != set(synthetic_data.columns):
raise ValueError('The columns in the real and synthetic data must match.')

if continuous_column_names is not None:
bad_continuous_columns = "' ,'".join([
column for column in continuous_column_names if column not in real_data.columns
])
if bad_continuous_columns:
raise ValueError(
f"Continuous column(s) '{bad_continuous_columns}' not found in the data."
)

if not isinstance(num_discrete_bins, int) or num_discrete_bins <= 0:
raise ValueError('`num_discrete_bins` must be an integer greater than zero.')

@classmethod
def compute(cls, real_data, synthetic_data):
def compute(cls, real_data, synthetic_data, continuous_column_names=None, num_discrete_bins=10):
"""Compare the contingency similarity of two discrete columns.
Args:
real_data (Union[numpy.ndarray, pandas.Series]):
real_data (pd.DataFrame):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.Series]):
synthetic_data (pd.DataFrame):
The values from the synthetic dataset.
continuous_column_names (list[str], optional):
The list of columns to discretize before running the metric. The column names in
this list should match the column names in the real and synthetic data. Defaults
to ``None``.
num_discrete_bins (int, optional):
The number of bins to create for the continuous columns. Defaults to 10.
Returns:
float:
The contingency similarity of the two columns.
"""
cls._validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins)
columns = real_data.columns[:2]
real = real_data[columns]
synthetic = synthetic_data[columns]
if continuous_column_names is not None:
for column in continuous_column_names:
real[column], synthetic[column] = discretize_column(
real[column], synthetic[column], num_discrete_bins=num_discrete_bins
)

contingency_real = real.groupby(list(columns), dropna=False).size() / len(real)
contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
synthetic
Expand Down
5 changes: 2 additions & 3 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics.utils import (
discretize_column,
get_alternate_keys,
get_columns_from_metadata,
get_type_from_column_meta,
Expand Down Expand Up @@ -116,9 +117,7 @@ def discretize_table_data(real_data, synthetic_data, metadata):
real_col = pd.to_numeric(real_col)
synthetic_col = pd.to_numeric(synthetic_col)

bin_edges = np.histogram_bin_edges(real_col.dropna())
binned_real_col = np.digitize(real_col, bins=bin_edges)
binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges)
binned_real_col, binned_synthetic_col = discretize_column(real_col, synthetic_col)

binned_real[column_name] = binned_real_col
binned_synthetic[column_name] = binned_synthetic_col
Expand Down
22 changes: 22 additions & 0 deletions sdmetrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,28 @@ def is_datetime(data):
)


def discretize_column(real_column, synthetic_column, num_discrete_bins=10):
    """Bin a real and a synthetic column into the same discrete buckets.

    The bin edges are computed from the real column only, then widened to
    (-inf, +inf) at the extremes so every synthetic value falls inside a bin.

    Args:
        real_column (pd.Series):
            The real column.
        synthetic_column (pd.Series):
            The synthetic column.
        num_discrete_bins (int, optional):
            The number of bins to create. Defaults to 10.

    Returns:
        tuple(pd.Series, pd.Series):
            The discretized real and synthetic columns.
    """
    edges = np.histogram_bin_edges(real_column.dropna(), bins=num_discrete_bins)
    # Open the outermost bins so out-of-range synthetic values land in bins 1 and N
    # instead of overflowing past the edge array.
    edges[0] = -np.inf
    edges[-1] = np.inf
    binned = tuple(np.digitize(column, bins=edges) for column in (real_column, synthetic_column))
    return binned


class HyperTransformer:
"""HyperTransformer class.
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/column_pairs/statistical/test_contingency_similarity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from unittest.mock import patch

import pandas as pd
Expand All @@ -7,6 +8,59 @@


class TestContingencySimilarity:
def test__validate_inputs(self):
    """Ensure ``_validate_inputs`` rejects every kind of malformed input."""
    # Setup
    not_a_dataframe = pd.Series(range(5))
    real = pd.DataFrame({'col1': range(10), 'col2': range(10, 20)})
    synthetic_wrong_columns = pd.DataFrame({'bad_column': range(10), 'col2': range(10)})
    synthetic = pd.DataFrame({'col1': range(5), 'col2': range(5)})
    bad_continuous_columns = ['col1', 'missing_col']
    bad_num_discrete_bins = -1

    # Run and Assert
    with pytest.raises(
        ValueError, match=re.escape('The data must be a pandas DataFrame with two columns.')
    ):
        ContingencySimilarity._validate_inputs(
            real_data=not_a_dataframe,
            synthetic_data=not_a_dataframe,
            continuous_column_names=None,
            num_discrete_bins=10,
        )

    with pytest.raises(
        ValueError, match=re.escape('The columns in the real and synthetic data must match.')
    ):
        ContingencySimilarity._validate_inputs(
            real_data=real,
            synthetic_data=synthetic_wrong_columns,
            continuous_column_names=None,
            num_discrete_bins=10,
        )

    with pytest.raises(
        ValueError,
        match=re.escape("Continuous column(s) 'missing_col' not found in the data."),
    ):
        ContingencySimilarity._validate_inputs(
            real_data=real,
            synthetic_data=synthetic,
            continuous_column_names=bad_continuous_columns,
            num_discrete_bins=10,
        )

    with pytest.raises(
        ValueError,
        match=re.escape('`num_discrete_bins` must be an integer greater than zero.'),
    ):
        ContingencySimilarity._validate_inputs(
            real_data=real,
            synthetic_data=synthetic,
            continuous_column_names=['col1'],
            num_discrete_bins=bad_num_discrete_bins,
        )

def test_compute(self):
"""Test the ``compute`` method.
Expand All @@ -32,6 +86,22 @@ def test_compute(self):
# Assert
assert result == expected_score

def test_compute_with_discretization(self):
    """Test the ``compute`` method with continuous columns."""
    # Setup
    real_data = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2': [1, 2, 3, 4]})
    synthetic_data = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2': [2, 3, 7, -10]})

    # Run
    score = ContingencySimilarity().compute(
        real_data, synthetic_data, continuous_column_names=['col2'], num_discrete_bins=4
    )

    # Assert
    assert score == 0.25

@patch('sdmetrics.column_pairs.statistical.contingency_similarity.ColumnPairsMetric.normalize')
def test_normalize(self, normalize_mock):
"""Test the ``normalize`` method.
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/reports/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ def test_discretize_table_data():

# Assert
expected_real = pd.DataFrame({
'col1': [1, 6, 11],
'col1': [1, 6, 10],
'col2': ['a', 'b', 'c'],
'col3': [2, 1, 11],
'col3': [2, 1, 10],
'col4': [True, False, True],
'col5': [10, 1, 11],
'col5': [10, 1, 10],
})
expected_synth = pd.DataFrame({
'col1': [11, 1, 11],
'col1': [10, 1, 10],
'col2': ['c', 'a', 'c'],
'col3': [11, 0, 5],
'col3': [10, 1, 5],
'col4': [False, False, True],
'col5': [10, 5, 11],
'col5': [10, 5, 10],
})

pd.testing.assert_frame_equal(discretized_real, expected_real)
Expand Down Expand Up @@ -193,18 +193,18 @@ def test_discretize_table_data_new_metadata():

# Assert
expected_real = pd.DataFrame({
'col1': [1, 6, 11],
'col1': [1, 6, 10],
'col2': ['a', 'b', 'c'],
'col3': [2, 1, 11],
'col3': [2, 1, 10],
'col4': [True, False, True],
'col5': [10, 1, 11],
'col5': [10, 1, 10],
})
expected_synth = pd.DataFrame({
'col1': [11, 1, 11],
'col1': [10, 1, 10],
'col2': ['c', 'a', 'c'],
'col3': [11, 0, 5],
'col3': [10, 1, 5],
'col4': [False, False, True],
'col5': [10, 5, 11],
'col5': [10, 5, 10],
})

pd.testing.assert_frame_equal(discretized_real, expected_real)
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sdmetrics.utils import (
HyperTransformer,
discretize_column,
get_alternate_keys,
get_cardinality_distribution,
get_columns_from_metadata,
Expand Down Expand Up @@ -54,6 +55,21 @@ def test_get_missing_percentage():
assert percentage_nan == 28.57


def test_discretize_column():
    """Test the ``discretize_column`` method."""
    # Setup
    real = pd.Series(range(10))
    synthetic = pd.Series([-10] + list(range(1, 9)) + [20])

    # Run
    binned_real, binned_synthetic = discretize_column(real, synthetic, num_discrete_bins=5)

    # Assert
    expected_bins = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
    np.testing.assert_array_equal(expected_bins, binned_real)
    np.testing.assert_array_equal(expected_bins, binned_synthetic)


def test_get_columns_from_metadata():
"""Test the ``get_columns_from_metadata`` method with current metadata format.
Expand Down

0 comments on commit 799c4ad

Please sign in to comment.