Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Select parameter to Model/Index.query/scan #969

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions pynamodb/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ def scan(
consistent_read: Optional[bool] = None,
index_name: Optional[str] = None,
settings: OperationSettings = OperationSettings.default,
select: Optional[str] = None,
) -> Dict:
"""
Performs the scan operation
Expand All @@ -1232,6 +1233,11 @@ def scan(
if filter_condition is not None:
filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values)
operation_kwargs[FILTER_EXPRESSION] = filter_expression
if select is not None:
select = select.upper()
if select not in SELECT_VALUES:
raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES))
operation_kwargs[SELECT] = select
if attributes_to_get is not None:
projection_expression = create_projection_expression(attributes_to_get, name_placeholders)
operation_kwargs[PROJECTION_EXPRESSION] = projection_expression
Expand Down Expand Up @@ -1319,9 +1325,10 @@ def query(
if return_consumed_capacity:
operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity))
if select:
if select.upper() not in SELECT_VALUES:
select = select.upper()
if select not in SELECT_VALUES:
raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES))
operation_kwargs[SELECT] = str(select).upper()
operation_kwargs[SELECT] = select
if scan_index_forward is not None:
operation_kwargs[SCAN_INDEX_FORWARD] = scan_index_forward
if name_placeholders:
Expand Down
2 changes: 2 additions & 0 deletions pynamodb/connection/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,15 @@ def scan(
consistent_read: Optional[bool] = None,
index_name: Optional[str] = None,
settings: OperationSettings = OperationSettings.default,
select: Optional[str] = None,
) -> Dict:
"""
Performs the scan operation
"""
return self.connection.scan(
self.table_name,
filter_condition=filter_condition,
select=select,
attributes_to_get=attributes_to_get,
limit=limit,
return_consumed_capacity=return_consumed_capacity,
Expand Down
9 changes: 5 additions & 4 deletions pynamodb/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Pynamodb constants
"""
from typing import Final

# Operations
TRANSACT_WRITE_ITEMS = 'TransactWriteItems'
Expand Down Expand Up @@ -152,10 +153,10 @@

# These are the valid select values for the Scan operation
# See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Scan.html#DDB-Scan-request-Select
ALL_ATTRIBUTES = 'ALL_ATTRIBUTES'
ALL_PROJECTED_ATTRIBUTES = 'ALL_PROJECTED_ATTRIBUTES'
SPECIFIC_ATTRIBUTES = 'SPECIFIC_ATTRIBUTES'
COUNT = 'COUNT'
ALL_ATTRIBUTES: Final = 'ALL_ATTRIBUTES'
ALL_PROJECTED_ATTRIBUTES: Final = 'ALL_PROJECTED_ATTRIBUTES'
SPECIFIC_ATTRIBUTES: Final = 'SPECIFIC_ATTRIBUTES'
COUNT: Final = 'COUNT'
SELECT_VALUES = [ALL_ATTRIBUTES, ALL_PROJECTED_ATTRIBUTES, SPECIFIC_ATTRIBUTES, COUNT]

# These are the valid comparison operators for the Scan operation
Expand Down
6 changes: 5 additions & 1 deletion pynamodb/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
PynamoDB Indexes
"""
from inspect import getmembers
from typing import Any, Dict, Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, TypeVar, Literal
from typing import TYPE_CHECKING

from pynamodb._compat import GenericMeta
Expand Down Expand Up @@ -90,6 +90,7 @@ def query(
attributes_to_get: Optional[List[str]] = None,
page_size: Optional[int] = None,
rate_limit: Optional[float] = None,
select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None,
) -> ResultIterator[_M]:
"""
Queries an index
Expand All @@ -106,6 +107,7 @@ def query(
attributes_to_get=attributes_to_get,
page_size=page_size,
rate_limit=rate_limit,
select=select,
)

