Skip to content

Commit

Permalink
feat: arrow join methods (#558)
Browse files Browse the repository at this point in the history
* feat: pyarrow join methods

* fill_null include type

* add extra hypothesis test because I'm paranoid

* fix typo in err msg

---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Jul 20, 2024
1 parent fb80986 commit be4ecde
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 77 deletions.
34 changes: 23 additions & 11 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_parquet
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -208,7 +209,7 @@ def join(
self,
other: Self,
*,
how: Literal["inner"] = "inner",
how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner",
left_on: str | list[str] | None,
right_on: str | list[str] | None,
) -> Self:
Expand All @@ -217,24 +218,35 @@ def join(
if isinstance(right_on, str):
right_on = [right_on]

if how == "cross": # type: ignore[comparison-overlap]
raise NotImplementedError

if how == "anti": # type: ignore[comparison-overlap]
raise NotImplementedError
how_to_join_map = {
"anti": "left anti",
"semi": "left semi",
"inner": "inner",
"left": "left outer",
}

if how == "semi": # type: ignore[comparison-overlap]
raise NotImplementedError
if how == "cross":
plx = self.__narwhals_namespace__()
key_token = generate_unique_token(
n_bytes=8, columns=[*self.columns, *other.columns]
)

if how == "left": # type: ignore[comparison-overlap]
raise NotImplementedError
return self._from_native_dataframe(
self.with_columns(**{key_token: plx.lit(0, None)})._native_dataframe.join(
other.with_columns(**{key_token: plx.lit(0, None)})._native_dataframe,
keys=key_token,
right_keys=key_token,
join_type="inner",
right_suffix="_right",
),
).drop(key_token)

return self._from_native_dataframe(
self._native_dataframe.join(
other._native_dataframe,
keys=left_on,
right_keys=right_on,
join_type=how,
join_type=how_to_join_map[how],
right_suffix="_right",
),
)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def sample(
self, "sample", n=n, fraction=fraction, with_replacement=with_replacement
)

def fill_null(self: Self, value: Any) -> Self:
    """Replace null entries in this expression's output with *value*.

    Delegates to the underlying Arrow series' ``fill_null`` method via
    the shared expression-to-series reuse helper.
    """
    return reuse_series_implementation(self, "fill_null", value=value)

@property
def dt(self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def sample(
mask = np.random.choice(idx, size=n, replace=with_replacement)
return self._from_native_series(pc.take(ser, mask))

def fill_null(self: Self, value: Any) -> Self:
    """Return a new series with null entries replaced by *value*.

    The fill value is wrapped in a pyarrow scalar typed with the series'
    native type, so the result keeps the original dtype.
    """
    pa = get_pyarrow()
    pc = get_pyarrow_compute()
    native = self._native_series
    # Cast the replacement to the column's own Arrow type before filling.
    typed_fill = pa.scalar(value, native.type)
    return self._from_native_series(pc.fill_null(native, typed_fill))

@property
def shape(self) -> tuple[int]:
return (len(self._native_series),)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import generate_unique_token
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_dataframe_comparand
Expand All @@ -23,6 +22,7 @@
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
29 changes: 0 additions & 29 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import secrets
from enum import Enum
from enum import auto
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -458,31 +457,3 @@ def int_dtype_mapper(dtype: Any) -> str:
if str(dtype).lower() != str(dtype): # pragma: no cover
return "Int64"
return "int64"


def generate_unique_token(n_bytes: int, columns: list[str]) -> str:  # pragma: no cover
    """Return a random hex token of ``n_bytes`` bytes absent from *columns*.

    Arguments:
        n_bytes: Number of random bytes used to build the token.
        columns: Existing column names the token must not collide with.

    Returns:
        A hex string that does not appear in ``columns``.

    Raises:
        AssertionError: If every one of 101 random draws collides.
    """
    # Try up to 101 random draws before giving up; a collision with real
    # column names is astronomically unlikely, so this bound is defensive.
    for _ in range(101):
        candidate = secrets.token_hex(n_bytes)
        if candidate not in columns:
            return candidate
    msg = (
        "Internal Error: Narwhals was not able to generate a column name to perform cross "
        "join operation"
    )
    raise AssertionError(msg)
29 changes: 29 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
import secrets
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
Expand Down Expand Up @@ -325,3 +326,31 @@ def is_ordered_categorical(series: Series) -> bool:
return native_series.type.ordered # type: ignore[no-any-return]
# If it doesn't match any of the above, let's just play it safe and return False.
return False # pragma: no cover


def generate_unique_token(n_bytes: int, columns: list[str]) -> str:  # pragma: no cover
    """Generate a random hex token that is not already used as a column name.

    Arguments:
        n_bytes: Number of random bytes in the token.
        columns: Column names the token must avoid.

    Returns:
        A hex string not present in ``columns``.

    Raises:
        AssertionError: If 101 consecutive random draws all collide.
    """
    attempts = 0
    token = secrets.token_hex(n_bytes)
    # Redraw on collision, bounded so a pathological input cannot loop forever.
    while token in columns:
        attempts += 1
        if attempts > 100:
            msg = (
                "Internal Error: Narwhals was not able to generate a column name to perform given "
                "join operation"
            )
            raise AssertionError(msg)
        token = secrets.token_hex(n_bytes)
    return token
7 changes: 1 addition & 6 deletions tests/expr_and_series/fill_null_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

Expand All @@ -12,10 +10,7 @@
}


def test_fill_null(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_fill_null(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)

result = df.with_columns(nw.col("a", "b", "c").fill_null(99))
Expand Down
37 changes: 8 additions & 29 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,11 @@ def test_inner_join_single_key(constructor: Any) -> None:
compare_dicts(result, expected)


def test_cross_join(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_cross_join(constructor: Any) -> None:
data = {"a": [1, 3, 2]}
df = nw.from_native(constructor(data))
result = df.join(df, how="cross") # type: ignore[arg-type]

expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]}
result = df.join(df, how="cross").sort("a", "a_right") # type: ignore[arg-type]
expected = {"a": [1, 1, 1, 2, 2, 2, 3, 3, 3], "a_right": [1, 2, 3, 1, 2, 3, 1, 2, 3]}
compare_dicts(result, expected)

with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"):
Expand All @@ -71,15 +67,11 @@ def test_cross_join_non_pandas() -> None:
],
)
def test_anti_join(
request: Any,
constructor: Any,
join_key: list[str],
filter_expr: nw.Expr,
expected: dict[str, list[Any]],
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
other = df.filter(filter_expr)
Expand All @@ -96,15 +88,11 @@ def test_anti_join(
],
)
def test_semi_join(
request: Any,
constructor: Any,
join_key: list[str],
filter_expr: nw.Expr,
expected: dict[str, list[Any]],
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
other = df.filter(filter_expr)
Expand All @@ -127,10 +115,7 @@ def test_join_not_implemented(constructor: Any, how: str) -> None:


@pytest.mark.filterwarnings("ignore:the default coalesce behavior")
def test_left_join(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_left_join(constructor: Any) -> None:
data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6]}
data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 7]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
Expand All @@ -143,10 +128,7 @@ def test_left_join(request: Any, constructor: Any) -> None:


