Skip to content

Commit

Permalink
Type hints: Finish type hints and mark package as typed (#3272)
Browse files Browse the repository at this point in the history
* Type Undefined as UndefinedType. Some minor mypy fixes

* Add entry to changelog

* Make types public so users can use them in their own code if needed

* Add py.typed

* Remove type annotation on inputs to vconcat as it is too complex to type. Already removed for layer and hconcat in a previous PR

* Move some changes into code generation files

* Remove unused ignore statement

* Make SupportsGeoInterface and DataFrameLike public
  • Loading branch information
binste authored Nov 23, 2023
1 parent 1e17571 commit dd5b61a
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 72 deletions.
3 changes: 3 additions & 0 deletions altair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"CalculateTransform",
"Categorical",
"Chart",
"ChartDataType",
"Color",
"ColorDatum",
"ColorDef",
Expand Down Expand Up @@ -125,7 +126,9 @@
"Cyclical",
"Data",
"DataFormat",
"DataFrameLike",
"DataSource",
"DataType",
"Datasets",
"DateTime",
"DatumChannelMixin",
Expand Down
Empty file added altair/py.typed
Empty file.
6 changes: 3 additions & 3 deletions altair/utils/_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.core import _DataFrameLike
from altair.utils.core import DataFrameLike
from altair.utils.schemapi import Undefined

Scope = Tuple[int, ...]
Expand Down Expand Up @@ -56,7 +56,7 @@ def transformed_data(
chart: Union[Chart, FacetChart],
row_limit: Optional[int] = None,
exclude: Optional[Iterable[str]] = None,
) -> Optional[_DataFrameLike]:
) -> Optional[DataFrameLike]:
...


Expand All @@ -65,7 +65,7 @@ def transformed_data(
chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
row_limit: Optional[int] = None,
exclude: Optional[Iterable[str]] = None,
) -> List[_DataFrameLike]:
) -> List[DataFrameLike]:
...


Expand Down
12 changes: 6 additions & 6 deletions altair/utils/_vegafusion_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from typing import TypedDict, Final

from altair.utils._importers import import_vegafusion
from altair.utils.core import _DataFrameLike
from altair.utils.data import _DataType, _ToValuesReturnType, MaxRowsError
from altair.utils.core import DataFrameLike
from altair.utils.data import DataType, ToValuesReturnType, MaxRowsError
from altair.vegalite.data import default_data_transformer

# Temporary storage for dataframes that have been extracted
# from charts by the vegafusion data transformer. Use a WeakValueDictionary
# rather than a dict so that the Python interpreter is free to garbage
# collect the stored DataFrames.
extracted_inline_tables: MutableMapping[str, DataFrameLike] = WeakValueDictionary()

# Special URL prefix that VegaFusion uses to denote that a
# dataset in a Vega spec corresponds to an entry in the `inline_datasets`
Expand All @@ -29,8 +29,8 @@ class _ToVegaFusionReturnUrlDict(TypedDict):

@curried.curry
def vegafusion_data_transformer(
data: _DataType, max_rows: int = 100000
) -> Union[_ToVegaFusionReturnUrlDict, _ToValuesReturnType]:
data: DataType, max_rows: int = 100000
) -> Union[_ToVegaFusionReturnUrlDict, ToValuesReturnType]:
"""VegaFusion Data Transformer"""
if hasattr(data, "__geo_interface__"):
# Use default transformer for geo interface objects
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_inline_table_names(vega_spec: dict) -> Set[str]:
return table_names


