Skip to content

Commit

Permalink
perf: Fix issues with Chart|LayerChart.encode, 1.32x speedup to `in…
Browse files Browse the repository at this point in the history
…fer_encoding_types` (#3444)

* fix, doc, perf: Fix issues with `Chart|LayerChart.encode`, 1.33x speedup to `infer_encoding_types`

Fixes:
- [Sphinx warning](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html#altair.Chart) on `Chart.encode`. Also incorrectly under `Attributes` section
- Preserve static typing previously found in `_encode_signature` but lost after `_EncodingMixin.encode`
  - Re-running `mypy` output 'Found 63 errors in 47 files (checked 360 source files)', tests/examples

Perf:
- This was a response to the `TODO` left at the top of `infer_encoding_types`
- Will be adding the benchmark to the PR description

* fix(typing): Resolve assignment type errors revealed

Incompatible types in assignment (expression has type "Chart", variable has type "DataFrame")

* fix(typing): Resolve direct arg-type errors revealed

`Color` -> `Fill` when passed to `fill` channel

* fix(typing): Resolve `alt.condition` overload-related arg-type errors revealed

'error: Argument "color" to "encode" of "_EncodingMixin" has incompatible type "dict[Any, Any] | SchemaBase"; expected "str | Color | dict[Any, Any] | ColorDatum | ColorValue | UndefinedType"  [arg-type]'

* test: update `infer_encoding_types` tests

- New implementation does not use `**kwargs`, which eliminates an entire class of tests based on `.encode(invalidChannel=...)` as these now trigger a runtime error

* test: Rename `invalidChannel` to `invalidArgument`

Fixes https://github.com/vega/altair/pull/3444/files/e4ab7052e9a1e62ff1fd80379864489c69a1e020#r1657008627

* chore: remove PR note comment

Fixes #3444 (comment)

* docs: fix typo

* Exclude LookupData export from core.py to fix issue with mypy where it assumes that altair.LookupData comes from core.py instead of api.py

* Remove 'pd' and 'jsonschema' from __init__.py __all__. Unclear why they show up only now...

* Format code

---------

Co-authored-by: Stefan Binder <[email protected]>
Co-authored-by: Mattijn van Hoek <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent 79783ef commit c82f8c2
Show file tree
Hide file tree
Showing 13 changed files with 568 additions and 457 deletions.
203 changes: 140 additions & 63 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@
Callable,
TypeVar,
Any,
Sequence,
Iterator,
cast,
Literal,
Protocol,
TYPE_CHECKING,
runtime_checkable,
)
from itertools import groupby
from operator import itemgetter

import jsonschema
import pandas as pd
import numpy as np
from pandas.api.types import infer_dtype

from altair.utils.schemapi import SchemaBase
from altair.utils.schemapi import SchemaBase, Undefined
from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -773,9 +775,133 @@ def display_traceback(in_ipython: bool = True):
traceback.print_exception(*exc_info)


_ChannelType = Literal["field", "datum", "value"]
_CHANNEL_CACHE: _ChannelCache
"""Singleton `_ChannelCache` instance.
Initialized on first use.
"""


class _ChannelCache:
    """Cached two-way mapping between channel classes and encoding names.

    Built once per interpreter via `from_cache` so that repeated
    `encode(...)` calls do not re-scan the channels module.
    """

    # channel class (e.g. ``X``, ``ColorValue``) -> encoding name (e.g. "x").
    channel_to_name: dict[type[SchemaBase], str]
    # encoding name -> {"field": cls, "value": cls} (no "datum"; see
    # `_invert_group_channels`).
    name_to_channel: dict[str, dict[_ChannelType, type[SchemaBase]]]

    @classmethod
    def from_channels(cls, channels: ModuleType, /) -> _ChannelCache:
        """Build an uncached instance by scanning `channels`.

        - This branch is only kept for tests that depend on mocking `channels`.
        - No longer needs to pass around `channels` reference and rebuild every call.
        """
        c_to_n = {
            c: c._encoding_name
            for c in channels.__dict__.values()
            if isinstance(c, type)
            and issubclass(c, SchemaBase)
            and hasattr(c, "_encoding_name")
        }
        # Bypass __init__ (none is defined); assign the two tables directly.
        self = cls.__new__(cls)
        self.channel_to_name = c_to_n
        self.name_to_channel = _invert_group_channels(c_to_n)
        return self

    @classmethod
    def from_cache(cls) -> _ChannelCache:
        """Return the singleton, building it on first use."""
        global _CHANNEL_CACHE
        try:
            cached = _CHANNEL_CACHE
        except NameError:
            # First call: populate the module-level singleton.
            cached = cls.__new__(cls)
            cached.channel_to_name = _init_channel_to_name()
            cached.name_to_channel = _invert_group_channels(cached.channel_to_name)
            _CHANNEL_CACHE = cached
        return _CHANNEL_CACHE

    def get_encoding(self, tp: type[Any], /) -> str:
        """Return the encoding name for channel class `tp`.

        Raises
        ------
        NotImplementedError
            If `tp` is not a known channel class.
        """
        if encoding := self.channel_to_name.get(tp):
            return encoding
        # BUG FIX: `tp` is already a type, so `type(tp).__name__` always
        # reported the metaclass name ("type") rather than the offending class.
        msg = f"positional of type {tp.__name__!r}"
        raise NotImplementedError(msg)

    def _wrap_in_channel(self, obj: Any, encoding: str, /):
        """Coerce `obj` into the channel class registered for `encoding`."""
        if isinstance(obj, SchemaBase):
            return obj
        elif isinstance(obj, str):
            # Shorthand like "x:Q" becomes a field-channel dict.
            obj = {"shorthand": obj}
        elif isinstance(obj, (list, tuple)):
            return [self._wrap_in_channel(el, encoding) for el in obj]
        if channel := self.name_to_channel.get(encoding):
            tp = channel["value" if "value" in obj else "field"]
            try:
                # Don't force validation here; some objects won't be valid until
                # they're created in the context of a chart.
                return tp.from_dict(obj, validate=False)
            except jsonschema.ValidationError:
                # our attempts at finding the correct class have failed
                return obj
        else:
            warnings.warn(f"Unrecognized encoding channel {encoding!r}", stacklevel=1)
            return obj

    def infer_encoding_types(self, kwargs: dict[str, Any], /):
        """Wrap every non-`Undefined` kwarg in its channel class."""
        return {
            encoding: self._wrap_in_channel(obj, encoding)
            for encoding, obj in kwargs.items()
            if obj is not Undefined
        }


def _init_channel_to_name():
    """
    Build the mapping from channel class to its encoding name.

    Note
    ----
    The return type is not expressible using annotations, but is used
    internally by `mypy`/`pyright` and avoids the need for type ignores.

    Returns
    -------
    mapping: dict[type[`<subclass of FieldChannelMixin and SchemaBase>`] | type[`<subclass of ValueChannelMixin and SchemaBase>`] | type[`<subclass of DatumChannelMixin and SchemaBase>`], str]
    """
    # Imported lazily to avoid a circular import at module load time.
    from altair.vegalite.v5.schema import channels as ch

    channel_mixins = (ch.FieldChannelMixin, ch.ValueChannelMixin, ch.DatumChannelMixin)

    def _is_channel(obj: Any) -> bool:
        # A channel is a class deriving from SchemaBase and one of the mixins.
        return (
            isinstance(obj, type)
            and issubclass(obj, channel_mixins)
            and issubclass(obj, SchemaBase)
        )

    mapping: dict[Any, str] = {}
    for member in ch.__dict__.values():
        if _is_channel(member):
            mapping[member] = member._encoding_name
    return mapping


def _invert_group_channels(
    m: dict[type[SchemaBase], str], /
) -> dict[str, dict[_ChannelType, type[SchemaBase]]]:
    """Grouped inverted index for `_ChannelCache.channel_to_name`.

    NOTE: `groupby` only groups *consecutive* runs sharing a key, so this
    relies on same-name channel classes appearing adjacently in `m` — which
    they do when `m` is built in channels-module definition order.
    """

    def _collapse(group: Iterator[tuple[type[Any], str]]) -> Any:
        """Collapse one encoding's classes into a 1-2 item dict.

        `Datum` variants are dropped: `_wrap_in_channel` only ever looks up
        the "value" or "field" entry.
        """
        collected: dict[Any, type[SchemaBase]] = {}
        for channel_type, _ in group:
            cls_name = channel_type.__name__
            if cls_name.endswith("Datum"):
                continue
            key = "value" if cls_name.endswith("Value") else "field"
            collected[key] = channel_type
        return collected

    inverted: dict[str, dict[_ChannelType, type[SchemaBase]]] = {}
    for enc_name, group in groupby(m.items(), key=itemgetter(1)):
        inverted[enc_name] = _collapse(group)
    return inverted


def infer_encoding_types(
    args: tuple[Any, ...], kwargs: dict[str, Any], channels: ModuleType | None = None
):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : tuple
        Positional channel instances; each one's encoding name is inferred
        from its type.
    kwargs : dict
        Values keyed by encoding name: channel instances, shorthand strings,
        dicts, or lists/tuples thereof.
    channels : ModuleType, optional
        Module holding the channel definitions. Only supplied by tests that
        mock the channels module; when omitted, the cached singleton mapping
        is used.

    Returns
    -------
    mapping : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    cache = (
        _ChannelCache.from_channels(channels)
        if channels
        else _ChannelCache.from_cache()
    )
    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        # For a list/tuple arg, the first element determines the channel.
        el = next(iter(arg), None) if isinstance(arg, (list, tuple)) else arg
        encoding = cache.get_encoding(type(el))
        if encoding not in kwargs:
            kwargs[encoding] = arg
        else:
            msg = f"encoding {encoding!r} specified twice."
            raise ValueError(msg)
    return cache.infer_encoding_types(kwargs)
76 changes: 43 additions & 33 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
from typing import Union, cast, Any, Iterable, Literal, IO, TYPE_CHECKING
from typing_extensions import TypeAlias
import typing

from .schema import core, channels, mixins, Undefined, SCHEMA_URL

Expand Down Expand Up @@ -74,8 +75,6 @@
Step,
RepeatRef,
NonNormalizedSpec,
LayerSpec,
UnitSpec,
UrlData,
SequenceGenerator,
GraticuleGenerator,
Expand Down Expand Up @@ -381,6 +380,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
return False


# A test predicate: a vega expression string, an `expr` Expression, or a
# composed predicate from core.
_TestPredicateType = Union[str, _expr_core.Expression, core.PredicateComposition]
# Everything the `predicate` argument of `condition` may be, including a
# `Parameter` and dict-encoded predicates.
_PredicateType = Union[
    Parameter,
    core.Expr,
    typing.Dict[str, Any],
    _TestPredicateType,
    _expr_core.OperatorMixin,
]
# Dict representation of a resolved condition.
_ConditionType = typing.Dict[str, Union[_TestPredicateType, Any]]
# Unions for the `if_true`/`if_false` statement arguments of `condition`.
_DictOrStr = Union[typing.Dict[str, Any], str]
_DictOrSchema = Union[core.SchemaBase, typing.Dict[str, Any]]
_StatementType = Union[core.SchemaBase, _DictOrStr]

# ------------------------------------------------------------------------
# Top-Level Functions

Expand Down Expand Up @@ -826,18 +838,33 @@ def binding_range(**kwargs):
return core.BindRange(input="range", **kwargs)


# Bound TypeVar so `condition` can return the same SchemaBase subtype it
# was given as `if_false` (first overload below).
_TSchemaBase = typing.TypeVar("_TSchemaBase", bound=core.SchemaBase)


# `if_false` is a schema object -> result keeps that exact type.
@typing.overload
def condition(
    predicate: _PredicateType, if_true: _StatementType, if_false: _TSchemaBase, **kwargs
) -> _TSchemaBase: ...
# Both branches plain strings: typed NoReturn — presumably rejected at
# runtime by the implementation; confirm against the `condition` body.
@typing.overload
def condition(
    predicate: _PredicateType, if_true: str, if_false: str, **kwargs
) -> typing.NoReturn: ...
# Schema-or-dict `if_true` with dict/str `if_false` -> plain dict result.
@typing.overload
def condition(
    predicate: _PredicateType, if_true: _DictOrSchema, if_false: _DictOrStr, **kwargs
) -> dict[str, _ConditionType | Any]: ...
# dict/str `if_true` with dict `if_false` -> plain dict result.
@typing.overload
def condition(
    predicate: _PredicateType,
    if_true: _DictOrStr,
    if_false: dict[str, Any],
    **kwargs,
) -> dict[str, _ConditionType | Any]: ...
# TODO: update the docstring
def condition(
predicate: Parameter
| str
| Expression
| Expr
| PredicateComposition
| dict[str, Any],
# Types of these depends on where the condition is used so we probably
# can't be more specific here.
if_true: Any,
if_false: Any,
predicate: _PredicateType,
if_true: _StatementType,
if_false: _StatementType,
**kwargs,
) -> dict[str, Any] | SchemaBase:
"""A conditional attribute or encoding
Expand Down Expand Up @@ -2729,24 +2756,7 @@ def resolve_scale(self, *args, **kwargs) -> Self:
return self._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs))


class _EncodingMixin:
@utils.use_signature(channels._encode_signature)
def encode(self, *args, **kwargs) -> Self:
# Convert args to kwargs based on their types.
kwargs = utils.infer_encoding_types(args, kwargs, channels)

# get a copy of the dict representation of the previous encoding
# ignore type as copy method comes from SchemaBase
copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined]
encoding = copy._get("encoding", {})
if isinstance(encoding, core.VegaLiteSchema):
encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined}

# update with the new encodings, and apply them to the copy
encoding.update(kwargs)
copy.encoding = core.FacetedEncoding(**encoding)
return copy

class _EncodingMixin(channels._EncodingMixin):
def facet(
self,
facet: Optional[str | Facet] = Undefined,
Expand Down Expand Up @@ -3614,20 +3624,20 @@ def transformed_data(

return transformed_data(self, row_limit=row_limit, exclude=exclude)

def __iadd__(self, other: LayerSpec | UnitSpec) -> Self:
def __iadd__(self, other: LayerChart | Chart) -> Self:
_check_if_valid_subspec(other, "LayerChart")
_check_if_can_be_layered(other)
self.layer.append(other)
self.data, self.layer = _combine_subchart_data(self.data, self.layer)
self.params, self.layer = _combine_subchart_params(self.params, self.layer)
return self

def __add__(self, other: LayerSpec | UnitSpec) -> Self:
def __add__(self, other: LayerChart | Chart) -> Self:
copy = self.copy(deep=["layer"])
copy += other
return copy

def add_layers(self, *layers: LayerSpec | UnitSpec) -> Self:
def add_layers(self, *layers: LayerChart | Chart) -> Self:
copy = self.copy(deep=["layer"])
for layer in layers:
copy += layer
Expand Down
Loading

0 comments on commit c82f8c2

Please sign in to comment.