Skip to content

Commit

Permalink
remove del column_relationship after warning
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 4, 2024
1 parent 8c1dd0c commit de2c268
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _detect_multi_column_transformers(self):
"""
result = {}
if self.metadata.column_relationships:
for relationship in self.metadata.column_relationships:
for relationship in self.metadata._valid_column_relationships:
column_names = tuple(relationship['column_names'])
relationship_type = relationship['type']
if relationship_type == 'address':
Expand Down
25 changes: 14 additions & 11 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def _append_error(self, errors, method, *args, **kwargs):
except InvalidMetadataError as e:
errors.append(e)

def _validate_column_relationship(self, relationship, index):
def _validate_column_relationship(self, relationship):
"""Validate a column relationship.
Verify that a column relationship has a valid relationship type, has
Expand All @@ -625,8 +625,6 @@ def _validate_column_relationship(self, relationship, index):
Args:
relationship (dict):
Column relationship to validate.
index (int):
Index of the relationship in the list of relationships.
Raises:
- ``InvalidMetadataError`` if relationship is invalid
Expand All @@ -653,13 +651,15 @@ def _validate_column_relationship(self, relationship, index):
}
try:
self._COLUMN_RELATIONSHIP_TYPES[relationship_type](columns_to_sdtypes)

except ImportError:
warnings.warn(
f"The metadata contains a column relationship of type '{relationship_type}'. "
f'which requires the {relationship_type} add-on.'
'This relationship will be ignored. For higher quality data in this'
' relationship, please inquire about the SDV Enterprise tier.')
del self.column_relationships[index]
' relationship, please inquire about the SDV Enterprise tier.'
)
raise ImportError

except Exception as e:
errors.append(str(e))
Expand Down Expand Up @@ -702,13 +702,16 @@ def _validate_all_column_relationships(self, column_relationships):

# Validate each individual relationship
errors = []
self._valid_column_relationships = deepcopy(column_relationships)
for idx, relationship in enumerate(column_relationships):
self._append_error(
errors,
self._validate_column_relationship,
relationship,
idx
)
try:
self._append_error(
errors,
self._validate_column_relationship,
relationship,
)
except ImportError:
self._valid_column_relationships.pop(idx)

if errors:
raise InvalidMetadataError(
Expand Down
1 change: 1 addition & 0 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test__detect_multi_column_transformers_with_address(self, transfomers_mock):
}
]
})
metadata.validate()
dp = DataProcessor(SingleTableMetadata())
dp.metadata = metadata
dp._locales = ['en_US', 'en_GB']
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,7 @@ def test__validate_column_relationship(self):
}

# Run
instance._validate_column_relationship(relationship, 0)
instance._validate_column_relationship(relationship)

# Assert
expected_columns_to_sdtypes = {
Expand Down Expand Up @@ -1712,7 +1712,7 @@ def test__validate_column_relationship_bad_relationship_type(self):
"Must be one of ['mock_relationship']."
)
with pytest.raises(InvalidMetadataError, match=msg):
instance._validate_column_relationship(relationship, 0)
instance._validate_column_relationship(relationship)

def test__validate_column_relationship_bad_columns(self):
"""Test validation fails for invalid columns."""
Expand Down Expand Up @@ -1744,7 +1744,7 @@ def validation_side_effect(*args, **kwargs):
"Columns ['a', 'b'] have unsupported sdtype."
)
with pytest.raises(InvalidMetadataError, match=err_msg):
instance._validate_column_relationship(relationship, 0)
instance._validate_column_relationship(relationship)

# Assert
expected_columns_to_sdtypes = {
Expand Down Expand Up @@ -1774,8 +1774,8 @@ def test__validate_all_column_relationships(self):

# Assert
mock_validate_relationship.assert_has_calls([
call(relationship_one, 0),
call(relationship_two, 1)
call(relationship_one),
call(relationship_two)
])

def test__validate_all_column_relationships_invalid_relationship_structure(self):
Expand Down Expand Up @@ -1817,7 +1817,7 @@ def test__validate_all_column_relationships_repeated_column(self):
def test__validate_all_column_relationships_bad_relationship(self):
"""Test validation fails if individual relationship validation fails."""
# Setup
def mock_relationship_validate(relationship, idx):
def mock_relationship_validate(relationship):
raise InvalidMetadataError(
f"Error in '{relationship['type']}' relationship."
)
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def test_set_address_columns_warning(self):
"""Test ``set_address_columns`` method when the synthesizer has been fitted."""
# Setup
synthesizer = BaseSingleTableSynthesizer(SingleTableMetadata())
synthesizer._check_address_columns = Mock()
synthesizer._data_processor.set_address_transformer = Mock()

# Run and Assert
expected_message = re.escape(
Expand Down

0 comments on commit de2c268

Please sign in to comment.