Skip to content

Commit

Permalink
fix(python): fix issue with invalid Mapping objects used as schema …
Browse files Browse the repository at this point in the history
…being silently ignored (pola-rs#12027)
  • Loading branch information
alexander-beedie authored Oct 26, 2023
1 parent b4b68d9 commit 51e3d9a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
5 changes: 3 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def _scan_ndjson(
else:
sources = [normalize_filepath(source) for source in source]
source = None # type: ignore[assignment]

self = cls.__new__(cls)
self._ldf = PyLazyFrame.new_from_ndjson(
source,
Expand All @@ -566,13 +567,13 @@ def _scan_ndjson(
@classmethod
def _scan_python_function(
cls,
schema: pa.schema | dict[str, PolarsDataType],
schema: pa.schema | Mapping[str, PolarsDataType],
scan_fn: Any,
*,
pyarrow: bool = False,
) -> Self:
self = cls.__new__(cls)
if isinstance(schema, dict):
if isinstance(schema, Mapping):
self._ldf = PyLazyFrame.scan_from_python_function_pl_schema(
list(schema.items()), scan_fn, pyarrow
)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def _unpack_schema(
)

# determine column names from schema
if isinstance(schema, dict):
if isinstance(schema, Mapping):
column_names: list[str] = list(schema)
# coerce schema to list[str | tuple[str, PolarsDataType | PythonDataType | None]
schema = list(schema.items())
Expand Down Expand Up @@ -849,7 +849,7 @@ def dict_to_pydf(
nan_to_null: bool = False,
) -> PyDataFrame:
"""Construct a PyDataFrame from a dictionary of sequences."""
if isinstance(schema, dict) and data:
if isinstance(schema, Mapping) and data:
if not all((col in schema) for col in data):
raise ValueError(
"the given column-schema names do not match the data dictionary"
Expand Down
29 changes: 28 additions & 1 deletion py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,41 @@
from __future__ import annotations

from collections import OrderedDict
from datetime import date, timedelta
from typing import Any
from typing import Any, Iterator, Mapping

import pytest

import polars as pl
from polars.testing import assert_frame_equal


class CustomSchema(Mapping[str, Any]):
"""Dummy schema object for testing compatibility with Mapping."""

_entries: dict[str, Any]

def __init__(self, **named_entries: Any) -> None:
self._items = OrderedDict(named_entries.items())

def __getitem__(self, key: str) -> Any:
return self._items[key]

def __len__(self) -> int:
return len(self._items)

def __iter__(self) -> Iterator[str]:
yield from self._items


def test_custom_schema() -> None:
df = pl.DataFrame(schema=CustomSchema(bool=pl.Boolean, misc=pl.UInt8))
assert df.schema == OrderedDict([("bool", pl.Boolean), ("misc", pl.UInt8)])

with pytest.raises(ValueError):
pl.DataFrame(schema=CustomSchema(bool="boolean", misc="unsigned int"))


def test_schema_on_agg() -> None:
df = pl.DataFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]})

Expand Down

0 comments on commit 51e3d9a

Please sign in to comment.