diff --git a/pynamodb/models.py b/pynamodb/models.py index 7d5e99161..dbbfb12af 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -46,7 +46,7 @@ ATTR_DEFINITIONS, ATTR_NAME, ATTR_TYPE, KEY_SCHEMA, KEY_TYPE, ITEM, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS, RANGE_KEY, ATTRIBUTES, PUT, DELETE, RESPONSES, - INDEX_NAME, PROVISIONED_THROUGHPUT, PROJECTION, ALL_NEW, + INDEX_NAME, PROVISIONED_THROUGHPUT, PROJECTION, ALL_NEW, NONE, GLOBAL_SECONDARY_INDEXES, LOCAL_SECONDARY_INDEXES, KEYS, PROJECTION_TYPE, NON_KEY_ATTRIBUTES, TABLE_STATUS, ACTIVE, RETURN_VALUES, BATCH_GET_PAGE_LIMIT, @@ -417,31 +417,43 @@ def delete(self, condition: Optional[Condition] = None, settings: OperationSetti return self._get_connection().delete_item(hk_value, range_key=rk_value, condition=condition, settings=settings) - def update(self, actions: List[Action], condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Any: + def update( + self, + actions: List[Action], + condition: Optional[Condition] = None, + settings: OperationSettings = OperationSettings.default, + read_back: str = ALL_NEW, + ) -> Any: """ Updates an item using the UpdateItem operation. :param actions: a list of Action updates to apply :param condition: an optional Condition on which to update :param settings: per-operation settings + :param read_back: what to read back from the DB after the update + ALL_NEW: read back entire object after the update (default) + NONE: read back nothing, local version is not updated :raises ModelInstance.DoesNotExist: if the object to be updated does not exist :raises pynamodb.exceptions.UpdateError: if the `condition` is not met """ if not isinstance(actions, list) or len(actions) == 0: raise TypeError("the value of `actions` is expected to be a non-empty list") + if read_back not in (ALL_NEW, NONE): + raise ValueError("expected `read_back` to be `ALL_NEW` or `NONE`, but was: {}".format(read_back)) hk_value, rk_value = self._get_hash_range_key_serialized_values() version_condition = self._handle_version_attribute(actions=actions) if version_condition is not None: condition &= version_condition - data = self._get_connection().update_item(hk_value, range_key=rk_value, return_values=ALL_NEW, condition=condition, actions=actions, settings=settings) - item_data = data[ATTRIBUTES] - stored_cls = self._get_discriminator_class(item_data) - if stored_cls and stored_cls != type(self): - raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__)) - self.deserialize(item_data) - return data + data = self._get_connection().update_item(hk_value, range_key=rk_value, return_values=read_back, condition=condition, actions=actions, settings=settings) + if read_back == ALL_NEW: + item_data = data[ATTRIBUTES] + stored_cls = self._get_discriminator_class(item_data) + if stored_cls and stored_cls != type(self): + raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__)) + self.deserialize(item_data) + return data def save(self, condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Dict[str, Any]: """ diff --git a/tests/test_model.py b/tests/test_model.py index 94bc1799a..a96cb6f18 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -17,7 +17,7 @@ from pynamodb.constants import ( ITEM, STRING, ALL, KEYS_ONLY, INCLUDE, REQUEST_ITEMS, UNPROCESSED_KEYS, CAMEL_COUNT, RESPONSES, KEYS, ITEMS, LAST_EVALUATED_KEY, EXCLUSIVE_START_KEY, ATTRIBUTES, BINARY, - UNPROCESSED_ITEMS, DEFAULT_ENCODING, MAP, LIST, NUMBER, SCANNED_COUNT, + UNPROCESSED_ITEMS, DEFAULT_ENCODING, MAP, LIST, NUMBER, SCANNED_COUNT, ALL_NEW, NONE ) from pynamodb.models import Model from pynamodb.indexes import ( @@ -921,6 +921,28 @@ def test_update(self, mock_time): assert item.views is None self.assertEqual({'bob'}, item.custom_aliases) + def test_update_readback(self): + self.init_table_meta(SimpleUserModel, SIMPLE_MODEL_TABLE_DATA) + item = SimpleUserModel(user_name='foo', is_active=True, email='original@example.com', signature='foo', views=100) + + with patch(PATCH_METHOD) as req: + req.return_value = {} + item.update( + actions=[SimpleUserModel.email.set('changed@example.com')], + read_back=NONE) + params = { + 'TableName': 'SimpleModel', + 'Key': {'user_name': {'S': 'foo'}}, + 'ReturnValues': 'NONE', + 'UpdateExpression': 'SET #0 = :0', + 'ExpressionAttributeNames': {'#0': 'email'}, + 'ExpressionAttributeValues': {':0': {'S': 'changed@example.com'}}, + 'ReturnConsumedCapacity': 'TOTAL' + } + args = req.call_args[0][1] + deep_eq(args, params, _assert=True) + assert item.email == 'original@example.com' + def test_update_doesnt_do_validation_on_null_attributes(self): self.init_table_meta(CarModel, CAR_MODEL_TABLE_DATA) item = CarModel(12345)