diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index c3f9c32360fa..1abdc570397b 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -19,6 +19,7 @@ import pytest +# TODO: Import these from py-polars/tests/unit/operations/test_inefficient_apply.py MY_CONSTANT = 3 MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} MY_LIST = [1, 2, 3] diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index be5440b1c854..d6d18b3963c9 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -15,7 +15,139 @@ from polars.testing import assert_frame_equal, assert_series_equal from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser from polars.utils.various import in_terminal_that_supports_colour -from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES + +MY_CONSTANT = 3 +MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} +MY_LIST = [1, 2, 3] + + +# column_name, function, expected_suggestion +TEST_CASES = [ + # --------------------------------------------- + # numeric expr: math, comparison, logic ops + # --------------------------------------------- + ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'), + ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'), + ("a", "lambda x: x & True", 'pl.col("a") & True'), + ("a", "lambda x: x | False", 'pl.col("a") | False'), + ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'), + ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'), + ("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'), + ("a", "lambda x: x is None", 'pl.col("a") is None'), + ("a", "lambda x: x is not None", 'pl.col("a") is not None'), + ( + "a", + "lambda x: ((x * -x) ** x) * 1.0", + '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0', + ), + ( + "a", + "lambda x: 1.0 * (x * (x**x))", + '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))', + ), + ( + "a", + "lambda x: (x / x) + ((x * x) - x)", + '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))', + ), + ( + "a", + "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))", + '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))', + ), + ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'), + ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'), + ( + "a", + "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", + 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', + ), + ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'), + ("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'), + ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'), + ( + "a", # note: functions operate on consts + "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)", + '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)', + ), + ( + "a", + "lambda x: (float(x) * int(x)) // 2", + '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', + ), + # --------------------------------------------- + # logical 'and/or' (validate nesting levels) + # --------------------------------------------- + ( + "a", + "lambda x: x > 1 or (x == 1 and x == 2)", + '(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)', + ), + ( + "a", + "lambda x: (x > 1 or x == 1) and x == 2", + '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', + ), + ( + "a", + "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)", + '(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))', + ), + ( + "a", + "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3", + '(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)', + ), + ( + "a", + "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3", + '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', + ), + # --------------------------------------------- + # string expr: case/cast ops + # --------------------------------------------- + ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), + ( + "b", + 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', + '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', + ), + # --------------------------------------------- + # json expr: load/extract + # --------------------------------------------- + ("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'), + # --------------------------------------------- + # map_dict + # --------------------------------------------- + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'), + ( + "a", + "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", + '(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)', + ), + # --------------------------------------------- + # standard library datetime parsing + # --------------------------------------------- + ( + "d", + 'lambda x: datetime.strptime(x, "%Y-%m-%d")', + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', + ), + ( + "d", + 'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")', + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', + ), +] + +NOOP_TEST_CASES = [ + "lambda x: x", + "lambda x, y: x + y", + "lambda x: x[0] + 1", + "lambda x: MY_LIST[x]", + "lambda x: MY_DICT[1]", + 'lambda x: "first" if x == 1 else "not first"', +] EVAL_ENVIRONMENT = { "np": numpy,