Skip to content

Commit

Permalink
feat: arrow string namespace(s) (#551)
Browse files Browse the repository at this point in the history
* feat: arrow str namespace(s)

* Update narwhals/_arrow/series.py

---------

Co-authored-by: Marco Edward Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Jul 20, 2024
1 parent be4ecde commit 4bf5bd2
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 43 deletions.
34 changes: 34 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ class ArrowExprStringNamespace:
def __init__(self, expr: ArrowExpr) -> None:
self._expr = expr

def starts_with(self, prefix: str) -> ArrowExpr:
    """Build an expression testing whether each string begins with ``prefix``.

    The work is handed to ``reuse_series_namespace_implementation`` with the
    ``"str"`` namespace and the ``"starts_with"`` method name; presumably that
    helper invokes the series-level method of the same name (confirm against
    its definition).
    """
    return reuse_series_namespace_implementation(
        self._expr, "str", "starts_with", prefix
    )

def ends_with(self, suffix: str) -> ArrowExpr:
    """Build an expression testing whether each string finishes with ``suffix``.

    The work is handed to ``reuse_series_namespace_implementation`` with the
    ``"str"`` namespace and the ``"ends_with"`` method name; presumably that
    helper invokes the series-level method of the same name (confirm against
    its definition).
    """
    return reuse_series_namespace_implementation(
        self._expr, "str", "ends_with", suffix
    )

def contains(self, pattern: str, *, literal: bool = False) -> ArrowExpr:
    """Build an expression testing whether each string contains ``pattern``.

    Args:
        pattern: Text to search for.
        literal: Forwarded as a keyword to the series-level ``contains``.
            Defaults to ``False`` so this signature agrees with the
            series-level string namespace.

    Returns:
        The expression produced by ``reuse_series_namespace_implementation``
        for the ``"str"`` namespace and the ``"contains"`` method.
    """
    return reuse_series_namespace_implementation(
        self._expr,
        "str",
        "contains",
        pattern,
        literal=literal,
    )

def slice(self, offset: int, length: int | None = None) -> ArrowExpr:
    """Build an expression taking a substring of each string.

    ``offset`` and ``length`` are forwarded positionally, in that order, to
    ``reuse_series_namespace_implementation`` for the ``"str"`` namespace and
    the ``"slice"`` method.
    """
    return reuse_series_namespace_implementation(
        self._expr,
        "str",
        "slice",
        offset,
        length,
    )

def to_datetime(self, format: str | None = None) -> ArrowExpr:  # noqa: A002
    """Build an expression parsing each string into a datetime.

    ``format`` is forwarded positionally to
    ``reuse_series_namespace_implementation`` for the ``"str"`` namespace and
    the ``"to_datetime"`` method.
    """
    return reuse_series_namespace_implementation(
        self._expr, "str", "to_datetime", format
    )

def to_uppercase(self) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._expr,
Expand Down
40 changes: 37 additions & 3 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,16 +375,50 @@ def get_categories(self) -> ArrowSeries:


class ArrowSeriesStringNamespace:
def __init__(self, series: ArrowSeries) -> None:
def __init__(self: Self, series: ArrowSeries) -> None:
    """Keep a reference to the series whose string methods this object exposes."""
    self._arrow_series = series

def to_uppercase(self) -> ArrowSeries:
def starts_with(self: Self, prefix: str) -> ArrowSeries:
    """Return a boolean series marking elements that begin with ``prefix``.

    Args:
        prefix: Literal text each element is checked against.

    Returns:
        A new series wrapping the result of ``pc.starts_with``.
    """
    pc = get_pyarrow_compute()
    # Use the dedicated kernel instead of slicing and comparing: with an
    # empty prefix the slice-based check compared the *whole* string against
    # "" and so reported False for every non-empty element.
    return self._arrow_series._from_native_series(
        pc.starts_with(self._arrow_series._native_series, pattern=prefix)
    )

def ends_with(self: Self, suffix: str) -> ArrowSeries:
    """Return a boolean series marking elements that finish with ``suffix``.

    Args:
        suffix: Literal text each element is checked against.

    Returns:
        A new series wrapping the result of ``pc.ends_with``.
    """
    pc = get_pyarrow_compute()
    # Use the dedicated kernel instead of slicing and comparing: with an
    # empty suffix ``-len(suffix)`` is 0, so the slice-based check compared
    # the *whole* string against "" and reported False for non-empty elements.
    return self._arrow_series._from_native_series(
        pc.ends_with(self._arrow_series._native_series, pattern=suffix)
    )

def contains(self: Self, pattern: str, *, literal: bool = False) -> ArrowSeries:
    """Return a boolean series marking elements that contain ``pattern``.

    When ``literal`` is true the pattern is matched with
    ``pc.match_substring``; otherwise ``pc.match_substring_regex`` is used.
    """
    pc = get_pyarrow_compute()
    native = self._arrow_series._native_series
    if literal:
        matcher = pc.match_substring
    else:
        matcher = pc.match_substring_regex
    matched = matcher(native, pattern)
    return self._arrow_series._from_native_series(matched)

def slice(self: Self, offset: int, length: int | None = None) -> ArrowSeries:
    """Return the substring of each element starting at ``offset``.

    Args:
        offset: Start position; a negative value counts from the end of
            the string.
        length: Maximum number of characters to keep. ``None`` keeps
            everything from ``offset`` to the end of the string.

    Returns:
        A new series wrapping the result of ``pc.utf8_slice_codeunits``.
    """
    pc = get_pyarrow_compute()
    if length is None:
        stop = None
    else:
        # Test against None explicitly: ``length=0`` is a legitimate request
        # for empty strings and must not be treated as "no length given".
        stop = offset + length
        if offset < 0 <= stop:
            # A negative start paired with a non-negative stop would mix
            # "from the end" and "from the start" indices (e.g. [-2:0]),
            # truncating or emptying the result. The requested window
            # reaches the end of the string, so leave the slice open-ended.
            stop = None
    return self._arrow_series._from_native_series(
        pc.utf8_slice_codeunits(
            self._arrow_series._native_series, start=offset, stop=stop
        ),
    )

def to_datetime(self: Self, format: str | None = None) -> ArrowSeries:  # noqa: A002
    """Parse each string into a timestamp using ``pc.strptime``.

    The timestamps are requested with microsecond (``"us"``) resolution.

    NOTE(review): ``format=None`` is forwarded to ``pc.strptime`` unchanged;
    confirm that the kernel accepts a missing format.
    """
    pc = get_pyarrow_compute()
    native = self._arrow_series._native_series
    parsed = pc.strptime(native, format=format, unit="us")
    return self._arrow_series._from_native_series(parsed)

def to_uppercase(self: Self) -> ArrowSeries:
    """Return a new series with ``pc.utf8_upper`` applied to the native data."""
    pc = get_pyarrow_compute()
    uppercased = pc.utf8_upper(self._arrow_series._native_series)
    return self._arrow_series._from_native_series(uppercased)

def to_lowercase(self) -> ArrowSeries:
def to_lowercase(self: Self) -> ArrowSeries:
pc = get_pyarrow_compute()
return self._arrow_series._from_native_series(
pc.utf8_lower(self._arrow_series._native_series),
Expand Down
6 changes: 2 additions & 4 deletions tests/expr_and_series/str/contains_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pandas as pd
import polars as pl
import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts
Expand All @@ -13,9 +12,8 @@
df_polars = pl.DataFrame(data)


@pytest.mark.parametrize("df_any", [df_pandas, df_polars])
def test_contains(df_any: Any) -> None:
df = nw.from_native(df_any, eager_only=True)
def test_contains(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.with_columns(
case_insensitive_match=nw.col("pets").str.contains("(?i)parrot|Dove")
)
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/str/head_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

data = {"a": ["foo", "bars"]}


def test_str_head(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_str_head(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(nw.col("a").str.head(3))
expected = {
Expand Down
5 changes: 1 addition & 4 deletions tests/expr_and_series/str/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
[(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})],
)
def test_str_slice(
request: Any, constructor: Any, offset: int, length: int | None, expected: Any
constructor: Any, offset: int, length: int | None, expected: Any
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data), eager_only=True)
result_frame = df.select(nw.col("a").str.slice(offset, length))
compare_dicts(result_frame, expected)
Expand Down
12 changes: 2 additions & 10 deletions tests/expr_and_series/str/starts_with_ends_with_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Any

import pytest

import narwhals.stable.v1 as nw

# Don't move this into typechecking block, for coverage
Expand All @@ -13,10 +11,7 @@
data = {"a": ["fdas", "edfas"]}


def test_ends_with(request: Any, constructor_with_lazy: Any) -> None:
if "pyarrow_table" in str(constructor_with_lazy):
request.applymarker(pytest.mark.xfail)

def test_ends_with(constructor_with_lazy: Any) -> None:
df = nw.from_native(constructor_with_lazy(data)).lazy()
result = df.select(nw.col("a").str.ends_with("das"))
expected = {
Expand All @@ -31,10 +26,7 @@ def test_ends_with(request: Any, constructor_with_lazy: Any) -> None:
compare_dicts(result, expected)


def test_starts_with(request: Any, constructor_with_lazy: Any) -> None:
if "pyarrow_table" in str(constructor_with_lazy):
request.applymarker(pytest.mark.xfail)

def test_starts_with(constructor_with_lazy: Any) -> None:
df = nw.from_native(constructor_with_lazy(data)).lazy()
result = df.select(nw.col("a").str.starts_with("fda"))
expected = {
Expand Down
7 changes: 1 addition & 6 deletions tests/expr_and_series/str/tail_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

data = {"a": ["foo", "bars"]}


def test_str_tail(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_str_tail(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
expected = {"a": ["foo", "ars"]}

Expand Down
18 changes: 7 additions & 11 deletions tests/expr_and_series/str/to_datetime_test.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from datetime import datetime
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals.stable.v1 as nw

df_pandas = pd.DataFrame({"a": ["2020-01-01T12:34:56"]})
df_polars = pl.DataFrame({"a": ["2020-01-01T12:34:56"]})
data = {"a": ["2020-01-01T12:34:56"]}


@pytest.mark.parametrize("df_any", [df_pandas, df_polars])
def test_to_datetime(df_any: Any) -> None:
result = nw.from_native(df_any, eager_only=True).select(
b=nw.col("a").str.to_datetime(format="%Y-%m-%dT%H:%M:%S")
)["b"][0]
def test_to_datetime(constructor: Any) -> None:
    """Check that ``str.to_datetime`` parses an ISO-like timestamp string."""
    frame = nw.from_native(constructor(data), eager_only=True)
    parsed = frame.select(
        b=nw.col("a").str.to_datetime(format="%Y-%m-%dT%H:%M:%S")
    )
    result = parsed.item(row=0, column="b")
    assert result == datetime(2020, 1, 1, 12, 34, 56)

0 comments on commit 4bf5bd2

Please sign in to comment.