diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst index b2a83c89..e94e1a6b 100644 --- a/docs/source/user-guide/common-operations/expressions.rst +++ b/docs/source/user-guide/common-operations/expressions.rst @@ -110,6 +110,35 @@ This function returns an integer indicating the total number of elements in the In this example, the `num_elements` column will contain `3` for both rows. +To concatenate two arrays, you can use the function :py:func:`datafusion.functions.array_cat` or :py:func:`datafusion.functions.array_concat`. +These functions return a new array that is the concatenation of the input arrays. + +.. ipython:: python + + from datafusion import SessionContext, col + from datafusion.functions import array_cat, array_concat + + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3]], "b": [[4, 5, 6]]}) + df.select(array_cat(col("a"), col("b")).alias("concatenated_array")) + +In this example, the `concatenated_array` column will contain `[1, 2, 3, 4, 5, 6]`. + +To repeat the elements of an array a specified number of times, you can use the function :py:func:`datafusion.functions.array_repeat`. +This function returns a new array with the elements repeated. + +.. ipython:: python + + from datafusion import SessionContext, col, literal + from datafusion.functions import array_repeat + + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3]]}) + df.select(array_repeat(col("a"), literal(2)).alias("repeated_array")) + +In this example, the `repeated_array` column will contain `[[1, 2, 3], [1, 2, 3]]`. + + Structs ------- diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 5a2eab56..88ea7280 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -147,6 +147,8 @@ "length", "levenshtein", "list_append", + "list_cat", + "list_concat", "list_dims", "list_distinct", "list_element", @@ -162,6 +164,7 @@ "list_prepend", "list_push_back", "list_push_front", + "list_repeat", "list_remove", "list_remove_all", "list_remove_n", @@ -1145,6 +1148,22 @@ def array_distinct(array: Expr) -> Expr: return Expr(f.array_distinct(array.expr)) +def list_cat(*args: Expr) -> Expr: + """Concatenates the input arrays. + + This is an alias for :py:func:`array_concat`, :py:func:`array_cat`. + """ + return array_concat(*args) + + +def list_concat(*args: Expr) -> Expr: + """Concatenates the input arrays. + + This is an alias for :py:func:`array_concat`, :py:func:`array_cat`. + """ + return array_concat(*args) + + def list_distinct(array: Expr) -> Expr: """Returns distinct values from the array after removing duplicates. @@ -1369,6 +1388,14 @@ def array_repeat(element: Expr, count: Expr) -> Expr: return Expr(f.array_repeat(element.expr, count.expr)) +def list_repeat(element: Expr, count: Expr) -> Expr: + """Returns an array containing ``element`` ``count`` times. + + This is an alias for :py:func:`array_repeat`. + """ + return array_repeat(element, count) + + def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: """Replaces the first occurrence of ``from_val`` with ``to_val``.""" return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr)) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index b3a5a065..c14cfc2d 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -291,6 +291,14 @@ def py_flatten(arr): lambda col: f.array_cat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], ], + [ + lambda col: f.list_cat(col, col), + lambda data: [np.concatenate([arr, arr]) for arr in data], + ], + [ + lambda col: f.list_concat(col, col), + lambda data: [np.concatenate([arr, arr]) for arr in data], + ], [ lambda col: f.array_dims(col), lambda data: [[len(r)] for r in data], @@ -439,6 +447,10 @@ def py_flatten(arr): lambda col: f.array_repeat(col, literal(2)), lambda data: [[arr] * 2 for arr in data], ], + [ + lambda col: f.list_repeat(col, literal(2)), + lambda data: [[arr] * 2 for arr in data], + ], [ lambda col: f.array_replace(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],