diff --git a/.github/workflows/check_tpch_queries.yml b/.github/workflows/check_tpch_queries.yml
index ce7da6f8ea..df2e31def0 100644
--- a/.github/workflows/check_tpch_queries.yml
+++ b/.github/workflows/check_tpch_queries.yml
@@ -25,7 +25,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: local-install
-        run: uv pip install -U --pre -e ".[dev, core, dask]" --system
+        run: uv pip install -U --pre -e ".[tests, core, dask]" --system
       - name: generate-data
         run: cd tpch && python generate_data.py
       - name: tpch-tests
diff --git a/.github/workflows/downstream_tests.yml b/.github/workflows/downstream_tests.yml
index 4c7e3be8b5..58a4f439cf 100644
--- a/.github/workflows/downstream_tests.yml
+++ b/.github/workflows/downstream_tests.yml
@@ -220,7 +220,7 @@ jobs:
         run: |
           cd tea-tasting
           pdm remove narwhals
-          pdm add ./..[dev]
+          pdm add ./..[tests]
       - name: show-deps
         run: |
           cd tea-tasting
diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml
index 923e716eb5..6749db244a 100644
--- a/.github/workflows/extremes.yml
+++ b/.github/workflows/extremes.yml
@@ -28,7 +28,7 @@ jobs:
         run: uv pip install pipdeptree tox virtualenv setuptools pandas==0.25.3 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" scipy==1.5.0 scikit-learn==1.1.0 duckdb==1.0 tzdata --system
       - name: install-reqs
         run: |
-          uv pip install -e ".[dev]" --system
+          uv pip install -e ".[tests]" --system
       - name: show-deps
         run: uv pip freeze
       - name: Assert dependencies
@@ -64,7 +64,7 @@ jobs:
       - name: install-pretty-old-versions
         run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.5.0 scikit-learn==1.1.0 duckdb==1.0 tzdata --system
       - name: install-reqs
-        run: uv pip install -e ".[dev]" --system
+        run: uv pip install -e ".[tests]" --system
       - name: show-deps
         run: uv pip freeze
       - name: show-deptree
@@ -103,7 +103,7 @@ jobs:
       - name: install-not-so-old-versions
         run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.8.0 scikit-learn==1.3.0 duckdb==1.0 dask[dataframe]==2024.10 tzdata --system
       - name: install-reqs
-        run: uv pip install -e ".[dev]" --system
+        run: uv pip install -e ".[tests]" --system
       - name: show-deps
         run: uv pip freeze
       - name: Assert not so old versions dependencies
@@ -140,7 +140,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
-        run: uv pip install -e ".[dev]" --system
+        run: uv pip install -e ".[tests]" --system
       - name: install-kaggle
         run: uv pip install kaggle --system
       - name: Download Kaggle notebook artifact
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 1222fe1d92..aa0c769048 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -25,7 +25,7 @@ jobs:
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
        # Python3.8 is technically at end-of-life, so we don't test everything
-        run: uv pip install -e ".[dev, core]" --system
+        run: uv pip install -e ".[tests, core]" --system
      - name: show-deps
        run: uv pip freeze
      - name: Run pytest
@@ -49,7 +49,7 @@ jobs:
          cache-suffix: ${{ matrix.python-version }}
          cache-dependency-glob: "pyproject.toml"
      - name: install-reqs
-        run: uv pip install -e ".[dev, core, extra, dask, modin]" --system
+        run: uv pip install -e ".[tests, core, extra, dask, modin]" --system
      - name: install pyspark
        run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+
@@ -83,7 +83,7 @@ jobs:
          cache-suffix: ${{ matrix.python-version }}
          cache-dependency-glob: "pyproject.toml"
      - name: install-reqs
-        run: uv pip install -e ".[dev, core, extra, modin, dask]" --system
+        run: uv pip install -e ".[tests, core, extra, modin, dask]" --system
      - name: install pyspark
        run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+
diff --git a/.github/workflows/random_ci_pytest.yml b/.github/workflows/random_ci_pytest.yml
index 4ec50da065..bf77ce4065 100644
--- a/.github/workflows/random_ci_pytest.yml
+++ b/.github/workflows/random_ci_pytest.yml
@@ -27,7 +27,7 @@ jobs:
       - name: install-random-verions
         run: uv pip install -r random-requirements.txt --system
       - name: install-narwhals
-        run: uv pip install -e ".[dev]" --system
+        run: uv pip install -e ".[tests]" --system
       - name: show versions
         run: uv pip freeze
       - name: Run pytest
diff --git a/.github/workflows/typing.yml b/.github/workflows/typing.yml
new file mode 100644
index 0000000000..c6ccb13562
--- /dev/null
+++ b/.github/workflows/typing.yml
@@ -0,0 +1,40 @@
+name: Type checking
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  mypy:
+    strategy:
+      matrix:
+        python-version: ["3.11"]
+        os: [ubuntu-latest]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: "true"
+          cache-suffix: ${{ matrix.python-version }}
+          cache-dependency-glob: "pyproject.toml"
+      - name: Create venv
+        run: uv venv .venv
+      - name: install-reqs
+        # TODO: add more dependencies/backends incrementally
+        run: |
+          source .venv/bin/activate
+          uv pip install -e ".[tests, typing, core]"
+      - name: show-deps
+        run: |
+          source .venv/bin/activate
+          uv pip freeze
+      - name: Run mypy
+        run: |
+          source .venv/bin/activate
+          make typing
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84d5a6df04..4f98b8d26f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,6 @@
 ci:
   autoupdate_schedule: monthly
+  skip: [mypy]
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
@@ -13,12 +14,6 @@ repos:
     - id: ruff
       alias: check-docstrings
       entry: python utils/check_docstrings.py
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: 'v1.14.1'
-  hooks:
-    - id: mypy
-      additional_dependencies: ['polars==1.4.1', 'pytest==8.3.2']
-      files: ^(narwhals|tests)/
 - repo: https://github.com/codespell-project/codespell
   rev: 'v2.4.1'
   hooks:
@@ -84,6 +79,13 @@ repos:
       entry: pull_request_target
       language: pygrep
      files: ^\.github/workflows/
+    - id: mypy
+      name: mypy
+      entry: make typing
+      files: ^(narwhals|tests)/
+      language: system
+      types: [python]
+      require_serial: true
 - repo: https://github.com/adamchainz/blacken-docs
   rev: "1.19.1"  # replace with latest tag on GitHub
   hooks:
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..b1d40f406a
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,23 @@
+# Mostly based on polars Makefile
+# https://github.com/pola-rs/polars/blob/main/py-polars/Makefile
+
+.DEFAULT_GOAL := help
+
+SHELL=bash
+VENV=./.venv
+
+ifeq ($(OS),Windows_NT)
+	VENV_BIN=$(VENV)/Scripts
+else
+	VENV_BIN=$(VENV)/bin
+endif
+
+
+.PHONY: help
+help:  ## Display this help screen
+	@echo -e "\033[1mAvailable commands:\033[0m"
+	@grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "  \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort
+
+.PHONY: typing
+typing:  ## Run typing checks
+	$(VENV_BIN)/mypy
diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py
index 3b98dacc09..469111a15f 100644
--- a/narwhals/_arrow/series.py
+++ b/narwhals/_arrow/series.py
@@ -39,6 +39,7 @@
     from narwhals._arrow.namespace import ArrowNamespace
     from narwhals.dtypes import DType
     from narwhals.typing import _1DArray
+    from narwhals.typing import _2DArray
     from narwhals.utils import Version
 
 
@@ -340,7 +341,7 @@ def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
     def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self:
         import numpy as np  # ignore-banned-import
 
-        mask = np.zeros(self.len(), dtype=bool)
+        mask: _1DArray = np.zeros(self.len(), dtype=bool)
         mask[indices] = True
         if isinstance(values, self.__class__):
             ser, values = broadcast_and_extract_native(
@@ -729,7 +730,7 @@ def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFram
         name = self._name
         da = series.dictionary_encode(null_encoding="encode").combine_chunks()
 
-        columns = np.zeros((len(da.dictionary), len(da)), np.int8)
+        columns: _2DArray = np.zeros((len(da.dictionary), len(da)), np.int8)
         columns[da.indices, np.arange(len(da))] = 1
         null_col_pa, null_col_pl = f"{name}{separator}None", f"{name}{separator}null"
         cols = [
diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py
index 7c2a672786..764480ab4b 100644
--- a/narwhals/_duckdb/dataframe.py
+++ b/narwhals/_duckdb/dataframe.py
@@ -7,10 +7,10 @@
 
 import duckdb
 from duckdb import ColumnExpression
-from duckdb import ConstantExpression
 from duckdb import FunctionExpression
 
 from narwhals._duckdb.utils import ExprKind
+from narwhals._duckdb.utils import lit
 from narwhals._duckdb.utils import native_to_narwhals_dtype
 from narwhals._duckdb.utils import parse_exprs
 from narwhals.dependencies import get_duckdb
@@ -145,7 +145,7 @@ def aggregate(self: Self, *exprs: DuckDBExpr) -> Self:
         new_columns_map = parse_exprs(self, *exprs)
         return self._from_native_frame(
             self._native_frame.aggregate(
-                [val.alias(col) for col, val in new_columns_map.items()]
+                [val.alias(col) for col, val in new_columns_map.items()]  # type: ignore[arg-type]
             ),
             validate_column_names=False,
         )
@@ -302,7 +302,7 @@ def join(
             raise NotImplementedError(msg)
             rel = self._native_frame.set_alias("lhs").cross(  # pragma: no cover
                 other._native_frame.set_alias("rhs")
-            )
+            )  # type: ignore[operator]
         else:
             # help mypy
             assert left_on is not None  # noqa: S101
@@ -467,9 +467,9 @@ def explode(self: Self, columns: list[str]) -> Self:
         rel = self._native_frame
         original_columns = self.columns
 
-        not_null_condition = (
-            col_to_explode.isnotnull() & FunctionExpression("len", col_to_explode) > 0
-        )
+        not_null_condition = col_to_explode.isnotnull() & FunctionExpression(
+            "len", col_to_explode
+        ) > lit(0)
         non_null_rel = rel.filter(not_null_condition).select(
             *(
                 FunctionExpression("unnest", col_to_explode).alias(col)
                 for col in original_columns
@@ -480,10 +480,7 @@
         )
 
         null_rel = rel.filter(~not_null_condition).select(
-            *(
-                ConstantExpression(None).alias(col) if col in columns else col
-                for col in original_columns
-            )
+            *(lit(None).alias(col) if col in columns else col for col in original_columns)
         )
 
         return self._from_native_frame(
diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py
index e946ca27a1..544381addf 100644
--- a/narwhals/_duckdb/expr.py
+++ b/narwhals/_duckdb/expr.py
@@ -9,14 +9,15 @@
 from duckdb import CaseExpression
 from duckdb import CoalesceOperator
 from duckdb import ColumnExpression
-from duckdb import ConstantExpression
 from duckdb import FunctionExpression
+from duckdb.typing import DuckDBPyType
 
 from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
 from narwhals._duckdb.expr_list import DuckDBExprListNamespace
 from narwhals._duckdb.expr_name import DuckDBExprNameNamespace
 from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
 from narwhals._duckdb.utils import ExprKind
+from narwhals._duckdb.utils import lit
 from narwhals._duckdb.utils import maybe_evaluate
 from narwhals._duckdb.utils import n_ary_operation_expr_kind
 from narwhals._duckdb.utils import narwhals_to_native_dtype
@@ -33,7 +34,7 @@
     from narwhals.utils import Version
 
 
-class DuckDBExpr(CompliantExpr["duckdb.Expression"]):
+class DuckDBExpr(CompliantExpr["duckdb.Expression"]):  # type: ignore[type-var]
     _implementation = Implementation.DUCKDB
     _depth = 0  # Unused, just for compatibility with CompliantExpr
@@ -311,15 +312,13 @@ def mean(self: Self) -> Self:
     def skew(self: Self) -> Self:
         def func(_input: duckdb.Expression) -> duckdb.Expression:
             count = FunctionExpression("count", _input)
-            return CaseExpression(
-                condition=count == 0, value=ConstantExpression(None)
-            ).otherwise(
+            return CaseExpression(condition=(count == lit(0)), value=lit(None)).otherwise(
                 CaseExpression(
-                    condition=count == 1, value=ConstantExpression(float("nan"))
+                    condition=(count == lit(1)), value=lit(float("nan"))
                 ).otherwise(
-                    CaseExpression(
-                        condition=count == 2, value=ConstantExpression(0.0)
-                    ).otherwise(FunctionExpression("skewness", _input))
+                    CaseExpression(condition=(count == lit(2)), value=lit(0.0)).otherwise(
+                        FunctionExpression("skewness", _input)
+                    )
                 )
             )
@@ -353,9 +352,7 @@
     ) -> Self:
         def func(_input: duckdb.Expression) -> duckdb.Expression:
             if interpolation == "linear":
-                return FunctionExpression(
-                    "quantile_cont", _input, ConstantExpression(quantile)
-                )
+                return FunctionExpression("quantile_cont", _input, lit(quantile))
             msg = "Only linear interpolation methods are supported for DuckDB quantile."
             raise NotImplementedError(msg)
@@ -395,9 +392,9 @@ def func(_input: duckdb.Expression) -> duckdb.Expression:
                 "array_unique", FunctionExpression("array_agg", _input)
             ) + FunctionExpression(
                 "max",
-                CaseExpression(
-                    condition=_input.isnotnull(), value=ConstantExpression(0)
-                ).otherwise(ConstantExpression(1)),
+                CaseExpression(condition=_input.isnotnull(), value=lit(0)).otherwise(
+                    lit(1)
+                ),
             )
 
         return self._from_call(
@@ -423,11 +420,11 @@ def len(self: Self) -> Self:
     def std(self: Self, ddof: int) -> Self:
         def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
             n_samples = FunctionExpression("count", _input)
-
+            # NOTE: Not implemented Error: Unable to transform python value of type '' to DuckDB LogicalType
             return (
                 FunctionExpression("stddev_pop", _input)
                 * FunctionExpression("sqrt", n_samples)
-                / (FunctionExpression("sqrt", (n_samples - ddof)))
+                / (FunctionExpression("sqrt", (n_samples - ddof)))  # type: ignore[operator]
             )
 
         return self._from_call(
@@ -440,7 +437,8 @@ def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
     def var(self: Self, ddof: int) -> Self:
         def _var(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
             n_samples = FunctionExpression("count", _input)
-            return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof)
+            # NOTE: Not implemented Error: Unable to transform python value of type '' to DuckDB LogicalType
+            return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof)  # type: ignore[operator]
 
         return self._from_call(
             _var,
@@ -493,16 +491,14 @@ def is_finite(self: Self) -> Self:
 
     def is_in(self: Self, other: Sequence[Any]) -> Self:
         return self._from_call(
-            lambda _input: _input.isin(*[ConstantExpression(x) for x in other]),
+            lambda _input: _input.isin(*[lit(x) for x in other]),
             "is_in",
             expr_kind=self._expr_kind,
         )
 
     def round(self: Self, decimals: int) -> Self:
         return self._from_call(
-            lambda _input: FunctionExpression(
-                "round", _input, ConstantExpression(decimals)
-            ),
+            lambda _input: FunctionExpression("round", _input, lit(decimals)),
             "round",
             expr_kind=self._expr_kind,
         )
@@ -513,7 +509,7 @@ def fill_null(self: Self, value: Any, strategy: Any, limit: int | None) -> Self:
             raise NotImplementedError(msg)
 
         return self._from_call(
-            lambda _input: CoalesceOperator(_input, ConstantExpression(value)),
+            lambda _input: CoalesceOperator(_input, lit(value)),
             "fill_null",
             expr_kind=self._expr_kind,
         )
@@ -521,7 +517,7 @@ def fill_null(self: Self, value: Any, strategy: Any, limit: int | None) -> Self:
     def cast(self: Self, dtype: DType | type[DType]) -> Self:
        def func(_input: duckdb.Expression) -> duckdb.Expression:
             native_dtype = narwhals_to_native_dtype(dtype, self._version)
-            return _input.cast(native_dtype)
+            return _input.cast(DuckDBPyType(native_dtype))
 
         return self._from_call(
             func,
diff --git a/narwhals/_duckdb/expr_dt.py b/narwhals/_duckdb/expr_dt.py
index b7a750d2ac..4cb9c73fa9 100644
--- a/narwhals/_duckdb/expr_dt.py
+++ b/narwhals/_duckdb/expr_dt.py
@@ -2,9 +2,10 @@
 
 from typing import TYPE_CHECKING
 
-from duckdb import ConstantExpression
 from duckdb import FunctionExpression
 
+from narwhals._duckdb.utils import lit
+
 if TYPE_CHECKING:
     from typing_extensions import Self
 
@@ -60,7 +61,7 @@ def second(self: Self) -> DuckDBExpr:
     def millisecond(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression("millisecond", _input)
-            - FunctionExpression("second", _input) * 1_000,
+            - FunctionExpression("second", _input) * lit(1_000),
             "millisecond",
             expr_kind=self._compliant_expr._expr_kind,
         )
@@ -68,7 +69,7 @@ def millisecond(self: Self) -> DuckDBExpr:
     def microsecond(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression("microsecond", _input)
-            - FunctionExpression("second", _input) * 1_000_000,
+            - FunctionExpression("second", _input) * lit(1_000_000),
             "microsecond",
             expr_kind=self._compliant_expr._expr_kind,
         )
@@ -76,16 +77,14 @@ def microsecond(self: Self) -> DuckDBExpr:
     def nanosecond(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression("nanosecond", _input)
-            - FunctionExpression("second", _input) * 1_000_000_000,
+            - FunctionExpression("second", _input) * lit(1_000_000_000),
             "nanosecond",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
     def to_string(self: Self, format: str) -> DuckDBExpr:  # noqa: A002
         return self._compliant_expr._from_call(
-            lambda _input: FunctionExpression(
-                "strftime", _input, ConstantExpression(format)
-            ),
+            lambda _input: FunctionExpression("strftime", _input, lit(format)),
             "to_string",
             expr_kind=self._compliant_expr._expr_kind,
         )
@@ -113,36 +112,33 @@ def date(self: Self) -> DuckDBExpr:
     def total_minutes(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: FunctionExpression(
-                "datepart", ConstantExpression("minute"), _input
-            ),
+            lambda _input: FunctionExpression("datepart", lit("minute"), _input),
             "total_minutes",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
     def total_seconds(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: 60
-            * FunctionExpression("datepart", ConstantExpression("minute"), _input)
-            + FunctionExpression("datepart", ConstantExpression("second"), _input),
+            lambda _input: lit(60) * FunctionExpression("datepart", lit("minute"), _input)
+            + FunctionExpression("datepart", lit("second"), _input),
             "total_seconds",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
     def total_milliseconds(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: 60_000
-            * FunctionExpression("datepart", ConstantExpression("minute"), _input)
-            + FunctionExpression("datepart", ConstantExpression("millisecond"), _input),
+            lambda _input: lit(60_000)
+            * FunctionExpression("datepart", lit("minute"), _input)
+            + FunctionExpression("datepart", lit("millisecond"), _input),
             "total_milliseconds",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
     def total_microseconds(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: 60_000_000
-            * FunctionExpression("datepart", ConstantExpression("minute"), _input)
-            + FunctionExpression("datepart", ConstantExpression("microsecond"), _input),
+            lambda _input: lit(60_000_000)
+            * FunctionExpression("datepart", lit("minute"), _input)
+            + FunctionExpression("datepart", lit("microsecond"), _input),
             "total_microseconds",
             expr_kind=self._compliant_expr._expr_kind,
         )
diff --git a/narwhals/_duckdb/expr_str.py b/narwhals/_duckdb/expr_str.py
index 6b4540914a..a3d1b81dd3 100644
--- a/narwhals/_duckdb/expr_str.py
+++ b/narwhals/_duckdb/expr_str.py
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from typing import NoReturn
 
-from duckdb import ConstantExpression
 from duckdb import FunctionExpression
 
+from narwhals._duckdb.utils import lit
+
 if TYPE_CHECKING:
     import duckdb
+    from typing_extensions import Never
     from typing_extensions import Self
 
     from narwhals._duckdb.expr import DuckDBExpr
@@ -19,18 +20,14 @@ def __init__(self: Self, expr: DuckDBExpr) -> None:
     def starts_with(self: Self, prefix: str) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: FunctionExpression(
-                "starts_with", _input, ConstantExpression(prefix)
-            ),
+            lambda _input: FunctionExpression("starts_with", _input, lit(prefix)),
             "starts_with",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
     def ends_with(self: Self, suffix: str) -> DuckDBExpr:
         return self._compliant_expr._from_call(
-            lambda _input: FunctionExpression(
-                "ends_with", _input, ConstantExpression(suffix)
-            ),
+            lambda _input: FunctionExpression("ends_with", _input, lit(suffix)),
             "ends_with",
             expr_kind=self._compliant_expr._expr_kind,
         )
@@ -38,10 +35,8 @@ def ends_with(self: Self, suffix: str) -> DuckDBExpr:
     def contains(self: Self, pattern: str, *, literal: bool) -> DuckDBExpr:
         def func(_input: duckdb.Expression) -> duckdb.Expression:
             if literal:
-                return FunctionExpression("contains", _input, ConstantExpression(pattern))
-            return FunctionExpression(
-                "regexp_matches", _input, ConstantExpression(pattern)
-            )
+                return FunctionExpression("contains", _input, lit(pattern))
+            return FunctionExpression("regexp_matches", _input, lit(pattern))
 
         return self._compliant_expr._from_call(
             func, "contains", expr_kind=self._compliant_expr._expr_kind
         )
@@ -49,15 +44,16 @@ def func(_input: duckdb.Expression) -> duckdb.Expression:
     def slice(self: Self, offset: int, length: int) -> DuckDBExpr:
         def func(_input: duckdb.Expression) -> duckdb.Expression:
+            offset_lit = lit(offset)
             return FunctionExpression(
                 "array_slice",
                 _input,
-                ConstantExpression(offset + 1)
+                lit(offset + 1)
                 if offset >= 0
-                else FunctionExpression("length", _input) + offset + 1,
+                else FunctionExpression("length", _input) + offset_lit + lit(1),
                 FunctionExpression("length", _input)
                 if length is None
-                else ConstantExpression(length) + offset,
+                else lit(length) + offset_lit,
             )
 
         return self._compliant_expr._from_call(
@@ -92,9 +88,7 @@ def strip_chars(self: Self, characters: str | None) -> DuckDBExpr:
             lambda _input: FunctionExpression(
                 "trim",
                 _input,
-                ConstantExpression(
-                    string.whitespace if characters is None else characters
-                ),
+                lit(string.whitespace if characters is None else characters),
             ),
             "strip_chars",
             expr_kind=self._compliant_expr._expr_kind,
@@ -104,26 +98,20 @@ def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> DuckD
         if not literal:
             return self._compliant_expr._from_call(
                 lambda _input: FunctionExpression(
-                    "regexp_replace",
-                    _input,
-                    ConstantExpression(pattern),
-                    ConstantExpression(value),
-                    ConstantExpression("g"),
+                    "regexp_replace", _input, lit(pattern), lit(value), lit("g")
                 ),
                 "replace_all",
                 expr_kind=self._compliant_expr._expr_kind,
             )
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression(
-                "replace", _input, ConstantExpression(pattern), ConstantExpression(value)
+                "replace", _input, lit(pattern), lit(value)
             ),
             "replace_all",
             expr_kind=self._compliant_expr._expr_kind,
         )
 
-    def replace(
-        self: Self, pattern: str, value: str, *, literal: bool, n: int
-    ) -> NoReturn:
+    def replace(self: Self, pattern: str, value: str, *, literal: bool, n: int) -> Never:
         msg = "`replace` is currently not supported for DuckDB"
         raise NotImplementedError(msg)
@@ -133,9 +121,7 @@ def to_datetime(self: Self, format: str | None) -> DuckDBExpr:  # noqa: A002
             raise NotImplementedError(msg)
 
         return self._compliant_expr._from_call(
-            lambda _input: FunctionExpression(
-                "strptime", _input, ConstantExpression(format)
-            ),
+            lambda _input: FunctionExpression("strptime", _input, lit(format)),
             "to_datetime",
             expr_kind=self._compliant_expr._expr_kind,
         )
diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py
index ed0cbb77de..3ba9a5c9bb 100644
--- a/narwhals/_duckdb/group_by.py
+++ b/narwhals/_duckdb/group_by.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from duckdb import Expression
     from typing_extensions import Self
 
     from narwhals._duckdb.dataframe import DuckDBLazyFrame
@@ -23,7 +24,7 @@ def __init__(
         self._keys = keys
 
     def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
-        agg_columns = self._keys.copy()
+        agg_columns: list[str | Expression] = list(self._keys)
         df = self._compliant_frame
         for expr in exprs:
             output_names = expr._evaluate_output_names(df)
@@ -49,5 +50,5 @@ def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
         )
 
         return self._compliant_frame._from_native_frame(
-            self._compliant_frame._native_frame.aggregate(agg_columns)
+            self._compliant_frame._native_frame.aggregate(agg_columns)  # type: ignore[arg-type]
         )
diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py
index b7089659c5..840c202b13 100644
--- a/narwhals/_duckdb/namespace.py
+++ b/narwhals/_duckdb/namespace.py
@@ -13,12 +13,14 @@
 from duckdb import CaseExpression
 from duckdb import CoalesceOperator
 from duckdb import ColumnExpression
-from duckdb import ConstantExpression
 from duckdb import FunctionExpression
+from duckdb.typing import BIGINT
+from duckdb.typing import VARCHAR
 
 from narwhals._duckdb.expr import DuckDBExpr
 from narwhals._duckdb.selectors import DuckDBSelectorNamespace
 from narwhals._duckdb.utils import ExprKind
+from narwhals._duckdb.utils import lit
 from narwhals._duckdb.utils import n_ary_operation_expr_kind
 from narwhals._duckdb.utils import narwhals_to_native_dtype
 from narwhals._expression_parsing import combine_alias_output_names
@@ -34,7 +36,7 @@
     from narwhals.utils import Version
 
 
-class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]):
+class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]):  # type: ignore[type-var]
     def __init__(
         self: Self, *, backend_version: tuple[int, ...], version: Version
     ) -> None:
@@ -98,9 +100,9 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                 cols_separated = [
                     y
                     for x in [
-                        (col.cast("string"),)
+                        (col.cast(VARCHAR),)
                         if i == len(cols) - 1
-                        else (col.cast("string"), ConstantExpression(separator))
+                        else (col.cast(VARCHAR), lit(separator))
                         for i, col in enumerate(cols)
                     ]
                     for y in x
@@ -111,15 +113,11 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                 )
             else:
                 init_value, *values = [
-                    CaseExpression(~nm, col.cast("string")).otherwise(
-                        ConstantExpression("")
-                    )
+                    CaseExpression(~nm, col.cast(VARCHAR)).otherwise(lit(""))
                     for col, nm in zip(cols, null_mask)
                 ]
                 separators = (
-                    CaseExpression(nm, ConstantExpression("")).otherwise(
-                        ConstantExpression(separator)
-                    )
+                    CaseExpression(nm, lit("")).otherwise(lit(separator))
                     for nm in null_mask[:-1]
                 )
                 result = reduce(
@@ -205,11 +203,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
 
     def sum_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
-            cols = (
-                CoalesceOperator(col, ConstantExpression(0))
-                for _expr in exprs
-                for col in _expr(df)
-            )
+            cols = (CoalesceOperator(col, lit(0)) for _expr in exprs for col in _expr(df))
             return [reduce(operator.add, cols)]
 
         return DuckDBExpr(
@@ -227,11 +221,8 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
             return [
                 (
-                    reduce(
-                        operator.add,
-                        (CoalesceOperator(col, ConstantExpression(0)) for col in cols),
-                    )
-                    / reduce(operator.add, (col.isnotnull().cast("int") for col in cols))
+                    reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols))
+                    / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))
                 )
             ]
 
@@ -272,11 +263,11 @@ def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr:
         def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             if dtype is not None:
                 return [
-                    ConstantExpression(value).cast(
-                        narwhals_to_native_dtype(dtype, version=self._version)
+                    lit(value).cast(
+                        narwhals_to_native_dtype(dtype, version=self._version)  # type: ignore[arg-type]
                     )
                 ]
-            return [ConstantExpression(value)]
+            return [lit(value)]
 
         return DuckDBExpr(
             func,
@@ -329,7 +320,7 @@ def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
             value = self._then_value(df)[0]
         else:
             # `self._otherwise_value` is a scalar
-            value = ConstantExpression(self._then_value)
+            value = lit(self._then_value)
         value = cast("duckdb.Expression", value)
 
         if self._otherwise_value is None:
@@ -338,7 +329,7 @@ def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
             # `self._otherwise_value` is a scalar
             return [
                 CaseExpression(condition=condition, value=value).otherwise(
-                    ConstantExpression(self._otherwise_value)
+                    lit(self._otherwise_value)
                 )
             ]
         otherwise = self._otherwise_value(df)[0]
diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py
index 6d6bb45283..b63f8c700e 100644
--- a/narwhals/_duckdb/series.py
+++ b/narwhals/_duckdb/series.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from typing import NoReturn
 
 from narwhals._duckdb.utils import native_to_narwhals_dtype
 from narwhals.dependencies import get_duckdb
@@ -10,6 +9,7 @@
     from types import ModuleType
 
     import duckdb
+    from typing_extensions import Never
     from typing_extensions import Self
 
     from narwhals.dtypes import DType
@@ -31,7 +31,7 @@ def __native_namespace__(self: Self) -> ModuleType:
     def dtype(self: Self) -> DType:
         return native_to_narwhals_dtype(str(self._native_series.types[0]), self._version)
 
-    def __getattr__(self: Self, attr: str) -> NoReturn:
+    def __getattr__(self: Self, attr: str) -> Never:
         msg = (  # pragma: no cover
             f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
             "If you would like to see this kind of object better supported in "
diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py
index 79d1e2878d..009296c760 100644
--- a/narwhals/_duckdb/utils.py
+++ b/narwhals/_duckdb/utils.py
@@ -18,6 +18,9 @@
     from narwhals.dtypes import DType
     from narwhals.utils import Version
 
+lit = duckdb.ConstantExpression
+"""Alias for `duckdb.ConstantExpression`."""
+
 
 class ExprKind(Enum):
     """Describe which kind of expression we are dealing with.
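Note on the change above: `narwhals/_duckdb/utils.py` now aliases `duckdb.ConstantExpression` as `lit`, and the rest of the DuckDB backend uses that alias. As a rough standalone sketch of the pattern (not part of this diff; the sample relation and column name are invented for illustration), plain Python scalars have to be wrapped as constant expressions before they can be combined with other `duckdb.Expression` objects:

```python
# Standalone sketch using only the public duckdb relational API.
import duckdb
from duckdb import ColumnExpression, FunctionExpression

lit = duckdb.ConstantExpression  # same alias the diff introduces

rel = duckdb.sql("SELECT range AS a FROM range(5)")

# round(a, 2) > 1: the scalars 2 and 1 are wrapped via lit() so the whole
# comparison stays inside DuckDB's expression system.
condition = FunctionExpression("round", ColumnExpression("a"), lit(2)) > lit(1)
print(rel.filter(condition))
```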
diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py
index 0595a52484..c72e43c084 100644
--- a/narwhals/_pandas_like/dataframe.py
+++ b/narwhals/_pandas_like/dataframe.py
@@ -57,7 +57,7 @@
     from narwhals.typing import CompliantDataFrame
     from narwhals.typing import CompliantLazyFrame
 
-CLASSICAL_NUMPY_DTYPES = frozenset(
+CLASSICAL_NUMPY_DTYPES: frozenset[np.dtype] = frozenset(
     [
         np.dtype("float64"),
         np.dtype("float32"),
@@ -257,8 +257,8 @@ def __getitem__(
 
         elif isinstance(item, tuple) and len(item) == 2:
             if isinstance(item[1], str):
-                item = (item[0], self._native_frame.columns.get_loc(item[1]))  # pyright: ignore[reportAssignmentType]
-                native_series = self._native_frame.iloc[item]
+                index = (item[0], self._native_frame.columns.get_loc(item[1]))
+                native_series = self._native_frame.iloc[index]
             elif isinstance(item[1], int):
                 native_series = self._native_frame.iloc[item]
             else:  # pragma: no cover
diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py
index 2d586fc989..2b047b3cfc 100644
--- a/narwhals/_pandas_like/series.py
+++ b/narwhals/_pandas_like/series.py
@@ -6,6 +6,7 @@
 from typing import Iterator
 from typing import Literal
 from typing import Sequence
+from typing import cast
 from typing import overload
 
 from narwhals._pandas_like.series_cat import PandasLikeSeriesCatNamespace
@@ -30,6 +31,7 @@
 
 if TYPE_CHECKING:
     from types import ModuleType
+    from typing import Hashable
 
     import pandas as pd
     import polars as pl
@@ -39,6 +41,7 @@
     from narwhals._pandas_like.dataframe import PandasLikeDataFrame
     from narwhals.dtypes import DType
     from narwhals.typing import _1DArray
+    from narwhals.typing import _AnyDArray
     from narwhals.utils import Version
 
 PANDAS_TO_NUMPY_DTYPE_NO_MISSING = {
@@ -654,7 +657,7 @@ def sort(self: Self, *, descending: bool, nulls_last: bool) -> PandasLikeSeries:
             )
         ).alias(self.name)
 
-    def alias(self: Self, name: str) -> Self:
+    def alias(self: Self, name: str | Hashable) -> Self:
         if name != self.name:
             return self._from_native_series(
                 rename(
@@ -1039,7 +1042,7 @@ def hist(
         from narwhals._pandas_like.dataframe import PandasLikeDataFrame
 
         ns = self.__native_namespace__()
-        data: dict[str, Sequence[int | float | str]]
+        data: dict[str, Sequence[int | float | str] | _AnyDArray]
 
         if bin_count == 0 or (bins is not None and len(bins) <= 1):
             data = {}
@@ -1056,15 +1059,10 @@ def hist(
             )
         elif self._native_series.count() < 1:
             if bins is not None:
-                data = {
-                    "breakpoint": bins[1:],
-                    "count": zeros(shape=len(bins) - 1),
-                }
+                data = {"breakpoint": bins[1:], "count": zeros(shape=len(bins) - 1)}
             else:
-                data = {
-                    "breakpoint": linspace(0, 1, bin_count),
-                    "count": zeros(shape=bin_count),
-                }
+                count = cast("int", bin_count)
+                data = {"breakpoint": linspace(0, 1, count), "count": zeros(shape=count)}
 
             if not include_breakpoint:
                 del data["breakpoint"]
diff --git a/pyproject.toml b/pyproject.toml
index 554bfc6d4e..7537879165 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,9 +37,8 @@ polars = ["polars>=0.20.3"]
 dask = ["dask[dataframe]>=2024.8"]
 duckdb = ["duckdb>=1.0"]
 ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"]
-dev = [
+tests = [
     "covdefaults",
-    "pre-commit",
     "pytest",
     "pytest-cov",
     "pytest-randomly",
@@ -47,12 +46,22 @@ dev = [
     "hypothesis",
     "typing_extensions",
 ]
+typing = [
+    "typing_extensions",
+    "mypy~=1.15.0",
+]
+dev = [
+    "pre-commit",
+    "narwhals[tests]",
+    "narwhals[typing]",
+]
 core = [
     "duckdb",
     "pandas",
     "polars",
     "pyarrow",
-    "pyarrow-stubs",
+    #TODO: reintroduce when fixing #1961
+    # "pyarrow-stubs",
 ]
 extra = [  # heavier dependencies we don't necessarily need in every testing job
     "scikit-learn",
@@ -209,13 +218,27 @@ exclude_also = [
 ]
 
 [tool.mypy]
-strict = true
+files = ["narwhals", "tests"]
+# TODO: reenable strict mode
+# strict = true
 
 [[tool.mypy.overrides]]
-# the pandas API is just too inconsistent for type hinting to be useful.
 module = [
-    "pandas.*",
+    # TODO: enable step by step when it makes sense
+    # e.g. the pandas API is just too inconsistent for type hinting to be useful.
     "cudf.*",
+    "dask.*",
+    "dask_expr.*",
+    "duckdb.*",
+    "ibis.*",
     "modin.*",
+    "numpy.*",
+    "pandas.*",
+    "pyarrow.*",
+    "pyspark.*",
+    "sklearn.*",
+    "sqlframe.*",
 ]
+# TODO: remove follow_imports
+follow_imports = "skip"
 ignore_missing_imports = true
diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py
index 3a5cde3a36..4260e40880 100644
--- a/tests/expr_and_series/arithmetic_test.py
+++ b/tests/expr_and_series/arithmetic_test.py
@@ -159,10 +159,7 @@ def test_truediv_same_dims(
 
 
 @pytest.mark.slow
-@given(  # type: ignore[misc]
-    left=st.integers(-100, 100),
-    right=st.integers(-100, 100),
-)
+@given(left=st.integers(-100, 100), right=st.integers(-100, 100))
 @pytest.mark.skipif(PANDAS_VERSION < (2, 0), reason="convert_dtypes not available")
 def test_floordiv(left: int, right: int) -> None:
     # hypothesis complains if we add `constructor` as an argument, so this
@@ -197,10 +194,7 @@ def test_floordiv(left: int, right: int) -> None:
 
 
 @pytest.mark.slow
-@given(  # type: ignore[misc]
-    left=st.integers(-100, 100),
-    right=st.integers(-100, 100),
-)
+@given(left=st.integers(-100, 100), right=st.integers(-100, 100))
 @pytest.mark.skipif(PANDAS_VERSION < (2, 0), reason="convert_dtypes not available")
 def test_mod(left: int, right: int) -> None:
     # hypothesis complains if we add `constructor` as an argument, so this
diff --git a/tests/expr_and_series/dt/datetime_duration_test.py b/tests/expr_and_series/dt/datetime_duration_test.py
index 32bf5d9049..c4ed2d2a1c 100644
--- a/tests/expr_and_series/dt/datetime_duration_test.py
+++ b/tests/expr_and_series/dt/datetime_duration_test.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from datetime import timedelta
+from typing import Literal
 
 import numpy as np
 import pyarrow as pa
@@ -107,7 +108,9 @@ def test_duration_attributes_series(
         ("total_nanoseconds", 70e9),
     ],
 )
-def test_pyarrow_units(unit: str, attribute: str, expected: int) -> None:
+def test_pyarrow_units(
+    unit: Literal["s", "ms", "us", "ns"], attribute: str, expected: int
+) -> None:
     data = [None, timedelta(minutes=1, seconds=10)]
     arr = pc.cast(pa.array(data), pa.duration(unit))
     df = nw.from_native(pa.table({"a": arr}), eager_only=True)
diff --git a/tests/expr_and_series/dt/ordinal_day_test.py b/tests/expr_and_series/dt/ordinal_day_test.py
index 82e30d8a12..3fe87fc5f2 100644
--- a/tests/expr_and_series/dt/ordinal_day_test.py
+++ b/tests/expr_and_series/dt/ordinal_day_test.py
@@ -12,7 +12,7 @@
 from tests.utils import PANDAS_VERSION
 
 
-@given(dates=st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)))  # type: ignore[misc]
+@given(dates=st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)))
 @pytest.mark.skipif(
     PANDAS_VERSION < (2, 0, 0),
     reason="pyarrow dtype not available",
diff --git a/tests/expr_and_series/dt/timestamp_test.py b/tests/expr_and_series/dt/timestamp_test.py
index b7e20519fb..448091cad4 100644
--- a/tests/expr_and_series/dt/timestamp_test.py
+++ b/tests/expr_and_series/dt/timestamp_test.py
@@ -204,7 +204,7 @@ def test_timestamp_invalid_unit_series(constructor_eager: ConstructorEager) -> N
         nw.from_native(constructor_eager(data))["a"].dt.timestamp(time_unit_invalid)  # type: ignore[arg-type]
 
 
-@given(  # type: ignore[misc]
+@given(
     inputs=st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)),
     time_unit=st.sampled_from(["ms", "us", "ns"]),
     # We keep 'ms' out for now due to an upstream bug: https://github.com/pola-rs/polars/issues/19309
diff --git a/tests/expr_and_series/dt/total_minutes_test.py b/tests/expr_and_series/dt/total_minutes_test.py
index 094c51cbfa..41bfe0aaf3 100644
--- a/tests/expr_and_series/dt/total_minutes_test.py
+++ b/tests/expr_and_series/dt/total_minutes_test.py
@@ -17,11 +17,8 @@
         min_value=-timedelta(days=5, minutes=70, seconds=10),
         max_value=timedelta(days=3, minutes=90, seconds=60),
     )
-)  # type: ignore[misc]
-@pytest.mark.skipif(
-    PANDAS_VERSION < (2, 2, 0),
-    reason="pyarrow dtype not available",
 )
+@pytest.mark.skipif(PANDAS_VERSION < (2, 2, 0), reason="pyarrow dtype not available")
 @pytest.mark.slow
 def test_total_minutes(timedeltas: timedelta) -> None:
     result_pd = nw.from_native(
diff --git a/tests/expr_and_series/rolling_mean_test.py b/tests/expr_and_series/rolling_mean_test.py
index 5d69639ce8..ddcd0c79bf 100644
--- a/tests/expr_and_series/rolling_mean_test.py
+++ b/tests/expr_and_series/rolling_mean_test.py
@@ -69,7 +69,7 @@ def test_rolling_mean_series(constructor_eager: ConstructorEager) -> None:
     assert_equal_data(result, expected)
 
 
-@given(  # type: ignore[misc]
+@given(
     center=st.booleans(),
     values=st.lists(st.floats(-10, 10), min_size=3, max_size=10),
 )
diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py
index 59872befcc..2b9e273adf 100644
--- a/tests/expr_and_series/rolling_sum_test.py
+++ b/tests/expr_and_series/rolling_sum_test.py
@@ -188,7 +188,7 @@ def test_rolling_sum_series_invalid_params(
         df["a"].rolling_sum(window_size=window_size, min_samples=min_samples)
 
 
-@given(  # type: ignore[misc]
+@given(
     center=st.booleans(),
     values=st.lists(st.floats(-10, 10), min_size=3, max_size=10),
 )
diff --git a/tests/expr_and_series/rolling_var_test.py b/tests/expr_and_series/rolling_var_test.py
index 5eeda49061..86b47df330 100644
--- a/tests/expr_and_series/rolling_var_test.py
+++ b/tests/expr_and_series/rolling_var_test.py
@@ -99,7 +99,7 @@ def test_rolling_var_series(
     assert_equal_data(result, {name: expected})
 
 
-@given(  # type: ignore[misc]
+@given(
     center=st.booleans(),
     values=st.lists(st.floats(-10, 10), min_size=5, max_size=10),
 )
diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py
index c357d4adf0..912079248c 100644
--- a/tests/frame/collect_test.py
+++ b/tests/frame/collect_test.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing import Any
 
 import pandas as pd
 import polars as pl
@@ -30,6 +31,7 @@ def test_collect_to_default_backend(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.lazy().collect().to_native()
 
+    expected_cls: Any
     if "polars" in str(constructor):
         expected_cls = pl.DataFrame
     elif any(x in str(constructor) for x in ("pandas", "dask")):
diff --git a/tests/frame/lazy_test.py b/tests/frame/lazy_test.py
index 12229afba6..fa11cae09b 100644
--- a/tests/frame/lazy_test.py
+++ b/tests/frame/lazy_test.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing import Any
 
 import pandas as pd
 import polars as pl
@@ -28,6 +29,7 @@ def test_lazy_to_default(constructor_eager: ConstructorEager) -> None:
     result = df.lazy()
     assert isinstance(result, nw_v1.LazyFrame)
 
+    expected_cls: Any
     if "polars" in str(constructor_eager):
         expected_cls = pl.LazyFrame
     elif "pandas" in str(constructor_eager):
diff --git a/tests/hypothesis/basic_arithmetic_test.py b/tests/hypothesis/basic_arithmetic_test.py
index 00818271d1..5e9b18703f 100644
--- a/tests/hypothesis/basic_arithmetic_test.py
+++ b/tests/hypothesis/basic_arithmetic_test.py
@@ -21,24 +21,13 @@
         min_size=3,
         max_size=3,
     ),
-)  # type: ignore[misc]
+)
 @pytest.mark.slow
 def test_mean(
-    integer: st.SearchStrategy[list[int]],
-    floats: st.SearchStrategy[float],
+    integer: st.SearchStrategy[list[int]], floats: st.SearchStrategy[float]
 ) -> None:
-    df_pandas = pd.DataFrame(
-        {
-            "integer": integer,
-            "floats": floats,
-        }
-    )
-    df_polars = pl.DataFrame(
-        {
-            "integer": integer,
-            "floats": floats,
-        },
-    )
+    df_pandas = pd.DataFrame({"integer": integer, "floats": floats})
+    df_polars = pl.DataFrame({"integer": integer, "floats": floats})
 
     df_nw1 = nw.from_native(df_pandas, eager_only=True)
     df_nw2 = nw.from_native(df_polars, eager_only=True)
diff --git a/tests/hypothesis/getitem_test.py b/tests/hypothesis/getitem_test.py
index 33c31f7611..f6cfd45897 100644
--- a/tests/hypothesis/getitem_test.py
+++ b/tests/hypothesis/getitem_test.py
@@ -46,11 +46,8 @@ def pandas_or_pyarrow_constructor(
 TEST_DATA_NUM_ROWS = len(TEST_DATA[TEST_DATA_COLUMNS[0]])
 
 
-@st.composite  # type: ignore[misc]
-def string_slice(
-    draw: st.DrawFn,
-    strs: Sequence[str],
-) -> slice:
+@st.composite
+def string_slice(draw: st.DrawFn, strs: Sequence[str]) -> slice:
     """Return slices such as `"a":`, `"a":"c"`, `"a":"c":2`, etc."""
     n_cols = len(strs)
     index_slice = draw(
@@ -90,7 +87,7 @@ def string_slice(
     )
 
 
-@st.composite  # type: ignore[misc]
+@st.composite
 def tuple_selector(draw: st.DrawFn) -> tuple[Any, Any]:
     rows = st.one_of(
         st.lists(
@@ -105,7 +102,7 @@ def tuple_selector(draw: st.DrawFn) -> tuple[Any, Any]:
         ),
         st.slices(TEST_DATA_NUM_ROWS),
         arrays(
-            dtype=st.sampled_from([np.int8, np.int16, np.int32, np.int64]),
+            dtype=st.sampled_from([np.int8, np.int16, np.int32, np.int64]),  # type: ignore[arg-type]
             shape=st.integers(min_value=0, max_value=10),
             elements=st.integers(
                 min_value=0,  # pyarrow does not support negative indexing
@@ -135,11 +132,8 @@
 
 
 @given(
-    selector=st.one_of(
-        single_selector,
-        tuple_selector(),
-    ),
-)  # type: ignore[misc]
+    selector=st.one_of(single_selector, tuple_selector()),
+)
 @pytest.mark.slow
 def test_getitem(
     pandas_or_pyarrow_constructor: Any,
diff --git a/tests/hypothesis/join_test.py b/tests/hypothesis/join_test.py
index da4a61679e..97830ab0ac 100644
--- a/tests/hypothesis/join_test.py
+++ b/tests/hypothesis/join_test.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import cast
+
 import pandas as pd
 import polars as pl
 import pyarrow as pa
@@ -37,7 +39,7 @@
         max_size=3,
         unique=True,
     ),
-)  # type: ignore[misc]
+)
 @pytest.mark.skipif(POLARS_VERSION < (0, 20, 13), reason="0.0 == -0.0")
 @pytest.mark.skipif(PANDAS_VERSION < (2, 0, 0), reason="requires pyarrow")
 @pytest.mark.slow
@@ -48,18 +50,20 @@ def test_join(  # pragma: no cover
     cols: st.SearchStrategy[list[str]],
 ) -> None:
     data = {"a": integers, "b": other_integers, "c": floats}
+    join_cols = cast(list[str], cols)
 
     df_polars = pl.DataFrame(data)
     df_polars2 = pl.DataFrame(data)
     df_pl = nw.from_native(df_polars, eager_only=True)
     other_pl = nw.from_native(df_polars2, eager_only=True)
-    dframe_pl = df_pl.join(other_pl, left_on=cols, right_on=cols, how="inner")
+
+    dframe_pl = df_pl.join(other_pl, left_on=join_cols, right_on=join_cols, how="inner")
 
     df_pandas = pd.DataFrame(data)
     df_pandas2 = pd.DataFrame(data)
     df_pd = nw.from_native(df_pandas, eager_only=True)
     other_pd = nw.from_native(df_pandas2, eager_only=True)
-    dframe_pd = df_pd.join(other_pd, left_on=cols, right_on=cols, how="inner")
+    dframe_pd = df_pd.join(other_pd, left_on=join_cols, right_on=join_cols, how="inner")
 
     dframe_pd1 = nw.to_native(dframe_pl).to_pandas()
     dframe_pd1 = dframe_pd1.sort_values(
@@ -85,7 +89,7 @@
         min_size=3,
         max_size=3,
     ),
-)  # type: ignore[misc]
+)
 @pytest.mark.skipif(PANDAS_VERSION < (2, 0, 0), reason="requires pyarrow")
 @pytest.mark.slow
 def test_cross_join(  # pragma: no cover
@@ -119,7 +123,7 @@ def test_cross_join(  # pragma: no cover
     assert_frame_equal(dframe_pd1, dframe_pd2)
 
 
-@given(  # type: ignore[misc]
+@given(
     a_left_data=st.lists(st.integers(min_value=0, max_value=5), min_size=3, max_size=3),
     b_left_data=st.lists(st.integers(min_value=0, max_value=5), min_size=3, max_size=3),
     c_left_data=st.lists(st.integers(min_value=0, max_value=5), min_size=3, max_size=3),
diff --git a/tests/utils_test.py b/tests/utils_test.py
index 66046186de..aec7c652a4 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -257,7 +257,7 @@ def test_get_trivial_version_with_uninstalled_module() -> None:
     assert result == (0, 0, 0)
 
 
-@given(n_bytes=st.integers(1, 100))  # type: ignore[misc]
+@given(n_bytes=st.integers(1, 100))
 @pytest.mark.slow
 def test_generate_temporary_column_name(n_bytes: int) -> None:
     columns = ["abc", "XYZ"]
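A recurring pattern in the test changes above is narrowing loosely typed test inputs with `typing.cast` (for example `join_cols = cast(list[str], cols)` in `tests/hypothesis/join_test.py` and `count = cast("int", bin_count)` in the pandas-like `hist`) rather than silencing mypy with `# type: ignore`. A minimal sketch of the idea, with a hypothetical helper name; `cast` only informs the type checker and performs no conversion or validation at runtime:

```python
# Hypothetical helper illustrating typing.cast-based narrowing.
from typing import cast


def sorted_join_columns(cols: object) -> list[str]:
    # cast() is a no-op at runtime; it just tells mypy what `cols` really is.
    join_cols = cast("list[str]", cols)
    return sorted(join_cols)


print(sorted_join_columns(["b", "a"]))  # ['a', 'b']
```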