Skip to content

Commit

Permalink
feat: cross join (#347)
Browse files Browse the repository at this point in the history
* cross join

* cross join test

* fix example from narwhalify

* lazy join docstring

* cudf, modin, and min pandas

* random token as key

* fix issue with non-str col names

* raise if keys are passed for cross join

* explicit expected values

* raise after 100 iterations

* add hypothesis test

* remove unnecessary xfail

* Revert "remove unnecessary xfail"

This reverts commit 09c33bc.

* remove floats from cross-join

---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Jul 1, 2024
1 parent 838135c commit 02ea8da
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 71 deletions.
53 changes: 50 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -279,15 +280,61 @@ def join(
self,
other: Self,
*,
how: Literal["left", "inner", "outer"] = "inner",
left_on: str | list[str],
right_on: str | list[str],
how: Literal["left", "inner", "outer", "cross"] = "inner",
left_on: str | list[str] | None = None,
right_on: str | list[str] | None = None,
) -> Self:
if isinstance(left_on, str):
left_on = [left_on]
if isinstance(right_on, str):
right_on = [right_on]

if how == "cross":
if self._implementation in {"modin", "cudf"} or (
self._implementation == "pandas"
and (pd := get_pandas()) is not None
and parse_version(pd.__version__) < parse_version("1.4.0")
):

def generate_unique_token(
    n_bytes: int, columns: list[str]
) -> str:  # pragma: no cover
    """Return a random hex string that does not appear in ``columns``.

    Arguments:
        n_bytes: Number of random bytes handed to ``secrets.token_hex``.
        columns: Names the returned token must not collide with.

    Returns:
        The first generated token that is not found in ``columns``.

    Raises:
        AssertionError: If every one of the 101 candidates collided with
            an entry of ``columns``.
    """
    import secrets

    # Bounded retry: give up after 101 colliding candidates instead of
    # spinning forever.
    for _attempt in range(101):
        candidate = secrets.token_hex(n_bytes)
        if candidate not in columns:
            return candidate

    msg = (
        "Internal Error: Narwhals was not able to generate a column name to perform cross "
        "join operation"
    )
    raise AssertionError(msg)  # pragma: no cover

key_token = generate_unique_token(8, self.columns)

return self._from_dataframe(
self._dataframe.assign(**{key_token: 0}).merge(
other._dataframe.assign(**{key_token: 0}),
how="inner",
left_on=key_token,
right_on=key_token,
suffixes=("", "_right"),
),
).drop(key_token)
else:
return self._from_dataframe(
self._dataframe.merge(
other._dataframe,
how="cross",
suffixes=("", "_right"),
),
)

return self._from_dataframe(
self._dataframe.merge(
other._dataframe,
Expand Down
153 changes: 88 additions & 65 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,19 @@ def join(
self,
other: Self,
*,
how: Literal["inner"] = "inner",
left_on: str | list[str],
right_on: str | list[str],
how: Literal["inner", "cross"] = "inner",
left_on: str | list[str] | None = None,
right_on: str | list[str] | None = None,
) -> Self:
if how != "inner":
raise NotImplementedError("Only inner joins are supported for now")
_supported_joins = {"inner", "cross"}
if how not in _supported_joins:
msg = f"Only the following join stragies are supported: {_supported_joins}"
raise NotImplementedError(msg)

if how == "cross" and (left_on or right_on):
msg = "Can not pass left_on, right_on for cross join"
raise ValueError(msg)

validate_same_library([self, other])
return self._from_dataframe(
self._dataframe.join(
Expand Down Expand Up @@ -1440,21 +1447,20 @@ def join(
self,
other: Self,
*,
how: Literal["inner"] = "inner",
left_on: str | list[str],
right_on: str | list[str],
how: Literal["inner", "cross"] = "inner",
left_on: str | list[str] | None = None,
right_on: str | list[str] | None = None,
) -> Self:
r"""
Join in SQL-like fashion.
Arguments:
other: DataFrame to join with.
how: {'inner'}
Join strategy.
how: Join strategy.
* *inner*: Returns rows that have matching values in both
tables
* *inner*: Returns rows that have matching values in both tables
* *cross*: Returns the Cartesian product of rows from both tables
left_on: Name(s) of the left join column(s).
Expand All @@ -1464,30 +1470,39 @@ def join(
A new joined DataFrame
Examples:
>>> import polars as pl
>>> import narwhals as nw
>>> df_pl = pl.DataFrame(
... {
... "foo": [1, 2, 3],
... "bar": [6.0, 7.0, 8.0],
... "ham": ["a", "b", "c"],
... }
... )
>>> other_df_pl = pl.DataFrame(
... {
... "apple": ["x", "y", "z"],
... "ham": ["a", "b", "d"],
... }
... )
>>> df = nw.from_native(df_pl, eager_only=True)
>>> other_df = nw.from_native(other_df_pl, eager_only=True)
>>> dframe = df.join(other_df, left_on="ham", right_on="ham")
>>> dframe
┌───────────────────────────────────────────────┐
| Narwhals DataFrame |
| Use `narwhals.to_native` to see native output |
└───────────────────────────────────────────────┘
>>> nw.to_native(dframe)
>>> import pandas as pd
>>> import polars as pl
>>> data = {
... "foo": [1, 2, 3],
... "bar": [6.0, 7.0, 8.0],
... "ham": ["a", "b", "c"],
... }
>>> data_other = {
... "apple": ["x", "y", "z"],
... "ham": ["a", "b", "d"],
... }
>>> df_pd = pd.DataFrame(data)
>>> other_pd = pd.DataFrame(data_other)
>>> df_pl = pl.DataFrame(data)
>>> other_pl = pl.DataFrame(data_other)
Let's define a dataframe-agnostic function in which we join over "ham" column:
>>> @nw.narwhalify
... def join_on_ham(df, other):
... return df.join(other, left_on="ham", right_on="ham")
We can now pass either pandas or Polars to the function:
>>> join_on_ham(df_pd, other_pd)
foo bar ham apple
0 1 6.0 a x
1 2 7.0 b y
>>> join_on_ham(df_pl, other_pl)
shape: (2, 4)
┌─────┬─────┬─────┬───────┐
│ foo ┆ bar ┆ ham ┆ apple │
Expand Down Expand Up @@ -2843,21 +2858,20 @@ def join(
self,
other: Self,
*,
how: Literal["inner"] = "inner",
left_on: str | list[str],
right_on: str | list[str],
how: Literal["inner", "cross"] = "inner",
left_on: str | list[str] | None = None,
right_on: str | list[str] | None = None,
) -> Self:
r"""
Add a join operation to the Logical Plan.
Arguments:
other: Lazy DataFrame to join with.
how: {'inner'}
Join strategy.
how: Join strategy.
* *inner*: Returns rows that have matching values in both
tables
* *inner*: Returns rows that have matching values in both tables
* *cross*: Returns the Cartesian product of rows from both tables
left_on: Join column of the left DataFrame.
Expand All @@ -2867,30 +2881,39 @@ def join(
A new joined LazyFrame
Examples:
>>> import polars as pl
>>> import narwhals as nw
>>> lf_pl = pl.LazyFrame(
... {
... "foo": [1, 2, 3],
... "bar": [6.0, 7.0, 8.0],
... "ham": ["a", "b", "c"],
... }
... )
>>> other_lf_pl = pl.LazyFrame(
... {
... "apple": ["x", "y", "z"],
... "ham": ["a", "b", "d"],
... }
... )
>>> lf = nw.from_native(lf_pl)
>>> other_lf = nw.from_native(other_lf_pl)
>>> lframe = lf.join(other_lf, left_on="ham", right_on="ham").collect()
>>> lframe
┌───────────────────────────────────────────────┐
| Narwhals DataFrame |
| Use `narwhals.to_native` to see native output |
└───────────────────────────────────────────────┘
>>> nw.to_native(lframe)
>>> import pandas as pd
>>> import polars as pl
>>> data = {
... "foo": [1, 2, 3],
... "bar": [6.0, 7.0, 8.0],
... "ham": ["a", "b", "c"],
... }
>>> data_other = {
... "apple": ["x", "y", "z"],
... "ham": ["a", "b", "d"],
... }
>>> df_pd = pd.DataFrame(data)
>>> other_pd = pd.DataFrame(data_other)
>>> df_pl = pl.LazyFrame(data)
>>> other_pl = pl.LazyFrame(data_other)
Let's define a dataframe-agnostic function in which we join over "ham" column:
>>> @nw.narwhalify
... def join_on_ham(df, other):
... return df.join(other, left_on="ham", right_on="ham")
We can now pass either pandas or Polars to the function:
>>> join_on_ham(df_pd, other_pd)
foo bar ham apple
0 1 6.0 a x
1 2 7.0 b y
>>> join_on_ham(df_pl, other_pl).collect()
shape: (2, 4)
┌─────┬─────┬─────┬───────┐
│ foo ┆ bar ┆ ham ┆ apple │
Expand Down
30 changes: 30 additions & 0 deletions tests/frame/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,36 @@ def test_join(df_raw: Any) -> None:
compare_dicts(result_native, expected)


@pytest.mark.parametrize(
    ("df_raw", "expected"),
    [
        (
            native_frame,
            {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]},
        )
        # Every backend is expected to produce the same Cartesian product.
        for native_frame in (df_polars, df_lazy, df_pandas, df_mpd)
    ],
)
def test_cross_join(df_raw: Any, expected: dict[str, list[Any]]) -> None:
    """Cross-join the "a" column with itself and check the result.

    Also checks that passing a join key together with ``how="cross"``
    raises ``ValueError``.
    """
    frame = nw.from_native(df_raw).select("a")

    product = frame.join(frame, how="cross")  # type: ignore[arg-type]
    compare_dicts(product, expected)

    # Join keys are meaningless for a cross join and must be rejected.
    with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"):
        frame.join(frame, how="cross", left_on="a")  # type: ignore[arg-type]


@pytest.mark.parametrize(
"df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow]
)
Expand Down
48 changes: 45 additions & 3 deletions tests/hypothesis/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd
import polars as pl
import pytest
from hypothesis import example
from hypothesis import given
from hypothesis import strategies as st
from pandas.testing import assert_frame_equal
Expand All @@ -14,7 +13,6 @@
pl_version = parse_version(pl.__version__)


@example([0, 0, 0], [0, 0, 0], [0.0, 0.0, -0.0], ["c"]) # type: ignore[misc]
@given(
st.lists(
st.integers(min_value=-9223372036854775807, max_value=9223372036854775807),
Expand All @@ -38,8 +36,8 @@
unique=True,
),
) # type: ignore[misc]
@pytest.mark.skipif(pl_version < parse_version("0.20.13"), reason="0.0 == -0.0")
@pytest.mark.slow()
@pytest.mark.xfail(pl_version < parse_version("0.20.13"), reason="0.0 == -0.0")
def test_join( # pragma: no cover
integers: st.SearchStrategy[list[int]],
other_integers: st.SearchStrategy[list[int]],
Expand Down Expand Up @@ -71,3 +69,47 @@ def test_join( # pragma: no cover
)

assert_frame_equal(dframe_pd1, dframe_pd2)


@given(
    st.lists(
        st.integers(min_value=-9223372036854775807, max_value=9223372036854775807),
        min_size=3,
        max_size=3,
    ),
    st.lists(
        st.integers(min_value=-9223372036854775807, max_value=9223372036854775807),
        min_size=3,
        max_size=3,
    ),
)  # type: ignore[misc]
@pytest.mark.slow()
def test_cross_join(  # pragma: no cover
    integers: st.SearchStrategy[list[int]],
    other_integers: st.SearchStrategy[list[int]],
) -> None:
    """Check that a cross join gives the same rows via Polars and via pandas.

    Two three-element integer lists become columns "a" and "b"; the frame is
    cross-joined with an identically built frame through each backend, and
    the results are compared after sorting by all columns.
    """

    def _sorted_by_all_columns(frame: pd.DataFrame) -> pd.DataFrame:
        # Sort on every column so the comparison ignores row order.
        return frame.sort_values(by=frame.columns.to_list(), ignore_index=True)

    data = {"a": integers, "b": other_integers}

    # Polars side.
    left_pl = nw.DataFrame(pl.DataFrame(data))
    right_pl = nw.DataFrame(pl.DataFrame(data))
    joined_pl = left_pl.join(right_pl, how="cross")

    # pandas side.
    left_pd = nw.DataFrame(pd.DataFrame(data))
    right_pd = nw.DataFrame(pd.DataFrame(data))
    joined_pd = left_pd.join(right_pd, how="cross")

    from_polars = _sorted_by_all_columns(nw.to_native(joined_pl).to_pandas())
    from_pandas = _sorted_by_all_columns(nw.to_native(joined_pd))

    assert_frame_equal(from_polars, from_pandas)

0 comments on commit 02ea8da

Please sign in to comment.