diff --git a/planet/clients/data.py b/planet/clients/data.py index 7b8133a8..3f7e6e17 100644 --- a/planet/clients/data.py +++ b/planet/clients/data.py @@ -20,11 +20,13 @@ from typing import Any, AsyncIterator, Callable, Dict, List, Optional import uuid +from tqdm.asyncio import tqdm + from ..data_filter import empty_filter from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session -from ..models import Paged, StreamingBody +from ..models import Paged from ..specs import validate_data_item_type BASE_URL = f'{PLANET_BASE_URL}/data/v1/' @@ -595,13 +597,30 @@ async def download_asset(self, raise exceptions.ClientError( 'asset missing ["location"] entry. Is asset active?') - async with self._session.stream(method='GET', url=location) as resp: - body = StreamingBody(resp) - dl_path = Path(directory, filename or body.name) - dl_path.parent.mkdir(exist_ok=True, parents=True) - await body.write(dl_path, - overwrite=overwrite, - progress_bar=progress_bar) + response = await self._session.request(method='GET', url=location) + filename = filename or response.filename + if not filename: + raise exceptions.ClientError( + f'Could not determine filename at {location}') + + dl_path = Path(directory, filename) + dl_path.parent.mkdir(exist_ok=True, parents=True) + LOGGER.info(f'Downloading {dl_path}') + + try: + mode = 'wb' if overwrite else 'xb' + with open(dl_path, mode) as fp: + with tqdm(total=response.length, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + update = progress.update if progress_bar else LOGGER.debug + await self._session.write(location, fp, update) + except FileExistsError: + LOGGER.info(f'File {dl_path} exists, not overwriting') + return dl_path @staticmethod diff --git a/planet/clients/orders.py b/planet/clients/orders.py index 92d74645..3ff0c059 100644 --- a/planet/clients/orders.py +++ b/planet/clients/orders.py @@ -15,17 +15,19 @@ """Functionality for interacting with the orders api""" import asyncio import logging +from pathlib import Path import time from typing import AsyncIterator, Callable, List, Optional import uuid import json import hashlib -from pathlib import Path +from tqdm.asyncio import tqdm + from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session -from ..models import Paged, StreamingBody +from ..models import Paged BASE_URL = f'{PLANET_BASE_URL}/compute/ops' STATS_PATH = '/stats/orders/v2' @@ -251,14 +253,34 @@ async def download_asset(self, Raises: planet.exceptions.APIError: On API error. + planet.exceptions.ClientError: If location is not valid or retry + limit is exceeded. + """ - async with self._session.stream(method='GET', url=location) as resp: - body = StreamingBody(resp) - dl_path = Path(directory, filename or body.name) - dl_path.parent.mkdir(exist_ok=True, parents=True) - await body.write(dl_path, - overwrite=overwrite, - progress_bar=progress_bar) + response = await self._session.request(method='GET', url=location) + filename = filename or response.filename + length = response.length + if not filename: + raise exceptions.ClientError( + f'Could not determine filename at {location}') + + dl_path = Path(directory, filename) + dl_path.parent.mkdir(exist_ok=True, parents=True) + LOGGER.info(f'Downloading {dl_path}') + + try: + mode = 'wb' if overwrite else 'xb' + with open(dl_path, mode) as fp: + with tqdm(total=length, + unit_scale=True, + unit_divisor=1024 * 1024, + unit='B', + desc=str(filename), + disable=not progress_bar) as progress: + await self._session.write(location, fp, progress.update) + except FileExistsError: + LOGGER.info(f'File {dl_path} exists, not overwriting') + return dl_path async def download_order(self, diff --git a/planet/http.py b/planet/http.py index 6ec8fb75..36b5ad71 100644 --- a/planet/http.py +++ b/planet/http.py @@ -16,12 +16,11 @@ from __future__ import annotations # https://stackoverflow.com/a/33533514 import asyncio from collections import Counter -from contextlib import asynccontextmanager from http import HTTPStatus import logging import random import time -from typing import AsyncGenerator, Optional +from typing import Callable, Optional import httpx from typing_extensions import Literal @@ -42,7 +41,7 @@ httpx.ReadTimeout, httpx.RemoteProtocolError, exceptions.BadGateway, - exceptions.TooManyRequests + exceptions.TooManyRequests, ] MAX_RETRIES = 5 MAX_RETRY_BACKOFF = 64 # seconds @@ -327,6 +326,7 @@ async def _retry(self, func, *a, **kw): LOGGER.info(f'Retrying: sleeping {wait_time}s') await asyncio.sleep(wait_time) else: + LOGGER.info('Retrying: failed') raise e self.outcomes.update(['Successful']) @@ -394,26 +394,32 @@ async def _send(self, request, stream=False) -> httpx.Response: return http_resp - @asynccontextmanager - async def stream( - self, method: str, - url: str) -> AsyncGenerator[models.StreamingResponse, None]: - """Submit a request and get the response as a stream context manager. + async def write(self, url: str, fp, callback: Optional[Callable] = None): + """Write data to local file with limiting and retries. Parameters: - method: HTTP request method. - url: Location of the API endpoint. + url: Remote location url. + fp: Open write file pointer. + callback: Function that handles write progress updates. + + Raises: + planet.exceptions.APIException: On API error. - Returns: - Context manager providing the streaming response. """ - request = self._client.build_request(method=method, url=url) - http_response = await self._retry(self._send, request, stream=True) - response = models.StreamingResponse(http_response) - try: - yield response - finally: - await response.aclose() + + async def _limited_write(): + async with self._limiter: + async with self._client.stream('GET', url) as response: + previous = response.num_bytes_downloaded + + async for chunk in response.aiter_bytes(): + fp.write(chunk) + current = response.num_bytes_downloaded + if callback is not None: + callback(current - previous) + previous = current + + await self._retry(_limited_write) def client(self, name: Literal['data', 'orders', 'subscriptions'], diff --git a/planet/models.py b/planet/models.py index e2e7761c..cced0b64 100644 --- a/planet/models.py +++ b/planet/models.py @@ -14,16 +14,11 @@ # limitations under the License. """Manage data for requests and responses.""" import logging -import mimetypes -from pathlib import Path -import random import re -import string from typing import AsyncGenerator, Callable, List, Optional from urllib.parse import urlparse import httpx -from tqdm.asyncio import tqdm from .exceptions import PagingError @@ -49,134 +44,59 @@ def status_code(self) -> int: """HTTP status code""" return self._http_response.status_code - def json(self) -> dict: - """Response json""" - return self._http_response.json() - - -class StreamingResponse(Response): - - @property - def headers(self) -> httpx.Headers: - return self._http_response.headers - @property - def url(self) -> str: - return str(self._http_response.url) + def filename(self) -> Optional[str]: + """Name of the download file. - @property - def num_bytes_downloaded(self) -> int: - return self._http_response.num_bytes_downloaded + The filename is None if the response does not represent a download. + """ + filename = None - async def aiter_bytes(self): - async for c in self._http_response.aiter_bytes(): - yield c + if self.length is not None: # is a download file + filename = _get_filename_from_response(self._http_response) - async def aclose(self): - await self._http_response.aclose() + return filename + @property + def length(self) -> Optional[int]: + """Length of the download file. -class StreamingBody: - """A representation of a streaming resource from the API.""" + The length is None if the response does not represent a download. + """ + LOGGER.warning('here') + try: + length = int(self._http_response.headers["Content-Length"]) + except KeyError: + length = None + LOGGER.warning(length) + return length - def __init__(self, response: StreamingResponse): - """Initialize the object. + def json(self) -> dict: + """Response json""" + return self._http_response.json() - Parameters: - response: Response that was received from the server. - """ - self._response = response - @property - def name(self) -> str: - """The name of this resource. +def _get_filename_from_response(response) -> Optional[str]: + """The name of the response resource. The default is to use the content-disposition header value from the response. If not found, falls back to resolving the name from the url or generating a random name with the type from the response. """ - name = (_get_filename_from_headers(self._response.headers) - or _get_filename_from_url(self._response.url) - or _get_random_filename( - self._response.headers.get('content-type'))) - return name - - @property - def size(self) -> int: - """The size of the body.""" - return int(self._response.headers['Content-Length']) - - async def write(self, - filename: Path, - overwrite: bool = True, - progress_bar: bool = True): - """Write the body to a file. - Parameters: - filename: Name to assign to downloaded file. - overwrite: Overwrite any existing files. - progress_bar: Show progress bar during download. - """ - - class _LOG: - - def __init__(self, total, unit, filename, disable): - self.total = total - self.unit = unit - self.disable = disable - self.previous = 0 - self.filename = str(filename) - - if not self.disable: - LOGGER.debug(f'writing to {self.filename}') - - def update(self, new): - if new - self.previous > self.unit and not self.disable: - # LOGGER.debug(f'{new-self.previous}') - perc = int(100 * new / self.total) - LOGGER.debug(f'{self.filename}: ' - f'wrote {perc}% of {self.total}') - self.previous = new + name = (_get_filename_from_headers(response.headers) + or _get_filename_from_url(str(response.url))) + return name - unit = 1024 * 1024 - mode = 'wb' if overwrite else 'xb' - try: - with open(filename, mode) as fp: - _log = _LOG(self.size, - 16 * unit, - filename, - disable=progress_bar) - with tqdm(total=self.size, - unit_scale=True, - unit_divisor=unit, - unit='B', - desc=str(filename), - disable=not progress_bar) as progress: - previous = self._response.num_bytes_downloaded - async for chunk in self._response.aiter_bytes(): - fp.write(chunk) - new = self._response.num_bytes_downloaded - _log.update(new) - progress.update(new - previous) - previous = new - except FileExistsError: - LOGGER.info(f'File {filename} exists, not overwriting') - - -def _get_filename_from_headers(headers): - """Get a filename from the Content-Disposition header, if available. - - :param headers dict: a ``dict`` of response headers - :returns: a filename (i.e. ``basename``) - :rtype: str or None - """ +def _get_filename_from_headers(headers: httpx.Headers) -> Optional[str]: + """Get a filename from the Content-Disposition header, if available.""" cd = headers.get('content-disposition', '') match = re.search('filename="?([^"]+)"?', cd) return match.group(1) if match else None def _get_filename_from_url(url: str) -> Optional[str]: - """Get a filename from a url. + """Get a filename from the url. Getting a name for Landsat imagery uses this function. """ @@ -185,19 +105,6 @@ def _get_filename_from_url(url: str) -> Optional[str]: return name or None -def _get_random_filename(content_type=None): - """Get a pseudo-random, Planet-looking filename. - - :returns: a filename (i.e. ``basename``) - :rtype: str - """ - extension = mimetypes.guess_extension(content_type or '') or '' - characters = string.ascii_letters + '0123456789' - letters = ''.join(random.sample(characters, 8)) - name = 'planet-{}{}'.format(letters, extension) - return name - - class Paged: """Asynchronous iterator over results in a paged resource. diff --git a/tests/integration/test_data_api.py b/tests/integration/test_data_api.py index 7a57f539..2c57c567 100644 --- a/tests/integration/test_data_api.py +++ b/tests/integration/test_data_api.py @@ -837,11 +837,15 @@ async def _stream_img(): # populate request parameter to avoid respx cloning, which throws # an error caused by respx and not this code # https://github.com/lundberg/respx/issues/130 - mock_resp = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] basic_udm2_asset = { "_links": { @@ -863,7 +867,8 @@ async def _stream_img(): path = await cl.download_asset(basic_udm2_asset, directory=tmpdir, - overwrite=overwrite) + overwrite=overwrite, + progress_bar=False) assert path.name == 'img.tif' assert path.is_file() diff --git a/tests/integration/test_data_cli.py b/tests/integration/test_data_cli.py index ddb04b28..90834b25 100644 --- a/tests/integration/test_data_cli.py +++ b/tests/integration/test_data_cli.py @@ -966,12 +966,18 @@ async def _stream_img(): for i in range(math.ceil(len(v) / (chunksize))): yield v[i * chunksize:min((i + 1) * chunksize, len(v))] - # Mock the response for download_asset - mock_resp_download = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp_download + # populate request parameter to avoid respx cloning, which throws + # an error caused by respx and not this code + # https://github.com/lundberg/respx/issues/130 + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] runner = CliRunner() with runner.isolated_filesystem() as folder: diff --git a/tests/integration/test_orders_api.py b/tests/integration/test_orders_api.py index 3470436f..0c164c53 100644 --- a/tests/integration/test_orders_api.py +++ b/tests/integration/test_orders_api.py @@ -19,7 +19,6 @@ from http import HTTPStatus import logging import math -import os from pathlib import Path from unittest.mock import call, create_autospec @@ -587,17 +586,22 @@ async def _stream_img(): # populate request parameter to avoid respx cloning, which throws # an error caused by respx and not this code # https://github.com/lundberg/respx/issues/130 - mock_resp = httpx.Response(HTTPStatus.OK, - stream=_stream_img(), - headers=img_headers, - request='donotcloneme') - respx.get(dl_url).return_value = mock_resp + respx.get(dl_url).side_effect = [ + httpx.Response(HTTPStatus.OK, + headers=img_headers, + request='donotcloneme'), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=img_headers, + request='donotcloneme') + ] cl = OrdersClient(session, base_url=TEST_URL) filename = await cl.download_asset(dl_url, directory=str(tmpdir)) - assert Path(filename).name == 'img.tif' - assert os.path.isfile(filename) + assert filename.name == 'img.tif' + assert filename.is_file() + assert len(filename.read_bytes()) == 527 @respx.mock diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index bc876a1c..b464d3aa 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -17,7 +17,8 @@ import logging from http import HTTPStatus import math -from unittest.mock import patch +from pathlib import Path +from unittest.mock import MagicMock, patch import httpx import respx @@ -26,7 +27,7 @@ from planet import exceptions, http -TEST_URL = 'mock://fantastic.com' +TEST_URL = 'http://www.MockNotRealURL.com' LOGGER = logging.getLogger(__name__) @@ -185,7 +186,7 @@ async def control(): @pytest.mark.anyio -async def test_session_contextmanager(): +async def test_Session_contextmanager(): async with http.Session(): pass @@ -193,7 +194,7 @@ async def test_session_contextmanager(): @respx.mock @pytest.mark.anyio @pytest.mark.parametrize('data', (None, {'boo': 'baa'})) -async def test_session_request_success(data): +async def test_Session_request_success(data): async with http.Session() as ps: resp_json = {'foo': 'bar'} @@ -217,19 +218,7 @@ async def test_session_request_success(data): @respx.mock @pytest.mark.anyio -async def test_session_stream(): - async with http.Session() as ps: - mock_resp = httpx.Response(HTTPStatus.OK, text='bubba') - respx.get(TEST_URL).return_value = mock_resp - - async with ps.stream(method='GET', url=TEST_URL) as resp: - chunks = [c async for c in resp.aiter_bytes()] - assert chunks[0] == b'bubba' - - -@respx.mock -@pytest.mark.anyio -async def test_session_request_retry(): +async def test_Session_request_retry(): """Test the retry in the Session.request method""" async with http.Session() as ps: route = respx.get(TEST_URL) @@ -248,7 +237,7 @@ async def test_session_request_retry(): @respx.mock @pytest.mark.anyio -async def test_session__retry(): +async def test_Session__retry(): """A unit test for the _retry function""" async def test_func(): @@ -268,7 +257,7 @@ async def test_func(): assert args == [(1, 64), (2, 64), (3, 64), (4, 64), (5, 64)] -def test__calculate_wait(): +def test_Session__calculate_wait(): max_retry_backoff = 20 wait_times = [ http.Session._calculate_wait(i + 1, max_retry_backoff) @@ -284,6 +273,52 @@ def test__calculate_wait(): assert math.floor(wait) == expected +# just need this for python 3.7 +class AsyncMock(MagicMock): + + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + +@respx.mock +@pytest.mark.anyio +async def test_Session_write(open_test_img, tmpdir): + """Ensure that write retries and that it writes to the file pointer""" + + async def _stream_img(): + data = open_test_img.read() + v = memoryview(data) + + chunksize = 100 + for i in range(math.ceil(len(v) / (chunksize))): + yield v[i * chunksize:min((i + 1) * chunksize, len(v))] + + img_headers = { + 'Content-Type': 'image/tiff', + 'Content-Length': '527', + 'Content-Disposition': 'attachment; filename="img.tif"' + } + + route = respx.get(TEST_URL) + route.side_effect = [ + httpx.Response(HTTPStatus.TOO_MANY_REQUESTS, json={}), + httpx.Response(HTTPStatus.OK, + stream=_stream_img(), + headers=httpx.Headers(img_headers)) + ] + + dl_path = Path(tmpdir, 'test.tif') + with open(dl_path, 'wb') as fp: + async with http.Session() as ps: + # let's not actually introduce a wait into the tests + ps.max_retry_backoff = 0 + + await ps.write(TEST_URL, fp=fp) + + assert dl_path.is_file() + assert dl_path.stat().st_size == 527 + + @respx.mock @pytest.mark.anyio async def test_authsession_request(): diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 99cb11f2..bff92805 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math from unittest.mock import MagicMock -import os -from pathlib import Path -import re +import httpx import pytest from planet import models @@ -26,68 +23,42 @@ LOGGER = logging.getLogger(__name__) - -def test_StreamingBody_name_filename(): - r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' - } - body = models.StreamingBody(r) - assert body.name == 'open_california.tif' - - -def test_StreamingBody_name_url(): +NO_NAME_HEADERS = { + 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': 'bytes', + 'Content-Type': 'image/tiff', + 'Content-Length': '57350256' +} +OPEN_CALIFORNIA_HEADERS = httpx.Headers({ + 'date': + 'Thu, 14 Feb 2019 16:13:26 GMT', + 'last-modified': + 'Wed, 22 Nov 2017 17:22:31 GMT', + 'accept-ranges': + 'bytes', + 'Content-Type': + 'image/tiff', + 'Content-Length': + '57350256', + 'Content-Disposition': + 'attachment; filename="open_california.tif"' +}) + + +def test_Response_filename(): r = MagicMock(name='response') r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - } - body = models.StreamingBody(r) + r.headers = OPEN_CALIFORNIA_HEADERS - assert body.name == 'example.tif' + assert models.Response(r).filename == 'open_california.tif' -def test_StreamingBody_name_content(): +def test__get_filename_from_response(): r = MagicMock(name='response') - r.url = 'https://planet.com/path/to/noname/' - r.headers = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - } - body = models.StreamingBody(r) - - assert body.name.startswith('planet-') - assert (body.name.endswith('.tiff') or body.name.endswith('.tif')) - - -NO_NAME_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256' -} -OPEN_CALIFORNIA_HEADERS = { - 'date': 'Thu, 14 Feb 2019 16:13:26 GMT', - 'last-modified': 'Wed, 22 Nov 2017 17:22:31 GMT', - 'accept-ranges': 'bytes', - 'content-type': 'image/tiff', - 'content-length': '57350256', - 'content-disposition': 'attachment; filename="open_california.tif"' -} + r.url = 'https://planet.com/path/to/example.tif?foo=f6f1' + r.headers = OPEN_CALIFORNIA_HEADERS + assert models._get_filename_from_response(r) == 'open_california.tif' @pytest.mark.parametrize('headers,expected', @@ -113,41 +84,6 @@ def test__get_filename_from_url(url, expected): assert models._get_filename_from_url(url) == expected -@pytest.mark.parametrize( - 'content_type,check', - [ - (None, - lambda x: re.match(r'^planet-[a-z0-9]{8}$', x, re.I) is not None), - ('image/tiff', lambda x: x.endswith(('.tif', '.tiff'))), - ]) -def test__get_random_filename(content_type, check): - assert check(models._get_random_filename(content_type)) - - -@pytest.mark.anyio -async def test_StreamingBody_write_img(tmpdir, open_test_img): - - async def _aiter_bytes(): - data = open_test_img.read() - v = memoryview(data) - - chunksize = 100 - for i in range(math.ceil(len(v) / (chunksize))): - yield v[i * chunksize:min((i + 1) * chunksize, len(v))] - - r = MagicMock(name='response') - r.aiter_bytes = _aiter_bytes - r.num_bytes_downloaded = 0 - r.headers['Content-Length'] = 527 - body = models.StreamingBody(r) - - filename = Path(tmpdir) / 'test.tif' - await body.write(filename, progress_bar=False) - - assert os.path.isfile(filename) - assert os.stat(filename).st_size == 527 - - @pytest.mark.anyio async def test_Paged_iterator(): resp = MagicMock(name='response')