Skip to content

Commit

Permalink
feat(typing): Generate Literal(s) using "const" (#3538)
Browse files Browse the repository at this point in the history
* feat(typing): Generate `Literal`(s) using `"const"`

Previously only `"enum"` was considered, meaning `"const"` was represented as `str` - rather than a **single** valid `str` value.

* build: run `generate-schema-wrapper`

* fix(typing): Update `Chart` annotations to match now narrower schema

* fix: Ensure rewrapped `Literal`(s) order is deterministic

https://github.com/dangotbanned/altair/actions/runs/10404360936/job/28812668479

* chore: Dummy to retrigger action
  • Loading branch information
dangotbanned authored Aug 15, 2024
1 parent f8eb594 commit fa3c4c5
Show file tree
Hide file tree
Showing 7 changed files with 919 additions and 685 deletions.
8 changes: 5 additions & 3 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@
AggregateOp_T,
AutosizeType_T,
ColorName_T,
CompositeMark_T,
ImputeMethod_T,
LayoutAlign_T,
Mark_T,
MultiTimeUnit_T,
OneOrSeq,
ProjectionType_T,
Expand Down Expand Up @@ -3662,9 +3664,9 @@ def __init__(
self,
data: Optional[ChartDataType] = Undefined,
encoding: Optional[FacetedEncoding] = Undefined,
mark: Optional[str | AnyMark] = Undefined,
width: Optional[int | str | dict | Step] = Undefined,
height: Optional[int | str | dict | Step] = Undefined,
mark: Optional[AnyMark | Mark_T | CompositeMark_T] = Undefined,
width: Optional[int | dict | Step | Literal["container"]] = Undefined,
height: Optional[int | dict | Step | Literal["container"]] = Undefined,
**kwargs: Any,
) -> None:
# Data type hints won't match with what TopLevelUnitSpec expects
Expand Down
14 changes: 12 additions & 2 deletions altair/vegalite/v5/schema/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
"AllSortString_T",
"AutosizeType_T",
"AxisOrient_T",
"Baseline_T",
"BinnedTimeUnit_T",
"Blend_T",
"BoxPlot_T",
"ColorName_T",
"ColorScheme_T",
"CompositeMark_T",
"Cursor_T",
"ErrorBand_T",
"ErrorBarExtent_T",
"ErrorBar_T",
"FontWeight_T",
"ImputeMethod_T",
"Interpolate_T",
Expand Down Expand Up @@ -52,6 +55,7 @@
"StandardType_T",
"StrokeCap_T",
"StrokeJoin_T",
"TextBaseline_T",
"TextDirection_T",
"TimeInterval_T",
"TitleAnchor_T",
Expand Down Expand Up @@ -157,7 +161,6 @@ def func(
]
AutosizeType_T: TypeAlias = Literal["pad", "none", "fit", "fit-x", "fit-y"]
AxisOrient_T: TypeAlias = Literal["top", "bottom", "left", "right"]
Baseline_T: TypeAlias = Literal["top", "middle", "bottom"]
BinnedTimeUnit_T: TypeAlias = Literal[
"binnedyear",
"binnedyearquarter",
Expand Down Expand Up @@ -206,6 +209,7 @@ def func(
"color",
"luminosity",
]
BoxPlot_T: TypeAlias = Literal["boxplot"]
ColorName_T: TypeAlias = Literal[
"black",
"silver",
Expand Down Expand Up @@ -690,6 +694,7 @@ def func(
"rainbow",
"sinebow",
]
CompositeMark_T: TypeAlias = Literal["boxplot", "errorbar", "errorband"]
Cursor_T: TypeAlias = Literal[
"auto",
"default",
Expand Down Expand Up @@ -728,7 +733,9 @@ def func(
"grab",
"grabbing",
]
ErrorBand_T: TypeAlias = Literal["errorband"]
ErrorBarExtent_T: TypeAlias = Literal["ci", "iqr", "stderr", "stdev"]
ErrorBar_T: TypeAlias = Literal["errorbar"]
FontWeight_T: TypeAlias = Literal[
"normal", "bold", "lighter", "bolder", 100, 200, 300, 400, 500, 600, 700, 800, 900
]
Expand Down Expand Up @@ -1064,6 +1071,9 @@ def func(
StandardType_T: TypeAlias = Literal["quantitative", "ordinal", "temporal", "nominal"]
StrokeCap_T: TypeAlias = Literal["butt", "round", "square"]
StrokeJoin_T: TypeAlias = Literal["miter", "round", "bevel"]
TextBaseline_T: TypeAlias = Literal[
"alphabetic", "top", "middle", "bottom", "line-top", "line-bottom"
]
TextDirection_T: TypeAlias = Literal["ltr", "rtl"]
TimeInterval_T: TypeAlias = Literal[
"millisecond", "second", "minute", "hour", "day", "week", "month", "year"
Expand Down
446 changes: 238 additions & 208 deletions altair/vegalite/v5/schema/channels.py

Large diffs are not rendered by default.

746 changes: 403 additions & 343 deletions altair/vegalite/v5/schema/core.py

Large diffs are not rendered by default.

312 changes: 198 additions & 114 deletions altair/vegalite/v5/schema/mixins.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tools/schemapi/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ def get_args(self, si: SchemaInfo) -> list[str]:
altair_class_name = item_si.title
item_type = f"core.{altair_class_name}"
py_type = f"List[{item_type}]"
elif si.is_enum():
elif si.is_literal():
# If it's an enum, we can type hint it as a Literal which tells
# a type checker that only the values in enum are acceptable
py_type = TypeAliasTracer.add_literal(
si, spell_literal(si.enum), replace=True
si, spell_literal(si.literal), replace=True
)
contents.append(f"_: {py_type}")

Expand Down
74 changes: 61 additions & 13 deletions tools/schemapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def add_literal(
tp = alias
elif (alias := self._literals_invert.get(tp)) and replace:
tp = alias
elif replace and info.is_union_enum():
elif replace and info.is_union_literal():
# Handles one very specific edge case `WindowFieldDef`
# - Has an anonymous enum union
# - One of the members is declared afterwards
# - SchemaBase needs to be first, as the union wont be internally sorted
it = (
self.add_literal(el, spell_literal(el.enum), replace=True)
self.add_literal(el, spell_literal(el.literal), replace=True)
for el in info.anyOf
)
tp = f"Union[SchemaBase, {', '.join(it)}]"
Expand Down Expand Up @@ -420,13 +420,13 @@ def get_python_type_representation( # noqa: C901

if self.is_empty():
type_representations.append("Any")
elif self.is_enum():
tp_str = spell_literal(self.enum)
elif self.is_literal():
tp_str = spell_literal(self.literal)
if for_type_hints:
tp_str = TypeAliasTracer.add_literal(self, tp_str, replace=True)
type_representations.append(tp_str)
elif for_type_hints and self.is_union_enum():
it = chain.from_iterable(el.enum for el in self.anyOf)
elif for_type_hints and self.is_union_literal():
it = chain.from_iterable(el.literal for el in self.anyOf)
tp_str = TypeAliasTracer.add_literal(self, spell_literal(it), replace=True)
type_representations.append(tp_str)
elif self.is_anyOf():
Expand All @@ -436,7 +436,7 @@ def get_python_type_representation( # noqa: C901
)
for s in self.anyOf
)
type_representations.extend(it)
type_representations.extend(maybe_rewrap_literal(chain.from_iterable(it)))
elif isinstance(self.type, list):
options = []
subschema = SchemaInfo(dict(**self.schema))
Expand Down Expand Up @@ -557,9 +557,17 @@ def items(self) -> dict:
return self.schema.get("items", {})

@property
def enum(self) -> list:
def enum(self) -> list[str]:
return self.schema.get("enum", [])

@property
def const(self) -> str:
return self.schema.get("const", "")

@property
def literal(self) -> list[str]:
return self.schema.get("enum", [self.const])

@property
def refname(self) -> str:
return self.raw_schema.get("$ref", "#/").split("/")[-1]
Expand Down Expand Up @@ -598,6 +606,12 @@ def is_reference(self) -> bool:
def is_enum(self) -> bool:
return "enum" in self.schema

def is_const(self) -> bool:
return "const" in self.schema

def is_literal(self) -> bool:
return not ({"enum", "const"}.isdisjoint(self.schema))

def is_empty(self) -> bool:
return not (set(self.schema.keys()) - set(EXCLUDE_KEYS))

Expand Down Expand Up @@ -646,13 +660,13 @@ def is_union(self) -> bool:
"""
return self.is_anyOf() and self.type is None

def is_union_enum(self) -> bool:
def is_union_literal(self) -> bool:
"""
Candidate for reducing to a single ``Literal`` alias.
E.g. `BinnedTimeUnit`
"""
return self.is_union() and all(el.is_enum() for el in self.anyOf)
return self.is_union() and all(el.is_literal() for el in self.anyOf)


class RSTRenderer(_RSTRenderer):
Expand Down Expand Up @@ -792,9 +806,43 @@ def flatten(container: Iterable) -> Iterable:
yield i


def spell_literal(it: Iterable[str], /) -> str:
s = ", ".join(f"{s!r}" for s in it)
return f"Literal[{s}]"
def spell_literal(it: Iterable[str], /, *, quote: bool = True) -> str:
"""
Combine individual ``str`` type reprs into a single ``Literal``.
Parameters
----------
it
Type representations.
quote
Call ``repr()`` on each element in ``it``.
.. note::
Set to ``False`` if performing a second pass.
"""
it_el: Iterable[str] = (f"{s!r}" for s in it) if quote else it
return f"Literal[{', '.join(it_el)}]"


def maybe_rewrap_literal(it: Iterable[str], /) -> Iterator[str]:
"""
Where `it` may contain one or more `"enum"`, `"const"`, flatten to a single `Literal[...]`.
All other type representations are yielded unchanged.
"""
seen: set[str] = set()
for s in it:
if s.startswith("Literal["):
seen.add(unwrap_literal(s))
else:
yield s
if seen:
yield spell_literal(sorted(seen), quote=False)


def unwrap_literal(tp: str, /) -> str:
"""`"Literal['value']"` -> `"value"`."""
return re.sub(r"Literal\[(.+)\]", r"\g<1>", tp)


def ruff_format_str(code: str | list[str]) -> str:
Expand Down

0 comments on commit fa3c4c5

Please sign in to comment.