diff --git a/marshmallow_oneofschema/one_of_schema.py b/marshmallow_oneofschema/one_of_schema.py index 7957f46..60a3da3 100644 --- a/marshmallow_oneofschema/one_of_schema.py +++ b/marshmallow_oneofschema/one_of_schema.py @@ -1,4 +1,28 @@ -from marshmallow import Schema, ValidationError +from collections.abc import Mapping +import inspect + +from marshmallow import Schema, ValidationError, RAISE + + +# these helpers copied from marshmallow.utils # + + +def is_generator(obj) -> bool: + """Return True if ``obj`` is a generator""" + return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj) + + +def is_iterable_but_not_string(obj) -> bool: + """Return True if ``obj`` is an iterable object that isn't a string.""" + return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj) + + +def is_collection(obj) -> bool: + """Return True if ``obj`` is a collection type, e.g list, tuple, queryset.""" + return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping) + + +# end of helpers copied from marshmallow.utils # class OneOfSchema(Schema): @@ -63,32 +87,16 @@ def get_obj_type(self, obj): """Returns name of object schema""" return obj.__class__.__name__ - def dump(self, obj, *, many=None, **kwargs): - errors = {} - result_data = [] - result_errors = {} - many = self.many if many is None else bool(many) - if not many: - result = result_data = self._dump(obj, **kwargs) - else: - for idx, o in enumerate(obj): - try: - result = self._dump(o, **kwargs) - result_data.append(result) - except ValidationError as error: - result_errors[idx] = error.normalized_messages() - result_data.append(error.valid_data) - - result = result_data - errors = result_errors - - if not errors: - return result - else: - exc = ValidationError(errors, data=obj, valid_data=result) - raise exc - - def _dump(self, obj, *, update_fields=True, **kwargs): + # override the `_serialize` method of Schema, rather than `dump` + # this requires that we interact with a private 
API of marshmallow, but + # `_serialize` is the step that happens between pre_dump and post_dump + # hooks, so by using this rather than `dump()`, we get schema hooks to work + def _serialize(self, obj, *, many=False): + if many and obj is not None: + return [self._serialize(subdoc, many=False) for subdoc in obj] + return self._dump_type_schema(obj) + + def _dump_type_schema(self, obj): obj_type = self.get_obj_type(obj) if not obj_type: return ( @@ -104,46 +112,58 @@ def _dump(self, obj, *, update_fields=True, **kwargs): schema.context.update(getattr(self, "context", {})) - result = schema.dump(obj, many=False, **kwargs) + result = schema.dump(obj, many=False) if result is not None: result[self.type_field] = obj_type return result - def load(self, data, *, many=None, partial=None, unknown=None, **kwargs): - errors = {} - result_data = [] - result_errors = {} - many = self.many if many is None else bool(many) - if partial is None: - partial = self.partial - if not many: - try: - result = result_data = self._load( - data, partial=partial, unknown=unknown, **kwargs - ) - # result_data.append(result) - except ValidationError as error: - result_errors = error.normalized_messages() - result_data.append(error.valid_data) - else: - for idx, item in enumerate(data): - try: - result = self._load(item, partial=partial, **kwargs) - result_data.append(result) - except ValidationError as error: - result_errors[idx] = error.normalized_messages() - result_data.append(error.valid_data) - - result = result_data - errors = result_errors - - if not errors: - return result - else: - exc = ValidationError(errors, data=data, valid_data=result) - raise exc - - def _load(self, data, *, partial=None, unknown=None, **kwargs): + # override the `_deserialize` method of Schema, rather than `load` + # this requires that we interact with a private API of marshmallow, but + # `_deserialize` is the step that happens between pre_load and validation + # hooks, so by using this rather than `load()`, 
we get schema hooks to work + def _deserialize( + self, + data, + *, + error_store, + many=False, + partial=False, + unknown=RAISE, + index=None, + ): + index = index if self.opts.index_errors else None + # if many, check for non-collection data (error) or iterate and + # re-invoke `_deserialize` on each one with many=False + # this is paraphrased from marshmallow.Schema._deserialize + if many: + if not is_collection(data): + error_store.store_error([self.error_messages["type"]], index=index) + return [] + else: + return [ + self._deserialize( + subdoc, + error_store=error_store, + many=False, + partial=partial, + unknown=unknown, + index=idx, + ) + for idx, subdoc in enumerate(data) + ] + if not isinstance(data, Mapping): + error_store.store_error([self.error_messages["type"]], index=index) + return self.dict_class() + + try: + result = self._load_type_schema(data, partial=partial, unknown=unknown) + except ValidationError as err: + error_store.store_error(err.messages, index=index) + result = err.valid_data + + return result + + def _load_type_schema(self, data, *, partial=None, unknown=None): if not isinstance(data, dict): raise ValidationError({"_schema": "Invalid data type: %s" % data}) @@ -173,11 +193,4 @@ def _load(self, data, *, partial=None, unknown=None, **kwargs): schema.context.update(getattr(self, "context", {})) - return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs) - - def validate(self, data, *, many=None, partial=None): - try: - self.load(data, many=many, partial=partial) - except ValidationError as ve: - return ve.messages - return {} + return schema.load(data, many=False, partial=partial, unknown=unknown) diff --git a/tests/test_one_of_schema.py b/tests/test_one_of_schema.py index adbebe8..b77af36 100644 --- a/tests/test_one_of_schema.py +++ b/tests/test_one_of_schema.py @@ -185,6 +185,29 @@ class TestSchema(OneOfSchema): TestSchema(unknown="exclude").load({"type": "Bar", "bar": 123}) assert Nonlocal.data["type"] == 
"Bar" + def test_post_dump_remove_type_field(self): + # test using a @post_dump hook to remove the type field which + # OneOfSchema will add to the data by default + + # define a schema without post_dump + class MySchemaVariant1(OneOfSchema): + type_schemas = {"Foo": FooSchema, "Bar": BarSchema} + + # and a variant with post_dump + class MySchemaVariant2(MySchemaVariant1): + @m.post_dump + def remove_type_field(self, data, **kwargs): + del data["type"] + return data + + # sanity check: `type` should be present in a dump from Variant1 + assert MySchemaVariant1().dump(Foo("someval")) == { + "type": "Foo", + "value": "someval", + } + # now check that the post_dump hook fired + assert MySchemaVariant2().dump(Foo("someval")) == {"value": "someval"} + def test_load_non_dict(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load(123)