From 89bf5b3c06102884fce2ecf82fd44dfc3a329151 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 11 Aug 2021 11:28:16 -0400 Subject: [PATCH] Allow passing Select to Model/Index.query/scan --- pynamodb/connection/base.py | 11 +++++++++-- pynamodb/connection/table.py | 2 ++ pynamodb/constants.py | 9 +++++---- pynamodb/indexes.py | 6 +++++- pynamodb/models.py | 10 ++++++++-- tests/test_model.py | 24 ++++++++++++++++++++++-- 6 files changed, 51 insertions(+), 11 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index cca050f7d..e88e18912 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -1219,6 +1219,7 @@ def scan( consistent_read: Optional[bool] = None, index_name: Optional[str] = None, settings: OperationSettings = OperationSettings.default, + select: Optional[str] = None, ) -> Dict: """ Performs the scan operation @@ -1232,6 +1233,11 @@ def scan( if filter_condition is not None: filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values) operation_kwargs[FILTER_EXPRESSION] = filter_expression + if select is not None: + select = select.upper() + if select not in SELECT_VALUES: + raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES)) + operation_kwargs[SELECT] = select if attributes_to_get is not None: projection_expression = create_projection_expression(attributes_to_get, name_placeholders) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression @@ -1319,9 +1325,10 @@ def query( if return_consumed_capacity: operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if select: - if select.upper() not in SELECT_VALUES: + select = select.upper() + if select not in SELECT_VALUES: raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES)) - operation_kwargs[SELECT] = str(select).upper() + operation_kwargs[SELECT] = select if scan_index_forward is not None: operation_kwargs[SCAN_INDEX_FORWARD] = scan_index_forward if name_placeholders: diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..7762b254a 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -231,6 +231,7 @@ def scan( consistent_read: Optional[bool] = None, index_name: Optional[str] = None, settings: OperationSettings = OperationSettings.default, + select: Optional[str] = None, ) -> Dict: """ Performs the scan operation @@ -238,6 +239,7 @@ def scan( return self.connection.scan( self.table_name, filter_condition=filter_condition, + select=select, attributes_to_get=attributes_to_get, limit=limit, return_consumed_capacity=return_consumed_capacity, diff --git a/pynamodb/constants.py b/pynamodb/constants.py index 8c9076f85..1da746c56 100644 --- a/pynamodb/constants.py +++ b/pynamodb/constants.py @@ -1,6 +1,7 @@ """ Pynamodb constants """ +from typing import Final # Operations TRANSACT_WRITE_ITEMS = 'TransactWriteItems' @@ -152,10 +153,10 @@ # These are the valid select values for the Scan operation # See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Scan.html#DDB-Scan-request-Select -ALL_ATTRIBUTES = 'ALL_ATTRIBUTES' -ALL_PROJECTED_ATTRIBUTES = 'ALL_PROJECTED_ATTRIBUTES' -SPECIFIC_ATTRIBUTES = 'SPECIFIC_ATTRIBUTES' -COUNT = 'COUNT' +ALL_ATTRIBUTES: Final = 'ALL_ATTRIBUTES' +ALL_PROJECTED_ATTRIBUTES: Final = 'ALL_PROJECTED_ATTRIBUTES' +SPECIFIC_ATTRIBUTES: Final = 'SPECIFIC_ATTRIBUTES' +COUNT: Final = 'COUNT' SELECT_VALUES = [ALL_ATTRIBUTES, ALL_PROJECTED_ATTRIBUTES, SPECIFIC_ATTRIBUTES, COUNT] # These are the valid comparison operators for the Scan operation diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index 6ee5508b4..e028a670c 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -2,7 +2,7 @@ PynamoDB Indexes """ from inspect import getmembers -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import Any, Dict, Generic, List, Optional, TypeVar, Literal from typing import TYPE_CHECKING from pynamodb._compat import GenericMeta @@ -90,6 +90,7 @@ def query( attributes_to_get: Optional[List[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, + select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None, ) -> ResultIterator[_M]: """ Queries an index @@ -106,6 +107,7 @@ def query( attributes_to_get=attributes_to_get, page_size=page_size, rate_limit=rate_limit, + select=select, ) @classmethod @@ -120,6 +122,7 @@ def scan( consistent_read: Optional[bool] = None, rate_limit: Optional[float] = None, attributes_to_get: Optional[List[str]] = None, + select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None, ) -> ResultIterator[_M]: """ Scans an index @@ -135,6 +138,7 @@ def scan( index_name=cls.Meta.index_name, rate_limit=rate_limit, attributes_to_get=attributes_to_get, + select=select, ) @classmethod diff --git a/pynamodb/models.py b/pynamodb/models.py index ca768afca..956225926 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -7,7 +7,7 @@ import warnings import sys from inspect import getmembers -from typing import Any +from typing import Any, Literal from typing import Dict from typing import Generic from typing import Iterable @@ -632,6 +632,7 @@ def query( page_size: Optional[int] = None, rate_limit: Optional[float] = None, settings: OperationSettings = OperationSettings.default, + select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None, ) -> ResultIterator[_T]: """ Provides a high level query API @@ -646,6 +647,8 @@ def query( Controls descending or ascending results :param last_evaluated_key: If set, provides the starting point for query. :param attributes_to_get: If set, only returns these elements + :param select: If set, specifies which attributes to return; + if SPECIFIC_ATTRIBUTES is set, the attributes_to_get parameter must be passed :param page_size: Page size of the query to DynamoDB :param rate_limit: If set then consumed capacity will be limited to this amount per second """ @@ -673,6 +676,7 @@ def query( scan_index_forward=scan_index_forward, limit=page_size, attributes_to_get=attributes_to_get, + select=select, ) return ResultIterator( @@ -699,6 +703,7 @@ def scan( rate_limit: Optional[float] = None, attributes_to_get: Optional[Sequence[str]] = None, settings: OperationSettings = OperationSettings.default, + select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None, ) -> ResultIterator[_T]: """ Iterates through all items in the table @@ -731,7 +736,8 @@ def scan( total_segments=total_segments, consistent_read=consistent_read, index_name=index_name, - attributes_to_get=attributes_to_get + attributes_to_get=attributes_to_get, + select=select, ) return ResultIterator( diff --git a/tests/test_model.py b/tests/test_model.py index 2c9d9a9c1..99c78fa71 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1392,6 +1392,26 @@ def test_query_with_exclusive_start_key(self): self.assertEqual(results_iter.total_count, 10) self.assertEqual(results_iter.page_iter.total_scanned_count, 10) + def test_query_with_select(self): + with patch(PATCH_METHOD) as req: + req.return_value = MODEL_TABLE_DATA + UserModel('foo', 'bar') + + with patch(PATCH_METHOD) as req: + items = [] + + req.side_effect = [ + {'Count': 0, 'ScannedCount': 0, 'Items': items}, + ] + results_iter = UserModel.query('foo', limit=10, page_size=10, select='ALL_ATTRIBUTES') + + results = list(results_iter) + self.assertEqual(len(results), 0) + self.assertEqual(len(req.mock_calls), 1) + self.assertEqual(req.mock_calls[0].args[1]['Select'], 'ALL_ATTRIBUTES') + self.assertEqual(results_iter.total_count, 0) + self.assertEqual(results_iter.page_iter.total_scanned_count, 0) + def test_query(self): """ Model.query @@ -1749,8 +1769,7 @@ def fake_scan(*args): item['user_id'] = {STRING: 'id-{0}'.format(idx)} items.append(item) req.return_value = {'Count': len(items), 'ScannedCount': len(items), 'Items': items} - for item in UserModel.scan( - attributes_to_get=['email']): + for item in UserModel.scan(attributes_to_get=['email'], select='SPECIFIC_ATTRIBUTES'): self.assertIsNotNone(item) params = { 'ReturnConsumedCapacity': 'TOTAL', @@ -1758,6 +1777,7 @@ def fake_scan(*args): 'ExpressionAttributeNames': { '#0': 'email' }, + 'Select': 'SPECIFIC_ATTRIBUTES', 'TableName': 'UserModel' } self.assertEqual(params, req.call_args[0][1])