Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hint schemapi.py #3142

Merged
merged 3 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 113 additions & 33 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import inspect
import json
import sys
import textwrap
from typing import (
Any,
Expand All @@ -15,6 +16,11 @@
Tuple,
Iterable,
Type,
Generator,
Union,
overload,
Literal,
TypeVar,
)
from itertools import zip_longest

Expand All @@ -26,6 +32,13 @@

from altair import vegalite

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

_TSchemaBase = TypeVar("_TSchemaBase", bound="SchemaBase")

ValidationErrorList = List[jsonschema.exceptions.ValidationError]
GroupedValidationErrors = Dict[str, ValidationErrorList]

Expand All @@ -35,21 +48,21 @@
# larger specs, but leads to much more useful tracebacks for the user.
# Individual schema classes can override this by setting the
# class-level _class_is_valid_at_instantiation attribute to False
DEBUG_MODE = True
DEBUG_MODE: bool = True


def enable_debug_mode():
def enable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = True


def disable_debug_mode():
def disable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = False


@contextlib.contextmanager
def debug_mode(arg):
def debug_mode(arg: bool) -> Generator[None, None, None]:
global DEBUG_MODE
original = DEBUG_MODE
DEBUG_MODE = arg
Expand All @@ -59,12 +72,35 @@ def debug_mode(arg):
DEBUG_MODE = original


@overload
def validate_jsonschema(
spec: Dict[str, Any],
schema: Dict[str, Any],
rootschema: Optional[Dict[str, Any]] = None,
raise_error: bool = True,
rootschema: Optional[Dict[str, Any]] = ...,
*,
raise_error: Literal[True] = ...,
) -> None:
...


@overload
def validate_jsonschema(
spec: Dict[str, Any],
schema: Dict[str, Any],
rootschema: Optional[Dict[str, Any]] = ...,
*,
raise_error: Literal[False],
) -> Optional[jsonschema.exceptions.ValidationError]:
...


def validate_jsonschema(
spec,
schema,
rootschema=None,
*,
raise_error=True,
):
"""Validates the passed in spec against the schema in the context of the
rootschema. If any errors are found, they are deduplicated and prioritized
and only the most relevant errors are kept. Errors are then either raised
Expand All @@ -85,7 +121,7 @@ def validate_jsonschema(
# error message. Setting a new attribute like this is not ideal as
# it then no longer matches the type ValidationError. It would be better
# to refactor this function to never raise but only return errors.
main_error._all_errors = grouped_errors # type: ignore[attr-defined]
main_error._all_errors = grouped_errors
if raise_error:
raise main_error
else:
Expand Down Expand Up @@ -319,7 +355,7 @@ def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList:
return list({e.message: e for e in errors}.values())


def _subclasses(cls):
def _subclasses(cls: type) -> Generator[type, None, None]:
"""Breadth-first sequence of all classes which inherit from cls."""
seen = set()
current_set = {cls}
Expand All @@ -330,7 +366,7 @@ def _subclasses(cls):
yield cls


def _todict(obj, context):
def _todict(obj: Any, context: Optional[Dict[str, Any]]) -> Any:
"""Convert an object to a dict representation."""
if isinstance(obj, SchemaBase):
return obj.to_dict(validate=False, context=context)
Expand All @@ -348,7 +384,7 @@ def _todict(obj, context):
return obj


def _resolve_references(schema, root=None):
def _resolve_references(schema: dict, root: Optional[dict] = None) -> dict:
"""Resolve schema references."""
resolver = jsonschema.RefResolver.from_schema(root or schema)
while "$ref" in schema:
Expand Down Expand Up @@ -597,9 +633,9 @@ class SchemaBase:

_schema: Optional[Dict[str, Any]] = None
_rootschema: Optional[Dict[str, Any]] = None
_class_is_valid_at_instantiation = True
_class_is_valid_at_instantiation: bool = True

def __init__(self, *args, **kwds):
def __init__(self, *args: Any, **kwds: Any) -> None:
# Two valid options for initialization, which should be handled by
# derived classes:
# - a single arg with no kwds, for, e.g. {'type': 'string'}
Expand All @@ -623,7 +659,9 @@ def __init__(self, *args, **kwds):
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy(self, deep=True, ignore=()):
def copy(
self, deep: Union[bool, Iterable] = True, ignore: Optional[list] = None
) -> Self:
"""Return a copy of the object

Parameters
Expand All @@ -648,7 +686,9 @@ def _shallow_copy(obj):
else:
return obj

