Skip to content

Commit

Permalink
Upgrade to Datafusion 43 (#905)
Browse files Browse the repository at this point in the history
* patch datafusion deps

* migrate from deprecated RuntimeEnv::new to RuntimeEnv::try_new

Ref: apache/datafusion#12566

* remove Arc from create_udf call

Ref: apache/datafusion#12489

* doc typo

* migrage new UnnestOptions API

Ref: https://github.com/apache/datafusion/pull/12836/files

* update API for logical expr Limit

Ref: apache/datafusion#12836

* remove logical expr CrossJoin

It was removed upstream.

Ref: apache/datafusion#13076

* update PyWindowUDF

Ref: apache/datafusion#12803

* migrate window functions lead and lag to udwf

Ref: apache/datafusion#12802

* migrate window functions rank, dense_rank, and percent_rank to udwf

Ref: apache/datafusion#12648

* convert window function cume_dist to udwf

Ref: apache/datafusion#12695

* convert window function ntile to udwf

Ref: apache/datafusion#12694

* clean up functions_window invocation

* Only one column was being passed to udwf

* Update to DF 43.0.0

* Update tests to look for string_view type

* String view is now the default type for strings

* Making a variety of adjustments in wrappers and unit tests to account for the switch from string to string_view as default

* Resolve errors in doc building

---------

Co-authored-by: Tim Saucer <[email protected]>
  • Loading branch information
Michael-J-Ward and timsaucer authored Nov 10, 2024
1 parent 4a6c4d1 commit 3c66201
Show file tree
Hide file tree
Showing 19 changed files with 338 additions and 338 deletions.
373 changes: 199 additions & 174 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ substrait = ["dep:datafusion-substrait"]
tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "42.0.0", optional = true }
datafusion-proto = { version = "42.0.0" }
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "43.0.0", optional = true }
datafusion-proto = { version = "43.0.0" }
datafusion-functions-window-common = { version = "43.0.0" }
prost = "0.13" # keep in line with `datafusion-substrait`
uuid = { version = "1.11", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
Expand All @@ -58,4 +59,4 @@ crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
codegen-units = 1
codegen-units = 1
4 changes: 2 additions & 2 deletions examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def df_selection(col_name, col_type):
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
return F.round(col(col_name), lit(2)).alias(col_name)
elif col_type == pa.string():
elif col_type == pa.string() or col_type == pa.string_view():
return F.trim(col(col_name)).alias(col_name)
else:
return col(col_name)
Expand All @@ -43,7 +43,7 @@ def load_schema(col_name, col_type):
def expected_selection(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return F.trim(col(col_name)).cast(col_type).alias(col_name)
elif col_type == pa.string():
elif col_type == pa.string() or col_type == pa.string_view():
return F.trim(col(col_name)).alias(col_name)
else:
return col(col_name)
Expand Down
4 changes: 2 additions & 2 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
Column = expr_internal.Column
CreateMemoryTable = expr_internal.CreateMemoryTable
CreateView = expr_internal.CreateView
CrossJoin = expr_internal.CrossJoin
Distinct = expr_internal.Distinct
DropTable = expr_internal.DropTable
EmptyRelation = expr_internal.EmptyRelation
Expand Down Expand Up @@ -140,7 +139,6 @@
"Join",
"JoinType",
"JoinConstraint",
"CrossJoin",
"Union",
"Unnest",
"UnnestExpr",
Expand Down Expand Up @@ -376,6 +374,8 @@ def literal(value: Any) -> Expr:
``value`` must be a valid PyArrow scalar value or easily castable to one.
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
if not isinstance(value, pa.Scalar):
value = pa.scalar(value)
return Expr(expr_internal.Expr.literal(value))
Expand Down
11 changes: 8 additions & 3 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:

def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
"""Converts each element to its text representation."""
return Expr(f.array_to_string(expr.expr, delimiter.expr))
return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))


def array_join(expr: Expr, delimiter: Expr) -> Expr:
Expand Down Expand Up @@ -1067,7 +1067,10 @@ def struct(*args: Expr) -> Expr:

