From 54354d8803b7c8a2f0e12b9f55c8e03f5fe0dd3e Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Mon, 17 Jul 2023 18:11:45 -0600 Subject: [PATCH 1/7] Bump version to next dev version --- planet/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planet/__version__.py b/planet/__version__.py index a33997dd..06754e28 100644 --- a/planet/__version__.py +++ b/planet/__version__.py @@ -1 +1 @@ -__version__ = '2.1.0' +__version__ = '2.1.1dev' From e94ea985f98e2c67db79cfcbf0d834a5910910c9 Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Thu, 20 Jul 2023 10:13:34 -0700 Subject: [PATCH 2/7] add ServerError to list of exceptions to retry --- planet/http.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/planet/http.py b/planet/http.py index 6ec8fb75..52824e49 100644 --- a/planet/http.py +++ b/planet/http.py @@ -42,7 +42,8 @@ httpx.ReadTimeout, httpx.RemoteProtocolError, exceptions.BadGateway, - exceptions.TooManyRequests + exceptions.TooManyRequests, + exceptions.ServerError ] MAX_RETRIES = 5 MAX_RETRY_BACKOFF = 64 # seconds From ce8adc3e3b6de32832cd67023048a7d32bf31ccc Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Fri, 21 Jul 2023 11:47:45 -0700 Subject: [PATCH 3/7] move write functionality to Session, include limiting and retries --- planet/clients/data.py | 19 +++--- planet/clients/orders.py | 18 ++--- planet/http.py | 130 +++++++++++++++++++++++++++++++---- planet/models.py | 144 --------------------------------------- 4 files changed, 132 insertions(+), 179 deletions(-) diff --git a/planet/clients/data.py b/planet/clients/data.py index 7b8133a8..22495e38 100644 --- a/planet/clients/data.py +++ b/planet/clients/data.py @@ -24,7 +24,7 @@ from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session -from ..models import Paged, StreamingBody +from ..models import Paged from ..specs import validate_data_item_type BASE_URL = f'{PLANET_BASE_URL}/data/v1/' @@ -586,8 +586,8 @@ async def download_asset(self, Raises: planet.exceptions.APIError: On API error. - planet.exceptions.ClientError: If asset is not active or asset - description is not valid. + planet.exceptions.ClientError: If asset is not active, asset + description is not valid, or retry limit is exceeded. """ try: location = asset['location'] @@ -595,14 +595,11 @@ async def download_asset(self, raise exceptions.ClientError( 'asset missing ["location"] entry. Is asset active?') - async with self._session.stream(method='GET', url=location) as resp: - body = StreamingBody(resp) - dl_path = Path(directory, filename or body.name) - dl_path.parent.mkdir(exist_ok=True, parents=True) - await body.write(dl_path, - overwrite=overwrite, - progress_bar=progress_bar) - return dl_path + return await self._session.write(location, + filename=filename, + directory=directory, + overwrite=overwrite, + progress_bar=progress_bar) @staticmethod def validate_checksum(asset: dict, filename: Path): diff --git a/planet/clients/orders.py b/planet/clients/orders.py index 92d74645..a668c010 100644 --- a/planet/clients/orders.py +++ b/planet/clients/orders.py @@ -25,7 +25,7 @@ from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session -from ..models import Paged, StreamingBody +from ..models import Paged BASE_URL = f'{PLANET_BASE_URL}/compute/ops' STATS_PATH = '/stats/orders/v2' @@ -251,15 +251,15 @@ async def download_asset(self, Raises: planet.exceptions.APIError: On API error. + planet.exceptions.ClientError: If location is not valid or retry + limit is exceeded. + """ - async with self._session.stream(method='GET', url=location) as resp: - body = StreamingBody(resp) - dl_path = Path(directory, filename or body.name) - dl_path.parent.mkdir(exist_ok=True, parents=True) - await body.write(dl_path, - overwrite=overwrite, - progress_bar=progress_bar) - return dl_path + return await self._session.write(location, + filename=filename, + directory=directory, + overwrite=overwrite, + progress_bar=progress_bar) async def download_order(self, order_id: str, diff --git a/planet/http.py b/planet/http.py index 52824e49..ac164a76 100644 --- a/planet/http.py +++ b/planet/http.py @@ -19,11 +19,13 @@ from contextlib import asynccontextmanager from http import HTTPStatus import logging +from pathlib import Path import random import time from typing import AsyncGenerator, Optional import httpx +from tqdm.asyncio import tqdm from typing_extensions import Literal from .auth import Auth, AuthType @@ -395,26 +397,79 @@ async def _send(self, request, stream=False) -> httpx.Response: return http_resp - @asynccontextmanager - async def stream( - self, method: str, - url: str) -> AsyncGenerator[models.StreamingResponse, None]: - """Submit a request and get the response as a stream context manager. + async def write(self, + url: str, + directory: Path, + overwrite: bool, + progress_bar: bool, + filename: Optional[str] = None) -> str: + """Write data to local file with limiting and retries. Parameters: - method: HTTP request method. - url: Location of the API endpoint. + url: Remote location url + filename: Custom name to assign to downloaded file. + directory: Base directory for file download. This directory will be + created if it does not already exist. + overwrite: Overwrite any existing files. + progress_bar: Show progress bar during download. Returns: - Context manager providing the streaming response. + Path to downloaded file. + + Raises: + planet.exceptions.APIException: On API error. + planet.exceptions.ClientError: When retry limit is exceeded. + """ - request = self._client.build_request(method=method, url=url) - http_response = await self._retry(self._send, request, stream=True) - response = models.StreamingResponse(http_response) - try: - yield response - finally: - await response.aclose() + async def _limited_write(): + async with self._limiter: + dl_path = await self._write(url=url, + directory=directory, + overwrite=overwrite, + progress_bar=progress_bar, + filename=filename) + return dl_path + + return await self._retry(_limited_write) + + async def _write(self, + url: str, + directory: Path, + overwrite: bool, + progress_bar: bool, + filename: Optional[str] = None) -> Path: + """Write data to local file. To be used in write()""" + + async with self._client.stream('GET', url) as response: + + dl_path = Path(directory, + filename or _get_filename_from_response(response)) + dl_path.parent.mkdir(exist_ok=True, parents=True) + + total = int(response.headers["Content-Length"]) + + try: + mode = 'wb' if overwrite else 'xb' + with open(dl_path, mode) as fp: + LOGGER.info(f'Writing {dl_path}, size {total}B') + + with tqdm(total=total, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + + previous = response.num_bytes_downloaded + + async for chunk in response.aiter_bytes(): + fp.write(chunk) + new = response.num_bytes_downloaded - previous + progress.update(new - previous) + previous = new + progress.update() + except FileExistsError: + LOGGER.info(f'File {dl_path} exists, not overwriting') def client(self, name: Literal['data', 'orders', 'subscriptions'], @@ -440,6 +495,51 @@ def client(self, raise exceptions.ClientError("No such client.") +def _get_filename_from_response(response) -> str: + """The name of the response resource. + + The default is to use the content-disposition header value from the + response. If not found, falls back to resolving the name from the url + or generating a random name with the type from the response. + """ + name = (_get_filename_from_headers(response.headers) + or _get_filename_from_url(response.url) + or _get_random_filename(response.headers.get('content-type'))) + return name + + +def _get_filename_from_headers(headers): + """Get a filename from the Content-Disposition header, if available. + + :param headers dict: a ``dict`` of response headers + :returns: a filename (i.e. ``basename``) + :rtype: str or None + """ + cd = headers.get('content-disposition', '') + match = re.search('filename="?([^"]+)"?', cd) + return match.group(1) if match else None + + +def _get_filename_from_url(url: str) -> Optional[str]: + """Get a filename from a url. + + Getting a name for Landsat imagery uses this function. + """ + path = urlparse(url).path + name = path[path.rfind('/') + 1:] + return name or None + + +def _get_random_filename(content_type=None) -> str: + """Get a pseudo-random, Planet-looking filename. + """ + extension = mimetypes.guess_extension(content_type or '') or '' + characters = string.ascii_letters + '0123456789' + letters = ''.join(random.sample(characters, 8)) + name = 'planet-{}{}'.format(letters, extension) + return name + + class AuthSession(BaseSession): """Synchronous connection to the Planet Auth service.""" diff --git a/planet/models.py b/planet/models.py index e2e7761c..8524d47e 100644 --- a/planet/models.py +++ b/planet/models.py @@ -54,150 +54,6 @@ def json(self) -> dict: return self._http_response.json() -class StreamingResponse(Response): - - @property - def headers(self) -> httpx.Headers: - return self._http_response.headers - - @property - def url(self) -> str: - return str(self._http_response.url) - - @property - def num_bytes_downloaded(self) -> int: - return self._http_response.num_bytes_downloaded - - async def aiter_bytes(self): - async for c in self._http_response.aiter_bytes(): - yield c - - async def aclose(self): - await self._http_response.aclose() - - -class StreamingBody: - """A representation of a streaming resource from the API.""" - - def __init__(self, response: StreamingResponse): - """Initialize the object. - - Parameters: - response: Response that was received from the server. - """ - self._response = response - - @property - def name(self) -> str: - """The name of this resource. - - The default is to use the content-disposition header value from the - response. If not found, falls back to resolving the name from the url - or generating a random name with the type from the response. - """ - name = (_get_filename_from_headers(self._response.headers) - or _get_filename_from_url(self._response.url) - or _get_random_filename( - self._response.headers.get('content-type'))) - return name - - @property - def size(self) -> int: - """The size of the body.""" - return int(self._response.headers['Content-Length']) - - async def write(self, - filename: Path, - overwrite: bool = True, - progress_bar: bool = True): - """Write the body to a file. - Parameters: - filename: Name to assign to downloaded file. - overwrite: Overwrite any existing files. - progress_bar: Show progress bar during download. - """ - - class _LOG: - - def __init__(self, total, unit, filename, disable): - self.total = total - self.unit = unit - self.disable = disable - self.previous = 0 - self.filename = str(filename) - - if not self.disable: - LOGGER.debug(f'writing to {self.filename}') - - def update(self, new): - if new - self.previous > self.unit and not self.disable: - # LOGGER.debug(f'{new-self.previous}') - perc = int(100 * new / self.total) - LOGGER.debug(f'{self.filename}: ' - f'wrote {perc}% of {self.total}') - self.previous = new - - unit = 1024 * 1024 - - mode = 'wb' if overwrite else 'xb' - try: - with open(filename, mode) as fp: - _log = _LOG(self.size, - 16 * unit, - filename, - disable=progress_bar) - with tqdm(total=self.size, - unit_scale=True, - unit_divisor=unit, - unit='B', - desc=str(filename), - disable=not progress_bar) as progress: - previous = self._response.num_bytes_downloaded - async for chunk in self._response.aiter_bytes(): - fp.write(chunk) - new = self._response.num_bytes_downloaded - _log.update(new) - progress.update(new - previous) - previous = new - except FileExistsError: - LOGGER.info(f'File {filename} exists, not overwriting') - - -def _get_filename_from_headers(headers): - """Get a filename from the Content-Disposition header, if available. - - :param headers dict: a ``dict`` of response headers - :returns: a filename (i.e. ``basename``) - :rtype: str or None - """ - cd = headers.get('content-disposition', '') - match = re.search('filename="?([^"]+)"?', cd) - return match.group(1) if match else None - - -def _get_filename_from_url(url: str) -> Optional[str]: - """Get a filename from a url. - - Getting a name for Landsat imagery uses this function. - """ - path = urlparse(url).path - name = path[path.rfind('/') + 1:] - return name or None - - -def _get_random_filename(content_type=None): - """Get a pseudo-random, Planet-looking filename. - - :returns: a filename (i.e. ``basename``) - :rtype: str - """ - extension = mimetypes.guess_extension(content_type or '') or '' - characters = string.ascii_letters + '0123456789' - letters = ''.join(random.sample(characters, 8)) - name = 'planet-{}{}'.format(letters, extension) - return name - - class Paged: """Asynchronous iterator over results in a paged resource. From 6fbdefe8a76bda6ec70ad2dd8d7ab438c28ede66 Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Fri, 21 Jul 2023 14:49:27 -0700 Subject: [PATCH 4/7] add tests, fixes so tests pass, remove ServerError retry because it slows tests --- planet/clients/data.py | 8 +-- planet/clients/orders.py | 8 +-- planet/http.py | 103 ++++++++++++++------------- planet/models.py | 9 +-- tests/unit/test_http.py | 144 +++++++++++++++++++++++++++++++++----- tests/unit/test_models.py | 125 --------------------------------- 6 files changed, 188 insertions(+), 209 deletions(-) diff --git a/planet/clients/data.py b/planet/clients/data.py index 22495e38..8608c36d 100644 --- a/planet/clients/data.py +++ b/planet/clients/data.py @@ -596,10 +596,10 @@ async def download_asset(self, 'asset missing ["location"] entry. Is asset active?') return await self._session.write(location, - filename=filename, - directory=directory, - overwrite=overwrite, - progress_bar=progress_bar) + filename=filename, + directory=directory, + overwrite=overwrite, + progress_bar=progress_bar) @staticmethod def validate_checksum(asset: dict, filename: Path): diff --git a/planet/clients/orders.py b/planet/clients/orders.py index a668c010..82bec167 100644 --- a/planet/clients/orders.py +++ b/planet/clients/orders.py @@ -256,10 +256,10 @@ async def download_asset(self, """ return await self._session.write(location, - filename=filename, - directory=directory, - overwrite=overwrite, - progress_bar=progress_bar) + filename=filename, + directory=directory, + overwrite=overwrite, + progress_bar=progress_bar) async def download_order(self, order_id: str, diff --git a/planet/http.py b/planet/http.py index ac164a76..426e6995 100644 --- a/planet/http.py +++ b/planet/http.py @@ -16,13 +16,16 @@ from __future__ import annotations # https://stackoverflow.com/a/33533514 import asyncio from collections import Counter -from contextlib import asynccontextmanager from http import HTTPStatus import logging +import mimetypes from pathlib import Path import random +import re +import string import time -from typing import AsyncGenerator, Optional +from typing import Optional +from urllib.parse import urlparse import httpx from tqdm.asyncio import tqdm @@ -44,8 +47,7 @@ httpx.ReadTimeout, httpx.RemoteProtocolError, exceptions.BadGateway, - exceptions.TooManyRequests, - exceptions.ServerError + exceptions.TooManyRequests, # exceptions.ServerError ] MAX_RETRIES = 5 MAX_RETRY_BACKOFF = 64 # seconds @@ -399,10 +401,10 @@ async def _send(self, request, stream=False) -> httpx.Response: async def write(self, url: str, - directory: Path, - overwrite: bool, - progress_bar: bool, - filename: Optional[str] = None) -> str: + filename: Optional[str] = None, + directory: Path = Path('.'), + overwrite: bool = False, + progress_bar: bool = False) -> Path: """Write data to local file with limiting and retries. Parameters: @@ -421,55 +423,56 @@ async def write(self, planet.exceptions.ClientError: When retry limit is exceeded. """ + + async def _write(): + async with self._client.stream('GET', url) as response: + + dl_path = Path( + directory, + filename or _get_filename_from_response(response)) + dl_path.parent.mkdir(exist_ok=True, parents=True) + + await self._write_response(response, + dl_path, + overwrite=overwrite, + progress_bar=progress_bar) + + return dl_path + async def _limited_write(): async with self._limiter: - dl_path = await self._write(url=url, - directory=directory, - overwrite=overwrite, - progress_bar=progress_bar, - filename=filename) + dl_path = await _write() return dl_path return await self._retry(_limited_write) - async def _write(self, - url: str, - directory: Path, - overwrite: bool, - progress_bar: bool, - filename: Optional[str] = None) -> Path: - """Write data to local file. To be used in write()""" + async def _write_response(self, + response, + filename, + overwrite, + progress_bar): + total = int(response.headers["Content-Length"]) - async with self._client.stream('GET', url) as response: - - dl_path = Path(directory, - filename or _get_filename_from_response(response)) - dl_path.parent.mkdir(exist_ok=True, parents=True) - - total = int(response.headers["Content-Length"]) - - try: - mode = 'wb' if overwrite else 'xb' - with open(dl_path, mode) as fp: - LOGGER.info(f'Writing {dl_path}, size {total}B') - - with tqdm(total=total, - unit_scale=True, - unit_divisor=1024 * 1024, - unit='B', - desc=str(filename), - disable=not progress_bar) as progress: - - previous = response.num_bytes_downloaded - - async for chunk in response.aiter_bytes(): - fp.write(chunk) - new = response.num_bytes_downloaded - previous - progress.update(new - previous) - previous = new - progress.update() - except FileExistsError: - LOGGER.info(f'File {dl_path} exists, not overwriting') + try: + mode = 'wb' if overwrite else 'xb' + with open(filename, mode) as fp: + + with tqdm(total=total, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + previous = response.num_bytes_downloaded + + async for chunk in response.aiter_bytes(): + fp.write(chunk) + new = response.num_bytes_downloaded - previous + progress.update(new - previous) + previous = new + progress.update() + except FileExistsError: + LOGGER.info(f'File {filename} exists, not overwriting') def client(self, name: Literal['data', 'orders', 'subscriptions'], diff --git a/planet/models.py b/planet/models.py index 8524d47e..28ec645c 100644 --- a/planet/models.py +++ b/planet/models.py @@ -14,16 +14,9 @@ # limitations under the License. """Manage data for requests and responses.""" import logging -import mimetypes -from pathlib import Path -import random -import re -import string -from typing import AsyncGenerator, Callable, List, Optional -from urllib.parse import urlparse +from typing import AsyncGenerator, Callable, List import httpx -from tqdm.asyncio import tqdm from .exceptions import PagingError diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index bc876a1c..e1c9a80c 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -17,7 +17,9 @@ import logging from http import HTTPStatus import math -from unittest.mock import patch +from pathlib import Path +import re +from unittest.mock import MagicMock, patch import httpx import respx @@ -185,7 +187,7 @@ async def control(): @pytest.mark.anyio -async def test_session_contextmanager(): +async def test_Session_contextmanager(): async with http.Session(): pass @@ -193,7 +195,7 @@ async def test_session_contextmanager(): @respx.mock @pytest.mark.anyio @pytest.mark.parametrize('data', (None, {'boo': 'baa'})) -async def test_session_request_success(data): +async def test_Session_request_success(data): async with http.Session() as ps: resp_json = {'foo': 'bar'} @@ -217,19 +219,7 @@ async def test_session_request_success(data): @respx.mock @pytest.mark.anyio -async def test_session_stream(): - async with http.Session() as ps: - mock_resp = httpx.Response(HTTPStatus.OK, text='bubba') - respx.get(TEST_URL).return_value = mock_resp - - async with ps.stream(method='GET', url=TEST_URL) as resp: - chunks = [c async for c in resp.aiter_bytes()] - assert chunks[0] == b'bubba' - - -@respx.mock -@pytest.mark.anyio -async def test_session_request_retry(): +async def test_Session_request_retry(): """Test the retry in the Session.request method""" async with http.Session() as ps: route = respx.get(TEST_URL) @@ -248,7 +238,7 @@ async def test_session_request_retry(): @respx.mock @pytest.mark.anyio -async def test_session__retry(): +async def test_Session__retry(): """A unit test for the _retry function""" async def test_func(): @@ -268,7 +258,7 @@ async def test_func(): assert args == [(1, 64), (2, 64), (3, 64), (4, 64), (5, 64)] -def test__calculate_wait(): +def test_Session__calculate_wait(): max_retry_backoff = 20 wait_times = [ http.Session._calculate_wait(i + 1, max_retry_backoff) @@ -284,6 +274,59 @@ def test__calculate_wait(): assert math.floor(wait) == expected +@respx.mock +@pytest.mark.anyio +async def test_Session_write(): + """Ensure that write retries and that it passes the correct info to + _write_response.""" + resp_success = httpx.Response(HTTPStatus.OK, json={}) + + route = respx.get(TEST_URL) + route.side_effect = [ + httpx.Response(HTTPStatus.TOO_MANY_REQUESTS, json={}), resp_success + ] + + with patch('planet.http.Session._write_response') as mock_write_response: + + async with http.Session() as ps: + # let's not actually introduce a wait into the tests + ps.max_retry_backoff = 0 + + await ps.write(TEST_URL, filename='test', directory='testdir') + + req_call_args = mock_write_response.call_args[0] + assert req_call_args[0].status_code == resp_success.status_code + assert req_call_args[1] == Path('testdir') / 'test' # filename + + +@pytest.mark.anyio +async def test_Session__write_response(tmpdir, open_test_img): + """Ensure content is downloaded and written to the correct file""" + + async def _aiter_bytes(): + data = open_test_img.read() + v = memoryview(data) + + chunksize = 100 + for i in range(math.ceil(len(v) / (chunksize))): + yield v[i * chunksize:min((i + 1) * chunksize, len(v))] + + r = MagicMock(name='response') + r.aiter_bytes = _aiter_bytes + r.num_bytes_downloaded = 0 + r.headers['Content-Length'] = 527 + + dl_path = Path(tmpdir) / 'test.tif' + async with http.Session() as ps: + await ps._write_response(r, + dl_path, + overwrite=False, + progress_bar=False) + + assert dl_path.is_file() + assert dl_path.stat().st_size == 527 + + @respx.mock @pytest.mark.anyio async def test_authsession_request(): @@ -304,3 +347,68 @@ def test_authsession__raise_for_status(mock_response): with pytest.raises(exceptions.APIError): http.AuthSession._raise_for_status( mock_response(HTTPStatus.UNAUTHORIZED, json={})) + + +def test__get_filename_from_response(): + r = MagicMock(name='response') + r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' + r.headers = { + 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': 'bytes', + 'content-type': 'image/tiff', + 'content-length': '57350256', + 'content-disposition': 'attachment; filename="open_california.tif"' + } + assert http._get_filename_from_response(r) == 'open_california.tif' + + +NO_NAME_HEADERS = { + 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': 'bytes', + 'content-type': 'image/tiff', + 'content-length': '57350256' +} +OPEN_CALIFORNIA_HEADERS = { + 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': 'bytes', + 'content-type': 'image/tiff', + 'content-length': '57350256', + 'content-disposition': 'attachment; filename="open_california.tif"' +} + + +@pytest.mark.parametrize('headers,expected', + [(OPEN_CALIFORNIA_HEADERS, 'open_california.tif'), + (NO_NAME_HEADERS, None), + ({}, None)]) # yapf: disable +def test__get_filename_from_headers(headers, expected): + assert http._get_filename_from_headers(headers) == expected + + +@pytest.mark.parametrize( + 'url,expected', + [ + ('https://planet.com/', None), + ('https://planet.com/path/to/', None), + ('https://planet.com/path/to/example.tif', 'example.tif'), + ('https://planet.com/path/to/example.tif?foo=f6f1&bar=baz', + 'example.tif'), + ('https://planet.com/path/to/example.tif?foo=f6f1#quux', + 'example.tif'), + ]) +def test__get_filename_from_url(url, expected): + assert http._get_filename_from_url(url) == expected + + +@pytest.mark.parametrize( + 'content_type,check', + [ + (None, + lambda x: re.match(r'^planet-[a-z0-9]{8}$', x, re.I) is not None), + ('image/tiff', lambda x: x.endswith(('.tif', '.tiff'))), + ]) +def test__get_random_filename(content_type, check): + assert check(http._get_random_filename(content_type)) diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 99cb11f2..a542865a 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -13,11 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math from unittest.mock import MagicMock -import os -from pathlib import Path -import re import pytest @@ -27,127 +23,6 @@ LOGGER = logging.getLogger(__name__) -def test_StreamingBody_name_filename(): - r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' - } - body = models.StreamingBody(r) - assert body.name == 'open_california.tif' - - -def test_StreamingBody_name_url(): - r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - } - body = models.StreamingBody(r) - - assert body.name == 'example.tif' - - -def test_StreamingBody_name_content(): - r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/noname/' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - } - body = models.StreamingBody(r) - - assert body.name.startswith('planet-') - assert (body.name.endswith('.tiff') or body.name.endswith('.tif')) - - -NO_NAME_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256' -} -OPEN_CALIFORNIA_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' -} - - -@pytest.mark.parametrize('headers,expected', - [(OPEN_CALIFORNIA_HEADERS, 'open_california.tif'), - (NO_NAME_HEADERS, None), - ({}, None)]) # yapf: disable -def test__get_filename_from_headers(headers, expected): - assert models._get_filename_from_headers(headers) == expected - - -@pytest.mark.parametrize( - 'url,expected', - [ - ('https://planet.com/', None), - ('https://planet.com/path/to/', None), - ('https://planet.com/path/to/example.tif', 'example.tif'), - ('https://planet.com/path/to/example.tif?foo=f6f1&bar=baz', - 'example.tif'), - ('https://planet.com/path/to/example.tif?foo=f6f1#quux', - 'example.tif'), - ]) -def test__get_filename_from_url(url, expected): - assert models._get_filename_from_url(url) == expected - - -@pytest.mark.parametrize( - 'content_type,check', - [ - (None, - lambda x: re.match(r'^planet-[a-z0-9]{8}$', x, re.I) is not None), - ('image/tiff', lambda x: x.endswith(('.tif', '.tiff'))), - ]) -def test__get_random_filename(content_type, check): - assert check(models._get_random_filename(content_type)) - - -@pytest.mark.anyio -async def test_StreamingBody_write_img(tmpdir, open_test_img): - - async def _aiter_bytes(): - data = open_test_img.read() - v = memoryview(data) - - chunksize = 100 - for i in range(math.ceil(len(v) / (chunksize))): - yield v[i * chunksize:min((i + 1) * chunksize, len(v))] - - r = MagicMock(name='response') - r.aiter_bytes = _aiter_bytes - r.num_bytes_downloaded = 0 - r.headers['Content-Length'] = 527 - body = models.StreamingBody(r) - - filename = Path(tmpdir) / 'test.tif' - await body.write(filename, progress_bar=False) - - assert os.path.isfile(filename) - assert os.stat(filename).st_size == 527 - - @pytest.mark.anyio async def test_Paged_iterator(): resp = MagicMock(name='response') From d7611916f161b33d0289587e7dcfcdfa29403192 Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Fri, 21 Jul 2023 15:07:05 -0700 Subject: [PATCH 5/7] patch with asyncmock in python 3.7 --- tests/unit/test_http.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index e1c9a80c..64b7a435 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -274,6 +274,13 @@ def test_Session__calculate_wait(): assert math.floor(wait) == expected +# just need this for python 3.7 +class AsyncMock(MagicMock): + + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + @respx.mock @pytest.mark.anyio async def test_Session_write(): @@ -286,7 +293,8 @@ async def test_Session_write(): httpx.Response(HTTPStatus.TOO_MANY_REQUESTS, json={}), resp_success ] - with patch('planet.http.Session._write_response') as mock_write_response: + with patch('planet.http.Session._write_response', + new=AsyncMock()) as mock_write_response: async with http.Session() as ps: # let's not actually introduce a wait into the tests From 11136e2c4204266dd1378910d6d102d002c3c9d6 Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Fri, 21 Jul 2023 15:16:08 -0700 Subject: [PATCH 6/7] commented code clean up --- planet/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planet/http.py b/planet/http.py index 426e6995..2a75b758 100644 --- a/planet/http.py +++ b/planet/http.py @@ -47,7 +47,7 @@ httpx.ReadTimeout, httpx.RemoteProtocolError, exceptions.BadGateway, - exceptions.TooManyRequests, # exceptions.ServerError + exceptions.TooManyRequests, ] MAX_RETRIES = 5 MAX_RETRY_BACKOFF = 64 # seconds From 71d3085df52d12279125f2f6f8f79a0dc935bba2 Mon Sep 17 00:00:00 2001 From: Jennifer Reiber Kyle Date: Fri, 28 Jul 2023 18:13:14 -0700 Subject: [PATCH 7/7] move filename determination to models.Response and file and progress bar management to clients --- planet/clients/data.py | 36 ++++++-- planet/clients/orders.py | 34 +++++-- planet/http.py | 124 +++---------------------- planet/models.py | 60 ++++++++++++- tests/integration/test_data_api.py | 17 ++-- tests/integration/test_data_cli.py | 18 ++-- tests/integration/test_orders_api.py | 20 +++-- tests/unit/test_http.py | 129 +++++---------------------- tests/unit/test_models.py | 61 +++++++++++++ 9 files changed, 249 insertions(+), 250 deletions(-) diff --git a/planet/clients/data.py b/planet/clients/data.py index 8608c36d..3f7e6e17 100644 --- a/planet/clients/data.py +++ b/planet/clients/data.py @@ -20,6 +20,8 @@ from typing import Any, AsyncIterator, Callable, Dict, List, Optional import uuid +from tqdm.asyncio import tqdm + from ..data_filter import empty_filter from .. import exceptions from ..constants import PLANET_BASE_URL @@ -586,8 +588,8 @@ async def download_asset(self, Raises: planet.exceptions.APIError: On API error. - planet.exceptions.ClientError: If asset is not active, asset - description is not valid, or retry limit is exceeded. + planet.exceptions.ClientError: If asset is not active or asset + description is not valid. """ try: location = asset['location'] @@ -595,11 +597,31 @@ async def download_asset(self, raise exceptions.ClientError( 'asset missing ["location"] entry. Is asset active?') - return await self._session.write(location, - filename=filename, - directory=directory, - overwrite=overwrite, - progress_bar=progress_bar) + response = await self._session.request(method='GET', url=location) + filename = filename or response.filename + if not filename: + raise exceptions.ClientError( + f'Could not determine filename at {location}') + + dl_path = Path(directory, filename) + dl_path.parent.mkdir(exist_ok=True, parents=True) + LOGGER.info(f'Downloading {dl_path}') + + try: + mode = 'wb' if overwrite else 'xb' + with open(dl_path, mode) as fp: + with tqdm(total=response.length, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + update = progress.update if progress_bar else LOGGER.debug + await self._session.write(location, fp, update) + except FileExistsError: + LOGGER.info(f'File {dl_path} exists, not overwriting') + + return dl_path @staticmethod def validate_checksum(asset: dict, filename: Path): diff --git a/planet/clients/orders.py b/planet/clients/orders.py index 82bec167..3ff0c059 100644 --- a/planet/clients/orders.py +++ b/planet/clients/orders.py @@ -15,13 +15,15 @@ """Functionality for interacting with the orders api""" import asyncio import logging +from pathlib import Path import time from typing import AsyncIterator, Callable, List, Optional import uuid import json import hashlib -from pathlib import Path +from tqdm.asyncio import tqdm + from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session @@ -255,11 +257,31 @@ async def download_asset(self, limit is exceeded. """ - return await self._session.write(location, - filename=filename, - directory=directory, - overwrite=overwrite, - progress_bar=progress_bar) + response = await self._session.request(method='GET', url=location) + filename = filename or response.filename + length = response.length + if not filename: + raise exceptions.ClientError( + f'Could not determine filename at {location}') + + dl_path = Path(directory, filename) + dl_path.parent.mkdir(exist_ok=True, parents=True) + LOGGER.info(f'Downloading {dl_path}') + + try: + mode = 'wb' if overwrite else 'xb' + with open(dl_path, mode) as fp: + with tqdm(total=length, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + await self._session.write(location, fp, progress.update) + except FileExistsError: + LOGGER.info(f'File {dl_path} exists, not overwriting') + + return dl_path async def download_order(self, order_id: str, diff --git a/planet/http.py b/planet/http.py index 2a75b758..36b5ad71 100644 --- a/planet/http.py +++ b/planet/http.py @@ -18,17 +18,11 @@ from collections import Counter from http import HTTPStatus import logging -import mimetypes -from pathlib import Path import random -import re -import string import time -from typing import Optional -from urllib.parse import urlparse +from typing import Callable, Optional import httpx -from tqdm.asyncio import tqdm from typing_extensions import Literal from .auth import Auth, AuthType @@ -332,6 +326,7 @@ async def _retry(self, func, *a, **kw): LOGGER.info(f'Retrying: sleeping {wait_time}s') await asyncio.sleep(wait_time) else: + LOGGER.info('Retrying: failed') raise e self.outcomes.update(['Successful']) @@ -399,80 +394,32 @@ async def _send(self, request, stream=False) -> httpx.Response: return http_resp - async def write(self, - url: str, - filename: Optional[str] = None, - directory: Path = Path('.'), - overwrite: bool = False, - progress_bar: bool = False) -> Path: + async def write(self, url: str, fp, callback: Optional[Callable] = None): """Write data to local file with limiting and retries. Parameters: - url: Remote location url - filename: Custom name to assign to downloaded file. - directory: Base directory for file download. This directory will be - created if it does not already exist. - overwrite: Overwrite any existing files. - progress_bar: Show progress bar during download. - - Returns: - Path to downloaded file. + url: Remote location url. + fp: Open write file pointer. + callback: Function that handles write progress updates. Raises: planet.exceptions.APIException: On API error. - planet.exceptions.ClientError: When retry limit is exceeded. """ - async def _write(): - async with self._client.stream('GET', url) as response: - - dl_path = Path( - directory, - filename or _get_filename_from_response(response)) - dl_path.parent.mkdir(exist_ok=True, parents=True) - - await self._write_response(response, - dl_path, - overwrite=overwrite, - progress_bar=progress_bar) - - return dl_path - async def _limited_write(): async with self._limiter: - dl_path = await _write() - return dl_path - - return await self._retry(_limited_write) - - async def _write_response(self, - response, - filename, - overwrite, - progress_bar): - total = int(response.headers["Content-Length"]) - - try: - mode = 'wb' if overwrite else 'xb' - with open(filename, mode) as fp: - - with tqdm(total=total, - unit_scale=True, - unit_divisor=1024 * 1024, - unit='B', - desc=str(filename), - disable=not progress_bar) as progress: + async with self._client.stream('GET', url) as response: previous = response.num_bytes_downloaded async for chunk in response.aiter_bytes(): fp.write(chunk) - new = response.num_bytes_downloaded - previous - progress.update(new - previous) - previous = new - progress.update() - except FileExistsError: - LOGGER.info(f'File {filename} exists, not overwriting') + current = response.num_bytes_downloaded + if callback is not None: + callback(current - previous) + previous = current + + await self._retry(_limited_write) def client(self, name: Literal['data', 'orders', 'subscriptions'], @@ -498,51 +445,6 @@ def client(self, raise exceptions.ClientError("No such client.") -def _get_filename_from_response(response) -> str: - """The name of the response resource. - - The default is to use the content-disposition header value from the - response. If not found, falls back to resolving the name from the url - or generating a random name with the type from the response. - """ - name = (_get_filename_from_headers(response.headers) - or _get_filename_from_url(response.url) - or _get_random_filename(response.headers.get('content-type'))) - return name - - -def _get_filename_from_headers(headers): - """Get a filename from the Content-Disposition header, if available. - - :param headers dict: a ``dict`` of response headers - :returns: a filename (i.e. ``basename``) - :rtype: str or None - """ - cd = headers.get('content-disposition', '') - match = re.search('filename="?([^"]+)"?', cd) - return match.group(1) if match else None - - -def _get_filename_from_url(url: str) -> Optional[str]: - """Get a filename from a url. - - Getting a name for Landsat imagery uses this function. - """ - path = urlparse(url).path - name = path[path.rfind('/') + 1:] - return name or None - - -def _get_random_filename(content_type=None) -> str: - """Get a pseudo-random, Planet-looking filename. - """ - extension = mimetypes.guess_extension(content_type or '') or '' - characters = string.ascii_letters + '0123456789' - letters = ''.join(random.sample(characters, 8)) - name = 'planet-{}{}'.format(letters, extension) - return name - - class AuthSession(BaseSession): """Synchronous connection to the Planet Auth service.""" diff --git a/planet/models.py b/planet/models.py index 28ec645c..cced0b64 100644 --- a/planet/models.py +++ b/planet/models.py @@ -14,7 +14,9 @@ # limitations under the License. """Manage data for requests and responses.""" import logging -from typing import AsyncGenerator, Callable, List +import re +from typing import AsyncGenerator, Callable, List, Optional +from urllib.parse import urlparse import httpx @@ -42,11 +44,67 @@ def status_code(self) -> int: """HTTP status code""" return self._http_response.status_code + @property + def filename(self) -> Optional[str]: + """Name of the download file. + + The filename is None if the response does not represent a download. + """ + filename = None + + if self.length is not None: # is a download file + filename = _get_filename_from_response(self._http_response) + + return filename + + @property + def length(self) -> Optional[int]: + """Length of the download file. + + The length is None if the response does not represent a download. + """ + LOGGER.warning('here') + try: + length = int(self._http_response.headers["Content-Length"]) + except KeyError: + length = None + LOGGER.warning(length) + return length + def json(self) -> dict: """Response json""" return self._http_response.json() +def _get_filename_from_response(response) -> Optional[str]: + """The name of the response resource. + + The default is to use the content-disposition header value from the + response. If not found, falls back to resolving the name from the url + or generating a random name with the type from the response. + """ + name = (_get_filename_from_headers(response.headers) + or _get_filename_from_url(str(response.url))) + return name + + +def _get_filename_from_headers(headers: httpx.Headers) -> Optional[str]: + """Get a filename from the Content-Disposition header, if available.""" + cd = headers.get('content-disposition', '') + match = re.search('filename="?([^"]+)"?', cd) + return match.group(1) if match else None + + +def _get_filename_from_url(url: str) -> Optional[str]: + """Get a filename from the url. + + Getting a name for Landsat imagery uses this function. + """ + path = urlparse(url).path + name = path[path.rfind('/') + 1:] + return name or None + + class Paged: """Asynchronous iterator over results in a paged resource. diff --git a/tests/integration/test_data_api.py b/tests/integration/test_data_api.py index 7a57f539..2c57c567 100644 --- a/tests/integration/test_data_api.py +++ b/tests/integration/test_data_api.py @@ -837,11 +837,15 @@ async def _stream_img(): # populate request parameter to avoid respx cloning, which throws # an error caused by respx and not this code # https://github.com/lundberg/respx/issues/130 - mock_resp = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] basic_udm2_asset = { "_links": { @@ -863,7 +867,8 @@ async def _stream_img(): path = await cl.download_asset(basic_udm2_asset, directory=tmpdir, - overwrite=overwrite) + overwrite=overwrite, + progress_bar=False) assert path.name == 'img.tif' assert path.is_file() diff --git a/tests/integration/test_data_cli.py b/tests/integration/test_data_cli.py index ddb04b28..90834b25 100644 --- a/tests/integration/test_data_cli.py +++ b/tests/integration/test_data_cli.py @@ -966,12 +966,18 @@ async def _stream_img(): for i in range(math.ceil(len(v) / (chunksize))): yield v[i * chunksize:min((i + 1) * chunksize, len(v))] - # Mock the response for download_asset - mock_resp_download = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp_download + # populate request parameter to avoid respx cloning, which throws + # an error caused by respx and not this code + # https://github.com/lundberg/respx/issues/130 + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] runner = CliRunner() with runner.isolated_filesystem() as folder: diff --git a/tests/integration/test_orders_api.py b/tests/integration/test_orders_api.py index 3470436f..0c164c53 100644 --- a/tests/integration/test_orders_api.py +++ b/tests/integration/test_orders_api.py @@ -19,7 +19,6 @@ from http import HTTPStatus import logging import math -import os from pathlib import Path from unittest.mock import call, create_autospec @@ -587,17 +586,22 @@ async def _stream_img(): # populate request parameter to avoid respx cloning, which throws # an error caused by respx and not this code # https://github.com/lundberg/respx/issues/130 - mock_resp = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] cl = OrdersClient(session, base_url=TEST_URL) filename = await cl.download_asset(dl_url, directory=str(tmpdir)) - assert Path(filename).name == 'img.tif' - assert os.path.isfile(filename) + assert filename.name == 'img.tif' + assert filename.is_file() + assert len(filename.read_bytes()) == 527 @respx.mock diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 64b7a435..b464d3aa 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -18,7 +18,6 @@ from http import HTTPStatus import math from pathlib import Path -import re from unittest.mock import MagicMock, patch import httpx @@ -28,7 +27,7 @@ from planet import exceptions, http -TEST_URL = 'mock://fantastic.com' +TEST_URL = 'http://www.MockNotRealURL.com' LOGGER = logging.getLogger(__name__) @@ -283,35 +282,10 @@ async def __call__(self, *args, **kwargs): @respx.mock @pytest.mark.anyio -async def test_Session_write(): - """Ensure that write retries and that it passes the correct info to - _write_response.""" - resp_success = httpx.Response(HTTPStatus.OK, json={}) +async def test_Session_write(open_test_img, tmpdir): + """Ensure that write retries and that it writes to the file pointer""" - route = respx.get(TEST_URL) - route.side_effect = [ - httpx.Response(HTTPStatus.TOO_MANY_REQUESTS, json={}), resp_success - ] - - with patch('planet.http.Session._write_response', - new=AsyncMock()) as mock_write_response: - - async with http.Session() as ps: - # let's not actually introduce a wait into the tests - ps.max_retry_backoff = 0 - - await ps.write(TEST_URL, filename='test', directory='testdir') - - req_call_args = mock_write_response.call_args[0] - assert req_call_args[0].status_code == resp_success.status_code - assert req_call_args[1] == Path('testdir') / 'test' # filename - - -@pytest.mark.anyio -async def test_Session__write_response(tmpdir, open_test_img): - """Ensure content is downloaded and written to the correct file""" - - async def _aiter_bytes(): + async def _stream_img(): data = open_test_img.read() v = memoryview(data) @@ -319,17 +293,27 @@ async def _aiter_bytes(): for i in range(math.ceil(len(v) / (chunksize))): yield v[i * chunksize:min((i + 1) * chunksize, len(v))] - r = MagicMock(name='response') - r.aiter_bytes = _aiter_bytes - r.num_bytes_downloaded = 0 - r.headers['Content-Length'] = 527 + img_headers = { + 'Content-Type': 'image/tiff', + 'Content-Length': '527', + 'Content-Disposition': 'attachment; filename="img.tif"' + } - dl_path = Path(tmpdir) / 'test.tif' - async with http.Session() as ps: - await ps._write_response(r, - dl_path, - overwrite=False, - progress_bar=False) + route = respx.get(TEST_URL) + route.side_effect = [ + httpx.Response(HTTPStatus.TOO_MANY_REQUESTS, json={}), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=httpx.Headers(img_headers)) + ] + + dl_path = Path(tmpdir, 'test.tif') + with open(dl_path, 'wb') as fp: + async with http.Session() as ps: + # let's not actually introduce a wait into the tests + ps.max_retry_backoff = 0 + + await ps.write(TEST_URL, fp=fp) assert dl_path.is_file() assert dl_path.stat().st_size == 527 @@ -355,68 +339,3 @@ def test_authsession__raise_for_status(mock_response): with pytest.raises(exceptions.APIError): http.AuthSession._raise_for_status( mock_response(HTTPStatus.UNAUTHORIZED, json={})) - - -def test__get_filename_from_response(): - r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' - } - assert http._get_filename_from_response(r) == 'open_california.tif' - - -NO_NAME_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256' -} -OPEN_CALIFORNIA_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' -} - - -@pytest.mark.parametrize('headers,expected', - [(OPEN_CALIFORNIA_HEADERS, 'open_california.tif'), - (NO_NAME_HEADERS, None), - ({}, None)]) # yapf: disable -def test__get_filename_from_headers(headers, expected): - assert http._get_filename_from_headers(headers) == expected - - -@pytest.mark.parametrize( - 'url,expected', - [ - ('https://planet.com/', None), - ('https://planet.com/path/to/', None), - ('https://planet.com/path/to/example.tif', 'example.tif'), - ('https://planet.com/path/to/example.tif?foo=f6f1&bar=baz', - 'example.tif'), - ('https://planet.com/path/to/example.tif?foo=f6f1#quux', - 'example.tif'), - ]) -def test__get_filename_from_url(url, expected): - assert http._get_filename_from_url(url) == expected - - -@pytest.mark.parametrize( - 'content_type,check', - [ - (None, - lambda x: re.match(r'^planet-[a-z0-9]{8}$', x, re.I) is not None), - ('image/tiff', lambda x: x.endswith(('.tif', '.tiff'))), - ]) -def test__get_random_filename(content_type, check): - assert check(http._get_random_filename(content_type)) diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index a542865a..bff92805 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -15,6 +15,7 @@ import logging from unittest.mock import MagicMock +import httpx import pytest from planet import models @@ -22,6 +23,66 @@ LOGGER = logging.getLogger(__name__) +NO_NAME_HEADERS = { + 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': 'bytes', + 'Content-Type': 'image/tiff', + 'Content-Length': '57350256' +} +OPEN_CALIFORNIA_HEADERS = httpx.Headers({ + 'date': + 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': + 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': + 'bytes', + 'Content-Type': + 'image/tiff', + 'Content-Length': + '57350256', + 'Content-Disposition': + 'attachment; filename="open_california.tif"' +}) + + +def test_Response_filename(): + r = MagicMock(name='response') + r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' + r.headers = OPEN_CALIFORNIA_HEADERS + + assert models.Response(r).filename == 'open_california.tif' + + +def test__get_filename_from_response(): + r = MagicMock(name='response') + r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' + r.headers = OPEN_CALIFORNIA_HEADERS + assert models._get_filename_from_response(r) == 'open_california.tif' + + +@pytest.mark.parametrize('headers,expected', + [(OPEN_CALIFORNIA_HEADERS, 'open_california.tif'), + (NO_NAME_HEADERS, None), + ({}, None)]) # yapf: disable +def test__get_filename_from_headers(headers, expected): + assert models._get_filename_from_headers(headers) == expected + + +@pytest.mark.parametrize( + 'url,expected', + [ + ('https://planet.com/', None), + ('https://planet.com/path/to/', None), + ('https://planet.com/path/to/example.tif', 'example.tif'), + ('https://planet.com/path/to/example.tif?foo=f6f1&bar=baz', + 'example.tif'), + ('https://planet.com/path/to/example.tif?foo=f6f1#quux', + 'example.tif'), + ]) +def test__get_filename_from_url(url, expected): + assert models._get_filename_from_url(url) == expected + @pytest.mark.anyio async def test_Paged_iterator():