Skip to content

Commit

Permalink
include colum name in fit; update report names
Browse files Browse the repository at this point in the history
  • Loading branch information
rwedge committed Feb 20, 2025
1 parent 81a78ae commit a10d610
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions tests/benchmark/supported_dtypes_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,15 @@ def test_transformer(dtype, data, sdtype, transformer):

_transformer = transformer()
transformer_name = _transformer.get_name()
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f"{transformer_name}_FIT")
previous_transform_result, _ = get_previous_dtype_result(dtype, sdtype, f"{transformer_name}_TRANSFORM")
previous_reverse_result, _ = get_previous_dtype_result(dtype, sdtype, f"{transformer_name}_REVERSE")
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f"RDT_{transformer_name}_FIT")
previous_transform_result, _ = get_previous_dtype_result(dtype, sdtype, f"RDT_{transformer_name}_TRANSFORM")
previous_reverse_result, _ = get_previous_dtype_result(dtype, sdtype, f"RDT_{transformer_name}_REVERSE")
fit_result = False
transform_result = False
reverse_result = False

try:
_transformer.fit(data)
_transformer.fit(data, dtype)
fit_result = True
transformed_data = _transformer.transform(data)
transform_result = True
Expand All @@ -298,20 +298,20 @@ def test_transformer(dtype, data, sdtype, transformer):
save_results_to_json({
'dtype': dtype,
'sdtype': sdtype,
f"{transformer_name}_FIT": fit_result,
f"{transformer_name}_TRANSFORM": transform_result,
f"{transformer_name}_REVERSE": reverse_result,
f"RDT_{transformer_name}_FIT": fit_result,
f"RDT_{transformer_name}_TRANSFORM": transform_result,
f"RDT_{transformer_name}_REVERSE": reverse_result,
})

fit_assertion_message = f"{dtype} is no longer supported by '{transformer_name}_FIT'."
fit_assertion_message = f"{dtype} is no longer supported by 'RDT_{transformer_name}_FIT'."
if fit_result is False:
assert fit_result == previous_fit_result, fit_assertion_message

transform_assertion_message = f"{dtype} is no longer supported by '{transformer_name}_TRANSFORM'."
transform_assertion_message = f"{dtype} is no longer supported by 'RDT_{transformer_name}_TRANSFORM'."
if transform_result is False:
assert transform_result == previous_transform_result, transform_assertion_message

reverse_assertion_message = f"{dtype} is no longer supported by '{transformer_name}_REVERSE'."
reverse_assertion_message = f"{dtype} is no longer supported by 'RDT_{transformer_name}_REVERSE'."
if reverse_result is False:
assert reverse_result == previous_reverse_result, reverse_assertion_message

Expand Down Expand Up @@ -458,9 +458,9 @@ def test_fit_and_sample_single_column_constraints(constraint_name, constraint, d
metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype)
synthesizer = GaussianCopulaSynthesizer(metadata)
sdtype = metadata.columns[dtype].get('sdtype')
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'{constraint_name}_FIT')
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'CONSTRAINT_{constraint_name}_FIT')
previous_sample_result, _ = get_previous_dtype_result(
dtype, sdtype, f'{constraint_name}_SAMPLE'
dtype, sdtype, f'CONSTRAINT_{constraint_name}_SAMPLE'
)

# Prepare the constraint and data
Expand Down Expand Up @@ -491,16 +491,16 @@ def test_fit_and_sample_single_column_constraints(constraint_name, constraint, d
save_results_to_json({
'dtype': dtype,
'sdtype': sdtype,
f'{constraint_name}_FIT': fit_result,
f'{constraint_name}_SAMPLE': sample_result,
f'CONSTRAINT_{constraint_name}_FIT': fit_result,
f'CONSTRAINT_{constraint_name}_SAMPLE': sample_result,
})
if fit_result is False:
fit_assertion_message = f"{dtype} is no longer supported by '{constraint_name}_FIT''."
fit_assertion_message = f"{dtype} is no longer supported by 'CONSTRAINT_{constraint_name}_FIT''."
assert fit_result == previous_fit_result, fit_assertion_message

if sample_result is False:
sample_assertion_message = (
f"{dtype} is no longer supported by '{constraint_name}_FIT''."
f"{dtype} is no longer supported by 'CONSTRAINT_{constraint_name}_FIT''."
)
assert sample_result == previous_sample_result, sample_assertion_message

Expand Down Expand Up @@ -542,9 +542,9 @@ def test_fit_and_sample_multi_column_constraints(constraint_name, constraint, dt
if (sdtype, dtype, constraint_name) not in EXCLUDED_CONSTRAINT_TESTS:
metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype)
sdtype = metadata.columns[dtype].get('sdtype')
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'{constraint_name}_FIT')
previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'CONSTRAINT_{constraint_name}_FIT')
previous_sample_result, _ = get_previous_dtype_result(
dtype, sdtype, f'{constraint_name}_SAMPLE'
dtype, sdtype, f'CONSTRAINT_{constraint_name}_SAMPLE'
)

# Prepare constraints, data required and metadata
Expand Down Expand Up @@ -577,13 +577,13 @@ def test_fit_and_sample_multi_column_constraints(constraint_name, constraint, dt
save_results_to_json({
'dtype': dtype,
'sdtype': sdtype,
f'{constraint_name}_FIT': fit_result,
f'{constraint_name}_SAMPLE': sample_result,
f'CONSTRAINT_{constraint_name}_FIT': fit_result,
f'CONSTRAINT_{constraint_name}_SAMPLE': sample_result,
})
if fit_result is False:
fit_message = f"{dtype} failed during '{constraint_name}_FIT'."
fit_message = f"{dtype} failed during 'CONSTRAINT_{constraint_name}_FIT'."
assert fit_result == previous_fit_result, fit_message

if sample_result is False:
sample_msg = f"{dtype} failed during '{constraint_name}_SAMPLE'."
sample_msg = f"{dtype} failed during 'CONSTRAINT_{constraint_name}_SAMPLE'."
assert sample_result == previous_sample_result, sample_msg

0 comments on commit a10d610

Please sign in to comment.