From ac30c566940ac302bef600dee95abcbda14fc0a1 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Fri, 11 Aug 2023 10:00:46 -0500 Subject: [PATCH] Add flag to prevent operations on manytomany field w/unsaved instances. By default, attempting to read or write a many-to-many field on an unsaved model instance will now raise an exception. To disable this behavior, specify `prevent_unsaved=False` when initializing your ManyToManyField. Refs #2765 --- peewee.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/peewee.py b/peewee.py index 0a68ac2db..d261b310a 100644 --- a/peewee.py +++ b/peewee.py @@ -5574,6 +5574,9 @@ def __get__(self, instance, instance_type=None, force_query=False): return [getattr(obj, self.dest_fk.name) for obj in backref] src_id = getattr(instance, self.src_fk.rel_field.name) + if src_id is None and self.field._prevent_unsaved: + raise ValueError('Cannot get many-to-many "%s" for unsaved ' + 'instance "%s".' % (self.field, instance)) return (ManyToManyQuery(instance, self, self.rel_model) .join(self.through_model) .join(self.model) @@ -5582,6 +5585,10 @@ def __get__(self, instance, instance_type=None, force_query=False): return self.field def __set__(self, instance, value): + src_id = getattr(instance, self.src_fk.rel_field.name) + if src_id is None and self.field._prevent_unsaved: + raise ValueError('Cannot set many-to-many "%s" for unsaved ' + 'instance "%s".' % (self.field, instance)) query = self.__get__(instance, force_query=True) query.add(value, clear_existing=True) @@ -5590,7 +5597,7 @@ class ManyToManyField(MetaField): accessor_class = ManyToManyFieldAccessor def __init__(self, model, backref=None, through_model=None, on_delete=None, - on_update=None, _is_backref=False): + on_update=None, prevent_unsaved=True, _is_backref=False): if through_model is not None: if not (isinstance(through_model, DeferredThroughModel) or is_model(through_model)): @@ -5604,6 +5611,7 @@ def __init__(self, model, backref=None, through_model=None, on_delete=None, self._through_model = through_model self._on_delete = on_delete self._on_update = on_update + self._prevent_unsaved = prevent_unsaved self._is_backref = _is_backref def _get_descriptor(self):