Skip to content

Commit

Permalink
Extended CDS client to separate out fetch & download steps. (#314)
Browse files Browse the repository at this point in the history
* Extended CDS client to separate out fetch & download steps.

* fix fetcher_test.py test cases.

* Refactored the code to use SplitRequestMixin pattern.

* fix pytype error.

* relocated the transfer code to enhance its reusability.

* Bumped weather-dl version to v0.1.17.
  • Loading branch information
mahrsee1997 authored Apr 13, 2023
1 parent a546782 commit 046f675
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 49 deletions.
91 changes: 62 additions & 29 deletions weather_dl/download_pipeline/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@
import json
import logging
import os
import subprocess
import time
import typing as t
import warnings
from urllib.parse import urljoin

import cdsapi
from cdsapi import api as cds_api
import urllib3
from ecmwfapi import api

from .config import Config, optimize_selection_partition
from .manifest import Manifest, Stage
from .util import retry_with_exponential_backoff
from .util import download_with_aria2, retry_with_exponential_backoff

warnings.simplefilter(
"ignore", category=urllib3.connectionpool.InsecureRequestWarning)
Expand Down Expand Up @@ -73,6 +73,34 @@ def license_url(self):
pass


class SplitCDSRequest(cds_api.Client):
    """CDS API client whose retrieval is split into separate fetch and download steps.

    ``fetch`` calls the inherited ``retrieve`` without a target path and returns a
    plain dict describing the result; ``download`` later transfers the file that
    the dict points to.
    """

    @retry_with_exponential_backoff
    def _download(self, url, path: str, size: int) -> None:
        """Transfer ``url`` to ``path`` with aria2 and log the achieved rate.

        Args:
            url: Location of the result file.
            path: Local destination file path.
            size: Size of the file in bytes; only used for the log messages.
        """
        self.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size))
        start = time.time()

        download_with_aria2(url, path)

        elapsed = time.time() - start
        # Skip the rate message when no measurable time has passed, which also
        # avoids dividing by zero.
        if elapsed:
            self.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed))

    def fetch(self, request: t.Dict, dataset: str) -> t.Dict:
        """Run ``retrieve`` for ``dataset`` with ``request`` and describe the result.

        Returns:
            A dict with the result's location under ``'href'`` and its content
            length under ``'size'``.
        """
        result = self.retrieve(dataset, request)
        return {'href': result.location, 'size': result.content_length}

    def download(self, result: t.Dict, target: t.Optional[str] = None) -> None:
        """Download a previously fetched result to ``target``.

        Args:
            result: Dict as returned by ``fetch`` (``'href'`` and ``'size'`` keys).
            target: Local destination path. When omitted, the file is named after
                the last path segment of ``result['href']`` and placed in the
                current working directory.
        """
        if not target:
            # The transfer helper needs a real path, so derive one from the URL
            # instead of handing it ``None``.
            target = os.path.abspath(result["href"].split("/")[-1])

        if os.path.exists(target):
            # Empty the target file, if it already exists, otherwise the
            # transfer below might be fooled into thinking we're resuming
            # an interrupted download.
            with open(target, "w"):
                pass

        self._download(result["href"], target, result["size"])


