Skip to content

Commit

Permalink
Fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Oct 29, 2023
1 parent 92af19f commit e46539b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 1 deletion.
1 change: 1 addition & 0 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest

# TODO: Import these from py-polars/tests/unit/operations/test_inefficient_apply.py
MY_CONSTANT = 3
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = [1, 2, 3]
Expand Down
134 changes: 133 additions & 1 deletion py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,139 @@
from polars.testing import assert_frame_equal, assert_series_equal
from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser
from polars.utils.various import in_terminal_that_supports_colour
from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES

MY_CONSTANT = 3
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = [1, 2, 3]


# column_name, function, expected_suggestion
TEST_CASES = [
# ---------------------------------------------
# numeric expr: math, comparison, logic ops
# ---------------------------------------------
("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'),
("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'),
("a", "lambda x: x & True", 'pl.col("a") & True'),
("a", "lambda x: x | False", 'pl.col("a") | False'),
("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'),
("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'),
("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'),
("a", "lambda x: x is None", 'pl.col("a") is None'),
("a", "lambda x: x is not None", 'pl.col("a") is not None'),
(
"a",
"lambda x: ((x * -x) ** x) * 1.0",
'((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',
),
(
"a",
"lambda x: 1.0 * (x * (x**x))",
'1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',
),
(
"a",
"lambda x: (x / x) + ((x * x) - x)",
'(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',
),
(
"a",
"lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",
'(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',
),
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'),
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'),
(
"a",
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
),
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'),
("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'),
("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'),
(
"a", # note: functions operate on consts
"lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",
'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
),
(
"a",
"lambda x: (float(x) * int(x)) // 2",
'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
),
# ---------------------------------------------
# logical 'and/or' (validate nesting levels)
# ---------------------------------------------
(
"a",
"lambda x: x > 1 or (x == 1 and x == 2)",
'(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)',
),
(
"a",
"lambda x: (x > 1 or x == 1) and x == 2",
'((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',
),
(
"a",
"lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",
'(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))',
),
(
"a",
"lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",
'(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)',
),
(
"a",
"lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",
'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
),
# ---------------------------------------------
# string expr: case/cast ops
# ---------------------------------------------
("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'),
(
"b",
'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
),
# ---------------------------------------------
# json expr: load/extract
# ---------------------------------------------
("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'),
# ---------------------------------------------
# map_dict
# ---------------------------------------------
("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'),
(
"a",
"lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",
'(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)',
),
# ---------------------------------------------
# standard library datetime parsing
# ---------------------------------------------
(
"d",
'lambda x: datetime.strptime(x, "%Y-%m-%d")',
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
(
"d",
'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
]

NOOP_TEST_CASES = [
"lambda x: x",
"lambda x, y: x + y",
"lambda x: x[0] + 1",
"lambda x: MY_LIST[x]",
"lambda x: MY_DICT[1]",
'lambda x: "first" if x == 1 else "not first"',
]

EVAL_ENVIRONMENT = {
"np": numpy,
Expand Down

0 comments on commit e46539b

Please sign in to comment.