Skip to content

Commit

Permalink
Add list_cat, list_concat, list_repeat (#942)
Browse files Browse the repository at this point in the history
* Add list_cat, list_concat

* Add list_repeat

* docs: add examples for list_cat, list_concat, and list_repeat functions

* Amend list_repeat code example - literal

* Amend list_ to array_ in documentation
  • Loading branch information
kosiew authored Nov 12, 2024
1 parent e3e55b7 commit 53cdb11
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/source/user-guide/common-operations/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,35 @@ This function returns an integer indicating the total number of elements in the
In this example, the `num_elements` column will contain `3` for both rows.

To concatenate two arrays, you can use the function :py:func:`datafusion.functions.array_cat` or :py:func:`datafusion.functions.array_concat`.
These functions return a new array that is the concatenation of the input arrays.

.. ipython:: python
from datafusion import SessionContext, col
from datafusion.functions import array_cat, array_concat
ctx = SessionContext()
df = ctx.from_pydict({"a": [[1, 2, 3]], "b": [[4, 5, 6]]})
df.select(array_cat(col("a"), col("b")).alias("concatenated_array"))
In this example, the `concatenated_array` column will contain `[1, 2, 3, 4, 5, 6]`.

To repeat the elements of an array a specified number of times, you can use the function :py:func:`datafusion.functions.array_repeat`.
This function returns a new array with the elements repeated.

.. ipython:: python
from datafusion import SessionContext, col, literal
from datafusion.functions import array_repeat
ctx = SessionContext()
df = ctx.from_pydict({"a": [[1, 2, 3]]})
df.select(array_repeat(col("a"), literal(2)).alias("repeated_array"))
In this example, the `repeated_array` column will contain `[[1, 2, 3], [1, 2, 3]]`.


Structs
-------

Expand Down
27 changes: 27 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
"length",
"levenshtein",
"list_append",
"list_cat",
"list_concat",
"list_dims",
"list_distinct",
"list_element",
Expand All @@ -162,6 +164,7 @@
"list_prepend",
"list_push_back",
"list_push_front",
"list_repeat",
"list_remove",
"list_remove_all",
"list_remove_n",
Expand Down Expand Up @@ -1145,6 +1148,22 @@ def array_distinct(array: Expr) -> Expr:
return Expr(f.array_distinct(array.expr))


def list_cat(*args: Expr) -> Expr:
"""Concatenates the input arrays.
This is an alias for :py:func:`array_concat`, :py:func:`array_cat`.
"""
return array_concat(*args)


def list_concat(*args: Expr) -> Expr:
"""Concatenates the input arrays.
This is an alias for :py:func:`array_concat`, :py:func:`array_cat`.
"""
return array_concat(*args)


def list_distinct(array: Expr) -> Expr:
"""Returns distinct values from the array after removing duplicates.
Expand Down Expand Up @@ -1369,6 +1388,14 @@ def array_repeat(element: Expr, count: Expr) -> Expr:
return Expr(f.array_repeat(element.expr, count.expr))


def list_repeat(element: Expr, count: Expr) -> Expr:
"""Returns an array containing ``element`` ``count`` times.
This is an alias for :py:func:`array_repeat`.
"""
return array_repeat(element, count)


def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
"""Replaces the first occurrence of ``from_val`` with ``to_val``."""
return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr))
Expand Down
12 changes: 12 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ def py_flatten(arr):
lambda col: f.array_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
],
[
lambda col: f.list_cat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
],
[
lambda col: f.list_concat(col, col),
lambda data: [np.concatenate([arr, arr]) for arr in data],
],
[
lambda col: f.array_dims(col),
lambda data: [[len(r)] for r in data],
Expand Down Expand Up @@ -439,6 +447,10 @@ def py_flatten(arr):
lambda col: f.array_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
],
[
lambda col: f.list_repeat(col, literal(2)),
lambda data: [[arr] * 2 for arr in data],
],
[
lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
Expand Down

0 comments on commit 53cdb11

Please sign in to comment.