def _deep_copy(obj, ignore=()):
def _deep_copy(obj, ignore: Optional[list] = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
Expand All @@ -668,7 +708,7 @@ def _deep_copy(obj, ignore=()):
return obj

try:
deep = list(deep)
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
Expand All @@ -680,6 +720,8 @@ def _deep_copy(obj, ignore=()):
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -873,12 +915,19 @@ def to_json(
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)

@classmethod
def _default_wrapper_classes(cls):
def _default_wrapper_classes(cls) -> Generator[Type["SchemaBase"], None, None]:
"""Return the set of classes used within cls.from_dict()"""
return _subclasses(SchemaBase)

@classmethod
def from_dict(cls, dct, validate=True, _wrapper_classes=None):
def from_dict(
cls,
dct: dict,
validate: bool = True,
_wrapper_classes: Optional[Iterable[Type["SchemaBase"]]] = None,
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> "SchemaBase":
"""Construct class from a dictionary representation

Parameters
Expand All @@ -887,7 +936,7 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
The dict from which to construct the class
validate : boolean
If True (default), then validate the input against the schema.
_wrapper_classes : list (optional)
_wrapper_classes : iterable (optional)
The set of SchemaBase classes to use when constructing wrappers
of the dict inputs. If not specified, the result of
cls._default_wrapper_classes will be used.
Expand All @@ -910,7 +959,14 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
return converter.from_dict(dct, cls)

@classmethod
def from_json(cls, json_string, validate=True, **kwargs):
def from_json(
cls,
json_string: str,
validate: bool = True,
**kwargs: Any
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> Any:
"""Instantiate the object from a valid JSON string

Parameters
Expand All @@ -931,27 +987,36 @@ def from_json(cls, json_string, validate=True, **kwargs):
return cls.from_dict(dct, validate=validate)

@classmethod
def validate(cls, instance, schema=None):
def validate(
cls, instance: Dict[str, Any], schema: Optional[Dict[str, Any]] = None
) -> None:
"""
Validate the instance against the class schema in the context of the
rootschema.
"""
if schema is None:
schema = cls._schema
# For the benefit of mypy
assert schema is not None
return validate_jsonschema(
instance, schema, rootschema=cls._rootschema or cls._schema
)

@classmethod
def resolve_references(cls, schema=None):
def resolve_references(cls, schema: Optional[dict] = None) -> dict:
"""Resolve references in the context of this object's schema or root schema."""
schema_to_pass = schema or cls._schema
# For the benefit of mypy
assert schema_to_pass is not None
return _resolve_references(
schema=(schema or cls._schema),
schema=schema_to_pass,
root=(cls._rootschema or cls._schema or schema),
)

@classmethod
def validate_property(cls, name, value, schema=None):
def validate_property(
cls, name: str, value: Any, schema: Optional[dict] = None
) -> None:
"""
Validate a property against property schema in the context of the
rootschema
Expand All @@ -962,8 +1027,8 @@ def validate_property(cls, name, value, schema=None):
value, props.get(name, {}), rootschema=cls._rootschema or cls._schema
)

def __dir__(self):
return sorted(super().__dir__() + list(self._kwds.keys()))
def __dir__(self) -> list:
return sorted(list(super().__dir__()) + list(self._kwds.keys()))


def _passthrough(*args, **kwds):
Expand All @@ -980,7 +1045,7 @@ class _FromDict:

_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")

def __init__(self, class_list):
def __init__(self, class_list: Iterable[Type[SchemaBase]]) -> None:
# Create a mapping of a schema hash to a list of matching classes
# This lets us quickly determine the correct class to construct
self.class_dict = collections.defaultdict(list)
Expand All @@ -989,7 +1054,7 @@ def __init__(self, class_list):
self.class_dict[self.hash_schema(cls._schema)].append(cls)

@classmethod
def hash_schema(cls, schema, use_json=True):
def hash_schema(cls, schema: dict, use_json: bool = True) -> int:
"""
Compute a python hash for a nested dictionary which
properly handles dicts, lists, sets, and tuples.
Expand Down Expand Up @@ -1025,14 +1090,29 @@ def _freeze(val):
return hash(_freeze(schema))

def from_dict(
self, dct, cls=None, schema=None, rootschema=None, default_class=_passthrough
):
self,
dct: dict,
cls: Optional[Type[SchemaBase]] = None,
schema: Optional[dict] = None,
rootschema: Optional[dict] = None,
default_class=_passthrough,
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> Any:
"""Construct an object from a dict representation"""
if (schema is None) == (cls is None):
raise ValueError("Must provide either cls or schema, but not both.")
if schema is None:
schema = schema or cls._schema
rootschema = rootschema or cls._rootschema
# Can ignore type errors here: cls cannot be None when schema is None
schema = cls._schema # type: ignore[union-attr]
# For the benefit of mypy
assert schema is not None
if rootschema:
rootschema = rootschema
elif cls is not None and cls._rootschema is not None:
rootschema = cls._rootschema
else:
rootschema = None
rootschema = rootschema or schema

if isinstance(dct, SchemaBase):
Expand Down Expand Up @@ -1086,7 +1166,7 @@ def from_dict(


class _PropertySetter:
def __init__(self, prop, schema):
def __init__(self, prop: str, schema: dict) -> None:
self.prop = prop
self.schema = schema

Expand Down Expand Up @@ -1133,7 +1213,7 @@ def __call__(self, *args, **kwargs):
return obj


def with_property_setters(cls):
def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase:
"""
Decorator to add property setters to a Schema class.
"""
Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,7 +2592,7 @@ def _get_name(cls):
return f"view_{cls._counter}"

@classmethod
def from_dict(cls, dct, validate=True) -> "Chart": # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future
def from_dict(cls, dct, validate=True) -> core.SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future
"""Construct class from a dictionary representation

Parameters
Expand Down
Loading