def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
"""Returns a struct with the given names and arguments pairs."""
name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
name_pair_exprs = [
[Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
for pair in name_pairs
]

# flatten
name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
Expand Down Expand Up @@ -1424,7 +1427,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
return Expr(
f.array_sort(
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
array.expr,
Expr.literal(pa.scalar(desc, type=pa.string())).expr,
Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
)
)

Expand Down
1 change: 1 addition & 0 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def udaf(
which this UDAF is used. The following examples are all valid.
.. code-block:: python
import pyarrow as pa
import pyarrow.compute as pc
Expand Down
16 changes: 12 additions & 4 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def test_limit(test_ctx):

plan = plan.to_variant()
assert isinstance(plan, Limit)
assert plan.skip() == 0
# TODO: Upstream now has expressions for skip and fetch
# REF: https://github.com/apache/datafusion/pull/12836
# assert plan.skip() == 0

df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
plan = df.logical_plan()

plan = plan.to_variant()
assert isinstance(plan, Limit)
assert plan.skip() == 5
# TODO: Upstream now has expressions for skip and fetch
# REF: https://github.com/apache/datafusion/pull/12836
# assert plan.skip() == 5


def test_aggregate_query(test_ctx):
Expand Down Expand Up @@ -126,7 +130,10 @@ def test_relational_expr(test_ctx):
ctx = SessionContext()

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
[
pa.array([1, 2, 3]),
pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="batch_array")
Expand All @@ -141,7 +148,8 @@ def test_relational_expr(test_ctx):
assert df.filter(col("b") == "beta").count() == 1
assert df.filter(col("b") != "beta").count() == 2

assert df.filter(col("a") == "beta").count() == 0
with pytest.raises(Exception):
df.filter(col("a") == "beta").count()


def test_expr_to_variant():
Expand Down
67 changes: 47 additions & 20 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def df():
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", "World", "!"]),
pa.array(["Hello", "World", "!"], type=pa.string_view()),
pa.array([4, 5, 6]),
pa.array(["hello ", " world ", " !"]),
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
pa.array(
[
datetime(2022, 12, 31),
Expand Down Expand Up @@ -88,16 +88,18 @@ def test_literal(df):
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([1] * 3)
assert result.column(1) == pa.array(["1"] * 3)
assert result.column(2) == pa.array(["OK"] * 3)
assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
assert result.column(3) == pa.array([3.14] * 3)
assert result.column(4) == pa.array([True] * 3)
assert result.column(5) == pa.array([b"hello world"] * 3)


def test_lit_arith(df):
"""Test literals with arithmetic operations"""
df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
df = df.select(
literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
)
result = df.collect()
assert len(result) == 1
result = result[0]
Expand Down Expand Up @@ -600,21 +602,33 @@ def test_array_function_obj_tests(stmt, py_expr):
f.ascii(column("a")),
pa.array([72, 87, 33], type=pa.int32()),
), # H = 72; W = 87; ! = 33
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
(
f.bit_length(column("a").cast(pa.string())),
pa.array([40, 40, 8], type=pa.int32()),
),
(
f.btrim(literal(" World ")),
pa.array(["World", "World", "World"], type=pa.string_view()),
),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(
f.concat_ws("-", column("a"), literal("test")),
pa.array(["Hello-test", "World-test", "!-test"]),
),
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
(
f.concat(column("a").cast(pa.string()), literal("?")),
pa.array(["Hello?", "World?", "!?"]),
),
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
(
f.ltrim(column("c")),
pa.array(["hello ", "world ", "!"], type=pa.string_view()),
),
(
f.md5(column("a")),
pa.array(
Expand All @@ -640,19 +654,25 @@ def test_array_function_obj_tests(stmt, py_expr):
f.rpad(column("a"), literal(8)),
pa.array(["Hello ", "World ", "! "]),
),
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
(
f.rtrim(column("c")),
pa.array(["hello", " world", " !"], type=pa.string_view()),
),
(
f.split_part(column("a"), literal("l"), literal(1)),
pa.array(["He", "Wor", "!"]),
),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
(
f.substr(column("a"), literal(3)),
pa.array(["llo", "rld", ""], type=pa.string_view()),
),
(
f.translate(column("a"), literal("or"), literal("ld")),
pa.array(["Helll", "Wldld", "!"]),
),
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
(f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(
Expand Down Expand Up @@ -794,9 +814,9 @@ def test_temporal_functions(df):
f.date_trunc(literal("month"), column("d")),
f.datetrunc(literal("day"), column("d")),
f.date_bin(
literal("15 minutes"),
literal("15 minutes").cast(pa.string()),
column("d"),
literal("2001-01-01 00:02:30"),
literal("2001-01-01 00:02:30").cast(pa.string()),
),
f.from_unixtime(literal(1673383974)),
f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
Expand Down Expand Up @@ -858,8 +878,8 @@ def test_case(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([10, 8, 8])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
assert result.column(2) == pa.array(["Hola", "Mundo", None])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())


def test_when_with_no_base(df):
Expand All @@ -877,8 +897,10 @@ def test_when_with_no_base(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([4, 5, 6])
assert result.column(1) == pa.array(["too small", "just right", "too big"])
assert result.column(2) == pa.array(["Hello", None, None])
assert result.column(1) == pa.array(
["too small", "just right", "too big"], type=pa.string_view()
)
assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())


def test_regr_funcs_sql(df):
Expand Down Expand Up @@ -1021,8 +1043,13 @@ def test_regr_funcs_df(func, expected):

def test_binary_string_functions(df):
df = df.select(
f.encode(column("a"), literal("base64")),
f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
f.decode(
f.encode(
column("a").cast(pa.string()), literal("base64").cast(pa.string())
),
literal("base64").cast(pa.string()),
),
)
result = df.collect()
assert len(result) == 1
Expand Down
2 changes: 0 additions & 2 deletions python/tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
Join,
JoinType,
JoinConstraint,
CrossJoin,
Union,
Like,
ILike,
Expand Down Expand Up @@ -129,7 +128,6 @@ def test_class_module_is_datafusion():
Join,
JoinType,
JoinConstraint,
CrossJoin,
Union,
Like,
ILike,
Expand Down
7 changes: 7 additions & 0 deletions python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,13 @@ def test_simple_select(ctx, tmp_path, arr):
batches = ctx.sql("SELECT a AS tt FROM t").collect()
result = batches[0].column(0)

# In DF 43.0.0 we now default to having BinaryView and StringView
# so the array that is saved to the parquet is slightly different
# than the array read. Convert to values for comparison.
if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray):
arr = arr.tolist()
result = result.tolist()

np.testing.assert_equal(result, arr)


Expand Down
2 changes: 1 addition & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ impl PySessionContext {
} else {
RuntimeConfig::default()
};
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let runtime = Arc::new(RuntimeEnv::try_new(runtime_config)?);
let session_state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
Expand Down
8 changes: 6 additions & 2 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ impl PyDataFrame {

#[pyo3(signature = (column, preserve_nulls=true))]
fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
let unnest_options = UnnestOptions { preserve_nulls };
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
let df = self
.df
.as_ref()
Expand All @@ -413,7 +415,9 @@ impl PyDataFrame {

#[pyo3(signature = (columns, preserve_nulls=true))]
fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
let unnest_options = UnnestOptions { preserve_nulls };
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self
.df
Expand Down
2 changes: 0 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub mod column;
pub mod conditional_expr;
pub mod create_memory_table;
pub mod create_view;
pub mod cross_join;
pub mod distinct;
pub mod drop_table;
pub mod empty_relation;
Expand Down Expand Up @@ -775,7 +774,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<join::PyJoin>()?;
m.add_class::<join::PyJoinType>()?;
m.add_class::<join::PyJoinConstraint>()?;
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
m.add_class::<unnest::PyUnnest>()?;
m.add_class::<unnest_expr::PyUnnestExpr>()?;
Expand Down
Loading

0 comments on commit 3c66201

Please sign in to comment.