Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explanation option to Chex shape asserts so that error messages can bring in lots of contextual information and make failures much more actionable. #198

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def _shape_matches(actual_shape: Sequence[int],
def assert_shape(
inputs: Union[Scalar, Union[Array, Sequence[Array]]],
expected_shapes: Union[_ai.TShapeMatcher,
Sequence[_ai.TShapeMatcher]]) -> None:
Sequence[_ai.TShapeMatcher]],
explanation: Optional[Union[str, Callable[[], str]]] = None) -> None:
"""Checks that the shape of all inputs matches specified ``expected_shapes``.

Valid usages include:
Expand All @@ -535,6 +536,8 @@ def assert_shape(
where the expected shape is a sequence of integer and `None` dimensions;
if all inputs have same shape, a single shape may be passed as
``expected_shapes``.
explanation: Additional message to give context when this assertion fails
(or a function/closure that returns such a message).

Raises:
AssertionError: If the lengths of ``inputs`` and ``expected_shapes`` do not
Expand Down Expand Up @@ -564,9 +567,19 @@ def assert_shape(
errors.append((idx, shape, _ai.format_shape_matcher(expected)))

if errors:
if callable(explanation):
try:
explanation: str = explanation()
except Exception as e: # pylint: disable=broad-except
explanation = ("[[`explanation` callback failed: " +
"\n".join(traceback.format_exception(
e.__class__, e, e.__traceback__, limit=4)) + "]]")
if not explanation:
explanation = ""
msg = "; ".join(
f"input {e[0]} has shape {e[1]} but expected {e[2]}" for e in errors)
raise AssertionError(f"Error in shape compatibility check: {msg}.")
raise AssertionError(
f"Error in shape compatibility check: {msg}. {explanation}")


@_static_assertion
Expand Down