Skip to content

Commit

Permalink
feat: Generate expr method signatures, docs (#3600)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Oct 12, 2024
1 parent 02ad17d commit 6cb633f
Show file tree
Hide file tree
Showing 9 changed files with 2,349 additions and 737 deletions.
1,789 changes: 1,141 additions & 648 deletions altair/expr/__init__.py

Large diffs are not rendered by default.

60 changes: 39 additions & 21 deletions tests/expr/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@

import operator
import sys
from inspect import classify_class_attrs, getmembers
from typing import Any, Iterator
from inspect import classify_class_attrs, getmembers, signature
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast

import pytest
from jsonschema.exceptions import ValidationError

from altair import datum, expr, ExprRef
from altair.expr import _ConstExpressionType
from altair.expr import _ExprMeta
from altair.expr.core import Expression, GetAttrExpression

if TYPE_CHECKING:
from inspect import _IntrospectableCallable

T = TypeVar("T")

# This maps vega expression function names to the Python name
VEGA_REMAP = {"if_": "if"}
Expand All @@ -19,20 +25,29 @@ def _is_property(obj: Any, /) -> bool:
return isinstance(obj, property)


def _get_classmethod_names(tp: type[Any], /) -> Iterator[str]:
for m in classify_class_attrs(tp):
if m.kind == "class method" and m.defining_class is tp:
yield m.name
def _get_property_names(tp: type[Any], /) -> Iterator[str]:
    """Yield the name of each member of ``tp`` accepted by ``_is_property``."""
    yield from (name for name, _member in getmembers(tp, _is_property))


def _remap_classmethod_names(tp: type[Any], /) -> Iterator[tuple[str, str]]:
    """
    Yield ``(vega_name, python_name)`` for each classmethod defined on ``tp``.

    ``vega_name`` is looked up in ``VEGA_REMAP`` and falls back to the Python name.
    """
    yield from (
        (VEGA_REMAP.get(method_name, method_name), method_name)
        for method_name in _get_classmethod_names(tp)
    )
def signature_n_params(
    obj: _IntrospectableCallable,
    /,
    *,
    exclude: Iterable[str] = frozenset(("cls", "self")),
) -> int:
    """
    Count the parameters in the signature of ``obj``.

    Parameters
    ----------
    obj
        Callable to introspect.
    exclude
        Parameter names that are left out of the count.
    """
    param_names = set(signature(obj).parameters)
    param_names.difference_update(exclude)
    return len(param_names)


def _get_property_names(tp: type[Any], /) -> Iterator[str]:
    """Yield the name of each member of ``tp`` accepted by ``_is_property``."""
    yield from (name for name, _member in getmembers(tp, _is_property))
def _iter_classmethod_specs(
    tp: type[T], /
) -> Iterator[tuple[str, Callable[..., Expression], int]]:
    """
    Yield one ``(vega_name, bound_method, n_params)`` triple per classmethod defined on ``tp``.

    - ``vega_name`` is the method name after applying ``VEGA_REMAP``.
    - ``bound_method`` is the underlying function bound to ``tp`` via ``__get__``.
    - ``n_params`` is the result of ``signature_n_params`` on the underlying function.
    """
    for attr in classify_class_attrs(tp):
        # Skip anything that is not a classmethod defined on ``tp`` itself.
        if attr.kind != "class method" or attr.defining_class is not tp:
            continue
        descriptor = cast("classmethod[T, ..., Expression]", attr.object)
        func = descriptor.__func__
        vega_name = VEGA_REMAP.get(attr.name, attr.name)
        yield vega_name, func.__get__(tp), signature_n_params(func)


def test_unary_operations():
Expand Down Expand Up @@ -86,23 +101,26 @@ def test_abs():
assert repr(z) == "abs(datum.xxx)"


@pytest.mark.parametrize(("veganame", "methodname"), _remap_classmethod_names(expr))
def test_expr_funcs(veganame: str, methodname: str):
    """Each ``expr`` method called with ``datum.xxx`` must repr as ``<veganame>(datum.xxx)``."""
    result = getattr(expr, methodname)(datum.xxx)
    expected = f"{veganame}(datum.xxx)"
    assert repr(result) == expected
@pytest.mark.parametrize(("veganame", "fn", "n_params"), _iter_classmethod_specs(expr))
def test_expr_methods(
    veganame: str, fn: Callable[..., Expression], n_params: int
) -> None:
    """Call ``fn`` with ``n_params`` distinct datum fields and compare the repr of the result."""
    columns = [f"col_{idx}" for idx in range(n_params)]
    call_args = [GetAttrExpression("datum", col) for col in columns]
    expected_args = ",".join([f"datum.{col}" for col in columns])

    result = fn(*call_args)
    assert repr(result) == f"{veganame}({expected_args})"


@pytest.mark.parametrize("constname", _get_property_names(_ConstExpressionType))
@pytest.mark.parametrize("constname", _get_property_names(_ExprMeta))
def test_expr_consts(constname: str):
    """Each constant on ``expr`` must repr by its own name when multiplied with ``datum.xxx``."""
    product = getattr(expr, constname) * datum.xxx
    expected = f"({constname} * datum.xxx)"
    assert repr(product) == expected


@pytest.mark.parametrize("constname", _get_property_names(_ConstExpressionType))
@pytest.mark.parametrize("constname", _get_property_names(_ExprMeta))
def test_expr_consts_immutable(constname: str):
"""Ensure e.g `alt.expr.PI = 2` is prevented."""
if sys.version_info >= (3, 11):
Expand Down
14 changes: 10 additions & 4 deletions tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,13 @@ def test_when_labels_position_based_on_condition() -> None:
# `mypy` will flag structural errors here
cond = when["condition"][0]
otherwise = when["value"]
param_color_py_when = alt.param(
expr=alt.expr.if_(cond["test"], cond["value"], otherwise)
)

# TODO: Open an issue on making `OperatorMixin` generic
# Something like this would be used as the return type for all `__dunder__` methods:
# R = TypeVar("R", Expression, SelectionPredicateComposition)
test = cond["test"]
assert not isinstance(test, alt.PredicateComposition)
param_color_py_when = alt.param(expr=alt.expr.if_(test, cond["value"], otherwise))
lhs_param = param_color_py_expr.param
rhs_param = param_color_py_when.param
assert isinstance(lhs_param, alt.VariableParameter)
Expand Down Expand Up @@ -600,7 +604,9 @@ def test_when_expressions_inside_parameters() -> None:
cond = when_then_otherwise["condition"][0]
otherwise = when_then_otherwise["value"]
expected = alt.expr.if_(alt.datum.b >= 0, 10, -20)
actual = alt.expr.if_(cond["test"], cond["value"], otherwise)
test = cond["test"]
assert not isinstance(test, alt.PredicateComposition)
actual = alt.expr.if_(test, cond["value"], otherwise)
assert expected == actual

text_conditioned = bar.mark_text(
Expand Down
9 changes: 8 additions & 1 deletion tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from tools import generate_api_docs, generate_schema_wrapper, schemapi, update_init_file
from tools import (
generate_api_docs,
generate_schema_wrapper,
markup,
schemapi,
update_init_file,
)

__all__ = [
"generate_api_docs",
"generate_schema_wrapper",
"markup",
"schemapi",
"update_init_file",
]
24 changes: 15 additions & 9 deletions tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,8 @@
sys.path.insert(0, str(Path.cwd()))


from tools.schemapi import ( # noqa: F401
CodeSnippet,
SchemaInfo,
arg_invalid_kwds,
arg_kwds,
arg_required_kwds,
codegen,
)
from tools.markup import rst_syntax_for_class
from tools.schemapi import CodeSnippet, SchemaInfo, arg_kwds, arg_required_kwds, codegen
from tools.schemapi.utils import (
SchemaProperties,
TypeAliasTracer,
Expand All @@ -37,16 +31,17 @@
import_typing_extensions,
indent_docstring,
resolve_references,
rst_syntax_for_class,
ruff_format_py,
ruff_write_lint_format_str,
spell_literal,
)
from tools.vega_expr import write_expr_module

if TYPE_CHECKING:
from tools.schemapi.codegen import ArgInfo, AttrGetter
from vl_convert import VegaThemes


SCHEMA_VERSION: Final = "v5.20.1"


Expand All @@ -60,8 +55,14 @@
"""

# URL template for a published JSON schema; has ``{library}`` and ``{version}`` placeholders.
SCHEMA_URL_TEMPLATE: Final = "https://vega.github.io/schema/{library}/{version}.json"
# URL template for the ``package.json`` of a tagged vega-lite release; has a ``{version}`` placeholder.
VL_PACKAGE_TEMPLATE = (
"https://raw.githubusercontent.com/vega/vega-lite/refs/tags/{version}/package.json"
)
SCHEMA_FILE = "vega-lite-schema.json"
THEMES_FILE = "vega-themes.json"
# Resolved path of ``altair/expr/__init__.py``, located relative to this file's directory.
# Passed as ``output`` to ``write_expr_module`` in ``main``.
EXPR_FILE: Path = (
Path(__file__).parent / ".." / "altair" / "expr" / "__init__.py"
).resolve()

CHANNEL_MYPY_IGNORE_STATEMENTS: Final = """\
# These errors need to be ignored as they come from the overload methods
Expand Down Expand Up @@ -1207,6 +1208,11 @@ def main() -> None:
args = parser.parse_args()
copy_schemapi_util()
vegalite_main(args.skip_download)
write_expr_module(
vlc.get_vega_version(),
output=EXPR_FILE,
header=HEADER_COMMENT,
)

# The modules below are imported after the generation of the new schema files
# as these modules import Altair. This allows them to use the new changes
Expand Down
150 changes: 150 additions & 0 deletions tools/markup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Tools for working with formats like ``.md``, ``.rst``."""

from __future__ import annotations

import re
from html import unescape
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Literal
from urllib import request

import mistune.util
from mistune import InlineParser as _InlineParser
from mistune import Markdown as _Markdown
from mistune.renderers.rst import RSTRenderer as _RSTRenderer

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 11):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from re import Pattern

from mistune import BaseRenderer, BlockParser, BlockState, InlineState

Url: TypeAlias = str

Token: TypeAlias = "dict[str, Any]"

# Captures the link text of a markdown link, i.e. the ``text`` in ``[text](target)``.
# The brackets and target are only asserted via lookbehind/lookahead, not consumed.
_RE_LINK: Pattern[str] = re.compile(r"(?<=\[)([^\]]+)(?=\]\([^\)]+\))", re.MULTILINE)
# Matches a run of 2-3 ``*``/``_`` characters, or a single backtick.
_RE_SPECIAL: Pattern[str] = re.compile(r"[*_]{2,3}|`", re.MULTILINE)
# Matches a ``{% include ... %}`` tag together with the single space preceding it.
_RE_LIQUID_INCLUDE: Pattern[str] = re.compile(r"( \{% include.+%\})")


class RSTRenderer(_RSTRenderer):
    """``.rst`` renderer that emits inline html through the ``:raw-html:`` role."""

    def __init__(self) -> None:
        super().__init__()

    def inline_html(self, token: Token, state: BlockState) -> str:
        """Wrap the token's ``"raw"`` html in a ``:raw-html:`` role, delimited by escaped spaces."""
        return rf"\ :raw-html:`{token['raw']}`\ "


class RSTParse(_Markdown):
    """
    Small ``Markdown`` extension that can also render pre-parsed `ast`_ tokens.

    Only the docstring tokens need to be converted to `.rst`.

    .. _ast:
        https://mistune.lepture.com/en/latest/guide.html#abstract-syntax-tree
    """

    def __init__(
        self,
        renderer: BaseRenderer | Literal["ast"] | None,
        block: BlockParser | None = None,
        inline: _InlineParser | None = None,
        plugins=None,
    ) -> None:
        # ``"ast"`` is forwarded to the parent as ``None`` (no renderer).
        super().__init__(
            None if renderer == "ast" else renderer, block, inline, plugins
        )

    def __call__(self, s: str) -> str:
        """Convert ``s``, unescape html entities and undo the escaped-space markers."""
        s = super().__call__(s)  # pyright: ignore[reportAssignmentType]
        unescaped = unescape(s)
        without_escaped_commas = unescaped.replace(r"\ ,", ",")
        return without_escaped_commas.replace(r"\ ", " ")

    def render_tokens(self, tokens: Iterable[Token], /) -> str:
        """
        Render ast tokens originating from another parser.

        Parameters
        ----------
        tokens
            All tokens will be rendered into a single `.rst` string

        Raises
        ------
        TypeError
            If this instance has no renderer.
        """
        renderer = self.renderer
        if renderer is None:
            msg = "Unable to render tokens without a renderer."
            raise TypeError(msg)
        state = self.block.state_cls()
        rendered = renderer(self._iter_render(tokens, state), state)
        return mistune.util.unescape(rendered)


class RSTParseVegaLite(RSTParse):
    """
    ``RSTParse`` variant that applies Vega-Lite specific cleanup to descriptions.

    Falls back to a fresh ``RSTRenderer`` when no renderer is given.
    """

    def __init__(
        self,
        renderer: RSTRenderer | None = None,
        block: BlockParser | None = None,
        inline: _InlineParser | None = None,
        plugins=None,
    ) -> None:
        super().__init__(renderer or RSTRenderer(), block, inline, plugins)

    def __call__(self, s: str) -> str:
        """
        Convert the description ``s`` and return the cleaned, stripped result.

        Parameters
        ----------
        s
            Description text to convert.
        """
        # remove formatting from links
        # ``_RE_LINK`` has one capture group, so ``split`` places the link text
        # at the odd indices; only those pieces have special markers removed.
        description = "".join(
            _RE_SPECIAL.sub("", d) if i % 2 else d
            for i, d in enumerate(_RE_LINK.split(s))
        )

        description = super().__call__(description)
        # Some entries in the Vega-Lite schema miss the second occurrence of '__'
        description = description.replace("__Default value: ", "__Default value:__ ")
        # Links to the vega-lite documentation cannot be relative but instead need to
        # contain the full URL.
        description = description.replace(
            "types#datetime", "https://vega.github.io/vega-lite/docs/datetime.html"
        )
        # Fixing ambiguous unicode.
        # Escapes are used instead of literal characters so the (partly invisible)
        # characters cannot be silently altered by editors or whitespace normalisation.
        description = description.replace("\u2019", "'")  # RIGHT SINGLE QUOTATION MARK
        description = description.replace("\u2013", "-")  # EN DASH
        description = description.replace("\u00a0", " ")  # NO-BREAK SPACE
        return description.strip()


class InlineParser(_InlineParser):
    """Inline parser that drops `liquid`_ include tags from plain text.

    .. _liquid:
        https://shopify.github.io/liquid/
    """

    def __init__(self, hard_wrap: bool = False) -> None:
        super().__init__(hard_wrap)

    def process_text(self, text: str, state: InlineState) -> None:
        """
        Append ``text`` to ``state`` as a text token, minus `liquid`_ templating markup.

        .. _liquid:
            https://shopify.github.io/liquid/
        """
        cleaned = _RE_LIQUID_INCLUDE.sub(r"", text)
        token = {"type": "text", "raw": cleaned}
        state.append_token(token)


def read_ast_tokens(source: Url | Path, /) -> list[Token]:
    """
    Parse markdown from ``source`` into ast tokens, discarding the ``BlockState``.

    A ``Path`` is read through the parser directly; anything else is opened as
    a url and its response decoded as utf-8 before parsing.

    Factored out to provide accurate typing.
    """
    markdown = _Markdown(renderer=None, inline=InlineParser())
    if isinstance(source, Path):
        return markdown.read(source)[0]
    with request.urlopen(source) as response:
        text = response.read().decode("utf-8")
    return markdown.parse(text, markdown.block.state_cls())[0]


def rst_syntax_for_class(class_name: str) -> str:
    """Return ``class_name`` wrapped in the `.rst` ``:class:`` role."""
    return ":class:`{}`".format(class_name)
2 changes: 2 additions & 0 deletions tools/schemapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from tools.schemapi.schemapi import SchemaBase, Undefined
from tools.schemapi.utils import OneOrSeq, SchemaInfo
from tools.vega_expr import write_expr_module

__all__ = [
"CodeSnippet",
Expand All @@ -21,4 +22,5 @@
"arg_required_kwds",
"codegen",
"utils",
"write_expr_module",
]
Loading

0 comments on commit 6cb633f

Please sign in to comment.