From bf7fc2988200972ce3e93edb7e27404260e36607 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:11:49 +0100 Subject: [PATCH] fix: Resolve `Then` copy `TypeError` (#3553) --- altair/utils/schemapi.py | 3 +++ altair/vegalite/v5/api.py | 3 +++ tests/vegalite/v5/test_api.py | 21 +++++++++++++++++++++ tools/schemapi/schemapi.py | 3 +++ 4 files changed, 30 insertions(+) diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index b91b90fbe..bc0b40581 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -854,6 +854,9 @@ def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: copy = partial(_deep_copy, by_ref=by_ref) if isinstance(obj, SchemaBase): + if copier := getattr(obj, "__deepcopy__", None): + with debug_mode(False): + return copier(obj) args = (copy(arg) for arg in obj._args) kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} with debug_mode(False): diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index d352b060b..4e8fde039 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1061,6 +1061,9 @@ def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: # type: ignore[ m = super().to_dict(*args, **kwds) return _Conditional(condition=m["condition"]) + def __deepcopy__(self, memo: Any) -> Self: + return type(self)(_Conditional(condition=_deepcopy(self.condition))) + class ChainedWhen(_BaseWhen): """ diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 29d68d1ea..241d47378 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -698,6 +698,27 @@ def test_when_condition_parity( assert chart_condition == chart_when +def test_when_then_interactive() -> None: + """Copy-related regression found in https://github.com/vega/altair/pull/3394#issuecomment-2302995453.""" + source = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/movies.json" + predicate = (alt.datum.IMDB_Rating == None) | ( # noqa: E711 + alt.datum.Rotten_Tomatoes_Rating == None # noqa: E711 + ) + + chart = ( + alt.Chart(source) + .mark_point(invalid=None) + .encode( + x="IMDB_Rating:Q", + y="Rotten_Tomatoes_Rating:Q", + color=alt.when(predicate).then(alt.value("grey")), # type: ignore[arg-type] + ) + ) + assert chart.interactive() + assert chart.copy() + assert chart.to_dict() + + def test_selection_to_dict(): brush = alt.selection_interval() diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index b6907ec8f..9d21ab793 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -852,6 +852,9 @@ def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: copy = partial(_deep_copy, by_ref=by_ref) if isinstance(obj, SchemaBase): + if copier := getattr(obj, "__deepcopy__", None): + with debug_mode(False): + return copier(obj) args = (copy(arg) for arg in obj._args) kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} with debug_mode(False):