Skip to content

Commit

Permalink
Merge pull request #1051 from lsst/tickets/DM-45680
Browse files Browse the repository at this point in the history
DM-45680: Allow boolean columns to be used in query 'where'
  • Loading branch information
dhirving authored Aug 12, 2024
2 parents 8e57637 + 7aa86be commit a0830e1
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 4 deletions.
4 changes: 4 additions & 0 deletions doc/changes/DM-45680.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix an issue where boolean metadata columns (like `exposure.can_see_sky` and
`exposure.has_simulated`) were not usable in `where` clauses for Registry query
functions. These column names can now be used as a boolean expression, for
example `where="exposure.can_see_sky"` or `where="NOT exposure.can_see_sky"`.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[An
# Docstring inherited.
return self.expect_scalar(expression.operand).desc()

def visit_boolean_wrapper(
    self, value: qt.ColumnExpression, flags: PredicateVisitFlags
) -> sqlalchemy.ColumnElement[bool]:
    # Docstring inherited.
    # The wrapped column is boolean-valued, so its scalar SQL form is
    # returned directly as the predicate; ``flags`` is not needed here.
    return self.expect_scalar(value)

def visit_comparison(
self,
a: qt.ColumnExpression,
Expand Down
45 changes: 44 additions & 1 deletion python/lsst/daf/butler/queries/_expression_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .tree import (
BinaryExpression,
ColumnExpression,
ColumnReference,
ComparisonOperator,
LiteralValue,
Predicate,
Expand Down Expand Up @@ -158,6 +159,16 @@ def visitBinaryOp(
return Predicate.is_null(rhs.value)
case ["!=", _Null(), _ColExpr() as rhs]:
return Predicate.is_null(rhs.value).logical_not()
# Boolean columns can be null, but will have been converted to
# Predicate, so we need additional cases.
case ["=" | "!=", Predicate() as pred, _Null()] | ["=" | "!=", _Null(), Predicate() as pred]:
column_ref = _get_boolean_column_reference(pred)
if column_ref is not None:
match operator:
case "=":
return Predicate.is_null(column_ref)
case "!=":
return Predicate.is_null(column_ref).logical_not()

# Handle arithmetic operations
case [("+" | "-" | "*" | "/" | "%") as op, _ColExpr() as lhs, _ColExpr() as rhs]:
Expand Down Expand Up @@ -198,7 +209,23 @@ def visitIdentifier(self, name: str, node: Node) -> _VisitorResult:
if categorizeConstant(name) == ExpressionConstant.NULL:
return _Null()

return _ColExpr(interpret_identifier(self.context, name))
column_expression = interpret_identifier(self.context, name)
if column_expression.column_type == "bool":
# Expression-handling code (in this file and elsewhere) expects
# boolean-valued expressions to be represented as Predicate, not a
# ColumnExpression.

# We should only be getting direct references to a column, not a
# more complicated expression.
# (Anything more complicated should be a Predicate already.)
assert (
column_expression.expression_type == "dataset_field"
or column_expression.expression_type == "dimension_field"
or column_expression.expression_type == "dimension_key"
)
return Predicate.from_bool_expression(column_expression)
else:
return _ColExpr(column_expression)

def visitNumericLiteral(self, value: str, node: Node) -> _VisitorResult:
numeric: int | float
Expand Down Expand Up @@ -303,3 +330,19 @@ def _convert_in_clause_to_predicate(lhs: ColumnExpression, rhs: _VisitorResult,
return Predicate.is_null(lhs)
case _:
raise InvalidQueryError(f"Invalid IN expression: '{node!s}")


def _get_boolean_column_reference(predicate: Predicate) -> ColumnReference | None:
    """Unwrap a predicate to recover the boolean ColumnReference it contains.

    This undoes the ColumnReference to Predicate conversion that occurs in
    visitIdentifier for boolean columns.

    Parameters
    ----------
    predicate : `Predicate`
        Predicate to inspect.

    Returns
    -------
    column_ref : `ColumnReference` or `None`
        The wrapped column reference, or `None` if this Predicate contains
        anything other than a single boolean ColumnReference operand.
    """
    operands = predicate.operands
    # A wrapped boolean column appears as exactly one operand group that
    # holds exactly one leaf; any other shape is a compound predicate.
    if len(operands) == 1 and len(operands[0]) == 1:
        predicate_leaf = operands[0][0]
        if predicate_leaf.predicate_type == "boolean_wrapper":
            return predicate_leaf.operand

    return None
62 changes: 61 additions & 1 deletion python/lsst/daf/butler/queries/expression_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@ def desc(self) -> tree.Reversed:
"""
return tree.Reversed(operand=self._expression)

def as_boolean(self) -> tree.Predicate:
    """Convert this scalar expression to a `Predicate` so it can be used
    as a boolean expression.

    This implementation always raises. `BooleanScalarExpressionProxy`
    defines its own ``as_boolean`` that performs the conversion for
    boolean columns.

    Returns
    -------
    predicate : `Predicate`
        This expression converted to a `Predicate`.

    Raises
    ------
    InvalidQueryError
        If this expression is not a boolean.
    """
    expr = self._expression
    raise InvalidQueryError(
        f"Expression '{expr}' with type"
        f" '{expr.column_type}' can't be used directly as a boolean value."
        " Use a comparison operator like '>' or '==' instead."
    )

def __eq__(self, other: object) -> tree.Predicate:  # type: ignore[override]
    # Returns a query-tree comparison rather than a ``bool``, hence the
    # override of the usual ``__eq__`` return type.
    return self._make_comparison(other, "==")

Expand Down Expand Up @@ -233,6 +254,42 @@ def _expression(self) -> tree.ColumnExpression:
return self._expr


class BooleanScalarExpressionProxy(ScalarExpressionProxy):
    """A `ScalarExpressionProxy` representing a boolean column.

    You should call `as_boolean()` on this object to convert it to an
    instance of `Predicate` before attempting to use it.

    Parameters
    ----------
    expression : `.tree.ColumnReference`
        Boolean column reference that backs this proxy.

    Raises
    ------
    ValueError
        Raised if ``expression.column_type`` is not ``"bool"``.
    """

    # This is a hack/work-around to make static typing work when referencing
    # dimension record metadata boolean columns. From the perspective of
    # typing, anything boolean should be a `Predicate`, but the type system has
    # no way of knowing whether a given column is a bool or some other type.

    def __init__(self, expression: tree.ColumnReference) -> None:
        if expression.column_type != "bool":
            raise ValueError(f"Expression is a {expression.column_type}, not a 'bool': {expression}")
        self._boolean_expression = expression

    @property
    def is_null(self) -> tree.Predicate:
        # Delegate to a regular proxy over the same column so NULL tests
        # work without going through `_expression`, which raises here.
        return ResolvedScalarExpressionProxy(self._boolean_expression).is_null

    def as_boolean(self) -> tree.Predicate:
        # Wrap the boolean column reference so it can act as a Predicate.
        return tree.Predicate.from_bool_expression(self._boolean_expression)

    @property
    def _expression(self) -> tree.ColumnExpression:
        # Always raises: a boolean column must be converted with
        # `as_boolean()` before it is combined with other expressions.
        raise InvalidQueryError(
            f"Boolean expression '{self._boolean_expression}' can't be used directly in other expressions."
            " Call the 'as_boolean()' method to convert it to a Predicate instead."
        )


class TimespanProxy(ExpressionProxy):
"""An `ExpressionProxy` specialized for timespan columns and literals.
Expand Down Expand Up @@ -350,7 +407,10 @@ def __getattr__(self, field: str) -> ScalarExpressionProxy:
expression = tree.DimensionFieldReference(element=self._element, field=field)
except InvalidQueryError:
raise AttributeError(field)
return ResolvedScalarExpressionProxy(expression)
if expression.column_type == "bool":
return BooleanScalarExpressionProxy(expression)
else:
return ResolvedScalarExpressionProxy(expression)

@property
def region(self) -> RegionProxy:
Expand Down
43 changes: 42 additions & 1 deletion python/lsst/daf/butler/queries/tree/_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ._base import QueryTreeBase
from ._column_expression import (
ColumnExpression,
ColumnReference,
is_one_datetime_and_one_ingest_date,
is_one_timespan_and_one_datetime,
)
Expand Down Expand Up @@ -155,6 +156,26 @@ def from_bool(cls, value: bool) -> Predicate:
#
return cls.model_construct(operands=() if value else ((),))

@classmethod
def from_bool_expression(cls, value: ColumnReference) -> Predicate:
    """Construct a predicate that wraps a boolean ColumnReference, taking
    on the value of the underlying ColumnReference.

    Parameters
    ----------
    value : `ColumnReference`
        Boolean-valued column reference to convert to Predicate.

    Returns
    -------
    predicate : `Predicate`
        Predicate representing the expression.

    Raises
    ------
    ValueError
        Raised if ``value.column_type`` is not ``"bool"``.
    """
    if value.column_type != "bool":
        raise ValueError(f"ColumnExpression must have column type 'bool', not '{value.column_type}'")

    return cls._from_leaf(BooleanWrapper(operand=value))

@classmethod
def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
"""Construct a predicate representing a binary comparison between
Expand Down Expand Up @@ -412,6 +433,26 @@ def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlag
return visitor._visit_logical_not(self.operand, flags)


class BooleanWrapper(PredicateLeafBase):
    """Pass-through to a pre-existing boolean column expression.

    This leaf lets a boolean column reference be used directly as a
    predicate; visitors receive the wrapped column via
    ``visit_boolean_wrapper``.
    """

    predicate_type: Literal["boolean_wrapper"] = "boolean_wrapper"

    operand: ColumnReference
    """Wrapped expression that will be used as the value for this predicate."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        # The string form is just the wrapped column's string form.
        return f"{self.operand}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_boolean_wrapper(self.operand, flags)


@final
class IsNull(PredicateLeafBase):
"""A boolean column expression that tests whether its operand is NULL."""
Expand Down Expand Up @@ -639,7 +680,7 @@ def _validate_column_types(self) -> InQuery:
return self


LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery | BooleanWrapper
PredicateLeaf: TypeAlias = Annotated[
LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type")
]
Expand Down
24 changes: 24 additions & 0 deletions python/lsst/daf/butler/queries/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,25 @@ class PredicateVisitor(Generic[_A, _O, _L]):
visit method arguments.
"""

@abstractmethod
def visit_boolean_wrapper(self, value: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L:
    """Visit a boolean-valued column expression.

    Parameters
    ----------
    value : `tree.ColumnExpression`
        Column expression, guaranteed to have `column_type == "bool"`.
    flags : `PredicateVisitFlags`
        Information about where this leaf appears in the larger predicate
        tree.

    Returns
    -------
    result : `object`
        Implementation-defined.
    """
    raise NotImplementedError()

@abstractmethod
def visit_comparison(
self,
Expand Down Expand Up @@ -448,6 +467,11 @@ class SimplePredicateVisitor(
return a replacement `Predicate` to construct a new tree.
"""

def visit_boolean_wrapper(
    self, value: tree.ColumnExpression, flags: PredicateVisitFlags
) -> tree.Predicate | None:
    # Default implementation: return `None` (no replacement `Predicate`)
    # for boolean-wrapper leaves.
    return None

def visit_comparison(
self,
a: tree.ColumnExpression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,12 @@ def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
if column == timespan_database_representation.TimespanDatabaseRepresentation.NAME
else element.RecordClass.fields.standard[column].getPythonType()
)
return ColumnExpression.reference(tag, dtype)
if dtype is bool:
# ColumnExpression is for non-boolean columns only. Booleans
# are represented as Predicate.
return Predicate.reference(tag)
else:
return ColumnExpression.reference(tag, dtype)
else:
tag = DimensionKeyColumnTag(element.name)
assert isinstance(element, Dimension)
Expand Down
Loading

0 comments on commit a0830e1

Please sign in to comment.