Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added isnone() function #801

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, ifelse, least
from .conditional import case, greatest, ifelse, isnone, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
from .string import byte_hamming_distance
Expand All @@ -42,6 +42,7 @@
"greatest",
"ifelse",
"int_hash_64",
"isnone",
"least",
"length",
"literal",
Expand Down
86 changes: 64 additions & 22 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Union
from typing import Optional, Union

from sqlalchemy import ColumnElement
from sqlalchemy import case as sql_case
from sqlalchemy.sql.elements import BinaryExpression

from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import Column
from datachain.sql.functions import conditional

from .func import ColT, Func

CaseT = Union[int, float, complex, bool, str]
CaseT = Union[int, float, complex, bool, str, Func]

shcheklein marked this conversation as resolved.
Show resolved Hide resolved

def greatest(*args: Union[ColT, float]) -> Func:
Expand Down Expand Up @@ -87,17 +88,19 @@ def least(*args: Union[ColT, float]) -> Func:
)


def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
def case(
*args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
) -> Func:
"""
Returns the case function that produces case expression which has a list of
conditions and corresponding results. Results can only be python primitives
like string, numbes or booleans. Result type is inferred from condition results.
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
like string, numbers or booleans. Result type is inferred from condition results.

Args:
args (tuple(BinaryExpression, value(str | int | float | complex | bool):
- Tuple of binary expression and values pair which corresponds to one
case condition - value
else_ (str | int | float | complex | bool): else value in case expression
args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))):
Tuple of condition and values pair.
else_ (str | int | float | complex | bool, Func): else value in case
expression.
shcheklein marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Func: A Func object that represents the case function.
Expand All @@ -111,45 +114,84 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
"""
supported_types = [int, float, complex, str, bool]

type_ = type(else_) if else_ else None
def _get_type(val):
if isinstance(val, Func):
# nested functions
return val.result_type
return type(val)

if not args:
raise DataChainParamsError("Missing statements")

type_ = _get_type(else_) if else_ is not None else None

for arg in args:
if type_ and not isinstance(arg[1], type_):
raise DataChainParamsError("Statement values must be of the same type")
type_ = type(arg[1])
arg_type = _get_type(arg[1])
if type_ and arg_type != type_:
raise DataChainParamsError(
f"Statement values must be of the same type, got {type_} and {arg_type}"
)
type_ = arg_type

if type_ not in supported_types:
raise DataChainParamsError(
f"Only python literals ({supported_types}) are supported for values"
)

kwargs = {"else_": else_}
return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)

return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)


def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
def ifelse(
condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
) -> Func:
"""
Returns the ifelse function that produces if expression which has a condition
and values for true and false outcome. Results can only be python primitives
like string, numbes or booleans. Result type is inferred from the values.
and values for true and false outcome. Results can be one of python primitives
like string, numbes or booleans, but can also be nested functions.
Result type is inferred from the values.
shcheklein marked this conversation as resolved.
Show resolved Hide resolved

Args:
condition: BinaryExpression - condition which is evaluated
if_val: (str | int | float | complex | bool): value for true condition outcome
else_val: (str | int | float | complex | bool): value for false condition
outcome
condition (ColumnElement, Func): Condition which is evaluated.
if_val (str | int | float | complex | bool, Func): Value for true
condition outcome.
else_val (str | int | float | complex | bool, Func): Value for false condition
outcome.

Returns:
Func: A Func object that represents the ifelse function.

Example:
```py
dc.mutate(
res=func.ifelse(C("num") > 0, "P", "N"),
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"),
)
```
"""
return case((condition, if_val), else_=else_val)


def isnone(col: Union[str, Column]) -> Func:
"""
Returns True if column value is None, otherwise False

Args:
col (str | Column): Column to check if it's None or not.
If a string is provided, it is assumed to be the name of the column.
shcheklein marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Func: A Func object that represents the conditional to check if column is None.

Example:
```py
dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"))
```
"""
from datachain import C

if isinstance(col, str):
# if string, it is assumed to be the name of the column
col = C(col)

return case((col == None, True), else_=False) # noqa: E711
19 changes: 13 additions & 6 deletions src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .window import Window


ColT = Union[str, ColumnElement, "Func"]
ColT = Union[str, ColumnElement, "Func", tuple]


