From f542e9e8f1e4bf0fa4fe64e89dc5e72b94881cd2 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:18:51 +0100 Subject: [PATCH] feat: Adds `@register_theme` decorator (#3526) * feat: Adds `@register_theme` decorator Resolves one item in https://github.com/vega/altair/issues/3519 * build: run `update-init-file` Adds `@register_theme` to top-level * test: Adds `test_register_theme_decorator` * refactor(typing): Specify `dict[str, Any]` instead of `dict[Any, Any]` The latter may give false-positives for json-incompatible dicts --------- Co-authored-by: Stefan Binder --- altair/__init__.py | 1 + altair/utils/plugin_registry.py | 2 +- altair/utils/theme.py | 10 ++-- altair/vegalite/v5/__init__.py | 2 +- altair/vegalite/v5/theme.py | 94 +++++++++++++++++++++++++++++++-- tests/vegalite/v5/test_theme.py | 17 +++++- 6 files changed, 114 insertions(+), 12 deletions(-) diff --git a/altair/__init__.py b/altair/__init__.py index 8cf283e49..d4e20f02f 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -617,6 +617,7 @@ "mixins", "param", "parse_shorthand", + "register_theme", "renderers", "repeat", "sample", diff --git a/altair/utils/plugin_registry.py b/altair/utils/plugin_registry.py index 996c6623e..b2723396a 100644 --- a/altair/utils/plugin_registry.py +++ b/altair/utils/plugin_registry.py @@ -115,7 +115,7 @@ def __init__( self.entry_point_group: str = entry_point_group self.plugin_type: IsPlugin if plugin_type is not callable and isinstance(plugin_type, type): - msg = ( + msg: Any = ( f"Pass a callable `TypeIs` function to `plugin_type` instead.\n" f"{type(self).__name__!r}(plugin_type)\n\n" f"See also:\n" diff --git a/altair/utils/theme.py b/altair/utils/theme.py index 02372e690..47e5da6ad 100644 --- a/altair/utils/theme.py +++ b/altair/utils/theme.py @@ -3,9 +3,9 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Dict -from .plugin_registry import PluginRegistry +from .plugin_registry import Plugin, PluginRegistry if sys.version_info >= (3, 11): from typing import LiteralString @@ -16,10 +16,12 @@ from altair.utils.plugin_registry import PluginEnabler from altair.vegalite.v5.theme import AltairThemes, VegaThemes -ThemeType = Callable[..., dict] +ThemeType = Plugin[Dict[str, Any]] -class ThemeRegistry(PluginRegistry[ThemeType, dict]): +# HACK: See for `LiteralString` requirement in `name` +# https://github.com/vega/altair/pull/3526#discussion_r1743350127 +class ThemeRegistry(PluginRegistry[ThemeType, Dict[str, Any]]): def enable( self, name: LiteralString | AltairThemes | VegaThemes | None = None, **options ) -> PluginEnabler: diff --git a/altair/vegalite/v5/__init__.py b/altair/vegalite/v5/__init__.py index bc0703ec6..a18be6e11 100644 --- a/altair/vegalite/v5/__init__.py +++ b/altair/vegalite/v5/__init__.py @@ -21,4 +21,4 @@ renderers, ) from .schema import * -from .theme import themes +from .theme import register_theme, themes diff --git a/altair/vegalite/v5/theme.py b/altair/vegalite/v5/theme.py index c98826826..2e438679b 100644 --- a/altair/vegalite/v5/theme.py +++ b/altair/vegalite/v5/theme.py @@ -2,21 +2,33 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final, Literal, get_args +import sys +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, Final, Literal, TypeVar, get_args from altair.utils.theme import ThemeRegistry from altair.vegalite.v5.schema._typing import VegaThemes -if TYPE_CHECKING: - import sys +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +if TYPE_CHECKING: + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString if sys.version_info >= (3, 10): from typing import TypeAlias else: from typing_extensions import TypeAlias +P = ParamSpec("P") +R = TypeVar("R", bound=Dict[str, Any]) AltairThemes: TypeAlias = Literal["default", "opaque"] -VEGA_THEMES: list[str] = list(get_args(VegaThemes)) +VEGA_THEMES: list[LiteralString] = list(get_args(VegaThemes)) class VegaTheme: @@ -60,3 +72,77 @@ def __repr__(self) -> str: themes.register(theme, VegaTheme(theme)) themes.enable("default") + + +# HACK: See for `LiteralString` requirement in `name` +# https://github.com/vega/altair/pull/3526#discussion_r1743350127 +def register_theme( + name: LiteralString, *, enable: bool +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Decorator for registering a theme function. + + Parameters + ---------- + name + Unique name assigned in ``alt.themes``. + enable + Auto-enable the wrapped theme. + + Examples + -------- + Register and enable a theme:: + + from __future__ import annotations + + from typing import Any + import altair as alt + + + @alt.register_theme("param_font_size", enable=True) + def custom_theme() -> dict[str, Any]: + sizes = 12, 14, 16, 18, 20 + return { + "autosize": {"contains": "content", "resize": True}, + "background": "#F3F2F1", + "config": { + "axisX": {"labelFontSize": sizes[1], "titleFontSize": sizes[1]}, + "axisY": {"labelFontSize": sizes[1], "titleFontSize": sizes[1]}, + "font": "'Lato', 'Segoe UI', Tahoma, Verdana, sans-serif", + "headerColumn": {"labelFontSize": sizes[1]}, + "headerFacet": {"labelFontSize": sizes[1]}, + "headerRow": {"labelFontSize": sizes[1]}, + "legend": {"labelFontSize": sizes[0], "titleFontSize": sizes[1]}, + "text": {"fontSize": sizes[0]}, + "title": {"fontSize": sizes[-1]}, + }, + "height": {"step": 28}, + "width": 350, + } + + Until another theme has been enabled, all charts will use defaults set in ``custom_theme``:: + + from vega_datasets import data + + source = data.stocks() + lines = ( + alt.Chart(source, title=alt.Title("Stocks")) + .mark_line() + .encode(x="date:T", y="price:Q", color="symbol:N") + ) + lines.interactive(bind_y=False) + + """ + + def decorate(func: Callable[P, R], /) -> Callable[P, R]: + themes.register(name, func) + if enable: + themes.enable(name) + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper + + return decorate diff --git a/tests/vegalite/v5/test_theme.py b/tests/vegalite/v5/test_theme.py index 0eab5546d..fa6be95ac 100644 --- a/tests/vegalite/v5/test_theme.py +++ b/tests/vegalite/v5/test_theme.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import pytest import altair.vegalite.v5 as alt -from altair.vegalite.v5.theme import VEGA_THEMES +from altair.vegalite.v5.theme import VEGA_THEMES, register_theme, themes @pytest.fixture @@ -9,7 +11,7 @@ def chart(): return alt.Chart("data.csv").mark_bar().encode(x="x:Q") -def test_vega_themes(chart): +def test_vega_themes(chart) -> None: for theme in VEGA_THEMES: with alt.themes.enable(theme): dct = chart.to_dict() @@ -17,3 +19,14 @@ def test_vega_themes(chart): assert dct["config"] == { "view": {"continuousWidth": 300, "continuousHeight": 300} } + + +def test_register_theme_decorator() -> None: + @register_theme("unique name", enable=True) + def custom_theme() -> dict[str, int]: + return {"height": 400, "width": 700} + + assert themes.active == "unique name" + registered = themes.get() + assert registered is not None + assert registered() == {"height": 400, "width": 700} == custom_theme()