def get_inline_tables(vega_spec: dict) -> Dict[str, _DataFrameLike]:
def get_inline_tables(vega_spec: dict) -> Dict[str, DataFrameLike]:
"""Get the inline tables referenced by a Vega specification
Note: This function should only be called on a Vega spec that corresponds
Expand Down
18 changes: 9 additions & 9 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
if TYPE_CHECKING:
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn

_V = TypeVar("_V")
_P = ParamSpec("_P")
V = TypeVar("V")
P = ParamSpec("P")


class DataFrameLike(Protocol):
    """Structural type for any object implementing the DataFrame
    interchange protocol, i.e. exposing a ``__dataframe__`` method.
    """

    def __dataframe__(self, *args, **kwargs) -> DfiDataFrame:
        ...

Expand Down Expand Up @@ -188,12 +188,12 @@ def __dataframe__(self, *args, **kwargs) -> DfiDataFrame:
]


# The set of Vega-Lite type codes that `infer_vegalite_type` can return.
InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
data: object,
) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
"""
From an array-like input, infer the correct vega typecode
('ordinal', 'nominal', 'quantitative', or 'temporal')
Expand Down Expand Up @@ -442,7 +442,7 @@ def sanitize_arrow_table(pa_table):

def parse_shorthand(
shorthand: Union[Dict[str, Any], str],
data: Optional[Union[pd.DataFrame, _DataFrameLike]] = None,
data: Optional[Union[pd.DataFrame, DataFrameLike]] = None,
parse_aggregates: bool = True,
parse_window_ops: bool = False,
parse_timeunits: bool = True,
Expand Down Expand Up @@ -637,7 +637,7 @@ def parse_shorthand(

def infer_vegalite_type_for_dfi_column(
column: Union[Column, "PandasColumn"],
) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
from pyarrow.interchange.from_dataframe import column_to_array

try:
Expand Down Expand Up @@ -672,10 +672,10 @@ def infer_vegalite_type_for_dfi_column(
raise ValueError(f"Unexpected DtypeKind: {kind}")


def use_signature(Obj: Callable[_P, Any]):
def use_signature(Obj: Callable[P, Any]):
"""Apply call signature and documentation of Obj to the decorated method"""

def decorate(f: Callable[..., _V]) -> Callable[_P, _V]:
def decorate(f: Callable[..., V]) -> Callable[P, V]:
# call-signature of f is exposed via __wrapped__.
# we want it to mimic Obj.__init__
f.__wrapped__ = Obj.__init__ # type: ignore
Expand Down
30 changes: 15 additions & 15 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TypeVar

from ._importers import import_pyarrow_interchange
from .core import sanitize_dataframe, sanitize_arrow_table, _DataFrameLike
from .core import sanitize_dataframe, sanitize_arrow_table, DataFrameLike
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry
Expand All @@ -23,15 +23,15 @@
import pyarrow.lib


class SupportsGeoInterface(Protocol):
    """Structural type for objects exposing a ``__geo_interface__`` mapping
    (GeoJSON-like geographic data, e.g. geopandas objects).
    """

    __geo_interface__: MutableMapping


# Any data object Altair can consume: a plain dict, a pandas DataFrame,
# an object with a geo interface, or anything implementing the DataFrame
# interchange protocol.
DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, DataFrameLike]
TDataType = TypeVar("TDataType", bound=DataType)

# Shape of the `data` entry embedded in a Vega-Lite spec.
VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
# Return shape of `to_values`-style transformers.
ToValuesReturnType = Dict[str, Union[dict, List[dict]]]


# ==============================================================================
Expand All @@ -46,7 +46,7 @@ class _SupportsGeoInterface(Protocol):
# form.
# ==============================================================================
class DataTransformerType(Protocol):
    """Callable signature that every registered data transformer must satisfy:
    it accepts a data object and returns a Vega-Lite ``data`` dict.
    """

    def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict:
        pass


Expand All @@ -70,7 +70,7 @@ class MaxRowsError(Exception):


@curried.curry
def limit_rows(data: _TDataType, max_rows: Optional[int] = 5000) -> _TDataType:
def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType:
"""Raise MaxRowsError if the data model has more than max_rows.
If max_rows is None, then do not perform any check.
Expand Down Expand Up @@ -122,7 +122,7 @@ def raise_max_rows_error():

@curried.curry
def sample(
data: _DataType, n: Optional[int] = None, frac: Optional[float] = None
data: DataType, n: Optional[int] = None, frac: Optional[float] = None
) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]:
"""Reduce the size of the data model by sampling without replacement."""
check_data_type(data)
Expand Down Expand Up @@ -180,7 +180,7 @@ class _ToCsvReturnUrlDict(TypedDict):

@curried.curry
def to_json(
data: _DataType,
data: DataType,
prefix: str = "altair-data",
extension: str = "json",
filename: str = "{prefix}-{hash}.{extension}",
Expand All @@ -199,7 +199,7 @@ def to_json(

@curried.curry
def to_csv(
data: Union[dict, pd.DataFrame, _DataFrameLike],
data: Union[dict, pd.DataFrame, DataFrameLike],
prefix: str = "altair-data",
extension: str = "csv",
filename: str = "{prefix}-{hash}.{extension}",
Expand All @@ -215,7 +215,7 @@ def to_csv(


@curried.curry
def to_values(data: _DataType) -> _ToValuesReturnType:
def to_values(data: DataType) -> ToValuesReturnType:
"""Replace a DataFrame by a data model with values."""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
Expand All @@ -242,7 +242,7 @@ def to_values(data: _DataType) -> _ToValuesReturnType:
raise ValueError("Unrecognized data type: {}".format(type(data)))


def check_data_type(data: _DataType) -> None:
def check_data_type(data: DataType) -> None:
if not isinstance(data, (dict, pd.DataFrame)) and not any(
hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
):
Expand All @@ -260,7 +260,7 @@ def _compute_data_hash(data_str: str) -> str:
return hashlib.md5(data_str.encode()).hexdigest()


def _data_to_json_string(data: _DataType) -> str:
def _data_to_json_string(data: DataType) -> str:
"""Return a JSON string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
Expand Down Expand Up @@ -288,7 +288,7 @@ def _data_to_json_string(data: _DataType) -> str:
)


def _data_to_csv_string(data: Union[dict, pd.DataFrame, _DataFrameLike]) -> str:
def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
"""return a CSV string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
Expand Down
10 changes: 3 additions & 7 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
else:
from typing_extensions import Self

# Type variable bound to SchemaBase subclasses; used by class decorators
# such as `with_property_setters` to preserve the decorated class's type.
TSchemaBase = TypeVar("TSchemaBase", bound=Type["SchemaBase"])

ValidationErrorList = List[jsonschema.exceptions.ValidationError]
GroupedValidationErrors = Dict[str, ValidationErrorList]
Expand Down Expand Up @@ -733,11 +733,7 @@ def __repr__(self):
return "Undefined"


# Singleton sentinel marking schema parameters that were not supplied.
Undefined = UndefinedType()


class SchemaBase:
Expand Down Expand Up @@ -1329,7 +1325,7 @@ def __call__(self, *args, **kwargs):
return obj


def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase:
def with_property_setters(cls: TSchemaBase) -> TSchemaBase:
"""
Decorator to add property setters to a Schema class.
"""
Expand Down
6 changes: 3 additions & 3 deletions altair/vegalite/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
check_data_type,
)
from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry
from ..utils.data import _DataType, _ToValuesReturnType
from ..utils.data import DataType, ToValuesReturnType
from ..utils.plugin_registry import PluginEnabler


@curried.curry
def default_data_transformer(
    data: DataType, max_rows: int = 5000
) -> ToValuesReturnType:
    """Default data transformer: raise ``MaxRowsError`` if `data` exceeds
    `max_rows` rows (via ``limit_rows``), then embed it inline as a values
    dict (via ``to_values``).
    """
    return curried.pipe(data, limit_rows(max_rows=max_rows), to_values)


Expand Down
Loading

0 comments on commit dd5b61a

Please sign in to comment.