@classmethod
Expand All @@ -120,6 +122,7 @@ def scan(
consistent_read: Optional[bool] = None,
rate_limit: Optional[float] = None,
attributes_to_get: Optional[List[str]] = None,
select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None,
) -> ResultIterator[_M]:
"""
Scans an index
Expand All @@ -135,6 +138,7 @@ def scan(
index_name=cls.Meta.index_name,
rate_limit=rate_limit,
attributes_to_get=attributes_to_get,
select=select,
)

@classmethod
Expand Down
10 changes: 8 additions & 2 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
import sys
from inspect import getmembers
from typing import Any
from typing import Any, Literal
from typing import Dict
from typing import Generic
from typing import Iterable
Expand Down Expand Up @@ -632,6 +632,7 @@ def query(
page_size: Optional[int] = None,
rate_limit: Optional[float] = None,
settings: OperationSettings = OperationSettings.default,
select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None,
) -> ResultIterator[_T]:
"""
Provides a high level query API
Expand All @@ -646,6 +647,8 @@ def query(
Controls descending or ascending results
:param last_evaluated_key: If set, provides the starting point for query.
:param attributes_to_get: If set, only returns these elements
:param select: If set, specifies which attributes to return;
if SPECIFIC_ATTRIBUTES is set, the attributes_to_get parameter must be passed
:param page_size: Page size of the query to DynamoDB
:param rate_limit: If set then consumed capacity will be limited to this amount per second
"""
Expand Down Expand Up @@ -673,6 +676,7 @@ def query(
scan_index_forward=scan_index_forward,
limit=page_size,
attributes_to_get=attributes_to_get,
select=select,
)

return ResultIterator(
Expand All @@ -699,6 +703,7 @@ def scan(
rate_limit: Optional[float] = None,
attributes_to_get: Optional[Sequence[str]] = None,
settings: OperationSettings = OperationSettings.default,
select: Optional[Literal['ALL_ATTRIBUTES', 'ALL_PROJECTED_ATTRIBUTES', 'COUNT', 'SPECIFIC_ATTRIBUTES']] = None,
) -> ResultIterator[_T]:
"""
Iterates through all items in the table
Expand Down Expand Up @@ -731,7 +736,8 @@ def scan(
total_segments=total_segments,
consistent_read=consistent_read,
index_name=index_name,
attributes_to_get=attributes_to_get
attributes_to_get=attributes_to_get,
select=select,
)

return ResultIterator(
Expand Down
24 changes: 22 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,26 @@ def test_query_with_exclusive_start_key(self):
self.assertEqual(results_iter.total_count, 10)
self.assertEqual(results_iter.page_iter.total_scanned_count, 10)

def test_query_with_select(self):
with patch(PATCH_METHOD) as req:
req.return_value = MODEL_TABLE_DATA
UserModel('foo', 'bar')

with patch(PATCH_METHOD) as req:
items = []

req.side_effect = [
{'Count': 0, 'ScannedCount': 0, 'Items': items},
]
results_iter = UserModel.query('foo', limit=10, page_size=10, select='ALL_ATTRIBUTES')

results = list(results_iter)
self.assertEqual(len(results), 0)
self.assertEqual(len(req.mock_calls), 1)
self.assertEqual(req.mock_calls[0].args[1]['Select'], 'ALL_ATTRIBUTES')
self.assertEqual(results_iter.total_count, 0)
self.assertEqual(results_iter.page_iter.total_scanned_count, 0)

def test_query(self):
"""
Model.query
Expand Down Expand Up @@ -1749,15 +1769,15 @@ def fake_scan(*args):
item['user_id'] = {STRING: 'id-{0}'.format(idx)}
items.append(item)
req.return_value = {'Count': len(items), 'ScannedCount': len(items), 'Items': items}
for item in UserModel.scan(
attributes_to_get=['email']):
for item in UserModel.scan(attributes_to_get=['email'], select='SPECIFIC_ATTRIBUTES'):
self.assertIsNotNone(item)
params = {
'ReturnConsumedCapacity': 'TOTAL',
'ProjectionExpression': '#0',
'ExpressionAttributeNames': {
'#0': 'email'
},
'Select': 'SPECIFIC_ATTRIBUTES',
'TableName': 'UserModel'
}
self.assertEqual(params, req.call_args[0][1])
Expand Down