Skip to content

Commit

Permalink
Add missing force flags
Browse files Browse the repository at this point in the history
  • Loading branch information
dkraczkowski committed Dec 6, 2023
1 parent 4af8bee commit 9c1454e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 27 deletions.
30 changes: 17 additions & 13 deletions chili/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,11 @@ class ClassDecoder(TypeDecoder):
_fields: Dict[str, TypeDecoder]
_schema: TypeSchema

def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None, force: bool = False):
self.class_name = class_name
self._schema = create_schema(class_name) # type: ignore
self._extra_decoders = extra_decoders
self.force = force

def decode(self, value: StateObject) -> Any:
if not isinstance(value, dict):
Expand All @@ -387,22 +388,24 @@ def _build(self) -> Dict[str, TypeDecoder]:
return {name: self._build_type_decoder(field.type) for name, field in self._schema.items()}

def _build_type_decoder(self, a_type: Type) -> TypeDecoder:
return build_type_decoder(a_type, self._extra_decoders, self.class_name.__module__) # type: ignore
return build_type_decoder(a_type, self._extra_decoders, self.class_name.__module__, self.force) # type: ignore


class GenericClassDecoder(ClassDecoder):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None, force: bool = False):
self._generic_type = class_name
self._generic_parameters = get_parameters_map(class_name)
self._extra_decoders = extra_decoders
type_: Type = get_origin_type(class_name) # type: ignore
self.force = force
super().__init__(type_)

def _build_type_decoder(self, a_type: Type) -> TypeDecoder:
return build_type_decoder(
map_generic_type(a_type, self._generic_parameters),
self._extra_decoders,
self._generic_type.__module__,
self.force
)


Expand All @@ -415,11 +418,12 @@ def decode(self, value: Any) -> E:


class NamedTupleDecoder(TypeDecoder):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None, force: bool = False):
self.class_name = class_name
self._is_typed = hasattr(class_name, "__annotations__")
self._arg_decoders: List[TypeDecoder] = []
self._extra_decoders = extra_decoders
self.force = force
if self._is_typed:
self._build()

Expand All @@ -440,16 +444,16 @@ def decode(self, value: list) -> tuple:
def _build(self) -> None:
field_types = self.class_name.__annotations__
for item_type in field_types.values():
self._arg_decoders.append(build_type_decoder(item_type, self._extra_decoders, self.class_name.__module__))
self._arg_decoders.append(build_type_decoder(item_type, self._extra_decoders, self.class_name.__module__, self.force))


class TypedDictDecoder(TypeDecoder):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None):
def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None, force: bool = False):
self.class_name = class_name
self._key_decoders = {}
self._extra_decoders = extra_decoders
for key_name, key_type in class_name.__annotations__.items():
self._key_decoders[key_name] = build_type_decoder(key_type, self._extra_decoders, class_name.__module__)
self._key_decoders[key_name] = build_type_decoder(key_type, self._extra_decoders, class_name.__module__, force)

def decode(self, value: dict) -> dict:
return {key: self._key_decoders[key].decode(item) for key, item in value.items()}
Expand Down Expand Up @@ -497,7 +501,7 @@ def build_type_decoder(
if origin_type and is_dataclass(origin_type):
if issubclass(origin_type, Generic): # type: ignore
return GenericClassDecoder(a_type)
return ClassDecoder(a_type, extra_decoders)
return ClassDecoder(a_type, extra_decoders, force)

if origin_type is None:
origin_type = a_type
Expand All @@ -509,7 +513,7 @@ def build_type_decoder(
return NamedTupleDecoder(origin_type, extra_decoders)

if is_class(origin_type) and is_typed_dict(origin_type):
return TypedDictDecoder(origin_type, extra_decoders)
return TypedDictDecoder(origin_type, extra_decoders, force)

if is_class(origin_type) and is_user_string(origin_type):
return SimpleDecoder[origin_type](origin_type) # type: ignore
Expand All @@ -518,7 +522,7 @@ def build_type_decoder(
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None): # type: ignore
return OptionalTypeDecoder(
build_type_decoder(a_type=type_args[0], extra_decoders=extra_decoders) # type: ignore
build_type_decoder(a_type=type_args[0], extra_decoders=extra_decoders, force=force) # type: ignore
)
return UnionDecoder(type_args, extra_decoders=extra_decoders, force=force)

Expand All @@ -530,10 +534,10 @@ def build_type_decoder(
if isinstance(a_type, TypeVar):
if a_type.__bound__ is None:
raise DecoderError.invalid_type(a_type)
return build_type_decoder(a_type.__bound__, extra_decoders, module)
return build_type_decoder(a_type.__bound__, extra_decoders, module, force)

if is_newtype(a_type):
return build_type_decoder(a_type.__supertype__, extra_decoders, module)
return build_type_decoder(a_type.__supertype__, extra_decoders, module, force)

if get_origin(origin_type) is not None:
raise DecoderError.invalid_type(a_type)
Expand All @@ -542,7 +546,7 @@ def build_type_decoder(
return Decoder[origin_type](decoders=extra_decoders) # type: ignore

if is_optional(a_type):
return OptionalTypeDecoder(build_type_decoder(unpack_optional(a_type))) # type: ignore
return OptionalTypeDecoder(build_type_decoder(unpack_optional(a_type), extra_decoders, module, force)) # type: ignore

if origin_type not in _supported_generics:
if force and is_class(origin_type):
Expand Down
33 changes: 19 additions & 14 deletions chili/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,11 @@ class ClassEncoder(TypeEncoder):
_fields: Dict[str, TypeEncoder]
_schema: TypeSchema

def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None, force: bool = False):
self.class_name = class_name
self._extra_encoders = extra_encoders
self._schema = create_schema(class_name) # type: ignore
self.force = force

def encode(self, value: Any) -> StateObject:
if not isinstance(value, self.class_name):
Expand All @@ -263,14 +264,15 @@ def _build(self) -> Dict[str, TypeEncoder]:
return {name: self._build_type_encoder(field.type) for name, field in self._schema.items()}

def _build_type_encoder(self, a_type: Type) -> TypeEncoder:
return build_type_encoder(a_type, self._extra_encoders, self.class_name.__module__) # type: ignore
return build_type_encoder(a_type, self._extra_encoders, self.class_name.__module__, self.force) # type: ignore


class GenericClassEncoder(ClassEncoder):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None, force: bool = False):
self._generic_type = class_name
self._extra_encoders = extra_encoders
self._generic_parameters = get_parameters_map(class_name)
self.force = force
type_: Type = get_origin_type(class_name) # type: ignore
super().__init__(type_)

Expand All @@ -279,6 +281,7 @@ def _build_type_encoder(self, a_type: Type) -> TypeEncoder:
map_generic_type(a_type, self._generic_parameters),
self._extra_encoders,
self._generic_type.__module__,
self.force
)


Expand All @@ -294,11 +297,12 @@ def encode(self, value: E) -> Any:


class NamedTupleEncoder(TypeEncoder):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None, force: bool = False):
self.type = class_name
self._is_typed = hasattr(class_name, "__annotations__")
self._arg_encoders: List[TypeEncoder] = []
self._extra_encoders = extra_encoders
self.force = force
if self._is_typed:
self._build()