@pytest.mark.filterwarnings("ignore: the default coalesce behavior")
def test_left_join_multiple_column(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_left_join_multiple_column(constructor: Any) -> None:
data_left = {"a": [1, 2, 3], "b": [4, 5, 6]}
data_right = {"a": [1, 2, 3], "c": [4, 5, 6]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
Expand All @@ -157,12 +139,9 @@ def test_left_join_multiple_column(request: Any, constructor: Any) -> None:


@pytest.mark.filterwarnings("ignore: the default coalesce behavior")
def test_left_join_overlapping_column(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data_left = {"a": [1, 2, 3], "b": [4, 5, 6], "d": [1, 4, 2]}
data_right = {"a": [1, 2, 3], "c": [4, 5, 6], "d": [1, 4, 2]}
def test_left_join_overlapping_column(constructor: Any) -> None:
data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6], "d": [1.0, 4, 2]}
data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 6], "d": [1.0, 4, 2]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
df_right = nw.from_native(constructor(data_right), eager_only=True)
result = df_left.join(df_right, left_on="b", right_on="c", how="left")
Expand Down
16 changes: 16 additions & 0 deletions tests/hypothesis/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from hypothesis import assume
from hypothesis import given
Expand Down Expand Up @@ -163,3 +164,18 @@ def test_left_join( # pragma: no cover
)
).select(pl.all().fill_null(float("nan")))
compare_dicts(result_pd.to_dict(as_series=False), result_pl.to_dict(as_series=False))
# For PyArrow, insert an extra sort, as the order of rows isn't guaranteed
result_pa = (
nw.from_native(pa.table(data_left), eager_only=True)
.join(
nw.from_native(pa.table(data_right), eager_only=True),
how="left",
left_on=left_key,
right_on=right_key,
)
.select(nw.all().cast(nw.Float64).fill_null(float("nan")))
.pipe(lambda df: df.sort(df.columns))
)
compare_dicts(
result_pa, result_pd.pipe(lambda df: df.sort(df.columns)).to_dict(as_series=False)
)
1 change: 0 additions & 1 deletion utils/check_backend_completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"DataFrame.pipe",
"DataFrame.unique",
"Series.drop_nulls",
"Series.fill_null",
"Series.from_iterable",
"Series.is_between",
"Series.is_duplicated",
Expand Down

0 comments on commit be4ecde

Please sign in to comment.