diff --git a/tests/baseline.py b/tests/baseline.py index a60dcfc60..135069e7f 100644 --- a/tests/baseline.py +++ b/tests/baseline.py @@ -44,6 +44,74 @@ def from_dict(cls, data: Dict): def num_function_variants_without_coverage(self): return self.coverage.num_function_variants_without_coverage() + def validate_against(self, expected): + errors = [] + + if self.registry.extension_count < expected.registry.extension_count: + errors.append( + f"Extension count mismatch: expected {expected.registry.extension_count}, got {self.registry.extension_count}" + ) + if self.registry.dependency_count < expected.registry.dependency_count: + errors.append( + f"Dependency count mismatch: expected {expected.registry.dependency_count}, got {self.registry.dependency_count}" + ) + if self.registry.function_count < expected.registry.function_count: + errors.append( + f"Function count mismatch: expected {expected.registry.function_count}, got {self.registry.function_count}" + ) + if ( + self.registry.num_aggregate_functions + < expected.registry.num_aggregate_functions + ): + errors.append( + f"Aggregate function count mismatch: expected {expected.registry.num_aggregate_functions}, got {self.registry.num_aggregate_functions}" + ) + if self.registry.num_scalar_functions < expected.registry.num_scalar_functions: + errors.append( + f"Scalar function count mismatch: expected {expected.registry.num_scalar_functions}, got {self.registry.num_scalar_functions}" + ) + if self.registry.num_window_functions < expected.registry.num_window_functions: + errors.append( + f"Window function count mismatch: expected {expected.registry.num_window_functions}, got {self.registry.num_window_functions}" + ) + if ( + self.registry.num_function_overloads + < expected.registry.num_function_overloads + ): + errors.append( + f"Function overload count mismatch: expected {expected.registry.num_function_overloads}, got {self.registry.num_function_overloads}" + ) + + if self.coverage.total_test_count < expected.coverage.total_test_count: + errors.append( + f"Total test count mismatch: expected {expected.coverage.total_test_count}, got {self.coverage.total_test_count}" + ) + if ( + self.coverage.num_function_variants + < expected.coverage.num_function_variants + ): + errors.append( + f"Total function variants mismatch: expected {expected.coverage.num_function_variants}, got {self.coverage.num_function_variants}" + ) + if ( + self.coverage.num_covered_function_variants + < expected.coverage.num_covered_function_variants + ): + errors.append( + f"Covered function variants mismatch: expected {expected.coverage.num_covered_function_variants}, got {self.coverage.num_covered_function_variants}" + ) + + expected_coverage_gap = expected.num_function_variants_without_coverage() + actual_coverage_gap = self.num_function_variants_without_coverage() + if actual_coverage_gap > expected_coverage_gap: + errors.append( + f"Coverage gap too large: {actual_coverage_gap} function variants with no tests, " + f"out of {self.coverage.num_function_variants} total function variants. " + f"New functions should be added along with test cases that illustrate their behavior." + ) + + return errors + def read_baseline_file(file_path: str) -> Baseline: with open(file_path, "r") as file: diff --git a/tests/cases/comparison/is_false.test b/tests/cases/comparison/is_false.test index ea2d78b2d..6cfd337d6 100644 --- a/tests/cases/comparison/is_false.test +++ b/tests/cases/comparison/is_false.test @@ -1,7 +1,7 @@ ### SUBSTRAIT_SCALAR_TEST: v1.0 ### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml' -# basic is_false +# basic: Basic examples without any special cases is_false(true::bool) = false::bool is_false(false::bool) = true::bool is_false(null::bool) = false::bool diff --git a/tests/cases/comparison/is_not_false.test b/tests/cases/comparison/is_not_false.test index a19bb53de..202b84318 100644 --- a/tests/cases/comparison/is_not_false.test +++ b/tests/cases/comparison/is_not_false.test @@ -1,7 +1,7 @@ ### SUBSTRAIT_SCALAR_TEST: v1.0 ### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml' -# basic is_not_false +# basic: Basic examples without any special cases is_not_false(true::bool) = true::bool is_not_false(false::bool) = false::bool is_not_false(null::bool) = true::bool diff --git a/tests/cases/comparison/is_not_true b/tests/cases/comparison/is_not_true.test similarity index 80% rename from tests/cases/comparison/is_not_true rename to tests/cases/comparison/is_not_true.test index 997735bb1..58b42781a 100644 --- a/tests/cases/comparison/is_not_true +++ b/tests/cases/comparison/is_not_true.test @@ -1,7 +1,7 @@ ### SUBSTRAIT_SCALAR_TEST: v1.0 ### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml' -# basic is_not_true +# basic: Basic examples without any special cases is_not_true(true::bool) = false::bool is_not_true(false::bool) = true::bool is_not_true(null::bool) = true::bool diff --git a/tests/cases/comparison/is_true.test b/tests/cases/comparison/is_true.test index cb1f30170..4bddaf2d8 100644 --- a/tests/cases/comparison/is_true.test +++ b/tests/cases/comparison/is_true.test @@ -1,7 +1,7 @@ ### SUBSTRAIT_SCALAR_TEST: v1.0 ### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml' -# basic is_true +# basic: Basic examples without any special cases is_true(true::bool) = true::bool is_true(false::bool) = false::bool is_true(null::bool) = false::bool diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 95fff93e1..064a67eaa 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -10,72 +10,6 @@ from tests.coverage.extensions import Extension -def compare_baselines(expected, actual): - errors = [] - - if actual.registry.extension_count < expected.registry.extension_count: - errors.append( - f"Extension count mismatch: expected {expected.registry.extension_count}, got {actual.registry.extension_count}" - ) - if actual.registry.dependency_count < expected.registry.dependency_count: - errors.append( - f"Dependency count mismatch: expected {expected.registry.dependency_count}, got {actual.registry.dependency_count}" - ) - if actual.registry.function_count < expected.registry.function_count: - errors.append( - f"Function count mismatch: expected {expected.registry.function_count}, got {actual.registry.function_count}" - ) - if ( - actual.registry.num_aggregate_functions - < expected.registry.num_aggregate_functions - ): - errors.append( - f"Aggregate function count mismatch: expected {expected.registry.num_aggregate_functions}, got {actual.registry.num_aggregate_functions}" - ) - if actual.registry.num_scalar_functions < expected.registry.num_scalar_functions: - errors.append( - f"Scalar function count mismatch: expected {expected.registry.num_scalar_functions}, got {actual.registry.num_scalar_functions}" - ) - if actual.registry.num_window_functions < expected.registry.num_window_functions: - errors.append( - f"Window function count mismatch: expected {expected.registry.num_window_functions}, got {actual.registry.num_window_functions}" - ) - if ( - actual.registry.num_function_overloads - < expected.registry.num_function_overloads - ): - errors.append( - f"Function overload count mismatch: expected {expected.registry.num_function_overloads}, got {actual.registry.num_function_overloads}" - ) - - if actual.coverage.total_test_count < expected.coverage.total_test_count: - errors.append( - f"Total test count mismatch: expected {expected.coverage.total_test_count}, got {actual.coverage.total_test_count}" - ) - if actual.coverage.num_function_variants < expected.coverage.num_function_variants: - errors.append( - f"Total function variants mismatch: expected {expected.coverage.num_function_variants}, got {actual.coverage.num_function_variants}" - ) - if ( - actual.coverage.num_covered_function_variants - < expected.coverage.num_covered_function_variants - ): - errors.append( - f"Covered function variants mismatch: expected {expected.coverage.num_covered_function_variants}, got {actual.coverage.num_covered_function_variants}" - ) - - expected_coverage_gap = expected.num_function_variants_without_coverage() - actual_coverage_gap = actual.num_function_variants_without_coverage() - if actual_coverage_gap > expected_coverage_gap: - errors.append( - f"Coverage gap too large: {actual_coverage_gap} function variants with no tests, " - f"out of {actual.coverage.num_function_variants} total function variants. " - f"New functions should be added along with test cases that illustrate their behavior." - ) - - return errors - - # NOTE: this test is run as part of pre-commit hook def test_substrait_extension_coverage(): script_dir = os.path.dirname(os.path.abspath(__file__)) @@ -92,7 +26,7 @@ def test_substrait_extension_coverage(): ), f"{coverage.num_tests_with_no_matching_function} tests with no matching function" actual_baseline = generate_baseline(registry, coverage) - errors = compare_baselines(baseline, actual_baseline) + errors = actual_baseline.validate_against(baseline) assert not errors, ( "\n".join(errors) + f"The baseline file does not match the current test coverage. "