Expand All @@ -314,15 +318,16 @@ def encode(self, value: tuple) -> list:
def _build(self) -> None:
field_types = self.type.__annotations__
for item_type in field_types.values():
self._arg_encoders.append(build_type_encoder(item_type, self._extra_encoders, self.type.__module__))
self._arg_encoders.append(build_type_encoder(item_type, self._extra_encoders, self.type.__module__, self.force))


class TypedDictEncoder(TypeEncoder):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None, force: bool = False):
self.type = class_name
self._key_encoders = {}
self.force = force
for key_name, key_type in class_name.__annotations__.items():
self._key_encoders[key_name] = build_type_encoder(key_type, extra_encoders, class_name.__module__)
self._key_encoders[key_name] = build_type_encoder(key_type, extra_encoders, class_name.__module__, self.force)

def encode(self, value: dict) -> dict:
return {key: self._key_encoders[key].encode(item) for key, item in value.items()}
Expand Down Expand Up @@ -384,7 +389,7 @@ def build_type_encoder(
if origin_type and is_dataclass(origin_type):
if issubclass(origin_type, Generic): # type: ignore
return GenericClassEncoder(a_type)
return ClassEncoder(a_type, extra_encoders)
return ClassEncoder(a_type, extra_encoders, force)

if origin_type is None:
origin_type = a_type
Expand All @@ -393,18 +398,18 @@ def build_type_encoder(
return EnumEncoder(origin_type)

if is_class(origin_type) and is_named_tuple(origin_type):
return NamedTupleEncoder(origin_type, extra_encoders)
return NamedTupleEncoder(origin_type, extra_encoders, force)

if is_class(origin_type) and is_typed_dict(origin_type):
return TypedDictEncoder(origin_type, extra_encoders)
return TypedDictEncoder(origin_type, extra_encoders, force)

if is_class(origin_type) and is_user_string(origin_type):
return SimpleEncoder[str](str)

if origin_type is Union or (UnionType and isinstance(origin_type, UnionType)):
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None):
return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders)) # type: ignore
return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders, module, force)) # type: ignore
return UnionEncoder(type_args, extra_encoders, force=force)

if isinstance(a_type, typing.ForwardRef) and module is not None:
Expand All @@ -423,18 +428,18 @@ def build_type_encoder(
if isinstance(a_type, TypeVar):
if a_type.__bound__ is None:
raise EncoderError.invalid_type(a_type)
return build_type_encoder(a_type.__bound__, extra_encoders, module)
return build_type_encoder(a_type.__bound__, extra_encoders, module, force)

if is_newtype(a_type):
return build_type_encoder(a_type.__supertype__, extra_encoders, module)
return build_type_encoder(a_type.__supertype__, extra_encoders, module, force)

if origin_type not in _supported_generics:
if is_class(origin_type) and force:
return Encoder[origin_type](encoders=extra_encoders) # type: ignore[valid-type]
raise EncoderError.invalid_type(type=a_type)

type_attributes: List[TypeEncoder] = [
build_type_encoder(subtype, extra_encoders=extra_encoders, module=module) # type: ignore
build_type_encoder(subtype, extra_encoders=extra_encoders, module=module, force=force) # type: ignore
if subtype is not ...
else ... # noqa: E501
for subtype in get_type_args(a_type)
Expand Down

0 comments on commit 9c1454e

Please sign in to comment.