class CdsClient(Client):
"""A client to access weather data from the Climate Data Store (CDS).
Expand All @@ -95,27 +123,33 @@ class CdsClient(Client):
"""Name patterns of datasets that are hosted internally on CDS servers."""
cds_hosted_datasets = {'reanalysis-era'}

def __init__(self, config: Config, level: int = logging.INFO) -> None:
super().__init__(config, level)
self.c = cdsapi.Client(
url=config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')),
key=config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')),
def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None:
c = CDSClientExtended(
url=self.config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')),
key=self.config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')),
debug_callback=self.logger.debug,
info_callback=self.logger.info,
warning_callback=self.logger.warning,
error_callback=self.logger.error,
)

def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None:
selection_ = optimize_selection_partition(selection)
manifest.set_stage(Stage.RETRIEVE)
precise_retrieve_start_time = (
datetime.datetime.utcnow()
.replace(tzinfo=datetime.timezone.utc)
.isoformat(timespec='seconds')
)
manifest.prev_stage_precise_start_time = precise_retrieve_start_time
self.c.retrieve(dataset, selection_, output)
with StdoutLogger(self.logger, level=logging.DEBUG):
manifest.set_stage(Stage.FETCH)
precise_fetch_start_time = (
datetime.datetime.utcnow()
.replace(tzinfo=datetime.timezone.utc)
.isoformat(timespec='seconds')
)
manifest.prev_stage_precise_start_time = precise_fetch_start_time
result = c.fetch(selection_, dataset)
manifest.set_stage(Stage.DOWNLOAD)
precise_download_start_time = (
datetime.datetime.utcnow()
.replace(tzinfo=datetime.timezone.utc)
.isoformat(timespec='seconds')
)
manifest.prev_stage_precise_start_time = precise_download_start_time
c.download(result, target=output)

@property
def license_url(self):
Expand Down Expand Up @@ -177,16 +211,9 @@ def _download(self, url, path: str, size: int) -> None:
)
self.log("From %s" % (url,))

dir_path, file_name = os.path.split(path)
try:
subprocess.run(
['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'],
check=True,
capture_output=True)
except subprocess.CalledProcessError as e:
self.log(f'Failed download from ECMWF server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}')
download_with_aria2(url, path)

def fetch(self, request: t.Dict) -> t.Dict:
def fetch(self, request: t.Dict, dataset: str) -> t.Dict:
status = None

self.connection.submit("%s/%s/requests" % (self.url, self.service), request)
Expand Down Expand Up @@ -224,13 +251,19 @@ def download(self, result: t.Dict, target: t.Optional[str] = None) -> None:
class SplitRequestMixin:
c = None

def fetch(self, req: t.Dict) -> t.Dict:
return self.c.fetch(req)
def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict:
return self.c.fetch(req, dataset)

def download(self, res: t.Dict, target: str) -> None:
self.c.download(res, target)


class CDSClientExtended(SplitRequestMixin):
    """Extended CDS Client class that separates the fetch and download stages.

    The ``fetch`` and ``download`` methods come from ``SplitRequestMixin`` and
    delegate to the ``SplitCDSRequest`` instance stored in ``self.c``.
    """
    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to the wrapped client.
        self.c = SplitCDSRequest(*args, **kwargs)


class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin):
"""Extended MARS ECMWFService class that separates the fetch and download stages."""
def __init__(self, *args, **kwargs):
Expand Down
40 changes: 21 additions & 19 deletions weather_dl/download_pipeline/fetcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import tempfile
import unittest
from unittest.mock import patch, ANY
from unittest.mock import patch

from .config import Config
from .fetcher import Fetcher
Expand All @@ -30,8 +30,9 @@ class FetchDataTest(unittest.TestCase):
def setUp(self) -> None:
self.dummy_manifest = MockManifest(Location('dummy-manifest'))

@patch('cdsapi.Client.retrieve')
def test_fetch_data(self, mock_retrieve):
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.download')
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch')
def test_fetch_data(self, mock_fetch, mock_download):
with tempfile.TemporaryDirectory() as tmpdir:
config = Config.from_dict({
'parameters': {
Expand All @@ -53,13 +54,14 @@ def test_fetch_data(self, mock_retrieve):

self.assertTrue(os.path.exists(os.path.join(tmpdir, 'download-01-12.nc')))

mock_retrieve.assert_called_with(
'reanalysis-era5-pressure-levels',
mock_fetch.assert_called_with(
config.selection,
ANY)
'reanalysis-era5-pressure-levels',
)

@patch('cdsapi.Client.retrieve')
def test_fetch_data__manifest__returns_success(self, mock_retrieve):
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.download')
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch')
def test_fetch_data__manifest__returns_success(self, mock_fetch, mock_download):
with tempfile.TemporaryDirectory() as tmpdir:
config = Config.from_dict({
'parameters': {
Expand Down Expand Up @@ -88,8 +90,8 @@ def test_fetch_data__manifest__returns_success(self, mock_retrieve):
username='unknown',
), list(self.dummy_manifest.records.values())[0])

@patch('cdsapi.Client.retrieve')
def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve):
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch')
def test_fetch_data__manifest__records_retrieve_failure(self, mock_fetch):
with tempfile.TemporaryDirectory() as tmpdir:
config = Config.from_dict({
'parameters': {
Expand All @@ -107,7 +109,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve):
})

error = IOError("We don't have enough permissions to download this.")
mock_retrieve.side_effect = error
mock_fetch.side_effect = error

with self.assertRaises(IOError) as e:
fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore())
Expand All @@ -118,16 +120,16 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve):
self.assertDictContainsSubset(dict(
selection=json.dumps(config.selection),
location=os.path.join(tmpdir, 'download-01-12.nc'),
stage='retrieve',
stage='fetch',
status='failure',
username='unknown',
), actual)

self.assertIn(error.args[0], actual['error'])
self.assertIn(error.args[0], e.exception.args[0])

@patch('cdsapi.Client.retrieve')
def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve):
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch')
def test_fetch_data__manifest__records_gcs_failure(self, mock_fetch):
with tempfile.TemporaryDirectory() as tmpdir:
config = Config.from_dict({
'parameters': {
Expand All @@ -145,7 +147,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve):
})

error = IOError("Can't open gcs file.")
mock_retrieve.side_effect = error
mock_fetch.side_effect = error

with self.assertRaises(IOError) as e:
fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore())
Expand All @@ -156,7 +158,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve):
self.assertDictContainsSubset(dict(
selection=json.dumps(config.selection),
location=os.path.join(tmpdir, 'download-01-12.nc'),
stage='retrieve',
stage='fetch',
status='failure',
username='unknown',
), actual)
Expand All @@ -165,8 +167,8 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve):
self.assertIn(error.args[0], e.exception.args[0])

@patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO())
@patch('cdsapi.Client.retrieve')
def test_fetch_data__skips_existing_download(self, mock_retrieve, mock_gcs_file):
@patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch')
def test_fetch_data__skips_existing_download(self, mock_fetch, mock_gcs_file):
with tempfile.TemporaryDirectory() as tmpdir:
config = Config.from_dict({
'parameters': {
Expand All @@ -191,4 +193,4 @@ def test_fetch_data__skips_existing_download(self, mock_retrieve, mock_gcs_file)
fetcher.fetch_data(config)

self.assertFalse(mock_gcs_file.called)
self.assertFalse(mock_retrieve.called)
self.assertFalse(mock_fetch.called)
14 changes: 14 additions & 0 deletions weather_dl/download_pipeline/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,17 @@ def get_wait_interval(num_retries: int = 0) -> float:
def generate_md5_hash(input: str) -> str:
"""Generates md5 hash for the input string."""
return hashlib.md5(input.encode('utf-8')).hexdigest()


def download_with_aria2(url: str, path: str) -> None:
    """Download ``url`` to ``path`` using the ``aria2c`` command-line utility.

    aria2c is invoked with ``-x 16 -s 16`` (up to 16 connections per server and
    16 splits per download) and is allowed to overwrite an existing file.

    Args:
        url: Location of the file to download.
        path: Local destination; split into a directory (``-d``) and a file
            name (``-o``) for aria2c.

    Raises:
        subprocess.CalledProcessError: If aria2c exits with a non-zero status.
            Its captured stderr is logged before the error is re-raised.
    """
    dir_path, file_name = os.path.split(path)
    command = ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite']
    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # Decode leniently: a strict decode could raise UnicodeDecodeError here
        # and hide the actual download failure from the caller.
        stderr = e.stderr.decode("utf-8", errors="replace")
        logger.error(f'Failed download from server {url!r} to {path!r} due to {stderr}')
        raise
2 changes: 1 addition & 1 deletion weather_dl/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
setup(
name='download_pipeline',
packages=find_packages(),
version='0.1.16',
version='0.1.17',
author='Anthromets',
author_email='[email protected]',
url='https://weather-tools.readthedocs.io/en/latest/weather_dl/',
Expand Down

0 comments on commit 046f675

Please sign in to comment.