From 74c6ba216b1e7796a07c66b50a6689895da3c6a8 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 22 Nov 2023 17:59:46 +0100 Subject: [PATCH] depr(python): Rename `map_dict` to `replace` and change default behavior (#12599) --- .../reference/expressions/modify_select.rst | 1 + .../source/reference/series/computation.rst | 1 + py-polars/polars/expr/expr.py | 308 +++++++------- py-polars/polars/series/series.py | 131 +++--- py-polars/polars/utils/udfs.py | 4 +- py-polars/polars/utils/various.py | 5 +- .../map/test_inefficient_map_warning.py | 6 +- .../tests/unit/operations/test_replace.py | 389 ++++++++++++++++++ py-polars/tests/unit/series/test_series.py | 58 --- py-polars/tests/unit/test_exprs.py | 312 -------------- 10 files changed, 630 insertions(+), 585 deletions(-) create mode 100644 py-polars/tests/unit/operations/test_replace.py diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index f18e992c2768..94b461d4c673 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -41,6 +41,7 @@ Manipulation/selection Expr.rechunk Expr.reinterpret Expr.repeat_by + Expr.replace Expr.reshape Expr.reverse Expr.rle diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index efaf1973e6b8..02007e99d13a 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -46,6 +46,7 @@ Computation Series.peak_max Series.peak_min Series.rank + Series.replace Series.rolling_apply Series.rolling_map Series.rolling_max diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index df4045ea1aec..995243061ab5 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -58,7 +58,7 @@ warn_closed_future_change, ) from polars.utils.meta import threadpool_size -from polars.utils.various import sphinx_accessor +from polars.utils.various import no_default, sphinx_accessor with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import arg_where as py_arg_where @@ -3985,8 +3985,8 @@ def map_batches( See Also -------- - map_dict map_elements + replace Examples -------- @@ -9012,173 +9012,124 @@ def cache(self) -> Self: """ return self - def map_dict( + def replace( self, - remapping: dict[Any, Any], + mapping: dict[Any, Any], *, - default: Any = None, + default: Any = no_default, return_dtype: PolarsDataType | None = None, ) -> Self: """ - Replace values in column according to remapping dictionary. + Replace values according to the given mapping. Needs a global string cache for lazily evaluated queries on columns of - type `pl.Categorical`. + type `Categorical`. Parameters ---------- - remapping - Dictionary containing the before/after values to map. + mapping + Mapping of values to their replacement. default - Value to use when the remapping dict does not contain the lookup value. - Accepts expression input. Non-expression inputs are parsed as literals. - Use `pl.first()`, to keep the original value. + Value to use when the mapping does not contain the lookup value. + Defaults to keeping the original value. Accepts expression input. + Non-expression inputs are parsed as literals. return_dtype Set return dtype to override automatic return dtype determination. See Also -------- - map + str.replace Examples -------- - >>> country_code_dict = { + Replace a single value by another value. Values not in the mapping remain + unchanged. + + >>> df = pl.DataFrame({"a": [1, 2, 2, 3]}) + >>> df.with_columns(pl.col("a").replace({2: 100}).alias("replaced")) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ 1 │ + │ 2 ┆ 100 │ + │ 2 ┆ 100 │ + │ 3 ┆ 3 │ + └─────┴──────────┘ + + Replace multiple values. Specify a default to set values not in the given map + to the default value. + + >>> df = pl.DataFrame({"country_code": ["FR", "ES", "DE", None]}) + >>> country_code_map = { ... "CA": "Canada", ... "DE": "Germany", ... "FR": "France", - ... None: "Not specified", + ... None: "unspecified", ... } - >>> df = pl.DataFrame( - ... { - ... "country_code": ["FR", None, "ES", "DE"], - ... } - ... ).with_row_count() - >>> df - shape: (4, 2) - ┌────────┬──────────────┐ - │ row_nr ┆ country_code │ - │ --- ┆ --- │ - │ u32 ┆ str │ - ╞════════╪══════════════╡ - │ 0 ┆ FR │ - │ 1 ┆ null │ - │ 2 ┆ ES │ - │ 3 ┆ DE │ - └────────┴──────────────┘ - - >>> df.with_columns( - ... pl.col("country_code").map_dict(country_code_dict).alias("remapped") - ... ) - shape: (4, 3) - ┌────────┬──────────────┬───────────────┐ - │ row_nr ┆ country_code ┆ remapped │ - │ --- ┆ --- ┆ --- │ - │ u32 ┆ str ┆ str │ - ╞════════╪══════════════╪═══════════════╡ - │ 0 ┆ FR ┆ France │ - │ 1 ┆ null ┆ Not specified │ - │ 2 ┆ ES ┆ null │ - │ 3 ┆ DE ┆ Germany │ - └────────┴──────────────┴───────────────┘ - - Set a default value for values that cannot be mapped... - >>> df.with_columns( ... pl.col("country_code") - ... .map_dict(country_code_dict, default="unknown") - ... .alias("remapped") + ... .replace(country_code_map, default=None) + ... .alias("replaced") ... ) - shape: (4, 3) - ┌────────┬──────────────┬───────────────┐ - │ row_nr ┆ country_code ┆ remapped │ - │ --- ┆ --- ┆ --- │ - │ u32 ┆ str ┆ str │ - ╞════════╪══════════════╪═══════════════╡ - │ 0 ┆ FR ┆ France │ - │ 1 ┆ null ┆ Not specified │ - │ 2 ┆ ES ┆ unknown │ - │ 3 ┆ DE ┆ Germany │ - └────────┴──────────────┴───────────────┘ - - ...or keep the original value, by making use of `pl.first()`: - - >>> df.with_columns( - ... pl.col("country_code") - ... .map_dict(country_code_dict, default=pl.first()) - ... .alias("remapped") - ... ) - shape: (4, 3) - ┌────────┬──────────────┬───────────────┐ - │ row_nr ┆ country_code ┆ remapped │ - │ --- ┆ --- ┆ --- │ - │ u32 ┆ str ┆ str │ - ╞════════╪══════════════╪═══════════════╡ - │ 0 ┆ FR ┆ France │ - │ 1 ┆ null ┆ Not specified │ - │ 2 ┆ ES ┆ ES │ - │ 3 ┆ DE ┆ Germany │ - └────────┴──────────────┴───────────────┘ - - ...or keep the original value, by explicitly referring to the column: - - >>> df.with_columns( - ... pl.col("country_code") - ... .map_dict(country_code_dict, default=pl.col("country_code")) - ... .alias("remapped") + shape: (4, 2) + ┌──────────────┬─────────────┐ + │ country_code ┆ replaced │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞══════════════╪═════════════╡ + │ FR ┆ France │ + │ ES ┆ null │ + │ DE ┆ Germany │ + │ null ┆ unspecified │ + └──────────────┴─────────────┘ + + The return type can be overridden with the `return_dtype` argument. + + >>> df = df.with_row_count() + >>> df.select( + ... "row_nr", + ... pl.col("row_nr") + ... .replace({1: 10, 2: 20}, default=0, return_dtype=pl.UInt8) + ... .alias("replaced"), ... ) - shape: (4, 3) - ┌────────┬──────────────┬───────────────┐ - │ row_nr ┆ country_code ┆ remapped │ - │ --- ┆ --- ┆ --- │ - │ u32 ┆ str ┆ str │ - ╞════════╪══════════════╪═══════════════╡ - │ 0 ┆ FR ┆ France │ - │ 1 ┆ null ┆ Not specified │ - │ 2 ┆ ES ┆ ES │ - │ 3 ┆ DE ┆ Germany │ - └────────┴──────────────┴───────────────┘ - - If you need to access different columns to set a default value, a struct needs - to be constructed; in the first field is the column that you want to remap and - the rest of the fields are the other columns used in the default expression. + shape: (4, 2) + ┌────────┬──────────┐ + │ row_nr ┆ replaced │ + │ --- ┆ --- │ + │ u32 ┆ u8 │ + ╞════════╪══════════╡ + │ 0 ┆ 0 │ + │ 1 ┆ 10 │ + │ 2 ┆ 20 │ + │ 3 ┆ 0 │ + └────────┴──────────┘ + + To reference other columns as a `default` value, a struct column must be + constructed first. The first field must be the column in which values are + replaced. The other columns can be used in the default expression. >>> df.with_columns( - ... pl.struct(pl.col(["country_code", "row_nr"])).map_dict( - ... remapping=country_code_dict, + ... pl.struct("country_code", "row_nr") + ... .replace( + ... mapping=country_code_map, ... default=pl.col("row_nr").cast(pl.Utf8), ... ) - ... ) - shape: (4, 2) - ┌────────┬───────────────┐ - │ row_nr ┆ country_code │ - │ --- ┆ --- │ - │ u32 ┆ str │ - ╞════════╪═══════════════╡ - │ 0 ┆ France │ - │ 1 ┆ Not specified │ - │ 2 ┆ 2 │ - │ 3 ┆ Germany │ - └────────┴───────────────┘ - - Override return dtype: - - >>> df.with_columns( - ... pl.col("row_nr") - ... .map_dict({1: 7, 3: 4}, default=3, return_dtype=pl.UInt8) - ... .alias("remapped") + ... .alias("replaced") ... ) shape: (4, 3) - ┌────────┬──────────────┬──────────┐ - │ row_nr ┆ country_code ┆ remapped │ - │ --- ┆ --- ┆ --- │ - │ u32 ┆ str ┆ u8 │ - ╞════════╪══════════════╪══════════╡ - │ 0 ┆ FR ┆ 3 │ - │ 1 ┆ null ┆ 7 │ - │ 2 ┆ ES ┆ 3 │ - │ 3 ┆ DE ┆ 4 │ - └────────┴──────────────┴──────────┘ - + ┌────────┬──────────────┬─────────────┐ + │ row_nr ┆ country_code ┆ replaced │ + │ --- ┆ --- ┆ --- │ + │ u32 ┆ str ┆ str │ + ╞════════╪══════════════╪═════════════╡ + │ 0 ┆ FR ┆ France │ + │ 1 ┆ ES ┆ 1 │ + │ 2 ┆ DE ┆ Germany │ + │ 3 ┆ null ┆ unspecified │ + └────────┴──────────────┴─────────────┘ """ def _remap_key_or_value_series( @@ -9191,9 +9142,9 @@ def _remap_key_or_value_series( is_keys: bool, ) -> Series: """ - Convert remapping keys or remapping values to `Series` with `dtype`. + Convert mapping keys or mapping values to `Series` with `dtype`. - Try to convert the remapping keys or remapping values to `Series` with + Try to convert the mapping keys or mapping values to `Series` with the specified dtype and check that none of the values are accidentally lost (replaced by nulls) during the conversion. @@ -9202,7 +9153,7 @@ def _remap_key_or_value_series( name Name of the keys or values Series. values - Values for the Series: `remapping.keys()` or `remapping.values()`. + Values for the Series: `mapping.keys()` or `mapping.values()`. dtype User specified dtype. If None, dtype_if_empty @@ -9210,16 +9161,16 @@ def _remap_key_or_value_series( or a list with only None values, set the Polars dtype of the Series data. dtype_keys - If user set dtype is None, try to see if Series for remapping.values() - can be converted to same dtype as the remapping.keys() Series dtype. + If user set dtype is None, try to see if Series for mapping.values() + can be converted to same dtype as the mapping.keys() Series dtype. is_keys - If values contains keys or values from remapping dict. + If values contains keys or values from mapping dict. """ try: if dtype is None: # If no dtype was set, which should only happen when: - # values = remapping.values() + # values = mapping.values() # create a Series from those values and infer the dtype. s = pl.Series( name, @@ -9251,13 +9202,13 @@ def _remap_key_or_value_series( ) if dtype != s.dtype: raise ValueError( - f"remapping values for `map_dict` could not be converted to {dtype!r}: found {s.dtype!r}" + f"mapping values for `replace` could not be converted to {dtype!r}: found {s.dtype!r}" ) else: # dtype was set, which should always be the case when: - # values = remapping.keys() + # values = mapping.keys() # and in cases where the user set the output dtype when: - # values = remapping.values() + # values = mapping.values() s = pl.Series( name, values, @@ -9267,38 +9218,38 @@ def _remap_key_or_value_series( ) if dtype != s.dtype: raise ValueError( - f"remapping {'keys' if is_keys else 'values'} for `map_dict` could not be converted to {dtype!r}: found {s.dtype!r}" + f"mapping {'keys' if is_keys else 'values'} for `replace` could not be converted to {dtype!r}: found {s.dtype!r}" ) except OverflowError as exc: if is_keys: raise ValueError( - f"remapping keys for `map_dict` could not be converted to {dtype!r}: {exc!s}" + f"mapping keys for `replace` could not be converted to {dtype!r}: {exc!s}" ) from exc else: raise ValueError( - f"choose a more suitable output dtype for `map_dict` as remapping value could not be converted to {dtype!r}: {exc!s}" + f"choose a more suitable output dtype for `replace` as mapping value could not be converted to {dtype!r}: {exc!s}" ) from exc if is_keys: - # values = remapping.keys() + # values = mapping.keys() if s.null_count() == 0: # noqa: SIM114 pass - elif s.null_count() == 1 and None in remapping: + elif s.null_count() == 1 and None in mapping: pass else: raise ValueError( - f"remapping keys for `map_dict` could not be converted to {dtype!r} without losing values in the conversion" + f"mapping keys for `replace` could not be converted to {dtype!r} without losing values in the conversion" ) else: - # values = remapping.values() + # values = mapping.values() if s.null_count() == 0: # noqa: SIM114 pass elif s.len() - s.null_count() == len(list(filter(None, values))): pass else: raise ValueError( - f"remapping values for `map_dict` could not be converted to {dtype!r} without losing values in the conversion" + f"remapping values for `replace` could not be converted to {dtype!r} without losing values in the conversion" ) return s @@ -9326,7 +9277,7 @@ def inner_func(s: Series, default_value: Any = None) -> Series: ) remap_key_s = _remap_key_or_value_series( name=remap_key_column, - values=remapping.keys(), + values=mapping.keys(), dtype=input_dtype, dtype_if_empty=input_dtype, dtype_keys=input_dtype, @@ -9336,7 +9287,7 @@ def inner_func(s: Series, default_value: Any = None) -> Series: # Create remap value Series with specified output dtype. remap_value_s = pl.Series( remap_value_column, - remapping.values(), + mapping.values(), dtype=return_dtype_, dtype_if_empty=input_dtype, ) @@ -9346,7 +9297,7 @@ def inner_func(s: Series, default_value: Any = None) -> Series: # Series is pl.Utf8 and remap key Series is pl.Categorical). remap_value_s = _remap_key_or_value_series( name=remap_value_column, - values=remapping.values(), + values=mapping.values(), dtype=None, dtype_if_empty=input_dtype, dtype_keys=input_dtype, @@ -9374,8 +9325,11 @@ def inner_func(s: Series, default_value: Any = None) -> Series: return mapped.collect(no_optimization=True).to_series(index=result_index) - remapping_func = partial(inner_func, default_value=default) - return self.map_batches(function=remapping_func, return_dtype=return_dtype) + if default is no_default: + default = F.first() + + mapping_func = partial(inner_func, default_value=default) + return self.map_batches(function=mapping_func, return_dtype=return_dtype) @deprecate_renamed_function("map_batches", version="0.19.0") def map( @@ -9795,6 +9749,42 @@ def cumcount(self, *, reverse: bool = False) -> Self: """ return self.cum_count(reverse=reverse) + @deprecate_function( + "It has been renamed to `replace`." + " The default behavior has changed to keep any values not present in the mapping unchanged." + " Pass `default=None` to keep existing behavior.", + version="0.19.16", + ) + @deprecate_renamed_parameter("remapping", "mapping", version="0.19.16") + def map_dict( + self, + mapping: dict[Any, Any], + *, + default: Any = None, + return_dtype: PolarsDataType | None = None, + ) -> Self: + """ + Replace values in column according to remapping dictionary. + + .. deprecated:: 0.19.16 + This method has been renamed to :meth:`replace`. The default behavior + has changed to keep any values not present in the mapping unchanged. + Pass `default=None` to keep existing behavior. + + Parameters + ---------- + mapping + Dictionary containing the before/after values to map. + default + Value to use when the remapping dict does not contain the lookup value. + Accepts expression input. Non-expression inputs are parsed as literals. + Use `pl.first()`, to keep the original value. + return_dtype + Set return dtype to override automatic return dtype determination. + + """ + return self.replace(mapping, default=default, return_dtype=return_dtype) + @property def bin(self) -> ExprBinaryNameSpace: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a943dbae17a8..72680c04b40f 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -94,6 +94,7 @@ from polars.utils.meta import get_index_type from polars.utils.various import ( _is_generator, + no_default, parse_percentiles, parse_version, range_to_series, @@ -6228,83 +6229,81 @@ def upper_bound(self) -> Self: """ - def map_dict( + def replace( self, - remapping: dict[Any, Any], + mapping: dict[Any, Any], *, - default: Any = None, + default: Any = no_default, return_dtype: PolarsDataType | None = None, ) -> Self: """ - Replace values in the Series using a remapping dictionary. + Replace values according to the given mapping. + + Needs a global string cache for lazily evaluated queries on columns of + type `Categorical`. Parameters ---------- - remapping - Dictionary containing the before/after values to map. + mapping + Mapping of values to their replacement. default - Value to use when the remapping dict does not contain the lookup value. - Use `pl.first()`, to keep the original value. + Value to use when the mapping does not contain the lookup value. + Defaults to keeping the original value. return_dtype Set return dtype to override automatic return dtype determination. - Examples + See Also -------- - >>> s = pl.Series("iso3166", ["TUR", "???", "JPN", "NLD"]) - >>> country_lookup = { - ... "JPN": "Japan", - ... "TUR": "Türkiye", - ... "NLD": "Netherlands", - ... } + str.replace - Remap, setting a default for unrecognised values... + Examples + -------- + Replace a single value by another value. Values not in the mapping remain + unchanged. - >>> s.map_dict(country_lookup, default="Unspecified").alias("country_name") + >>> s = pl.Series("a", [1, 2, 2, 3]) + >>> s.replace({2: 100}) shape: (4,) - Series: 'country_name' [str] + Series: 'a' [i64] [ - "Türkiye" - "Unspecified" - "Japan" - "Netherlands" + 1 + 100 + 100 + 3 ] - ...or keep the original value, by making use of `pl.first()`: + Replace multiple values. Specify a default to set values not in the given map + to the default value. - >>> s.map_dict(country_lookup, default=pl.first()).alias("country_name") + >>> s = pl.Series("country_code", ["FR", "ES", "DE", None]) + >>> country_code_map = { + ... "CA": "Canada", + ... "DE": "Germany", + ... "FR": "France", + ... None: "unspecified", + ... } + >>> s.replace(country_code_map, default=None) shape: (4,) - Series: 'country_name' [str] + Series: 'country_code' [str] [ - "Türkiye" - "???" - "Japan" - "Netherlands" + "France" + null + "Germany" + "unspecified" ] - ...or keep the original value, by assigning the input series: + The return type can be overridden with the `return_dtype` argument. - >>> s.map_dict(country_lookup, default=s).alias("country_name") + >>> s = pl.Series("a", [0, 1, 2, 3]) + >>> s.replace({1: 10, 2: 20}, default=0, return_dtype=pl.UInt8) shape: (4,) - Series: 'country_name' [str] + Series: 'a' [u8] [ - "Türkiye" - "???" - "Japan" - "Netherlands" - ] - - Override return dtype: - - >>> s = pl.Series("int8", [5, 2, 3], dtype=pl.Int8) - >>> s.map_dict({2: 7}, default=pl.first(), return_dtype=pl.Int16) - shape: (3,) - Series: 'int8' [i16] - [ - 5 - 7 - 3 + 0 + 10 + 20 + 0 ] - """ def reshape(self, dimensions: tuple[int, ...]) -> Series: @@ -7136,6 +7135,40 @@ def view(self, *, ignore_nulls: bool = False) -> SeriesView: """ return self._view(ignore_nulls=ignore_nulls) + @deprecate_function( + "It has been renamed to `replace`." + " The default behavior has changed to keep any values not present in the mapping unchanged." + " Pass `default=None` to keep existing behavior.", + version="0.19.16", + ) + @deprecate_renamed_parameter("remapping", "mapping", version="0.19.16") + def map_dict( + self, + mapping: dict[Any, Any], + *, + default: Any = None, + return_dtype: PolarsDataType | None = None, + ) -> Self: + """ + Replace values in the Series using a remapping dictionary. + + .. deprecated:: 0.19.16 + This method has been renamed to :meth:`replace`. The default behavior + has changed to keep any values not present in the mapping unchanged. + Pass `default=None` to keep existing behavior. + + Parameters + ---------- + mapping + Dictionary containing the before/after values to map. + default + Value to use when the remapping dict does not contain the lookup value. + Use `pl.first()`, to keep the original value. + return_dtype + Set return dtype to override automatic return dtype determination. + """ + return self.replace(mapping, default=default, return_dtype=return_dtype) + # Keep the `list` and `str` properties below at the end of the definition of Series, # as to not confuse mypy with the type annotation `str` and `list` diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index c2a12122cc1e..773dd78ef073 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -483,7 +483,7 @@ def op(inst: Instruction) -> str: elif inst.opname in OpNames.UNARY: return OpNames.UNARY[inst.opname] elif inst.opname == "BINARY_SUBSCR": - return "map_dict" + return "replace" else: raise AssertionError( "unrecognized opname" @@ -520,7 +520,7 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str if " " in e1 else f"{not_}{e1}.is_in({e2})" ) - elif op == "map_dict": + elif op == "replace": if not self._caller_variables: self._caller_variables.update(_get_all_caller_variables()) if not isinstance(self._caller_variables.get(e1, None), dict): diff --git a/py-polars/polars/utils/various.py b/py-polars/polars/utils/various.py index 7ef46e24e371..fdcf1d44ffc9 100644 --- a/py-polars/polars/utils/various.py +++ b/py-polars/polars/utils/various.py @@ -341,8 +341,9 @@ def str_duration_(td: str | None) -> int | None: .cast(tp) ) elif tp == Boolean: - cast_cols[c] = F.col(c).map_dict( - remapping={"true": True, "false": False}, + cast_cols[c] = F.col(c).replace( + mapping={"true": True, "false": False}, + default=None, return_dtype=Boolean, ) elif tp in INTEGER_DTYPES: diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index f4c8ceaacc71..592cd99f8ad2 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -116,13 +116,13 @@ # --------------------------------------------- ("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_decode()'), # --------------------------------------------- - # map_dict + # replace # --------------------------------------------- - ("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'), + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace(MY_DICT)'), ( "a", "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", - '(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)', + '(pl.col("a") - 1).replace(MY_DICT) + (1 + pl.col("a")).replace(MY_DICT)', ), # --------------------------------------------- # standard library datetime parsing diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py new file mode 100644 index 000000000000..84a1bc3926b0 --- /dev/null +++ b/py-polars/tests/unit/operations/test_replace.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_replace_expr() -> None: + country_code_dict = { + "CA": "Canada", + "DE": "Germany", + "FR": "France", + None: "Not specified", + } + df = pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + ] + ) + result = df.with_columns( + pl.col("country_code").replace(country_code_dict).alias("replaced") + ) + expected = pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series( + "replaced", + ["France", "Not specified", "ES", "Germany"], + dtype=pl.Utf8, + ), + ] + ) + assert_frame_equal(result, expected) + + assert_frame_equal( + df.with_columns( + pl.col("country_code") + .replace(country_code_dict, default=pl.col("country_code")) + .alias("remapped") + ), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series( + "remapped", + ["France", "Not specified", "ES", "Germany"], + dtype=pl.Utf8, + ), + ] + ), + ) + + result = df.with_columns( + pl.col("country_code") + .replace(country_code_dict, default=None) + .alias("remapped") + ) + expected = pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series( + "remapped", + ["France", "Not specified", None, "Germany"], + dtype=pl.Utf8, + ), + ] + ) + assert_frame_equal(result, expected) + + assert_frame_equal( + df.with_row_count().with_columns( + pl.struct(pl.col(["country_code", "row_nr"])) + .replace( + country_code_dict, + default=pl.col("row_nr").cast(pl.Utf8), + ) + .alias("remapped") + ), + pl.DataFrame( + [ + pl.Series("row_nr", [0, 1, 2, 3], dtype=pl.UInt32), + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series( + "remapped", + ["France", "Not specified", "2", "Germany"], + dtype=pl.Utf8, + ), + ] + ), + ) + + with pl.StringCache(): + assert_frame_equal( + df.with_columns( + pl.col("country_code") + .cast(pl.Categorical) + .replace(country_code_dict, default=pl.col("country_code")) + .alias("remapped") + ), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series( + "remapped", + ["France", "Not specified", "ES", "Germany"], + dtype=pl.Categorical, + ), + ] + ), + ) + + df_categorical_lazy = df.lazy().with_columns( + pl.col("country_code").cast(pl.Categorical) + ) + + with pl.StringCache(): + assert_frame_equal( + df_categorical_lazy.with_columns( + pl.col("country_code") + .replace(country_code_dict, default=pl.col("country_code")) + .alias("remapped") + ).collect(), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series( + "country_code", ["FR", None, "ES", "DE"], dtype=pl.Categorical + ), + pl.Series( + "remapped", + ["France", "Not specified", "ES", "Germany"], + dtype=pl.Categorical, + ), + ] + ), + ) + + int_to_int_dict = {1: 5, 3: 7} + + assert_frame_equal( + df.with_columns(pl.col("int").replace(int_to_int_dict).alias("remapped")), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [None, 5, None, 7], dtype=pl.Int16), + ] + ), + ) + + int_dict = {1: "b", 3: "d"} + + assert_frame_equal( + df.with_columns(pl.col("int").replace(int_dict).alias("remapped")), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [None, "b", None, "d"], dtype=pl.Utf8), + ] + ), + ) + + int_with_none_dict = {1: "b", 3: "d", None: "e"} + + assert_frame_equal( + df.with_columns(pl.col("int").replace(int_with_none_dict).alias("remapped")), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", ["e", "b", "e", "d"], dtype=pl.Utf8), + ] + ), + ) + + int_with_only_none_values_dict = {3: None} + + assert_frame_equal( + df.with_columns( + pl.col("int") + .replace(int_with_only_none_values_dict, default=6) + .alias("remapped") + ), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [6, 6, 6, None], dtype=pl.Int16), + ] + ), + ) + + assert_frame_equal( + df.with_columns( + pl.col("int") + .replace(int_with_only_none_values_dict, default=6, return_dtype=pl.Int32) + .alias("remapped") + ), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [6, 6, 6, None], dtype=pl.Int32), + ] + ), + ) + + result = df.with_columns( + pl.col("int") + .replace(int_with_only_none_values_dict, default=None) + .alias("remapped") + ) + expected = pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [None, None, None, None], dtype=pl.Int16), + ] + ) + assert_frame_equal(result, expected) + + empty_dict: dict[Any, Any] = {} + + assert_frame_equal( + df.with_columns(pl.col("int").replace(empty_dict).alias("remapped")), + pl.DataFrame( + [ + pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), + pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), + pl.Series("remapped", [None, 1, None, 3], dtype=pl.Int16), + ] + ), + ) + + float_dict = {1.0: "b", 3.0: "d"} + + with pytest.raises( + pl.ComputeError, + match=".*'float' object cannot be interpreted as an integer", + ): + df.with_columns(pl.col("int").replace(float_dict)) + + df_int_as_str = df.with_columns(pl.col("int").cast(pl.Utf8)) + + with pytest.raises( + pl.ComputeError, + match="mapping keys for `replace` could not be converted to Utf8 without losing values in the conversion", + ): + df_int_as_str.with_columns(pl.col("int").replace(int_dict)) + + with pytest.raises( + pl.ComputeError, + match="mapping keys for `replace` could not be converted to Utf8 without losing values in the conversion", + ): + df_int_as_str.with_columns(pl.col("int").replace(int_with_none_dict)) + + # 7132 + df = pl.DataFrame({"text": ["abc"]}) + mapper = {"abc": "123"} + assert_frame_equal( + df.select(pl.col("text").replace(mapper).str.replace_all("1", "-")), + pl.DataFrame( + [ + pl.Series("text", ["-23"], dtype=pl.Utf8), + ] + ), + ) + + result = pl.DataFrame( + [ + pl.Series("float_to_boolean", [1.0, None]), + pl.Series("boolean_to_int", [True, False]), + pl.Series("boolean_to_str", [True, False]), + ] + ).with_columns( + pl.col("float_to_boolean").replace({1.0: True}, default=None), + pl.col("boolean_to_int").replace({True: 1, False: 0}, default=None), + pl.col("boolean_to_str").replace({True: "1", False: "0"}, default=None), + ) + expected = pl.DataFrame( + [ + pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), + pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), + pl.Series("boolean_to_str", ["1", "0"], dtype=pl.Utf8), + ] + ) + assert_frame_equal(result, expected) + + lf = pl.LazyFrame({"a": [1, 2, 3]}) + assert_frame_equal( + lf.select( + pl.col("a").cast(pl.UInt8).replace({1: 11, 2: 22}, default=99) + ).collect(), + pl.DataFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}), + ) + + df = ( + pl.LazyFrame({"a": ["one", "two"]}) + .with_columns( + pl.col("a").replace({"one": 1}, default=None, return_dtype=pl.UInt32) + ) + .fill_null(999) + .collect() + ) + assert_frame_equal( + df, pl.DataFrame({"a": [1, 999]}, schema_overrides={"a": pl.UInt32}) + ) + + +def test_replace_series() -> None: + s = pl.Series("s", [-1, 2, None, 4, -5]) + remap = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + assert_series_equal( + s.abs().replace(remap, default="?"), + pl.Series("s", ["one", "two", "?", "four", "five"]), + ) + assert_series_equal( + s.replace(remap, default=s.cast(pl.Utf8)), + pl.Series("s", ["-1", "two", None, "four", "-5"]), + ) + + remap_int = {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} + + assert_series_equal( + s.replace(remap_int), + pl.Series("s", [-1, 22, None, 44, -5]), + ) + + assert_series_equal( + s.cast(pl.Int16).replace(remap_int, default=None), + pl.Series("s", [None, 22, None, 44, None], dtype=pl.Int16), + ) + + assert_series_equal( + s.cast(pl.Int16).replace(remap_int), + pl.Series("s", [-1, 22, None, 44, -5], dtype=pl.Int16), + ) + + assert_series_equal( + s.cast(pl.Int16).replace(remap_int, return_dtype=pl.Float32), + pl.Series("s", [-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32), + ) + + assert_series_equal( + s.cast(pl.Int16).replace(remap_int, default=9), + pl.Series("s", [9, 22, 9, 44, 9], dtype=pl.Int16), + ) + + assert_series_equal( + s.cast(pl.Int16).replace(remap_int, default=9, return_dtype=pl.Float32), + pl.Series("s", [9.0, 22.0, 9.0, 44.0, 9.0], dtype=pl.Float32), + ) + + assert_series_equal( + pl.Series("boolean_to_int", [True, False]).replace( + {True: 1, False: 0}, default=None + ), + pl.Series("boolean_to_int", [1, 0]), + ) + + assert_series_equal( + pl.Series("boolean_to_str", [True, False]).replace( + {True: "1", False: "0"}, default=None + ), + pl.Series("boolean_to_str", ["1", "0"]), + ) + + +def test_map_dict_deprecated() -> None: + s = pl.Series("a", [1, 2, 3]) + with pytest.deprecated_call(): + result = s.map_dict({2: 100}) + expected = pl.Series("a", [None, 100, None]) + assert_series_equal(result, expected) + + with pytest.deprecated_call(): + result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series() + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 3e2a36c90509..a785905dd8ae 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -2672,64 +2672,6 @@ def test_is_between() -> None: ] -def test_map_dict() -> None: - s = pl.Series("s", [-1, 2, None, 4, -5]) - remap = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - assert_series_equal( - s.abs().map_dict(remap, default="?"), - pl.Series("s", ["one", "two", "?", "four", "five"]), - ) - assert_series_equal( - s.map_dict(remap, default=s.cast(pl.Utf8)), - pl.Series("s", ["-1", "two", None, "four", "-5"]), - ) - - remap_int = {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} - - assert_series_equal( - s.map_dict(remap_int, default=pl.first()), - pl.Series("s", [-1, 22, None, 44, -5]), - ) - - assert_series_equal( - s.cast(pl.Int16).map_dict(remap_int), - pl.Series("s", [None, 22, None, 44, None], dtype=pl.Int16), - ) - - assert_series_equal( - s.cast(pl.Int16).map_dict(remap_int, default=pl.first()), - pl.Series("s", [-1, 22, None, 44, -5], dtype=pl.Int16), - ) - - assert_series_equal( - s.cast(pl.Int16).map_dict( - remap_int, default=pl.first(), return_dtype=pl.Float32 - ), - pl.Series("s", [-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32), - ) - - assert_series_equal( - s.cast(pl.Int16).map_dict(remap_int, default=9), - pl.Series("s", [9, 22, 9, 44, 9], dtype=pl.Int16), - ) - - assert_series_equal( - s.cast(pl.Int16).map_dict(remap_int, default=9, return_dtype=pl.Float32), - pl.Series("s", [9.0, 22.0, 9.0, 44.0, 9.0], dtype=pl.Float32), - ) - - assert_series_equal( - pl.Series("boolean_to_int", [True, False]).map_dict({True: 1, False: 0}), - pl.Series("boolean_to_int", [1, 0]), - ) - - assert_series_equal( - pl.Series("boolean_to_str", [True, False]).map_dict({True: "1", False: "0"}), - pl.Series("boolean_to_str", ["1", "0"]), - ) - - @pytest.mark.parametrize( ("dtype", "lower", "upper"), [ diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 1054e24e8e26..ff5ebe1d84af 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -485,318 +485,6 @@ def test_ewm_with_multiple_chunks() -> None: assert ewm_std.null_count().sum_horizontal()[0] == 4 -def test_map_dict() -> None: - country_code_dict = { - "CA": "Canada", - "DE": "Germany", - "FR": "France", - None: "Not specified", - } - df = pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - ] - ) - - assert_frame_equal( - df.with_columns( - pl.col("country_code") - .map_dict(country_code_dict, default=pl.first()) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series( - "remapped", - ["France", "Not specified", "ES", "Germany"], - dtype=pl.Utf8, - ), - ] - ), - ) - - assert_frame_equal( - df.with_columns( - pl.col("country_code") - .map_dict(country_code_dict, default=pl.col("country_code")) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series( - "remapped", - ["France", "Not specified", "ES", "Germany"], - dtype=pl.Utf8, - ), - ] - ), - ) - - assert_frame_equal( - df.with_columns( - pl.col("country_code").map_dict(country_code_dict).alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series( - "remapped", - ["France", "Not specified", None, "Germany"], - dtype=pl.Utf8, - ), - ] - ), - ) - - assert_frame_equal( - df.with_row_count().with_columns( - pl.struct(pl.col(["country_code", "row_nr"])) - .map_dict( - country_code_dict, - default=pl.col("row_nr").cast(pl.Utf8), - ) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("row_nr", [0, 1, 2, 3], dtype=pl.UInt32), - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series( - "remapped", - ["France", "Not specified", "2", "Germany"], - dtype=pl.Utf8, - ), - ] - ), - ) - - with pl.StringCache(): - assert_frame_equal( - df.with_columns( - pl.col("country_code") - .cast(pl.Categorical) - .map_dict(country_code_dict, default=pl.col("country_code")) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series( - "remapped", - ["France", "Not specified", "ES", "Germany"], - dtype=pl.Categorical, - ), - ] - ), - ) - - df_categorical_lazy = df.lazy().with_columns( - pl.col("country_code").cast(pl.Categorical) - ) - - with pl.StringCache(): - assert_frame_equal( - df_categorical_lazy.with_columns( - pl.col("country_code") - .map_dict(country_code_dict, default=pl.col("country_code")) - .alias("remapped") - ).collect(), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series( - "country_code", ["FR", None, "ES", "DE"], dtype=pl.Categorical - ), - pl.Series( - "remapped", - ["France", "Not specified", "ES", "Germany"], - dtype=pl.Categorical, - ), - ] - ), - ) - - int_to_int_dict = {1: 5, 3: 7} - - assert_frame_equal( - df.with_columns(pl.col("int").map_dict(int_to_int_dict).alias("remapped")), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [None, 5, None, 7], dtype=pl.Int16), - ] - ), - ) - - int_dict = {1: "b", 3: "d"} - - assert_frame_equal( - df.with_columns(pl.col("int").map_dict(int_dict).alias("remapped")), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [None, "b", None, "d"], dtype=pl.Utf8), - ] - ), - ) - - int_with_none_dict = {1: "b", 3: "d", None: "e"} - - assert_frame_equal( - df.with_columns(pl.col("int").map_dict(int_with_none_dict).alias("remapped")), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", ["e", "b", "e", "d"], dtype=pl.Utf8), - ] - ), - ) - - int_with_only_none_values_dict = {3: None} - - assert_frame_equal( - df.with_columns( - pl.col("int") - .map_dict(int_with_only_none_values_dict, default=6) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [6, 6, 6, None], dtype=pl.Int16), - ] - ), - ) - - assert_frame_equal( - df.with_columns( - pl.col("int") - .map_dict(int_with_only_none_values_dict, default=6, return_dtype=pl.Int32) - .alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [6, 6, 6, None], dtype=pl.Int32), - ] - ), - ) - - assert_frame_equal( - df.with_columns( - pl.col("int").map_dict(int_with_only_none_values_dict).alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [None, None, None, None], dtype=pl.Int16), - ] - ), - ) - - empty_dict: dict[Any, Any] = {} - - assert_frame_equal( - df.with_columns( - pl.col("int").map_dict(empty_dict, default=pl.first()).alias("remapped") - ), - pl.DataFrame( - [ - pl.Series("int", [None, 1, None, 3], dtype=pl.Int16), - pl.Series("country_code", ["FR", None, "ES", "DE"], dtype=pl.Utf8), - pl.Series("remapped", [None, 1, None, 3], dtype=pl.Int16), - ] - ), - ) - - float_dict = {1.0: "b", 3.0: "d"} - - with pytest.raises( - pl.ComputeError, - match=".*'float' object cannot be interpreted as an integer", - ): - df.with_columns(pl.col("int").map_dict(float_dict)) - - df_int_as_str = df.with_columns(pl.col("int").cast(pl.Utf8)) - - with pytest.raises( - pl.ComputeError, - match="remapping keys for `map_dict` could not be converted to Utf8 without losing values in the conversion", - ): - df_int_as_str.with_columns(pl.col("int").map_dict(int_dict)) - - with pytest.raises( - pl.ComputeError, - match="remapping keys for `map_dict` could not be converted to Utf8 without losing values in the conversion", - ): - df_int_as_str.with_columns(pl.col("int").map_dict(int_with_none_dict)) - - # 7132 - df = pl.DataFrame({"text": ["abc"]}) - mapper = {"abc": "123"} - assert_frame_equal( - df.select(pl.col("text").map_dict(mapper).str.replace_all("1", "-")), - pl.DataFrame( - [ - pl.Series("text", ["-23"], dtype=pl.Utf8), - ] - ), - ) - - assert_frame_equal( - pl.DataFrame( - [ - pl.Series("float_to_boolean", [1.0, None]), - pl.Series("boolean_to_int", [True, False]), - pl.Series("boolean_to_str", [True, False]), - ] - ).with_columns( - pl.col("float_to_boolean").map_dict({1.0: True}), - pl.col("boolean_to_int").map_dict({True: 1, False: 0}), - pl.col("boolean_to_str").map_dict({True: "1", False: "0"}), - ), - pl.DataFrame( - [ - pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), - pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), - pl.Series("boolean_to_str", ["1", "0"], dtype=pl.Utf8), - ] - ), - ) - - lf = pl.LazyFrame({"a": [1, 2, 3]}) - assert_frame_equal( - lf.select( - pl.col("a").cast(pl.UInt8).map_dict({1: 11, 2: 22}, default=99) - ).collect(), - pl.DataFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}), - ) - - df = ( - pl.LazyFrame({"a": ["one", "two"]}) - .with_columns(pl.col("a").map_dict({"one": 1}, return_dtype=pl.UInt32)) - .fill_null(999) - .collect() - ) - assert_frame_equal( - df, pl.DataFrame({"a": [1, 999]}, schema_overrides={"a": pl.UInt32}) - ) - - def test_lit_dtypes() -> None: def lit_series(value: Any, dtype: pl.PolarsDataType | None) -> pl.Series: return pl.select(pl.lit(value, dtype=dtype)).to_series()