Skip to content

Commit

Permalink
Fix _get_likelihoods not generating likelihood values (#1720)
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h authored Jan 3, 2024
1 parent d987178 commit 64e8df2
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
29 changes: 26 additions & 3 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,23 @@ def _estimate_num_columns(self):

return columns_per_table

def preprocess(self, data):
"""Transform the raw data to numerical space.
Args:
data (dict):
Dictionary mapping each table name to a ``pandas.DataFrame``.
Returns:
dict:
A dictionary with the preprocessed data.
"""
processed_data = super().preprocess(data)
for _, synthesizer in self._table_synthesizers.items():
synthesizer.reset_sampling()

return processed_data

def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
"""Generate the extension columns for this child table.
Expand Down Expand Up @@ -507,18 +524,24 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
A DataFrame of the likelihood of each parent id.
"""
likelihoods = {}
table_rows = table_rows.copy()

data_processor = self._table_synthesizers[table_name]._data_processor
table_rows = data_processor.transform(table_rows)
transformed = data_processor.transform(table_rows)
if transformed.index.name:
table_rows = table_rows.set_index(transformed.index.name)

table_rows = pd.concat(
[transformed, table_rows.drop(columns=transformed.columns)],
axis=1
)
for parent_id, row in parent_rows.iterrows():
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._table_synthesizers[table_name].get_metadata()
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
synthesizer._set_parameters(parameters)
try:
with np.random.Generator(np.random.get_state()[1]):
likelihoods[parent_id] = synthesizer._get_likelihood(table_rows)
likelihoods[parent_id] = synthesizer._get_likelihood(table_rows)

except (AttributeError, np.linalg.LinAlgError):
likelihoods[parent_id] = None
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,3 +1177,33 @@ def test_get_learned_distributions_error_msg(self):
)
with pytest.raises(SynthesizerInputError, match=error_msg):
synth.get_learned_distributions(table_name='guests')

def test__get_likelihoods(self):
"""Test ``_get_likelihoods`` generates likelihoods for parents."""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
hmasynthesizer = HMASynthesizer(metadata)
hmasynthesizer.fit(data)

sampled_data = {}
sampled_data['characters'] = hmasynthesizer._sample_rows(
hmasynthesizer._table_synthesizers['characters'],
len(data['characters'])
)
hmasynthesizer._sample_children('characters', sampled_data)

# Run
likelihoods = hmasynthesizer._get_likelihoods(
sampled_data['character_families'],
sampled_data['characters'].set_index('character_id'),
'character_families',
'character_id'
)

# Assert
not_nan_cols = [1, 3, 6]
nan_cols = [2, 4, 5, 7]
assert set(likelihoods.columns) == {1, 2, 3, 4, 5, 6, 7}
assert len(likelihoods) == len(sampled_data['character_families'])
assert not any(likelihoods[not_nan_cols].isna().any())
assert all(likelihoods[nan_cols].isna())

0 comments on commit 64e8df2

Please sign in to comment.