-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
feat: improving array casting #1865
Changes from 5 commits
1b8fd7d
569c204
ab25f4d
65287c9
7c84531
31dd90f
044db09
35c2a2e
c79507d
981f87c
690fc99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,10 +111,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: | |
) | ||
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): | ||
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) | ||
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): | ||
if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Array types in DuckDB can have multiple dimensions. The resulting type has the form `INNER_TYPE[d1][d2]...`. With this new regex we can parse multiple instances of the dimensions. |
||
duckdb_inner_type = match_.group(1) | ||
duckdb_shape = match_.group(2) | ||
shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape)) | ||
return dtypes.Array( | ||
native_to_narwhals_dtype(match_.group(1), version), | ||
int(match_.group(2)), | ||
inner=native_to_narwhals_dtype(duckdb_inner_type, version), | ||
shape=shape, | ||
) | ||
if duckdb_dtype.startswith("DECIMAL("): | ||
return dtypes.Decimal() | ||
|
@@ -171,8 +174,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st | |
) | ||
return f"STRUCT({inner})" | ||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover | ||
msg = "todo" | ||
raise NotImplementedError(msg) | ||
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr] | ||
while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr] | ||
dtype = dtype.inner # type: ignore[union-attr] | ||
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr] | ||
return f"{inner}{duckdb_shape_fmt}" | ||
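There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. First creates the shape suffix (one `[n]` per dimension), then unwraps any nested Array dtypes to reach the innermost dtype. |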
||
msg = f"Unknown dtype: {dtype}" # pragma: no cover | ||
raise AssertionError(msg) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any: | |
if isinstance(other, PandasLikeSeries): | ||
len_other = other.len() | ||
|
||
if len_other == 1: | ||
if len_other == 1 and len(index) != 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as for pyarrow: otherwise, for list/array types, we end up getting only the first element. |
||
# broadcast | ||
s = other._native_series | ||
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) | ||
|
@@ -387,9 +387,7 @@ def rename( | |
|
||
|
||
@lru_cache(maxsize=16) | ||
def non_object_native_to_narwhals_dtype( | ||
dtype: str, version: Version, _implementation: Implementation | ||
) -> DType: | ||
def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType: | ||
dtypes = import_dtypes_module(version) | ||
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: | ||
return dtypes.Int64() | ||
|
@@ -465,7 +463,7 @@ def native_to_narwhals_dtype( | |
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version) | ||
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version) | ||
if dtype != "object": | ||
return non_object_native_to_narwhals_dtype(dtype, version, implementation) | ||
return non_object_native_to_narwhals_dtype(dtype, version) | ||
if implementation is Implementation.DASK: | ||
# Dask columns are lazy, so we can't inspect values. | ||
# The most useful assumption is probably String | ||
|
@@ -649,68 +647,25 @@ def narwhals_to_native_dtype( # noqa: PLR0915 | |
if isinstance_or_issubclass(dtype, dtypes.Enum): | ||
msg = "Converting to Enum is not (yet) supported" | ||
raise NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.List): | ||
from narwhals._arrow.utils import ( | ||
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, | ||
) | ||
|
||
if implementation is Implementation.PANDAS and backend_version >= (2, 2): | ||
try: | ||
import pandas as pd | ||
import pyarrow as pa # ignore-banned-import | ||
except ImportError as exc: # pragma: no cover | ||
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" | ||
raise ImportError(msg) from exc | ||
|
||
return pd.ArrowDtype( | ||
pa.list_( | ||
value_type=arrow_narwhals_to_native_dtype( | ||
dtype.inner, # type: ignore[union-attr] | ||
version=version, | ||
) | ||
) | ||
) | ||
else: # pragma: no cover | ||
msg = ( | ||
"Converting to List dtype is not supported for implementation " | ||
f"{implementation} and version {version}." | ||
) | ||
return NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.Struct): | ||
if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This diff is quite nice. |
||
if implementation is Implementation.PANDAS and backend_version >= (2, 2): | ||
try: | ||
import pandas as pd | ||
import pyarrow as pa # ignore-banned-import | ||
import pyarrow as pa # ignore-banned-import # noqa: F401 | ||
except ImportError as exc: # pragma: no cover | ||
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" | ||
raise ImportError(msg) from exc | ||
from narwhals._arrow.utils import ( | ||
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, | ||
) | ||
|
||
return pd.ArrowDtype( | ||
pa.struct( | ||
[ | ||
( | ||
field.name, | ||
arrow_narwhals_to_native_dtype( | ||
field.dtype, | ||
version=version, | ||
), | ||
) | ||
for field in dtype.fields # type: ignore[union-attr] | ||
] | ||
) | ||
) | ||
return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version)) | ||
else: # pragma: no cover | ||
msg = ( | ||
"Converting to Struct dtype is not supported for implementation " | ||
f"Converting to {dtype} dtype is not supported for implementation " | ||
f"{implementation} and version {version}." | ||
) | ||
return NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover | ||
msg = "Converting to Array dtype is not supported yet" | ||
return NotImplementedError(msg) | ||
raise NotImplementedError(msg) | ||
msg = f"Unknown dtype: {dtype}" # pragma: no cover | ||
raise AssertionError(msg) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise for list/array types we end up getting the first element