From 046f675aaa39fccca51d47188995f5f57eae58d0 Mon Sep 17 00:00:00 2001 From: Rahul Mahrsee <86819420+mahrsee1997@users.noreply.github.com> Date: Thu, 13 Apr 2023 09:05:54 +0000 Subject: [PATCH] Extended CDS client to separate out fetch & download steps. (#314) * Extended CDS client to separate out fetch & download steps. * fix fetcher_test.py test cases. * Refactored the code to use SplitRequestMixin pattern. * fix pytype error. * relocated the transfer code to enhance its reusability. * Bumped weather-dl version to v0.1.17. --- weather_dl/download_pipeline/clients.py | 91 +++++++++++++------- weather_dl/download_pipeline/fetcher_test.py | 40 +++++---- weather_dl/download_pipeline/util.py | 14 +++ weather_dl/setup.py | 2 +- 4 files changed, 98 insertions(+), 49 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index a8553d1f..cb551880 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -21,18 +21,18 @@ import json import logging import os -import subprocess +import time import typing as t import warnings from urllib.parse import urljoin -import cdsapi +from cdsapi import api as cds_api import urllib3 from ecmwfapi import api from .config import Config, optimize_selection_partition from .manifest import Manifest, Stage -from .util import retry_with_exponential_backoff +from .util import download_with_aria2, retry_with_exponential_backoff warnings.simplefilter( "ignore", category=urllib3.connectionpool.InsecureRequestWarning) @@ -73,6 +73,34 @@ def license_url(self): pass +class SplitCDSRequest(cds_api.Client): + """Extended CDS class that separates fetch and download stage.""" + @retry_with_exponential_backoff + def _download(self, url, path: str, size: int) -> None: + self.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size)) + start = time.time() + + download_with_aria2(url, path) + + elapsed = time.time() - start + if elapsed: + self.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed)) + + def fetch(self, request: t.Dict, dataset: str) -> t.Dict: + result = self.retrieve(dataset, request) + return {'href': result.location, 'size': result.content_length} + + def download(self, result: cds_api.Result, target: t.Optional[str] = None) -> None: + if target: + if os.path.exists(target): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. + open(target, "w").close() + + self._download(result["href"], target, result["size"]) + + class CdsClient(Client): """A client to access weather data from the Cloud Data Store (CDS). @@ -95,27 +123,33 @@ class CdsClient(Client): """Name patterns of datasets that are hosted internally on CDS servers.""" cds_hosted_datasets = {'reanalysis-era'} - def __init__(self, config: Config, level: int = logging.INFO) -> None: - super().__init__(config, level) - self.c = cdsapi.Client( - url=config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')), - key=config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')), + def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: + c = CDSClientExtended( + url=self.config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')), + key=self.config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')), debug_callback=self.logger.debug, info_callback=self.logger.info, warning_callback=self.logger.warning, error_callback=self.logger.error, ) - - def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: selection_ = optimize_selection_partition(selection) - manifest.set_stage(Stage.RETRIEVE) - precise_retrieve_start_time = ( - datetime.datetime.utcnow() - .replace(tzinfo=datetime.timezone.utc) - .isoformat(timespec='seconds') - ) - manifest.prev_stage_precise_start_time = precise_retrieve_start_time - self.c.retrieve(dataset, selection_, output) + with StdoutLogger(self.logger, level=logging.DEBUG): + manifest.set_stage(Stage.FETCH) + precise_fetch_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec='seconds') + ) + manifest.prev_stage_precise_start_time = precise_fetch_start_time + result = c.fetch(selection_, dataset) + manifest.set_stage(Stage.DOWNLOAD) + precise_download_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec='seconds') + ) + manifest.prev_stage_precise_start_time = precise_download_start_time + c.download(result, target=output) @property def license_url(self): @@ -177,16 +211,9 @@ def _download(self, url, path: str, size: int) -> None: ) self.log("From %s" % (url,)) - dir_path, file_name = os.path.split(path) - try: - subprocess.run( - ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'], - check=True, - capture_output=True) - except subprocess.CalledProcessError as e: - self.log(f'Failed download from ECMWF server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}') + download_with_aria2(url, path) - def fetch(self, request: t.Dict) -> t.Dict: + def fetch(self, request: t.Dict, dataset: str) -> t.Dict: status = None self.connection.submit("%s/%s/requests" % (self.url, self.service), request) @@ -224,13 +251,19 @@ def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: class SplitRequestMixin: c = None - def fetch(self, req: t.Dict) -> t.Dict: - return self.c.fetch(req) + def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict: + return self.c.fetch(req, dataset) def download(self, res: t.Dict, target: str) -> None: self.c.download(res, target) +class CDSClientExtended(SplitRequestMixin): + """Extended CDS Client class that separates fetch and download stage.""" + def __init__(self, *args, **kwargs): + self.c = SplitCDSRequest(*args, **kwargs) + + class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin): """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): diff --git a/weather_dl/download_pipeline/fetcher_test.py b/weather_dl/download_pipeline/fetcher_test.py index a9372d43..db4fc717 100644 --- a/weather_dl/download_pipeline/fetcher_test.py +++ b/weather_dl/download_pipeline/fetcher_test.py @@ -17,7 +17,7 @@ import os import tempfile import unittest -from unittest.mock import patch, ANY +from unittest.mock import patch from .config import Config from .fetcher import Fetcher @@ -30,8 +30,9 @@ class FetchDataTest(unittest.TestCase): def setUp(self) -> None: self.dummy_manifest = MockManifest(Location('dummy-manifest')) - @patch('cdsapi.Client.retrieve') - def test_fetch_data(self, mock_retrieve): + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.download') + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') + def test_fetch_data(self, mock_fetch, mock_download): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { @@ -53,13 +54,14 @@ def test_fetch_data(self, mock_retrieve): self.assertTrue(os.path.exists(os.path.join(tmpdir, 'download-01-12.nc'))) - mock_retrieve.assert_called_with( - 'reanalysis-era5-pressure-levels', + mock_fetch.assert_called_with( config.selection, - ANY) + 'reanalysis-era5-pressure-levels', + ) - @patch('cdsapi.Client.retrieve') - def test_fetch_data__manifest__returns_success(self, mock_retrieve): + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.download') + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') + def test_fetch_data__manifest__returns_success(self, mock_fetch, mock_download): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { @@ -88,8 +90,8 @@ def test_fetch_data__manifest__returns_success(self, mock_retrieve): username='unknown', ), list(self.dummy_manifest.records.values())[0]) - @patch('cdsapi.Client.retrieve') - def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') + def test_fetch_data__manifest__records_retrieve_failure(self, mock_fetch): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { @@ -107,7 +109,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): }) error = IOError("We don't have enough permissions to download this.") - mock_retrieve.side_effect = error + mock_fetch.side_effect = error with self.assertRaises(IOError) as e: fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) @@ -118,7 +120,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): self.assertDictContainsSubset(dict( selection=json.dumps(config.selection), location=os.path.join(tmpdir, 'download-01-12.nc'), - stage='retrieve', + stage='fetch', status='failure', username='unknown', ), actual) @@ -126,8 +128,8 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): self.assertIn(error.args[0], actual['error']) self.assertIn(error.args[0], e.exception.args[0]) - @patch('cdsapi.Client.retrieve') - def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve): + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') + def test_fetch_data__manifest__records_gcs_failure(self, mock_fetch): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { @@ -145,7 +147,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve): }) error = IOError("Can't open gcs file.") - mock_retrieve.side_effect = error + mock_fetch.side_effect = error with self.assertRaises(IOError) as e: fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) @@ -156,7 +158,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve): self.assertDictContainsSubset(dict( selection=json.dumps(config.selection), location=os.path.join(tmpdir, 'download-01-12.nc'), - stage='retrieve', + stage='fetch', status='failure', username='unknown', ), actual) @@ -165,8 +167,8 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve): self.assertIn(error.args[0], e.exception.args[0]) @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) - @patch('cdsapi.Client.retrieve') - def test_fetch_data__skips_existing_download(self, mock_retrieve, mock_gcs_file): + @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') + def test_fetch_data__skips_existing_download(self, mock_fetch, mock_gcs_file): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { @@ -191,4 +193,4 @@ def test_fetch_data__skips_existing_download(self, mock_retrieve, mock_gcs_file) fetcher.fetch_data(config) self.assertFalse(mock_gcs_file.called) - self.assertFalse(mock_retrieve.called) + self.assertFalse(mock_fetch.called) diff --git a/weather_dl/download_pipeline/util.py b/weather_dl/download_pipeline/util.py index d2ae3053..faf12d0b 100644 --- a/weather_dl/download_pipeline/util.py +++ b/weather_dl/download_pipeline/util.py @@ -184,3 +184,17 @@ def get_wait_interval(num_retries: int = 0) -> float: def generate_md5_hash(input: str) -> str: """Generates md5 hash for the input string.""" return hashlib.md5(input.encode('utf-8')).hexdigest() + + +def download_with_aria2(url: str, path: str) -> None: + """Downloads a file from the given URL using the `aria2c` command-line utility, + with options set to improve download speed and reliability.""" + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'], + check=True, + capture_output=True) + except subprocess.CalledProcessError as e: + logger.error(f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}') + raise diff --git a/weather_dl/setup.py b/weather_dl/setup.py index e440f091..f232c93c 100644 --- a/weather_dl/setup.py +++ b/weather_dl/setup.py @@ -48,7 +48,7 @@ setup( name='download_pipeline', packages=find_packages(), - version='0.1.16', + version='0.1.17', author='Anthromets', author_email='anthromets-ecmwf@google.com', url='https://weather-tools.readthedocs.io/en/latest/weather_dl/',