Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Jan 17, 2025
1 parent fa23172 commit b694f60
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 71 deletions.
68 changes: 68 additions & 0 deletions tests/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,74 @@ def from_dict(cls, data: Dict):
def num_function_variants_without_coverage(self):
return self.coverage.num_function_variants_without_coverage()

def validate_against(self, expected):
errors = []

if self.registry.extension_count < expected.registry.extension_count:
errors.append(
f"Extension count mismatch: expected {expected.registry.extension_count}, got {self.registry.extension_count}"
)
if self.registry.dependency_count < expected.registry.dependency_count:
errors.append(
f"Dependency count mismatch: expected {expected.registry.dependency_count}, got {self.registry.dependency_count}"
)
if self.registry.function_count < expected.registry.function_count:
errors.append(
f"Function count mismatch: expected {expected.registry.function_count}, got {self.registry.function_count}"
)
if (
self.registry.num_aggregate_functions
< expected.registry.num_aggregate_functions
):
errors.append(
f"Aggregate function count mismatch: expected {expected.registry.num_aggregate_functions}, got {self.registry.num_aggregate_functions}"
)
if self.registry.num_scalar_functions < expected.registry.num_scalar_functions:
errors.append(
f"Scalar function count mismatch: expected {expected.registry.num_scalar_functions}, got {self.registry.num_scalar_functions}"
)
if self.registry.num_window_functions < expected.registry.num_window_functions:
errors.append(
f"Window function count mismatch: expected {expected.registry.num_window_functions}, got {self.registry.num_window_functions}"
)
if (
self.registry.num_function_overloads
< expected.registry.num_function_overloads
):
errors.append(
f"Function overload count mismatch: expected {expected.registry.num_function_overloads}, got {self.registry.num_function_overloads}"
)

if self.coverage.total_test_count < expected.coverage.total_test_count:
errors.append(
f"Total test count mismatch: expected {expected.coverage.total_test_count}, got {self.coverage.total_test_count}"
)
if (
self.coverage.num_function_variants
< expected.coverage.num_function_variants
):
errors.append(
f"Total function variants mismatch: expected {expected.coverage.num_function_variants}, got {self.coverage.num_function_variants}"
)
if (
self.coverage.num_covered_function_variants
< expected.coverage.num_covered_function_variants
):
errors.append(
f"Covered function variants mismatch: expected {expected.coverage.num_covered_function_variants}, got {self.coverage.num_covered_function_variants}"
)

expected_coverage_gap = expected.num_function_variants_without_coverage()
actual_coverage_gap = self.num_function_variants_without_coverage()
if actual_coverage_gap > expected_coverage_gap:
errors.append(
f"Coverage gap too large: {actual_coverage_gap} function variants with no tests, "
f"out of {self.coverage.num_function_variants} total function variants. "
f"New functions should be added along with test cases that illustrate their behavior."
)

return errors


def read_baseline_file(file_path: str) -> Baseline:
with open(file_path, "r") as file:
Expand Down
2 changes: 1 addition & 1 deletion tests/cases/comparison/is_false.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### SUBSTRAIT_SCALAR_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml'

# basic is_false
# basic: Basic examples without any special cases
is_false(true::bool) = false::bool
is_false(false::bool) = true::bool
is_false(null::bool) = false::bool
2 changes: 1 addition & 1 deletion tests/cases/comparison/is_not_false.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### SUBSTRAIT_SCALAR_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml'

# basic is_not_false
# basic: Basic examples without any special cases
is_not_false(true::bool) = true::bool
is_not_false(false::bool) = false::bool
is_not_false(null::bool) = true::bool
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### SUBSTRAIT_SCALAR_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml'

# basic is_not_true
# basic: Basic examples without any special cases
is_not_true(true::bool) = false::bool
is_not_true(false::bool) = true::bool
is_not_true(null::bool) = true::bool
2 changes: 1 addition & 1 deletion tests/cases/comparison/is_true.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### SUBSTRAIT_SCALAR_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_comparison.yaml'

# basic is_true
# basic: Basic examples without any special cases
is_true(true::bool) = true::bool
is_true(false::bool) = false::bool
is_true(null::bool) = false::bool
68 changes: 1 addition & 67 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,72 +10,6 @@
from tests.coverage.extensions import Extension


def compare_baselines(expected, actual):
errors = []

if actual.registry.extension_count < expected.registry.extension_count:
errors.append(
f"Extension count mismatch: expected {expected.registry.extension_count}, got {actual.registry.extension_count}"
)
if actual.registry.dependency_count < expected.registry.dependency_count:
errors.append(
f"Dependency count mismatch: expected {expected.registry.dependency_count}, got {actual.registry.dependency_count}"
)
if actual.registry.function_count < expected.registry.function_count:
errors.append(
f"Function count mismatch: expected {expected.registry.function_count}, got {actual.registry.function_count}"
)
if (
actual.registry.num_aggregate_functions
< expected.registry.num_aggregate_functions
):
errors.append(
f"Aggregate function count mismatch: expected {expected.registry.num_aggregate_functions}, got {actual.registry.num_aggregate_functions}"
)
if actual.registry.num_scalar_functions < expected.registry.num_scalar_functions:
errors.append(
f"Scalar function count mismatch: expected {expected.registry.num_scalar_functions}, got {actual.registry.num_scalar_functions}"
)
if actual.registry.num_window_functions < expected.registry.num_window_functions:
errors.append(
f"Window function count mismatch: expected {expected.registry.num_window_functions}, got {actual.registry.num_window_functions}"
)
if (
actual.registry.num_function_overloads
< expected.registry.num_function_overloads
):
errors.append(
f"Function overload count mismatch: expected {expected.registry.num_function_overloads}, got {actual.registry.num_function_overloads}"
)

if actual.coverage.total_test_count < expected.coverage.total_test_count:
errors.append(
f"Total test count mismatch: expected {expected.coverage.total_test_count}, got {actual.coverage.total_test_count}"
)
if actual.coverage.num_function_variants < expected.coverage.num_function_variants:
errors.append(
f"Total function variants mismatch: expected {expected.coverage.num_function_variants}, got {actual.coverage.num_function_variants}"
)
if (
actual.coverage.num_covered_function_variants
< expected.coverage.num_covered_function_variants
):
errors.append(
f"Covered function variants mismatch: expected {expected.coverage.num_covered_function_variants}, got {actual.coverage.num_covered_function_variants}"
)

expected_coverage_gap = expected.num_function_variants_without_coverage()
actual_coverage_gap = actual.num_function_variants_without_coverage()
if actual_coverage_gap > expected_coverage_gap:
errors.append(
f"Coverage gap too large: {actual_coverage_gap} function variants with no tests, "
f"out of {actual.coverage.num_function_variants} total function variants. "
f"New functions should be added along with test cases that illustrate their behavior."
)

return errors


# NOTE: this test is run as part of pre-commit hook
def test_substrait_extension_coverage():
script_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -92,7 +26,7 @@ def test_substrait_extension_coverage():
), f"{coverage.num_tests_with_no_matching_function} tests with no matching function"

actual_baseline = generate_baseline(registry, coverage)
errors = compare_baselines(baseline, actual_baseline)
errors = actual_baseline.validate_against(baseline)
assert not errors, (
"\n".join(errors)
+ f"The baseline file does not match the current test coverage. "
Expand Down

0 comments on commit b694f60

Please sign in to comment.