class Func(Function):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _db_cols(self) -> Sequence[ColT]:
return (
[
col
if isinstance(col, (Func, BindParameter, Case, Comparator))
if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
else ColumnMeta.to_db_name(
col.name if isinstance(col, ColumnElement) else col
)
Expand Down Expand Up @@ -381,17 +381,24 @@ def get_column(
col_type = self.get_result_type(signals_schema)
sql_type = python_to_sql(col_type)

def get_col(col: ColT) -> ColT:
def get_col(col: ColT, string_as_literal=False) -> ColT:
# string_as_literal is used only for conditionals like `case()` where
# literals are nested inside ColT as we have tuples of condition - values
# and if user wants to set some case value as column, explicit `C("col")`
# syntax must be used to distinguish from literals
if isinstance(col, tuple):
return tuple(get_col(x, string_as_literal=True) for x in col)
if isinstance(col, Func):
return col.get_column(signals_schema, table=table)
if isinstance(col, str):
if isinstance(col, str) and not string_as_literal:
column = Column(col, sql_type)
column.table = table
return column
return col

cols = [get_col(col) for col in self._db_cols]
func_col = self.inner(*cols, *self.args, **self.kwargs)
kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
func_col = self.inner(*cols, *self.args, **kwargs)

if self.is_window:
if not self.window:
Expand Down Expand Up @@ -423,7 +430,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
return sql_to_python(col)

return signals_schema.get_column_type(
col.name if isinstance(col, ColumnElement) else col
col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down
19 changes: 18 additions & 1 deletion tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def test_case_not_same_result_types(warehouse):
val = 2
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D"))
assert str(exc_info.value) == "Statement values must be of the same type"
assert str(exc_info.value) == (
"Statement values must be of the same type, got <class 'str'> amd <class 'int'>"
)


def test_case_wrong_result_type(warehouse):
Expand Down Expand Up @@ -124,3 +126,18 @@ def test_ifelse(warehouse, val, expected):
query = select(func.ifelse(val <= 3, "L", "H"))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)


@pytest.mark.parametrize(
"val,expected",
[
[None, True],
[func.literal("abcd"), False],
],
)
def test_isnone(warehouse, val, expected):
from datachain.func.conditional import isnone

query = select(isnone(val))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)
82 changes: 82 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
case,
ifelse,
int_hash_64,
isnone,
literal,
)
from datachain.func.random import rand
Expand All @@ -18,6 +19,7 @@
sqlite_byte_hamming_distance,
sqlite_int_hash_64,
)
from tests.utils import skip_if_not_sqlite


@pytest.fixture()
Expand Down Expand Up @@ -663,6 +665,59 @@ def test_case_mutate(dc, val, else_, type_):
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"val,else_,type_",
[
["A", "D", str],
[1, 2, int],
[1.5, 2.5, float],
[True, False, bool],
],
)
def test_nested_case_on_condition_mutate(dc, val, else_, type_):
res = dc.mutate(
test=case((case((C("num") < 2, True), else_=False), val), else_=else_)
)
assert list(res.order_by("test").collect("test")) == sorted(
[val, else_, else_, else_, else_]
)
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"v1,v2,v3,type_",
[
["A", "B", "C", str],
[1, 2, 3, int],
[1.5, 2.5, 3.5, float],
[False, True, True, bool],
],
)
def test_nested_case_on_value_mutate(dc, v1, v2, v3, type_):
res = dc.mutate(
test=case((C("num") < 4, case((C("num") < 2, v1), else_=v2)), else_=v3)
)
assert list(res.order_by("num").collect("test")) == sorted([v1, v2, v2, v3, v3])
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"v1,v2,v3,type_",
[
["A", "B", "C", str],
[1, 2, 3, int],
[1.5, 2.5, 3.5, float],
[False, True, True, bool],
],
)
def test_nested_case_on_else_mutate(dc, v1, v2, v3, type_):
res = dc.mutate(
test=case((C("num") < 3, v1), else_=case((C("num") < 4, v2), else_=v3))
)
assert list(res.order_by("num").collect("test")) == sorted([v1, v1, v2, v3, v3])
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"if_val,else_val,type_",
[
Expand All @@ -678,3 +733,30 @@ def test_ifelse_mutate(dc, if_val, else_val, type_):
[if_val, else_val, else_val, else_val, else_val]
)
assert res.schema["test"] == type_


@pytest.mark.parametrize("col", ["val", C("val")])
def test_isnone_mutate(col):
dc = DataChain.from_values(
num=list(range(1, 6)),
val=[None if i > 3 else "A" for i in range(1, 6)],
)

res = dc.mutate(test=isnone(col))
assert list(res.order_by("test").collect("test")) == sorted(
[False, False, False, True, True]
)
assert res.schema["test"] is bool


@pytest.mark.parametrize("col", [C("val"), "val"])
@skip_if_not_sqlite
def test_isnone_with_ifelse_mutate(col):
dc = DataChain.from_values(
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
num=list(range(1, 6)),
val=[None if i > 3 else "A" for i in range(1, 6)],
)

res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
assert res.schema["test"] is str
Loading