diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0325c936..b691aadb 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -56,12 +56,15 @@ jobs:
channels: conda-forge
environment-file: ci${{ matrix.python-version}}.yml
activate-environment: weather-tools
+ miniforge-variant: Mambaforge
+ miniforge-version: latest
+ use-mamba: true
- name: Check MetView's installation
shell: bash -l {0}
run: python -m metview selfcheck
- name: Run unit tests
shell: bash -l {0}
- run: pytest --memray
+ run: pytest --memray --ignore=weather_dl_v2 # Ignoring dl-v2 as it only supports py3.10
lint:
runs-on: ubuntu-latest
strategy:
@@ -84,13 +87,15 @@ jobs:
echo "::set-output name=dir::$(pip cache dir)"
- name: Install linter
run: |
- pip install ruff
+ pip install ruff==0.1.2
- name: Lint project
run: ruff check .
type-check:
runs-on: ubuntu-latest
strategy:
fail-fast: false
+ matrix:
+ python-version: ["3.8"]
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
@@ -98,28 +103,27 @@ jobs:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/head/main'}}
- uses: actions/checkout@v2
- - name: Set up Python 3.8
- uses: actions/setup-python@v2
+ - name: conda cache
+ uses: actions/cache@v2
+ env:
+        # Increase this value to reset the cache if ci3.8.yml has not changed
+ CACHE_NUMBER: 0
with:
- python-version: "3.8"
- - name: Setup conda
- uses: s-weigand/setup-conda@v1
+ path: ~/conda_pkgs_dir
+ key:
+ ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ matrix.python-version }}-${{ hashFiles('ci3.8.yml') }}
+ - name: Setup conda environment
+ uses: conda-incubator/setup-miniconda@v2
with:
- update-conda: true
- python-version: "3.8"
- conda-channels: anaconda, conda-forge
- - name: Install ecCodes
- run: |
- conda install -y eccodes>=2.21.0 -c conda-forge
- conda install -y pyproj -c conda-forge
- conda install -y gdal -c conda-forge
- - name: Get pip cache dir
- id: pip-cache
- run: |
- python -m pip install --upgrade pip wheel
- echo "::set-output name=dir::$(pip cache dir)"
- - name: Install weather-tools
+ python-version: ${{ matrix.python-version }}
+ channels: conda-forge
+ environment-file: ci${{ matrix.python-version}}.yml
+ activate-environment: weather-tools
+ miniforge-variant: Mambaforge
+ miniforge-version: latest
+ use-mamba: true
+ - name: Install weather-tools[test]
run: |
- pip install -e .[test] --use-deprecated=legacy-resolver
+ conda run -n weather-tools pip install -e .[test] --use-deprecated=legacy-resolver
- name: Run type checker
- run: pytype
+ run: conda run -n weather-tools pytype
diff --git a/Dockerfile b/Dockerfile
index 057442f5..30a72f57 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -29,6 +29,7 @@ ARG weather_tools_git_rev=main
RUN git clone https://github.com/google/weather-tools.git /weather
WORKDIR /weather
RUN git checkout "${weather_tools_git_rev}"
+RUN rm -r /weather/weather_*/test_data/
RUN conda env create -f environment.yml --debug
# Activate the conda env and update the PATH
diff --git a/ci3.8.yml b/ci3.8.yml
index d6a1e0bd..211d36be 100644
--- a/ci3.8.yml
+++ b/ci3.8.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- - xarray-beam=0.3.1
+ - xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -30,9 +30,11 @@ dependencies:
- pip=22.3
- pygrib=2.1.4
- xarray==2023.1.0
- - ruff==0.0.260
+ - ruff==0.1.2
- google-cloud-sdk=410.0.0
- aria2=1.36.0
+ - zarr=2.15.0
- pip:
+ - cython==0.29.34
- earthengine-api==0.1.329
- .[test]
diff --git a/ci3.9.yml b/ci3.9.yml
index a43cec16..86f0968d 100644
--- a/ci3.9.yml
+++ b/ci3.9.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- - xarray-beam=0.3.1
+ - xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -32,7 +32,9 @@ dependencies:
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- xarray==2023.1.0
- - ruff==0.0.260
+ - ruff==0.1.2
+ - zarr=2.15.0
- pip:
+ - cython==0.29.34
- earthengine-api==0.1.329
- .[test]
diff --git a/environment.yml b/environment.yml
index eae35f9c..0b043980 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,7 +4,7 @@ channels:
dependencies:
- python=3.8.13
- apache-beam=2.40.0
- - xarray-beam=0.3.1
+ - xarray-beam=0.6.2
- xarray=2023.1.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -25,7 +25,9 @@ dependencies:
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- pip=22.3
+ - zarr=2.15.0
- pip:
+ - cython==0.29.34
- earthengine-api==0.1.329
- firebase-admin==6.0.1
- .
diff --git a/pyproject.toml b/pyproject.toml
index 8e782e8d..2eaadb0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,4 +42,4 @@ target-version = "py310"
[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
-max-complexity = 10
\ No newline at end of file
+max-complexity = 10
diff --git a/setup.py b/setup.py
index af798889..233e5193 100644
--- a/setup.py
+++ b/setup.py
@@ -57,8 +57,9 @@
"earthengine-api>=0.1.263",
"pyproj", # requires separate binary installation!
"gdal", # requires separate binary installation!
- "xarray-beam==0.3.1",
+ "xarray-beam==0.6.2",
"gcsfs==2022.11.0",
+ "zarr==2.15.0",
]
weather_sp_requirements = [
@@ -70,7 +71,7 @@
test_requirements = [
"pytype==2021.11.29",
- "ruff",
+ "ruff==0.1.2",
"pytest",
"pytest-subtests",
"netcdf4",
@@ -82,6 +83,7 @@
"memray",
"pytest-memray",
"h5py",
+ "pooch",
]
all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \
@@ -115,7 +117,7 @@
],
python_requires='>=3.8, <3.10',
- install_requires=['apache-beam[gcp]==2.40.0'],
+ install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'],
use_scm_version=True,
setup_requires=['setuptools_scm'],
scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'],
diff --git a/weather_dl/download_pipeline/util.py b/weather_dl/download_pipeline/util.py
index 1ee9e24e..dd09ba29 100644
--- a/weather_dl/download_pipeline/util.py
+++ b/weather_dl/download_pipeline/util.py
@@ -102,10 +102,10 @@ def to_json_serializable_type(value: t.Any) -> t.Any:
return None
elif np.issubdtype(type(value), np.floating):
return float(value)
- elif type(value) == np.ndarray:
+ elif isinstance(value, np.ndarray):
# Will return a scaler if array is of size 1, else will return a list.
return value.tolist()
- elif type(value) == datetime.datetime or type(value) == str or type(value) == np.datetime64:
+ elif isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64):
# Assume strings are ISO format timestamps...
try:
value = datetime.datetime.fromisoformat(value)
@@ -126,7 +126,7 @@ def to_json_serializable_type(value: t.Any) -> t.Any:
# We assume here that naive timestamps are in UTC timezone.
return value.replace(tzinfo=datetime.timezone.utc).isoformat()
- elif type(value) == np.timedelta64:
+ elif isinstance(value, np.timedelta64):
# Return time delta in seconds.
return float(value / np.timedelta64(1, 's'))
# This check must happen after processing np.timedelta64 and np.datetime64.
diff --git a/weather_dl_v2/README.md b/weather_dl_v2/README.md
new file mode 100644
index 00000000..ea7b7bb5
--- /dev/null
+++ b/weather_dl_v2/README.md
@@ -0,0 +1,12 @@
+## weather-dl-v2
+
+
+
+> **_NOTE:_** weather-dl-v2 only supports python 3.10
+
+### Sequence of steps:
+1) Refer to downloader_kubernetes/README.md
+2) Refer to license_deployment/README.md
+3) Refer to fastapi-server/README.md
+4) Refer to cli/README.md
+
diff --git a/weather_dl_v2/__init__.py b/weather_dl_v2/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/cli/CLI-Documentation.md b/weather_dl_v2/cli/CLI-Documentation.md
new file mode 100644
index 00000000..ea16bb6b
--- /dev/null
+++ b/weather_dl_v2/cli/CLI-Documentation.md
@@ -0,0 +1,306 @@
+# CLI Documentation
+This doc describes the CLI commands along with their arguments and options.
+
+Base Command:
+```
+weather-dl-v2
+```
+
+## Ping
+Ping the FastAPI server and check if it’s live and reachable.
+
+ weather-dl-v2 ping
+
+##### Usage
+```
+weather-dl-v2 ping
+```
+
+
+
+
+## Download
+Manage download configs.
+
+
+### Add Downloads
+ weather-dl-v2 download add
+ Adds a new download config to specific licenses.
+
+
+
+##### Arguments
+> `FILE_PATH` : Path to config file.
+
+##### Options
+> `-l/--license` (Required): License ID to which this download should be added (may be repeated).
+> `-f/--force-download` : Force redownload of partitions that were previously downloaded.
+
+##### Usage
+```
+weather-dl-v2 download add /path/to/example.cfg -l L1 -l L2 [--force-download]
+```
+
+### List Downloads
+ weather-dl-v2 download list
+ List all the active downloads.
+
+
+The list can also be filtered by client name or status.
+Available filters:
+```
+Filter Key: client_name
+Values: cds, mars, ecpublic
+
+Filter Key: status
+Values: completed, failed, in-progress
+```
+
+##### Options
+> `--filter` : Filter the list by some key and value. Format: `filter_key=filter_value`.
+
+##### Usage
+```
+weather-dl-v2 download list
+weather-dl-v2 download list --filter client_name=cds
+weather-dl-v2 download list --filter status=completed
+weather-dl-v2 download list --filter status=failed
+weather-dl-v2 download list --filter status=in-progress
+weather-dl-v2 download list --filter client_name=cds --filter status=completed
+```
+
+### Download Get
+ weather-dl-v2 download get
+ Get a particular download by config name.
+
+
+##### Arguments
+> `CONFIG_NAME` : Name of the download config.
+
+##### Usage
+```
+weather-dl-v2 download get example.cfg
+```
+
+### Download Show
+ weather-dl-v2 download show
+ Get contents of a particular config by config name.
+
+
+##### Arguments
+> `CONFIG_NAME` : Name of the download config.
+
+##### Usage
+```
+weather-dl-v2 download show example.cfg
+```
+
+### Download Remove
+ weather-dl-v2 download remove
+ Remove a download by config name.
+
+
+##### Arguments
+> `CONFIG_NAME` : Name of the download config.
+
+##### Usage
+```
+weather-dl-v2 download remove example.cfg
+```
+
+### Download Refetch
+ weather-dl-v2 download refetch
+ Refetch all non-successful partitions of a config.
+
+
+##### Arguments
+> `CONFIG_NAME` : Name of the download config.
+
+##### Options
+> `-l/--license` (Required): License ID to which this download should be added (may be repeated).
+
+##### Usage
+```
+weather-dl-v2 download refetch example.cfg -l L1 -l L2
+```
+
+
+
+## License
+Manage licenses.
+
+### License Add
+ weather-dl-v2 license add
+ Add a new license. New licenses are added using a json file.
+
+
+The json file should be in this format:
+```
+{
+    "license_id": ,
+    "client_name": ,
+    "number_of_requests": ,
+    "secret_id":
+}
+```
+NOTE: `license_id` is case insensitive and has to be unique for each license.
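+
+For illustration only, a filled-in license file might look like the following sketch (the request count and secret name are hypothetical placeholders, not defaults):
+```
+{
+    "license_id": "L1",
+    "client_name": "cds",
+    "number_of_requests": 5,
+    "secret_id": "projects/<project>/secrets/cds-api-key"
+}
+```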
+
+
+##### Arguments
+> `FILE_PATH` : Path to the license json.
+
+##### Usage
+```
+weather-dl-v2 license add /path/to/new-license.json
+```
+
+### License Get
+ weather-dl-v2 license get
+ Get a particular license by license ID.
+
+
+##### Arguments
+> `LICENSE` : License ID of the license to be fetched.
+
+##### Usage
+```
+weather-dl-v2 license get L1
+```
+
+### License Remove
+ weather-dl-v2 license remove
+ Remove a particular license by license ID.
+
+
+##### Arguments
+> `LICENSE` : License ID of the license to be removed.
+
+##### Usage
+```
+weather-dl-v2 license remove L1
+```
+
+### License List
+ weather-dl-v2 license list
+ List all the licenses available.
+
+
+ The list can also be filtered by client name.
+
+##### Options
+> `--filter` : Filter the list by some key and value. Format: `filter_key=filter_value`.
+
+##### Usage
+```
+weather-dl-v2 license list
+weather-dl-v2 license list --filter client_name=cds
+```
+
+### License Update
+ weather-dl-v2 license update
+ Update an existing license using License ID and a license json.
+
+
+ The json should be of the same format used to add a new license.
+
+##### Arguments
+> `LICENSE` : License ID of the license to be edited.
+> `FILE_PATH` : Path to the license json.
+
+##### Usage
+```
+weather-dl-v2 license update L1 /path/to/license.json
+```
+
+
+
+## Queue
+Manage all the license queues.
+
+### Queue List
+ weather-dl-v2 queue list
+ List all the queues.
+
+
+ The list can also be filtered by client name.
+
+##### Options
+> `--filter` : Filter the list by some key and value. Format: `filter_key=filter_value`.
+
+##### Usage
+```
+weather-dl-v2 queue list
+weather-dl-v2 queue list --filter client_name=cds
+```
+
+### Queue Get
+ weather-dl-v2 queue get
+ Get a queue by license ID.
+
+
+
+##### Arguments
+> `LICENSE` : License ID of the queue to be fetched.
+
+##### Usage
+```
+weather-dl-v2 queue get L1
+```
+
+### Queue Edit
+ weather-dl-v2 queue edit
+    Edit the priority of configs inside a license queue.
+
+
+Priority can be edited in two ways:
+1. The new priority order is passed using a priority json file in the following format:
+```
+{
+    "priority": ["c1.cfg", "c3.cfg", "c2.cfg"]
+}
+```
+2. A config file name and its absolute priority can be passed, which updates the priority of that particular config in the given license queue.
+
+##### Arguments
+> `LICENSE` : License ID of queue to be edited.
+
+##### Options
+> `-f/--file` : Path of the new priority json file.
+> `-c/--config` : Config name for absolute priority.
+> `-p/--priority`: Absolute priority for the config in a license queue. Priority increases in ascending order with 0 having highest priority.
+
+##### Usage
+```
+weather-dl-v2 queue edit L1 --file /path/to/priority.json
+weather-dl-v2 queue edit L1 --config example.cfg --priority 0
+```
+
+
+
+## Config
+Configurations for cli.
+
+### Config Show IP
+ weather-dl-v2 config show-ip
+See the current server IP address.
+
+
+##### Usage
+```
+weather-dl-v2 config show-ip
+```
+
+### Config Set IP
+ weather-dl-v2 config set-ip
+Update the server IP address.
+
+
+##### Arguments
+> `NEW_IP` : New IP address. (Do not add port or protocol).
+
+##### Usage
+```
+weather-dl-v2 config set-ip 127.0.0.1
+```
+
diff --git a/weather_dl_v2/cli/Dockerfile b/weather_dl_v2/cli/Dockerfile
new file mode 100644
index 00000000..ec3536be
--- /dev/null
+++ b/weather_dl_v2/cli/Dockerfile
@@ -0,0 +1,43 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+FROM continuumio/miniconda3:latest
+
+COPY . .
+
+# Add the mamba solver for faster builds
+RUN conda install -n base conda-libmamba-solver
+RUN conda config --set solver libmamba
+
+# Create conda env using environment.yml
+RUN conda update conda -y
+RUN conda env create --name weather-dl-v2-cli --file=environment.yml
+
+# Activate the conda env and update the PATH
+ARG CONDA_ENV_NAME=weather-dl-v2-cli
+RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc
+ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH
+
+RUN apt-get update -y
+RUN apt-get install nano -y
+RUN apt-get install vim -y
+RUN apt-get install curl -y
+
+# Install gsutil
+RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-443.0.0-linux-arm.tar.gz
+RUN tar -xf google-cloud-cli-443.0.0-linux-arm.tar.gz
+RUN ./google-cloud-sdk/install.sh --quiet
+RUN echo "if [ -f '/google-cloud-sdk/path.bash.inc' ]; then . '/google-cloud-sdk/path.bash.inc'; fi" >> /root/.bashrc
+RUN echo "if [ -f '/google-cloud-sdk/completion.bash.inc' ]; then . '/google-cloud-sdk/completion.bash.inc'; fi" >> /root/.bashrc
diff --git a/weather_dl_v2/cli/README.md b/weather_dl_v2/cli/README.md
new file mode 100644
index 00000000..a4f4932f
--- /dev/null
+++ b/weather_dl_v2/cli/README.md
@@ -0,0 +1,58 @@
+# weather-dl-cli
+This is a command line interface for talking to the weather-dl-v2 FastAPI server.
+
+- Due to our org-level policy we can't expose an external IP using a LoadBalancer Service
+while deploying our FastAPI server. Hence we need to deploy the CLI on a VM and interact
+with the FastAPI server through it.
+
+Replace the FastAPI server pod's IP in cli_config.json.
+```
+Please make appropriate changes in cli_config.json, if required.
+```
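+
+For reference, `cli_config.json` only contains these two keys; a minimal sketch with a hypothetical pod IP would look like:
+```
+{
+    "pod_ip": "10.0.0.12",
+    "port": 8080
+}
+```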
+> Note: To get the pod IP, run `kubectl get pods -o wide`.
+>
+> Note that the pod IP may change if the pod restarts, so we still need to look
+> for a better solution here.
+
+## Create docker image for weather-dl-cli
+
+```
+export PROJECT_ID=
+export REPO= eg:weather-tools
+
+gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-cli" --timeout=79200 --machine-type=e2-highcpu-32
+```
+
+## Create a VM using the docker image created above
+```
+export ZONE= eg: us-west1-a
+export SERVICE_ACCOUNT= # Let's keep this as Compute Engine Default Service Account
+export IMAGE_PATH= # The above created image-path
+
+gcloud compute instances create-with-container weather-dl-v2-cli \
+ --project=$PROJECT_ID \
+ --zone=$ZONE \
+ --machine-type=e2-medium \
+ --network-interface=network-tier=PREMIUM,subnet=default \
+ --maintenance-policy=MIGRATE \
+ --provisioning-model=STANDARD \
+ --service-account=$SERVICE_ACCOUNT \
+ --scopes=https://www.googleapis.com/auth/cloud-platform \
+ --tags=http-server,https-server \
+ --image=projects/cos-cloud/global/images/cos-stable-105-17412-101-24 \
+ --boot-disk-size=10GB \
+ --boot-disk-type=pd-balanced \
+ --boot-disk-device-name=weather-dl-v2-cli \
+ --container-image=$IMAGE_PATH \
+ --container-restart-policy=on-failure \
+ --container-tty \
+ --no-shielded-secure-boot \
+ --shielded-vtpm \
+ --labels=goog-ec-src=vm_add-gcloud,container-vm=cos-stable-105-17412-101-24 \
+ --metadata-from-file=startup-script=vm-startup.sh
+```
+
+## Use the CLI after SSHing into the VM created above
+```
+weather-dl-v2 --help
+```
diff --git a/weather_dl_v2/cli/app/__init__.py b/weather_dl_v2/cli/app/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/cli/app/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/cli/app/cli_config.py b/weather_dl_v2/cli/app/cli_config.py
new file mode 100644
index 00000000..9bfeb1de
--- /dev/null
+++ b/weather_dl_v2/cli/app/cli_config.py
@@ -0,0 +1,62 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import dataclasses
+import typing as t
+import json
+import os
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class CliConfig:
+ pod_ip: str = ""
+ port: str = ""
+
+ @property
+ def BASE_URI(self) -> str:
+ return f"http://{self.pod_ip}:{self.port}"
+
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict):
+ config_instance = cls()
+
+ for key, value in config.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+
+ return config_instance
+
+
+cli_config = None
+
+
+def get_config():
+ global cli_config
+ # TODO: Update this so cli can work from any folder level.
+ # Right now it only works in folder where cli_config.json is present.
+ cli_config_json = os.path.join(os.getcwd(), "cli_config.json")
+
+ if cli_config is None:
+ with open(cli_config_json) as file:
+ firestore_dict = json.load(file)
+ cli_config = CliConfig.from_dict(firestore_dict)
+
+ return cli_config
diff --git a/weather_dl_v2/cli/app/main.py b/weather_dl_v2/cli/app/main.py
new file mode 100644
index 00000000..03a52577
--- /dev/null
+++ b/weather_dl_v2/cli/app/main.py
@@ -0,0 +1,48 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import typer
+import logging
+from app.cli_config import get_config
+import requests
+from app.subcommands import download, queue, license, config
+from app.utils import Loader
+
+logger = logging.getLogger(__name__)
+
+app = typer.Typer(
+    help="weather-dl-v2 is a CLI tool for communicating with the FastAPI server."
+)
+
+app.add_typer(download.app, name="download", help="Manage downloads.")
+app.add_typer(queue.app, name="queue", help="Manage queues.")
+app.add_typer(license.app, name="license", help="Manage licenses.")
+app.add_typer(config.app, name="config", help="Configurations for cli.")
+
+
+@app.command("ping", help="Check if FastAPI server is live and rechable.")
+def ping():
+ uri = f"{get_config().BASE_URI}/"
+ try:
+ with Loader("Sending request..."):
+ x = requests.get(uri)
+ except Exception as e:
+ print(f"error {e}")
+ return
+ print(x.text)
+
+
+if __name__ == "__main__":
+ app()
diff --git a/weather_dl_v2/cli/app/services/__init__.py b/weather_dl_v2/cli/app/services/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/cli/app/services/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/cli/app/services/download_service.py b/weather_dl_v2/cli/app/services/download_service.py
new file mode 100644
index 00000000..4d467271
--- /dev/null
+++ b/weather_dl_v2/cli/app/services/download_service.py
@@ -0,0 +1,128 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+import json
+import typing as t
+from app.services.network_service import network_service
+from app.cli_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+class DownloadService(abc.ABC):
+
+ @abc.abstractmethod
+ def _list_all_downloads(self):
+ pass
+
+ @abc.abstractmethod
+ def _list_all_downloads_by_filter(self, filter_dict: dict):
+ pass
+
+ @abc.abstractmethod
+ def _get_download_by_config(self, config_name: str):
+ pass
+
+ @abc.abstractmethod
+ def _show_config_content(self, config_name: str):
+ pass
+
+ @abc.abstractmethod
+ def _add_new_download(
+ self, file_path: str, licenses: t.List[str], force_download: bool
+ ):
+ pass
+
+ @abc.abstractmethod
+ def _remove_download(self, config_name: str):
+ pass
+
+ @abc.abstractmethod
+ def _refetch_config_partitions(self, config_name: str, licenses: t.List[str]):
+ pass
+
+
+class DownloadServiceNetwork(DownloadService):
+
+ def __init__(self):
+ self.endpoint = f"{get_config().BASE_URI}/download"
+
+ def _list_all_downloads(self):
+ return network_service.get(
+ uri=self.endpoint, header={"accept": "application/json"}
+ )
+
+ def _list_all_downloads_by_filter(self, filter_dict: dict):
+ return network_service.get(
+ uri=self.endpoint,
+ header={"accept": "application/json"},
+ query=filter_dict,
+ )
+
+ def _get_download_by_config(self, config_name: str):
+ return network_service.get(
+ uri=f"{self.endpoint}/{config_name}",
+ header={"accept": "application/json"},
+ )
+
+ def _show_config_content(self, config_name: str):
+ return network_service.get(
+ uri=f"{self.endpoint}/show/{config_name}",
+ header={"accept": "application/json"},
+ )
+
+ def _add_new_download(
+ self, file_path: str, licenses: t.List[str], force_download: bool
+ ):
+ try:
+ file = {"file": open(file_path, "rb")}
+ except FileNotFoundError:
+ return "File not found."
+
+ return network_service.post(
+ uri=self.endpoint,
+ header={"accept": "application/json"},
+ file=file,
+ payload={"licenses": licenses},
+ query={"force_download": force_download},
+ )
+
+ def _remove_download(self, config_name: str):
+ return network_service.delete(
+ uri=f"{self.endpoint}/{config_name}", header={"accept": "application/json"}
+ )
+
+ def _refetch_config_partitions(self, config_name: str, licenses: t.List[str]):
+ return network_service.post(
+ uri=f"{self.endpoint}/retry/{config_name}",
+ header={"accept": "application/json"},
+ payload=json.dumps({"licenses": licenses}),
+ )
+
+
+class DownloadServiceMock(DownloadService):
+ pass
+
+
+def get_download_service(test: bool = False):
+ if test:
+ return DownloadServiceMock()
+ else:
+ return DownloadServiceNetwork()
+
+
+download_service = get_download_service()
diff --git a/weather_dl_v2/cli/app/services/license_service.py b/weather_dl_v2/cli/app/services/license_service.py
new file mode 100644
index 00000000..09ff4f3c
--- /dev/null
+++ b/weather_dl_v2/cli/app/services/license_service.py
@@ -0,0 +1,107 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+import json
+from app.services.network_service import network_service
+from app.cli_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+class LicenseService(abc.ABC):
+
+ @abc.abstractmethod
+ def _get_all_license(self):
+ pass
+
+ @abc.abstractmethod
+ def _get_all_license_by_client_name(self, client_name: str):
+ pass
+
+ @abc.abstractmethod
+ def _get_license_by_license_id(self, license_id: str):
+ pass
+
+ @abc.abstractmethod
+ def _add_license(self, license_dict: dict):
+ pass
+
+ @abc.abstractmethod
+ def _remove_license(self, license_id: str):
+ pass
+
+ @abc.abstractmethod
+ def _update_license(self, license_id: str, license_dict: dict):
+ pass
+
+
+class LicenseServiceNetwork(LicenseService):
+
+ def __init__(self):
+ self.endpoint = f"{get_config().BASE_URI}/license"
+
+ def _get_all_license(self):
+ return network_service.get(
+ uri=self.endpoint, header={"accept": "application/json"}
+ )
+
+ def _get_all_license_by_client_name(self, client_name: str):
+ return network_service.get(
+ uri=self.endpoint,
+ header={"accept": "application/json"},
+ query={"client_name": client_name},
+ )
+
+ def _get_license_by_license_id(self, license_id: str):
+ return network_service.get(
+ uri=f"{self.endpoint}/{license_id}",
+ header={"accept": "application/json"},
+ )
+
+ def _add_license(self, license_dict: dict):
+ return network_service.post(
+ uri=self.endpoint,
+ header={"accept": "application/json"},
+ payload=json.dumps(license_dict),
+ )
+
+ def _remove_license(self, license_id: str):
+ return network_service.delete(
+ uri=f"{self.endpoint}/{license_id}",
+ header={"accept": "application/json"},
+ )
+
+ def _update_license(self, license_id: str, license_dict: dict):
+ return network_service.put(
+ uri=f"{self.endpoint}/{license_id}",
+ header={"accept": "application/json"},
+ payload=json.dumps(license_dict),
+ )
+
+
+class LicenseServiceMock(LicenseService):
+ pass
+
+
+def get_license_service(test: bool = False):
+ if test:
+ return LicenseServiceMock()
+ else:
+ return LicenseServiceNetwork()
+
+
+license_service = get_license_service()
diff --git a/weather_dl_v2/cli/app/services/network_service.py b/weather_dl_v2/cli/app/services/network_service.py
new file mode 100644
index 00000000..4406d91b
--- /dev/null
+++ b/weather_dl_v2/cli/app/services/network_service.py
@@ -0,0 +1,85 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import requests
+import json
+import logging
+from app.utils import Loader, timeit
+
+logger = logging.getLogger(__name__)
+
+
+class NetworkService:
+
+ def parse_response(self, response: requests.Response):
+ try:
+ parsed = json.loads(response.text)
+ except Exception as e:
+ logger.info(f"Parsing error: {e}.")
+ logger.info(f"Status code {response.status_code}")
+ logger.info(f"Response {response.text}")
+ return
+
+ if isinstance(parsed, list):
+ print(f"[Total {len(parsed)} items.]")
+
+ return json.dumps(parsed, indent=3)
+
+ @timeit
+ def get(self, uri, header, query=None, payload=None):
+ try:
+ with Loader("Sending request..."):
+ x = requests.get(uri, params=query, headers=header, data=payload)
+ return self.parse_response(x)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"request error: {e}")
+ raise SystemExit(e)
+
+ @timeit
+ def post(self, uri, header, query=None, payload=None, file=None):
+ try:
+ with Loader("Sending request..."):
+ x = requests.post(
+ uri, params=query, headers=header, data=payload, files=file
+ )
+ return self.parse_response(x)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"request error: {e}")
+ raise SystemExit(e)
+
+ @timeit
+ def put(self, uri, header, query=None, payload=None, file=None):
+ try:
+ with Loader("Sending request..."):
+ x = requests.put(
+ uri, params=query, headers=header, data=payload, files=file
+ )
+ return self.parse_response(x)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"request error: {e}")
+ raise SystemExit(e)
+
+ @timeit
+ def delete(self, uri, header, query=None):
+ try:
+ with Loader("Sending request..."):
+ x = requests.delete(uri, params=query, headers=header)
+ return self.parse_response(x)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"request error: {e}")
+ raise SystemExit(e)
+
+
+network_service = NetworkService()
diff --git a/weather_dl_v2/cli/app/services/queue_service.py b/weather_dl_v2/cli/app/services/queue_service.py
new file mode 100644
index 00000000..f6824934
--- /dev/null
+++ b/weather_dl_v2/cli/app/services/queue_service.py
@@ -0,0 +1,101 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+import json
+import typing as t
+from app.services.network_service import network_service
+from app.cli_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+class QueueService(abc.ABC):
+
+ @abc.abstractmethod
+ def _get_all_license_queues(self):
+ pass
+
+ @abc.abstractmethod
+ def _get_license_queue_by_client_name(self, client_name: str):
+ pass
+
+ @abc.abstractmethod
+ def _get_queue_by_license(self, license_id: str):
+ pass
+
+ @abc.abstractmethod
+ def _edit_license_queue(self, license_id: str, priority_list: t.List[str]):
+ pass
+
+ @abc.abstractmethod
+ def _edit_config_absolute_priority(
+ self, license_id: str, config_name: str, priority: int
+ ):
+ pass
+
+
+class QueueServiceNetwork(QueueService):
+
+ def __init__(self):
+ self.endpoint = f"{get_config().BASE_URI}/queues"
+
+ def _get_all_license_queues(self):
+ return network_service.get(
+ uri=self.endpoint, header={"accept": "application/json"}
+ )
+
+ def _get_license_queue_by_client_name(self, client_name: str):
+ return network_service.get(
+ uri=self.endpoint,
+ header={"accept": "application/json"},
+ query={"client_name": client_name},
+ )
+
+ def _get_queue_by_license(self, license_id: str):
+ return network_service.get(
+ uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"}
+ )
+
+ def _edit_license_queue(self, license_id: str, priority_list: t.List[str]):
+ return network_service.post(
+ uri=f"{self.endpoint}/{license_id}",
+ header={"accept": "application/json", "Content-Type": "application/json"},
+ payload=json.dumps(priority_list),
+ )
+
+ def _edit_config_absolute_priority(
+ self, license_id: str, config_name: str, priority: int
+ ):
+ return network_service.put(
+ uri=f"{self.endpoint}/priority/{license_id}",
+ header={"accept": "application/json"},
+ query={"config_name": config_name, "priority": priority},
+ )
+
+
+class QueueServiceMock(QueueService):
+ pass
+
+
+def get_queue_service(test: bool = False):
+ if test:
+ return QueueServiceMock()
+ else:
+ return QueueServiceNetwork()
+
+
+queue_service = get_queue_service()
diff --git a/weather_dl_v2/cli/app/subcommands/__init__.py b/weather_dl_v2/cli/app/subcommands/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/cli/app/subcommands/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/cli/app/subcommands/config.py b/weather_dl_v2/cli/app/subcommands/config.py
new file mode 100644
index 00000000..b2a03aaf
--- /dev/null
+++ b/weather_dl_v2/cli/app/subcommands/config.py
@@ -0,0 +1,60 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import typer
+import json
+import os
+from typing_extensions import Annotated
+from app.cli_config import get_config
+from app.utils import Validator
+
+app = typer.Typer()
+
+
+class ConfigValidator(Validator):
+ pass
+
+
+@app.command("show-ip", help="See the current server IP address.")
+def show_server_ip():
+ print(f"Current pod IP: {get_config().pod_ip}")
+
+
+@app.command("set-ip", help="Update the server IP address.")
+def update_server_ip(
+ new_ip: Annotated[
+ str, typer.Argument(help="New IP address. (Do not add port or protocol).")
+ ],
+):
+ file_path = os.path.join(os.getcwd(), "cli_config.json")
+ cli_config = {}
+ with open(file_path, "r") as file:
+ cli_config = json.load(file)
+
+ old_ip = cli_config["pod_ip"]
+ cli_config["pod_ip"] = new_ip
+
+ with open(file_path, "w") as file:
+ json.dump(cli_config, file)
+
+ validator = ConfigValidator(valid_keys=["pod_ip", "port"])
+
+ try:
+ cli_config = validator.validate_json(file_path=file_path)
+ except Exception as e:
+ print(f"payload error: {e}")
+ return
+
+ print(f"Pod IP Updated {old_ip} -> {new_ip} .")
diff --git a/weather_dl_v2/cli/app/subcommands/download.py b/weather_dl_v2/cli/app/subcommands/download.py
new file mode 100644
index 00000000..b16a26e8
--- /dev/null
+++ b/weather_dl_v2/cli/app/subcommands/download.py
@@ -0,0 +1,102 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import typer
+from typing_extensions import Annotated
+from app.services.download_service import download_service
+from app.utils import Validator, as_table
+from typing import List
+
+app = typer.Typer(rich_markup_mode="markdown")
+
+
+class DowloadFilterValidator(Validator):
+ pass
+
+
+@app.command("list", help="List out all the configs.")
+def get_downloads(
+ filter: Annotated[
+ List[str],
+ typer.Option(
+ help="""Filter by some value. Format: filter_key=filter_value. Available filters """
+ """[key: client_name, values: cds, mars, ecpublic] """
+ """[key: status, values: completed, failed, in-progress]"""
+ ),
+ ] = []
+):
+ if len(filter) > 0:
+ validator = DowloadFilterValidator(valid_keys=["client_name", "status"])
+
+ try:
+ filter_dict = validator.validate(filters=filter, allow_missing=True)
+ except Exception as e:
+ print(f"filter error: {e}")
+ return
+
+ print(as_table(download_service._list_all_downloads_by_filter(filter_dict)))
+ return
+
+ print(as_table(download_service._list_all_downloads()))
+
+
+# TODO: Add support for submitting multiple configs using *.cfg notation.
+@app.command("add", help="Submit new config to download.")
+def submit_download(
+ file_path: Annotated[
+ str, typer.Argument(help="File path of config to be uploaded.")
+ ],
+ license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")],
+ force_download: Annotated[
+ bool,
+ typer.Option(
+ "-f",
+ "--force-download",
+ help="Force redownload of partitions that were previously downloaded.",
+ ),
+ ] = False,
+):
+ print(download_service._add_new_download(file_path, license, force_download))
+
+
+@app.command("get", help="Get a particular config.")
+def get_download_by_config(
+ config_name: Annotated[str, typer.Argument(help="Config file name.")]
+):
+ print(as_table(download_service._get_download_by_config(config_name)))
+
+
+@app.command("show", help="Show contents of a particular config.")
+def show_config(
+ config_name: Annotated[str, typer.Argument(help="Config file name.")]
+):
+ print(download_service._show_config_content(config_name))
+
+
+@app.command("remove", help="Remove existing config.")
+def remove_download(
+ config_name: Annotated[str, typer.Argument(help="Config file name.")]
+):
+ print(download_service._remove_download(config_name))
+
+
+@app.command(
+ "refetch", help="Reschedule all partitions of a config that are not successful."
+)
+def refetch_config(
+ config_name: Annotated[str, typer.Argument(help="Config file name.")],
+ license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")],
+):
+ print(download_service._refetch_config_partitions(config_name, license))
diff --git a/weather_dl_v2/cli/app/subcommands/license.py b/weather_dl_v2/cli/app/subcommands/license.py
new file mode 100644
index 00000000..68dccd1d
--- /dev/null
+++ b/weather_dl_v2/cli/app/subcommands/license.py
@@ -0,0 +1,104 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import typer
+from typing_extensions import Annotated
+from app.services.license_service import license_service
+from app.utils import Validator, as_table
+
+app = typer.Typer()
+
+
+class LicenseValidator(Validator):
+ pass
+
+
+@app.command("list", help="List all licenses.")
+def get_all_license(
+ filter: Annotated[
+ str, typer.Option(help="Filter by some value. Format: filter_key=filter_value")
+ ] = None
+):
+ if filter:
+ validator = LicenseValidator(valid_keys=["client_name"])
+
+ try:
+ data = validator.validate(filters=[filter])
+ client_name = data["client_name"]
+ except Exception as e:
+ print(f"filter error: {e}")
+ return
+
+ print(as_table(license_service._get_all_license_by_client_name(client_name)))
+ return
+
+ print(as_table(license_service._get_all_license()))
+
+
+@app.command("get", help="Get a particular license by ID.")
+def get_license(license: Annotated[str, typer.Argument(help="License ID.")]):
+ print(as_table(license_service._get_license_by_license_id(license)))
+
+
+@app.command("add", help="Add new license.")
+def add_license(
+ file_path: Annotated[
+ str,
+ typer.Argument(
+ help="""Input json file. Example json for new license-"""
+ """{"license_id" : , "client_name" : , "number_of_requests" : , "secret_id" : }"""
+ """\nNOTE: license_id is case insensitive and has to be unique for each license."""
+ ),
+ ],
+):
+ validator = LicenseValidator(
+ valid_keys=["license_id", "client_name", "number_of_requests", "secret_id"]
+ )
+
+ try:
+ license_dict = validator.validate_json(file_path=file_path)
+ except Exception as e:
+ print(f"payload error: {e}")
+ return
+
+ print(license_service._add_license(license_dict))
+
+
+@app.command("remove", help="Remove a license.")
+def remove_license(license: Annotated[str, typer.Argument(help="License ID.")]):
+ print(license_service._remove_license(license))
+
+
+@app.command("update", help="Update existing license.")
+def update_license(
+ license: Annotated[str, typer.Argument(help="License ID.")],
+ file_path: Annotated[
+ str,
+ typer.Argument(
+ help="""Input json file. Example json for updated license- """
+ """{"client_id": , "client_name" : , "number_of_requests" : , "secret_id" : }"""
+ ),
+ ], # noqa
+):
+ validator = LicenseValidator(
+ valid_keys=["client_id", "client_name", "number_of_requests", "secret_id"]
+ )
+ try:
+ license_dict = validator.validate_json(file_path=file_path)
+ except Exception as e:
+ print(f"payload error: {e}")
+ return
+
+ print(license_service._update_license(license, license_dict))
diff --git a/weather_dl_v2/cli/app/subcommands/queue.py b/weather_dl_v2/cli/app/subcommands/queue.py
new file mode 100644
index 00000000..816564ca
--- /dev/null
+++ b/weather_dl_v2/cli/app/subcommands/queue.py
@@ -0,0 +1,111 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import typer
+from typing_extensions import Annotated
+from app.services.queue_service import queue_service
+from app.utils import Validator, as_table
+
+app = typer.Typer()
+
+
+class QueueValidator(Validator):
+ pass
+
+
+@app.command("list", help="List all the license queues.")
+def get_all_license_queue(
+ filter: Annotated[
+ str, typer.Option(help="Filter by some value. Format: filter_key=filter_value")
+ ] = None
+):
+ if filter:
+ validator = QueueValidator(valid_keys=["client_name"])
+
+ try:
+ data = validator.validate(filters=[filter])
+ client_name = data["client_name"]
+ except Exception as e:
+ print(f"filter error: {e}")
+ return
+
+ print(as_table(queue_service._get_license_queue_by_client_name(client_name)))
+ return
+
+ print(as_table(queue_service._get_all_license_queues()))
+
+
+@app.command("get", help="Get queue of particular license.")
+def get_license_queue(license: Annotated[str, typer.Argument(help="License ID")]):
+ print(as_table(queue_service._get_queue_by_license(license)))
+
+
+@app.command(
+ "edit",
+    help="Edit existing license queue. Queue can be edited via a priority "
+    "file or by moving a single config to a given priority.",
+) # noqa
+def modify_license_queue(
+ license: Annotated[str, typer.Argument(help="License ID.")],
+ file: Annotated[
+ str,
+ typer.Option(
+ "--file",
+ "-f",
+ help="""File path of priority json file. Example json: {"priority": ["c1.cfg", "c2.cfg",...]}""",
+ ),
+ ] = None, # noqa
+ config: Annotated[
+ str, typer.Option("--config", "-c", help="Config name for absolute priority.")
+ ] = None,
+ priority: Annotated[
+ int,
+ typer.Option(
+ "--priority",
+ "-p",
+ help="Absolute priority for the config in a license queue."
+ "Priority increases in ascending order with 0 having highest priority.",
+ ),
+ ] = None, # noqa
+):
+ if file is None and (config is None and priority is None):
+ print("Priority file or config name with absolute priority must be passed.")
+ return
+
+ if file is not None and (config is not None or priority is not None):
+ print("--config & --priority can't be used along with --file argument.")
+ return
+
+ if file is not None:
+ validator = QueueValidator(valid_keys=["priority"])
+
+ try:
+ data = validator.validate_json(file_path=file)
+ priority_list = data["priority"]
+ except Exception as e:
+ print(f"key error: {e}")
+ return
+ print(queue_service._edit_license_queue(license, priority_list))
+ return
+ elif config is not None and priority is not None:
+ if priority < 0:
+ print("Priority can not be negative.")
+ return
+
+ print(queue_service._edit_config_absolute_priority(license, config, priority))
+ return
+ else:
+ print("--config & --priority arguments should be used together.")
+ return
diff --git a/weather_dl_v2/cli/app/utils.py b/weather_dl_v2/cli/app/utils.py
new file mode 100644
index 00000000..1ced5c7b
--- /dev/null
+++ b/weather_dl_v2/cli/app/utils.py
@@ -0,0 +1,168 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+import dataclasses
+import typing as t
+import json
+from time import time
+from itertools import cycle
+from shutil import get_terminal_size
+from threading import Thread
+from time import sleep
+from tabulate import tabulate
+
+logger = logging.getLogger(__name__)
+
+
+def timeit(func):
+ def wrap_func(*args, **kwargs):
+ t1 = time()
+ result = func(*args, **kwargs)
+ t2 = time()
+ print(f"[executed in {(t2-t1):.4f}s.]")
+ return result
+
+ return wrap_func
+
+
+# TODO: Add a flag (may be -j/--json) to support raw response.
+def as_table(response: str):
+ data = json.loads(response)
+
+ if not isinstance(data, list):
+ # convert response to list if not a list.
+ data = [data]
+
+ if len(data) == 0:
+ return ""
+
+ header = data[0].keys()
+ # if any column has lists, convert that to a string.
+ rows = [
+ [
+ ",\n".join([f"{i} {ele}" for i, ele in enumerate(val)])
+ if isinstance(val, list)
+ else val
+ for val in x.values()
+ ]
+ for x in data
+ ]
+ rows.insert(0, list(header))
+ return tabulate(
+ rows, showindex=True, tablefmt="grid", maxcolwidths=[16] * len(header)
+ )
+
+
+class Loader:
+
+ def __init__(self, desc="Loading...", end="", timeout=0.1):
+ """
+ A loader-like context manager
+
+ Args:
+ desc (str, optional): The loader's description. Defaults to "Loading...".
+            end (str, optional): Final print. Defaults to "".
+ timeout (float, optional): Sleep time between prints. Defaults to 0.1.
+ """
+ self.desc = desc
+ self.end = end
+ self.timeout = timeout
+
+ self._thread = Thread(target=self._animate, daemon=True)
+ self.steps = ["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"]
+ self.done = False
+
+ def start(self):
+ self._thread.start()
+ return self
+
+ def _animate(self):
+ for c in cycle(self.steps):
+ if self.done:
+ break
+ print(f"\r{self.desc} {c}", flush=True, end="")
+ sleep(self.timeout)
+
+ def __enter__(self):
+ self.start()
+
+ def stop(self):
+ self.done = True
+ cols = get_terminal_size((80, 20)).columns
+ print("\r" + " " * cols, end="", flush=True)
+
+ def __exit__(self, exc_type, exc_value, tb):
+ # handle exceptions with those variables ^
+ self.stop()
+
+
+@dataclasses.dataclass
+class Validator(abc.ABC):
+ valid_keys: t.List[str]
+
+ def validate(
+ self, filters: t.List[str], show_valid_filters=True, allow_missing: bool = False
+ ):
+ filter_dict = {}
+
+ for filter in filters:
+ _filter = filter.split("=")
+
+ if len(_filter) != 2:
+ if show_valid_filters:
+ logger.info(f"valid filters are: {self.valid_keys}.")
+ raise ValueError("Incorrect Filter. Please Try again.")
+
+ key, value = _filter
+ filter_dict[key] = value
+
+ data_set = set(filter_dict.keys())
+ valid_set = set(self.valid_keys)
+
+ if self._validate_keys(data_set, valid_set, allow_missing):
+ return filter_dict
+
+ def validate_json(self, file_path, allow_missing: bool = False):
+ try:
+ with open(file_path) as f:
+ data: dict = json.load(f)
+ data_keys = data.keys()
+
+ data_set = set(data_keys)
+ valid_set = set(self.valid_keys)
+
+ if self._validate_keys(data_set, valid_set, allow_missing):
+ return data
+
+ except FileNotFoundError:
+ logger.info("file not found.")
+ raise FileNotFoundError
+
+ def _validate_keys(self, data_set: set, valid_set: set, allow_missing: bool):
+ missing_keys = valid_set.difference(data_set)
+ invalid_keys = data_set.difference(valid_set)
+
+ if not allow_missing and len(missing_keys) > 0:
+ raise ValueError(f"keys {missing_keys} are missing in file.")
+
+ if len(invalid_keys) > 0:
+ raise ValueError(f"keys {invalid_keys} are invalid keys.")
+
+ if allow_missing or data_set == valid_set:
+ return True
+
+ return False
diff --git a/weather_dl_v2/cli/cli_config.json b/weather_dl_v2/cli/cli_config.json
new file mode 100644
index 00000000..076ed641
--- /dev/null
+++ b/weather_dl_v2/cli/cli_config.json
@@ -0,0 +1,4 @@
+{
+ "pod_ip": "",
+ "port": 8080
+}
\ No newline at end of file
diff --git a/weather_dl_v2/cli/environment.yml b/weather_dl_v2/cli/environment.yml
new file mode 100644
index 00000000..f2ffec62
--- /dev/null
+++ b/weather_dl_v2/cli/environment.yml
@@ -0,0 +1,14 @@
+name: weather-dl-v2-cli
+channels:
+ - conda-forge
+dependencies:
+ - python=3.10
+ - pip=23.0.1
+ - typer=0.9.0
+ - tabulate=0.9.0
+ - pip:
+ - requests
+ - ruff
+ - pytype
+ - pytest
+ - .
diff --git a/weather_dl_v2/cli/setup.py b/weather_dl_v2/cli/setup.py
new file mode 100644
index 00000000..509f42fc
--- /dev/null
+++ b/weather_dl_v2/cli/setup.py
@@ -0,0 +1,30 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from setuptools import setup
+
+requirements = ["typer", "requests", "tabulate"]
+
+setup(
+ name="weather-dl-v2",
+ packages=["app", "app.subcommands", "app.services"],
+ install_requires=requirements,
+ version="0.0.1",
+ author="aniket",
+ description=(
+        "This CLI tool helps in interacting with the weather-dl-v2 FastAPI server."
+ ),
+ entry_points={"console_scripts": ["weather-dl-v2=app.main:app"]},
+)
diff --git a/weather_dl_v2/cli/vm-startup.sh b/weather_dl_v2/cli/vm-startup.sh
new file mode 100644
index 00000000..e36f6edc
--- /dev/null
+++ b/weather_dl_v2/cli/vm-startup.sh
@@ -0,0 +1,4 @@
+#! /bin/bash
+
+command="docker exec -it \\\$(docker ps -qf name=weather-dl-v2-cli) /bin/bash"
+sudo sh -c "echo \"$command\" >> /etc/profile"
\ No newline at end of file
diff --git a/weather_dl_v2/config.json b/weather_dl_v2/config.json
new file mode 100644
index 00000000..f5afae8b
--- /dev/null
+++ b/weather_dl_v2/config.json
@@ -0,0 +1,11 @@
+{
+ "download_collection": "download",
+ "queues_collection": "queues",
+ "license_collection": "license",
+ "manifest_collection": "manifest",
+ "storage_bucket": "XXXXXXX",
+ "gcs_project": "XXXXXXX",
+ "license_deployment_image": "XXXXXXX",
+ "downloader_k8_image": "XXXXXXX",
+ "welcome_message": "Greetings from weather-dl v2!"
+}
\ No newline at end of file
diff --git a/weather_dl_v2/downloader_kubernetes/Dockerfile b/weather_dl_v2/downloader_kubernetes/Dockerfile
new file mode 100644
index 00000000..74084030
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/Dockerfile
@@ -0,0 +1,46 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+FROM continuumio/miniconda3:latest
+
+# Update miniconda
+RUN conda update conda -y
+
+# Add the mamba solver for faster builds
+RUN conda install -n base conda-libmamba-solver
+RUN conda config --set solver libmamba
+
+# Create conda env using environment.yml
+COPY . .
+RUN conda env create -f environment.yml --debug
+
+# Activate the conda env and update the PATH
+ARG CONDA_ENV_NAME=weather-dl-v2-downloader
+RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc
+ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH
diff --git a/weather_dl_v2/downloader_kubernetes/README.md b/weather_dl_v2/downloader_kubernetes/README.md
new file mode 100644
index 00000000..b0d865f8
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/README.md
@@ -0,0 +1,23 @@
+# Deployment / Usage Instructions
+
+### User authorization required to set up the environment:
+* roles/container.admin
+
+### Authorization needed for the tool to operate:
+We are not configuring any service account here, so make sure that the Compute Engine default service account has the following roles:
+* roles/storage.admin
+* roles/bigquery.dataEditor
+* roles/bigquery.jobUser
+
+### Make changes in weather_dl_v2/config.json, if required [for running locally]
+```
+export CONFIG_PATH=/path/to/weather_dl_v2/config.json
+```
+
+### Create docker image for downloader:
+```
+export PROJECT_ID=
+export REPO= eg:weather-tools
+
+gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-downloader" --timeout=79200 --machine-type=e2-highcpu-32
+```
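+
+### Run the downloader locally (optional):
+With CONFIG_PATH exported as above, the downloader script can be invoked directly. The sketch below only
+documents the expected positional-argument order (it mirrors main() in downloader.py); in normal operation
+these arguments are supplied programmatically rather than typed by hand.
+```
+python downloader.py <config_name> <dataset> <selection> <user_id> <url> <target_path> <license_id>
+```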
diff --git a/weather_dl_v2/downloader_kubernetes/downloader.py b/weather_dl_v2/downloader_kubernetes/downloader.py
new file mode 100644
index 00000000..c8a5c7dc
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/downloader.py
@@ -0,0 +1,73 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This program downloads ECMWF data & uploads it to GCS.
+"""
+import tempfile
+import os
+import sys
+from manifest import FirestoreManifest, Stage
+from util import copy, download_with_aria2
+import datetime
+
+
+def download(url: str, path: str) -> None:
+ """Download data from client, with retries."""
+ if path:
+ if os.path.exists(path):
+ # Empty the target file, if it already exists, otherwise the
+ # transfer below might be fooled into thinking we're resuming
+ # an interrupted download.
+ open(path, "w").close()
+ download_with_aria2(url, path)
+
+
+def main(
+ config_name, dataset, selection, user_id, url, target_path, license_id
+) -> None:
+ """Download data from a client to a temp file, then upload it to the target path."""
+
+ manifest = FirestoreManifest(license_id=license_id)
+ temp_name = ""
+ with manifest.transact(config_name, dataset, selection, target_path, user_id):
+ with tempfile.NamedTemporaryFile(delete=False) as temp:
+ temp_name = temp.name
+ manifest.set_stage(Stage.DOWNLOAD)
+ precise_download_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_download_start_time
+ print(f"Downloading data for {target_path!r}.")
+ download(url, temp_name)
+ print(f"Download completed for {target_path!r}.")
+
+ manifest.set_stage(Stage.UPLOAD)
+ precise_upload_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_upload_start_time
+ print(f"Uploading to store for {target_path!r}.")
+ copy(temp_name, target_path)
+ print(f"Upload to store complete for {target_path!r}.")
+ os.unlink(temp_name)
+
+
+if __name__ == "__main__":
+ main(*sys.argv[1:])
diff --git a/weather_dl_v2/downloader_kubernetes/downloader_config.py b/weather_dl_v2/downloader_kubernetes/downloader_config.py
new file mode 100644
index 00000000..247ae664
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/downloader_config.py
@@ -0,0 +1,65 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import dataclasses
+import typing as t
+import json
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class DownloaderConfig:
+ manifest_collection: str = ""
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict):
+ config_instance = cls()
+
+ for key, value in config.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+
+ return config_instance
+
+
+downloader_config = None
+
+
+def get_config():
+ global downloader_config
+ if downloader_config:
+ return downloader_config
+
+ downloader_config_json = "config/config.json"
+ if not os.path.exists(downloader_config_json):
+ downloader_config_json = os.environ.get("CONFIG_PATH", None)
+
+ if downloader_config_json is None:
+ logger.error("Couldn't load config file for downloader.")
+ raise FileNotFoundError("Couldn't load config file for downloader.")
+
+ with open(downloader_config_json) as file:
+ config_dict = json.load(file)
+ downloader_config = DownloaderConfig.from_dict(config_dict)
+
+ return downloader_config
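+
+
+# Example usage (a sketch; assumes CONFIG_PATH points at weather_dl_v2/config.json):
+#   config = get_config()
+#   manifest_collection = config.manifest_collection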
diff --git a/weather_dl_v2/downloader_kubernetes/environment.yml b/weather_dl_v2/downloader_kubernetes/environment.yml
new file mode 100644
index 00000000..79e75565
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/environment.yml
@@ -0,0 +1,17 @@
+name: weather-dl-v2-downloader
+channels:
+ - conda-forge
+dependencies:
+ - python=3.10
+ - google-cloud-sdk=410.0.0
+ - aria2=1.36.0
+ - geojson=2.5.0=py_0
+ - xarray=2022.11.0
+ - google-apitools
+ - pip=22.3
+ - pip:
+ - apache_beam[gcp]==2.40.0
+ - firebase-admin
+ - google-cloud-pubsub
+ - kubernetes
+ - psutil
diff --git a/weather_dl_v2/downloader_kubernetes/manifest.py b/weather_dl_v2/downloader_kubernetes/manifest.py
new file mode 100644
index 00000000..0bc82264
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/manifest.py
@@ -0,0 +1,503 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Client interface for connecting to a manifest."""
+
+import abc
+import dataclasses
+import datetime
+import enum
+import json
+import pandas as pd
+import time
+import traceback
+import typing as t
+
+from util import (
+ to_json_serializable_type,
+ fetch_geo_polygon,
+ get_file_size,
+ get_wait_interval,
+ generate_md5_hash,
+ GLOBAL_COVERAGE_AREA,
+)
+
+import firebase_admin
+from firebase_admin import credentials
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentReference
+from google.cloud.firestore_v1.types import WriteResult
+from downloader_config import get_config
+
+"""An implementation-dependent Manifest URI."""
+Location = t.NewType("Location", str)
+
+
+class ManifestException(Exception):
+ """Errors that occur in Manifest Clients."""
+
+ pass
+
+
+class Stage(enum.Enum):
+ """A request can be in only one of the following stages at a time:
+
+ fetch : The request is currently in the fetch stage, i.e. it has been placed on the client's server
+ & is waiting for some result before the download can start (eg. MARS client).
+ download : The request is currently in the download stage, i.e. data is being downloaded from the client's
+ server to the worker's local file system.
+ upload : The request is currently in the upload stage, i.e. data is being uploaded from the worker's local
+ file system to the target location (GCS path).
+ retrieve : For clients where there is no proper separation of the fetch & download stages (eg. CDS client),
+ the request will be in the retrieve stage, i.e. fetch + download.
+ """
+
+ RETRIEVE = "retrieve"
+ FETCH = "fetch"
+ DOWNLOAD = "download"
+ UPLOAD = "upload"
+
+
+class Status(enum.Enum):
+ """Depicts the request's status:
+
+ scheduled : A request partition is created & scheduled for processing.
+ Note: its corresponding stage can only be None.
+ in-progress : The request is currently in progress (i.e. running).
+ The next status will be "success" or "failure".
+ success : The request's execution completed successfully without any error.
+ failure : The request's execution failed.
+ """
+
+ SCHEDULED = "scheduled"
+ IN_PROGRESS = "in-progress"
+ SUCCESS = "success"
+ FAILURE = "failure"
+
+
+@dataclasses.dataclass
+class DownloadStatus:
+ """Data recorded in `Manifest`s reflecting the status of a download."""
+
+ """The name of the config file associated with the request."""
+ config_name: str = ""
+
+ """Represents the dataset field of the configuration."""
+ dataset: t.Optional[str] = ""
+
+ """Copy of selection section of the configuration."""
+ selection: t.Dict = dataclasses.field(default_factory=dict)
+
+ """Location of the downloaded data."""
+ location: str = ""
+
+ """Represents area covered by the shard."""
+ area: str = ""
+
+ """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None."""
+ stage: t.Optional[Stage] = None
+
+ """Download status: 'scheduled', 'in-progress', 'success', or 'failure'."""
+ status: t.Optional[Status] = None
+
+ """Cause of error, if any."""
+ error: t.Optional[str] = ""
+
+ """Identifier for the user running the download."""
+ username: str = ""
+
+ """Shard size in GB."""
+ size: t.Optional[float] = 0
+
+ """A UTC datetime when download was scheduled."""
+ scheduled_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve stage starts."""
+ retrieve_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve stage ends."""
+ retrieve_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch stage starts."""
+ fetch_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch stage ends."""
+ fetch_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download stage starts."""
+ download_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download stage ends."""
+ download_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload stage starts."""
+ upload_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload stage ends."""
+ upload_end_time: t.Optional[str] = ""
+
+ @classmethod
+ def from_dict(cls, download_status: t.Dict) -> "DownloadStatus":
+ """Instantiate DownloadStatus dataclass from dict."""
+ download_status_instance = cls()
+ for key, value in download_status.items():
+ if key == "status":
+ setattr(download_status_instance, key, Status(value))
+ elif key == "stage" and value is not None:
+ setattr(download_status_instance, key, Stage(value))
+ else:
+ setattr(download_status_instance, key, value)
+ return download_status_instance
+
+ @classmethod
+ def to_dict(cls, instance) -> t.Dict:
+ """Return the fields of a dataclass instance as a manifest ingestible
+ dictionary mapping of field names to field values."""
+ download_status_dict = {}
+ for field in dataclasses.fields(instance):
+ key = field.name
+ value = getattr(instance, field.name)
+ if isinstance(value, Status) or isinstance(value, Stage):
+ download_status_dict[key] = value.value
+ elif isinstance(value, pd.Timestamp):
+ download_status_dict[key] = value.isoformat()
+ elif key == "selection" and value is not None:
+ download_status_dict[key] = json.dumps(value)
+ else:
+ download_status_dict[key] = value
+ return download_status_dict
+
+
+@dataclasses.dataclass
+class Manifest(abc.ABC):
+ """Abstract manifest of download statuses.
+
+ Update download statuses to some storage medium.
+
+ This class lets one indicate that a download is `scheduled` or in a transaction process.
+ In the event of a transaction, a download will be updated with an `in-progress`, `success`
+ or `failure` status (with accompanying metadata).
+
+ Example:
+ ```
+ my_manifest = parse_manifest_location(Location('fs://some-firestore-collection'))
+
+ # Schedule data for download
+ my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username')
+
+ # ...
+
+ # Initiate a transaction – it will record that the download is `in-progress`
+ with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx:
+ # download logic here
+ pass
+
+ # ...
+
+ # on error, will record the download as a `failure` before propagating the error. By default, it will
+ # record download as a `success`.
+ ```
+
+ Attributes:
+ status: The current `DownloadStatus` of the Manifest.
+ """
+
+ # To reduce the impact of _read() and _update() calls
+ # on the start time of the stage.
+ license_id: str = ""
+ prev_stage_precise_start_time: t.Optional[str] = None
+ status: t.Optional[DownloadStatus] = None
+
+ # This is overridden in subclass.
+ def __post_init__(self):
+ """Initialize the manifest."""
+ pass
+
+ def schedule(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Indicate that a job has been scheduled for download.
+
+ 'scheduled' jobs occur before 'in-progress', 'success' or 'failure'.
+ """
+ scheduled_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ self.status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=None,
+ status=Status.SCHEDULED,
+ error=None,
+ size=None,
+ scheduled_time=scheduled_time,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=None,
+ upload_end_time=None,
+ )
+ self._update(self.status)
+
+ def skip(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Updates the manifest to mark the shards that were skipped in the current job
+ as 'upload' stage and 'success' status, indicating that they have already been downloaded.
+ """
+ old_status = self._read(location)
+ # The manifest needs to be updated for a skipped shard if its entry is not present, or
+ # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'.
+ if (
+ old_status.location != location
+ or old_status.stage != Stage.UPLOAD
+ or old_status.status != Status.SUCCESS
+ ):
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ size = get_file_size(location)
+
+ status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=Stage.UPLOAD,
+ status=Status.SUCCESS,
+ error=None,
+ size=size,
+ scheduled_time=None,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=current_utc_time,
+ upload_end_time=current_utc_time,
+ )
+ self._update(status)
+ print(
+ f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}."
+ )
+
+ def _set_for_transaction(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Reset Manifest state in preparation for a new transaction."""
+ self.status = dataclasses.replace(self._read(location))
+ self.status.config_name = config_name
+ self.status.dataset = dataset if dataset else None
+ self.status.selection = selection
+ self.status.location = location
+ self.status.username = user
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, exc_type, exc_inst, exc_tb) -> None:
+ """Record end status of a transaction as either 'success' or 'failure'."""
+ if exc_type is None:
+ status = Status.SUCCESS
+ error = None
+ else:
+ status = Status.FAILURE
+ # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception
+ error = f"license_id: {self.license_id} "
+ error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb))
+
+ new_status = dataclasses.replace(self.status)
+ new_status.error = error
+ new_status.status = status
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ # This is necessary for setting the precise start time of the previous stage
+ # and end time of the final stage, as well as handling the case of Status.FAILURE.
+ if new_status.stage == Stage.FETCH:
+ new_status.fetch_start_time = self.prev_stage_precise_start_time
+ new_status.fetch_end_time = current_utc_time
+ elif new_status.stage == Stage.RETRIEVE:
+ new_status.retrieve_start_time = self.prev_stage_precise_start_time
+ new_status.retrieve_end_time = current_utc_time
+ elif new_status.stage == Stage.DOWNLOAD:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ else:
+ new_status.upload_start_time = self.prev_stage_precise_start_time
+ new_status.upload_end_time = current_utc_time
+
+ new_status.size = get_file_size(new_status.location)
+
+ self.status = new_status
+
+ self._update(self.status)
+
+ def transact(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> "Manifest":
+ """Create a download transaction."""
+ self._set_for_transaction(config_name, dataset, selection, location, user)
+ return self
+
+ def set_stage(self, stage: Stage) -> None:
+ """Sets the current stage in manifest."""
+ new_status = dataclasses.replace(self.status)
+ new_status.stage = stage
+ new_status.status = Status.IN_PROGRESS
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ if stage == Stage.DOWNLOAD:
+ new_status.download_start_time = current_utc_time
+ else:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ new_status.upload_start_time = current_utc_time
+
+ self.status = new_status
+ self._update(self.status)
+
+ @abc.abstractmethod
+ def _read(self, location: str) -> DownloadStatus:
+ pass
+
+ @abc.abstractmethod
+ def _update(self, download_status: DownloadStatus) -> None:
+ pass
+
+
+class FirestoreManifest(Manifest):
+ """A Firestore Manifest.
+ This Manifest implementation stores DownloadStatuses in a Firebase document store.
+ The document hierarchy for the manifest is as follows:
+ [manifest ]
+ ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... }
+ └── etc...
+ Where `[]` indicates a collection and ` {...}` indicates a document.
+ """
+
+ def _get_db(self) -> firestore.firestore.Client:
+ """Acquire a firestore client, initializing the firebase app if necessary.
+ Will attempt to get the db client five times. If it's still unsuccessful, a
+ `ManifestException` will be raised.
+ """
+ db = None
+ attempts = 0
+
+ while db is None:
+ try:
+ db = firestore.client()
+ except ValueError as e:
+ # The above call will fail with a value error when the firebase app is not initialized.
+ # Initialize the app here, and try again.
+ # Use the application default credentials.
+ cred = credentials.ApplicationDefault()
+
+ firebase_admin.initialize_app(cred)
+ print("Initialized Firebase App.")
+
+ if attempts > 4:
+ raise ManifestException(
+ "Exceeded number of retries to get firestore client."
+ ) from e
+
+ time.sleep(get_wait_interval(attempts))
+
+ attempts += 1
+
+ return db
+
+ def _read(self, location: str) -> DownloadStatus:
+ """Reads the JSON data from a manifest."""
+
+ doc_id = generate_md5_hash(location)
+
+ # Get the document reference for the download status
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result = download_doc_ref.get()
+ row = {}
+ if result.exists:
+ records = result.to_dict()
+ row = {n: to_json_serializable_type(v) for n, v in records.items()}
+ return DownloadStatus.from_dict(row)
+
+ def _update(self, download_status: DownloadStatus) -> None:
+ """Update or create a download status record."""
+ print("Updating Firestore Manifest.")
+
+ status = DownloadStatus.to_dict(download_status)
+ doc_id = generate_md5_hash(status["location"])
+
+ # Update document with download status
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result: WriteResult = download_doc_ref.set(status)
+
+ print(
+ f"Firestore manifest updated. "
+ f"update_time={result.update_time}, "
+ f"filename={download_status.location}."
+ )
+
+ def root_document_for_store(self, store_scheme: str) -> DocumentReference:
+ """Get the root manifest document given the user's config and current document's storage location."""
+ return (
+ self._get_db()
+ .collection(get_config().manifest_collection)
+ .document(store_scheme)
+ )
diff --git a/weather_dl_v2/downloader_kubernetes/util.py b/weather_dl_v2/downloader_kubernetes/util.py
new file mode 100644
index 00000000..5777234f
--- /dev/null
+++ b/weather_dl_v2/downloader_kubernetes/util.py
@@ -0,0 +1,226 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+import geojson
+import hashlib
+import itertools
+import os
+import socket
+import subprocess
+import sys
+import typing as t
+
+import numpy as np
+import pandas as pd
+from apache_beam.io.gcp import gcsio
+from apache_beam.utils import retry
+from xarray.core.utils import ensure_us_time_resolution
+from urllib.parse import urlparse
+from google.api_core.exceptions import BadRequest
+
+
+LATITUDE_RANGE = (-90, 90)
+LONGITUDE_RANGE = (-180, 180)
+GLOBAL_COVERAGE_AREA = [90, -180, -90, 180]
+
+
+def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter(
+ exception,
+) -> bool:
+ if isinstance(exception, socket.timeout):
+ return True
+ if isinstance(exception, TimeoutError):
+ return True
+ # To handle the concurrency issue in BigQuery.
+ if isinstance(exception, BadRequest):
+ return True
+ return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception)
+
+
+class _FakeClock:
+
+ def sleep(self, value):
+ pass
+
+
+def retry_with_exponential_backoff(fun):
+ """A retry decorator that doesn't apply during test time."""
+ clock = retry.Clock()
+
+ # Use a fake clock only during test time...
+ if "unittest" in sys.modules.keys():
+ clock = _FakeClock()
+
+ return retry.with_exponential_backoff(
+ retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter,
+ clock=clock,
+ )(fun)
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]:
+ """Yield evenly-sized chunks from an iterable."""
+ input_ = iter(iterable)
+ try:
+ while True:
+ it = itertools.islice(input_, n)
+ # peek to check if 'it' has next item.
+ first = next(it)
+ yield itertools.chain([first], it)
+ except StopIteration:
+ pass
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def copy(src: str, dst: str) -> None:
+ """Copy data via `gsutil cp`."""
+ try:
+ subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True)
+ except subprocess.CalledProcessError as e:
+ print(
+ f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}'
+ )
+ raise
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def to_json_serializable_type(value: t.Any) -> t.Any:
+ """Returns the value with a type serializable to JSON"""
+ # Note: The order of processing is significant.
+ print("Serializing to JSON")
+
+ if pd.isna(value) or value is None:
+ return None
+ elif np.issubdtype(type(value), np.floating):
+ return float(value)
+ elif isinstance(value, np.ndarray):
+ # Will return a scalar if the array is of size 1, else will return a list.
+ return value.tolist()
+ elif (
+ isinstance(value, datetime.datetime)
+ or isinstance(value, str)
+ or isinstance(value, np.datetime64)
+ ):
+ # Assume strings are ISO format timestamps...
+ try:
+ value = datetime.datetime.fromisoformat(value)
+ except ValueError:
+ # ... if they are not, assume serialization is already correct.
+ return value
+ except TypeError:
+ # ... maybe value is a numpy datetime ...
+ try:
+ value = ensure_us_time_resolution(value).astype(datetime.datetime)
+ except AttributeError:
+ # ... value is a datetime object, continue.
+ pass
+
+ # We use a string timestamp representation.
+ if value.tzname():
+ return value.isoformat()
+
+ # We assume here that naive timestamps are in UTC timezone.
+ return value.replace(tzinfo=datetime.timezone.utc).isoformat()
+ elif isinstance(value, np.timedelta64):
+ # Return time delta in seconds.
+ return float(value / np.timedelta64(1, "s"))
+ # This check must happen after processing np.timedelta64 and np.datetime64.
+ elif np.issubdtype(type(value), np.integer):
+ return int(value)
+
+ return value
+
+
+def fetch_geo_polygon(area: t.Union[list, str]) -> str:
+ """Calculates a geography polygon from an input area."""
+ # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973
+ if isinstance(area, str):
+ # European area
+ if area == "E":
+ area = [73.5, -27, 33, 45]
+ # Global area
+ elif area == "G":
+ area = GLOBAL_COVERAGE_AREA
+ else:
+ raise RuntimeError(f"Not a valid value for area in config: {area}.")
+
+ n, w, s, e = [float(x) for x in area]
+ if s < LATITUDE_RANGE[0]:
+ raise ValueError(f"Invalid latitude value for south: '{s}'")
+ if n > LATITUDE_RANGE[1]:
+ raise ValueError(f"Invalid latitude value for north: '{n}'")
+ if w < LONGITUDE_RANGE[0]:
+ raise ValueError(f"Invalid longitude value for west: '{w}'")
+ if e > LONGITUDE_RANGE[1]:
+ raise ValueError(f"Invalid longitude value for east: '{e}'")
+
+ # Define the coordinates of the bounding box.
+ coords = [[w, n], [w, s], [e, s], [e, n], [w, n]]
+
+ # Create the GeoJSON polygon object.
+ polygon = geojson.dumps(geojson.Polygon([coords]))
+ return polygon
+
+
+def get_file_size(path: str) -> float:
+ parsed_gcs_path = urlparse(path)
+ if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "":
+ return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0
+ else:
+ return (
+ gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0
+ )
+
+
+def get_wait_interval(num_retries: int = 0) -> float:
+ """Returns next wait interval in seconds, using an exponential backoff algorithm."""
+ if 0 == num_retries:
+ return 0
+ return 2**num_retries
+
+
+def generate_md5_hash(input: str) -> str:
+ """Generates md5 hash for the input string."""
+ return hashlib.md5(input.encode("utf-8")).hexdigest()
+
+
+def download_with_aria2(url: str, path: str) -> None:
+ """Downloads a file from the given URL using the `aria2c` command-line utility,
+ with options set to improve download speed and reliability."""
+ dir_path, file_name = os.path.split(path)
+ try:
+ subprocess.run(
+ [
+ "aria2c",
+ "-x",
+ "16",
+ "-s",
+ "16",
+ url,
+ "-d",
+ dir_path,
+ "-o",
+ file_name,
+ "--allow-overwrite",
+ ],
+ check=True,
+ capture_output=True,
+ )
+ except subprocess.CalledProcessError as e:
+ print(
+ f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}'
+ )
+ raise
diff --git a/weather_dl_v2/fastapi-server/API-Interactions.md b/weather_dl_v2/fastapi-server/API-Interactions.md
new file mode 100644
index 00000000..3ea4eece
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/API-Interactions.md
@@ -0,0 +1,25 @@
+# API Interactions
+| Command | Type | Endpoint |
+|---|---|---|
+| `weather-dl-v2 ping` | `get` | `/` |
+| Download | | |
+| `weather-dl-v2 download add -l [--force-download]` | `post` | `/download?force_download={value}` |
+| `weather-dl-v2 download list` | `get` | `/download/` |
+| `weather-dl-v2 download list --filter client_name=` | `get` | `/download?client_name={name}` |
+| `weather-dl-v2 download get ` | `get` | `/download/{config_name}` |
+| `weather-dl-v2 download show ` | `get` | `/download/show/{config_name}` |
+| `weather-dl-v2 download remove ` | `delete` | `/download/{config_name}` |
+| `weather-dl-v2 download refetch -l ` | `post` | `/download/refetch/{config_name}` |
+| License | | |
+| `weather-dl-v2 license add ` | `post` | `/license/` |
+| `weather-dl-v2 license get ` | `get` | `/license/{license_id}` |
+| `weather-dl-v2 license remove ` | `delete` | `/license/{license_id}` |
+| `weather-dl-v2 license list` | `get` | `/license/` |
+| `weather-dl-v2 license list --filter client_name=` | `get` | `/license?client_name={name}` |
+| `weather-dl-v2 license edit ` | `put` | `/license/{license_id}` |
+| Queue | | |
+| `weather-dl-v2 queue list` | `get` | `/queues/` |
+| `weather-dl-v2 queue list --filter client_name=` | `get` | `/queues?client_name={name}` |
+| `weather-dl-v2 queue get ` | `get` | `/queues/{license_id}` |
+| `weather-dl-v2 queue edit --config --priority ` | `post` | `/queues/{license_id}` |
+| `weather-dl-v2 queue edit --file ` | `put` | `/queues/priority/{license_id}` |
\ No newline at end of file
diff --git a/weather_dl_v2/fastapi-server/Dockerfile b/weather_dl_v2/fastapi-server/Dockerfile
new file mode 100644
index 00000000..b54e41c0
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/Dockerfile
@@ -0,0 +1,40 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+FROM continuumio/miniconda3:latest
+
+EXPOSE 8080
+
+# Update miniconda
+RUN conda update conda -y
+
+# Add the mamba solver for faster builds
+RUN conda install -n base conda-libmamba-solver
+RUN conda config --set solver libmamba
+
+COPY . .
+# Create conda env using environment.yml
+RUN conda env create -f environment.yml --debug
+
+# Activate the conda env and update the PATH
+ARG CONDA_ENV_NAME=weather-dl-v2-server
+RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc
+ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH
+
+# Use the ping endpoint as a healthcheck,
+# so Docker knows if the API is still running ok or needs to be restarted
+HEALTHCHECK --interval=21s --timeout=3s --start-period=10s CMD curl --fail http://localhost:8080/ping || exit 1
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
diff --git a/weather_dl_v2/fastapi-server/README.md b/weather_dl_v2/fastapi-server/README.md
new file mode 100644
index 00000000..2debb563
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/README.md
@@ -0,0 +1,91 @@
+# Deployment Instructions & General Notes
+
+### User authorization required to set up the environment:
+* roles/container.admin
+
+### Authorization needed for the tool to operate:
+We are not configuring any service account here, so make sure that the Compute Engine default service account has the following roles:
+* roles/pubsub.subscriber
+* roles/storage.admin
+* roles/bigquery.dataEditor
+* roles/bigquery.jobUser
+
+### Install kubectl:
+```
+apt-get update
+
+apt-get install -y kubectl
+```
+
+### Create cluster:
+```
+export PROJECT_ID=
+export REGION= eg: us-west1
+export ZONE= eg: us-west1-a
+export CLUSTER_NAME= eg: weather-dl-v2-cluster
+export DOWNLOAD_NODE_POOL=downloader-pool
+
+gcloud beta container --project $PROJECT_ID clusters create $CLUSTER_NAME --zone $ZONE --no-enable-basic-auth --cluster-version "1.27.2-gke.1200" --release-channel "regular" --machine-type "e2-standard-8" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "1100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/cloud-platform" --max-pods-per-node "16" --num-nodes "4" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/$PROJECT_ID/global/networks/default" --subnetwork "projects/$PROJECT_ID/regions/$REGION/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "16" --enable-autoscaling --min-nodes "4" --max-nodes "100" --location-policy "BALANCED" --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations $ZONE --node-labels preemptible=false && gcloud beta container --project $PROJECT_ID node-pools create $DOWNLOAD_NODE_POOL --cluster $CLUSTER_NAME --zone $ZONE --machine-type "e2-standard-8" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "1100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/cloud-platform" --max-pods-per-node "16" --num-nodes "1" --enable-autoscaling --min-nodes "1" --max-nodes "100" --location-policy "BALANCED" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations $ZONE --node-labels preemptible=false
+```
+
+### Connect to Cluster:
+```
+gcloud container clusters get-credentials $CLUSTER_NAME --zone $ZONE --project $PROJECT_ID
+```
+
+### How to create environment:
+```
+conda env create --name weather-dl-v2-server --file=environment.yml
+
+conda activate weather-dl-v2-server
+```
+
+### Make changes in weather_dl_v2/config.json, if required [for running locally]
+```
+export CONFIG_PATH=/path/to/weather_dl_v2/config.json
+```
+
+### To run fastapi server:
+```
+uvicorn main:app --reload
+```
+
+* Open your browser at http://127.0.0.1:8000.
+
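+To quickly check that the server is responding, you can also call the root (ping) endpoint listed in API-Interactions.md:
+```
+curl http://127.0.0.1:8000/
+```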
+
+### Create docker image for server:
+```
+export PROJECT_ID=
+export REPO= eg:weather-tools
+
+gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-server" --timeout=79200 --machine-type=e2-highcpu-32
+```
+
+### Add path of created server image in server.yaml:
+Add the FastAPI server's docker image path at line 42 of server.yaml.
+
+### Create ConfigMap of common configurations for services:
+Make any necessary changes to weather_dl_v2/config.json and run the following command.
+ConfigMap is used for:
+- Having a common configuration file for all services.
+- Decoupling docker image and config files.
+```
+kubectl create configmap dl-v2-config --from-file=/path/to/weather_dl_v2/config.json
+```
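+
+To confirm the ConfigMap was created, you can optionally run:
+```
+kubectl get configmap dl-v2-config
+```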
+
+### Deploy fastapi server on kubernetes:
+```
+kubectl apply -f server.yaml --force
+```
+
+## General Commands
+### For viewing the current pods:
+```
+kubectl get pods
+```
+
+### For deleting existing deployment:
+```
+kubectl delete -f server.yaml --force
+```
diff --git a/weather_dl_v2/fastapi-server/__init__.py b/weather_dl_v2/fastapi-server/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/fastapi-server/config_processing/config.py b/weather_dl_v2/fastapi-server/config_processing/config.py
new file mode 100644
index 00000000..fe2199b8
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/config.py
@@ -0,0 +1,120 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import calendar
+import copy
+import dataclasses
+import typing as t
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class Config:
+ """Contains pipeline parameters.
+
+ Attributes:
+ config_name:
+ Name of the config file.
+ client:
+ Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable.
+ dataset (optional):
+ Name of the target dataset. Allowed options are dictated by the client.
+ partition_keys (optional):
+ Choose the keys from the selection section to partition the data request.
+ This will compute a cartesian cross product of the selected keys
+ and assign each as their own download.
+ target_path:
+ Download artifact filename template. Can make use of Python's standard string formatting.
+ It can contain format symbols to be replaced by partition keys;
+ if this is used, the total number of format symbols must match the number of partition keys.
+ subsection_name:
+ Name of the particular subsection. 'default' if there is no subsection.
+ force_download:
+ Force redownload of partitions that were previously downloaded.
+ user_id:
+ Username from the environment variables.
+ kwargs (optional):
+ For representing subsections or any other parameters.
+ selection:
+ Contains parameters used to select desired data.
+ """
+
+ config_name: str = ""
+ client: str = ""
+ dataset: t.Optional[str] = ""
+ target_path: str = ""
+ partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list)
+ subsection_name: str = "default"
+ force_download: bool = False
+ user_id: str = "unknown"
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+ selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict) -> "Config":
+ config_instance = cls()
+ for section_key, section_value in config.items():
+ if section_key == "parameters":
+ for key, value in section_value.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+ if section_key == "selection":
+ config_instance.selection = section_value
+ return config_instance
+
+
+def optimize_selection_partition(selection: t.Dict) -> t.Dict:
+ """Compute right-hand-side values for the selection section of a single partition.
+
+ Used to support custom syntax and optimizations, such as 'all'.
+ """
+ selection_ = copy.deepcopy(selection)
+
+ if "day" in selection_.keys() and selection_["day"] == "all":
+ year, month = selection_["year"], selection_["month"]
+
+ multiples_error = (
+ "Cannot use keyword 'all' on selections with multiple '{type}'s."
+ )
+
+ if isinstance(year, list):
+ assert len(year) == 1, multiples_error.format(type="year")
+ year = year[0]
+
+ if isinstance(month, list):
+ assert len(month) == 1, multiples_error.format(type="month")
+ month = month[0]
+
+ if isinstance(year, str):
+ assert "/" not in year, multiples_error.format(type="year")
+
+ if isinstance(month, str):
+ assert "/" not in month, multiples_error.format(type="month")
+
+ year, month = int(year), int(month)
+
+ _, n_days_in_month = calendar.monthrange(year, month)
+
+ selection_[
+ "date"
+ ] = f"{year:04d}-{month:02d}-01/to/{year:04d}-{month:02d}-{n_days_in_month:02d}"
+ del selection_["day"]
+ del selection_["month"]
+ del selection_["year"]
+
+ return selection_
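+
+
+# A small illustration (hypothetical values): for a partition whose selection contains
+#   {"year": "2020", "month": "01", "day": "all"}
+# optimize_selection_partition() replaces those keys with
+#   {"date": "2020-01-01/to/2020-01-31"}.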
diff --git a/weather_dl_v2/fastapi-server/config_processing/manifest.py b/weather_dl_v2/fastapi-server/config_processing/manifest.py
new file mode 100644
index 00000000..35a8bf7b
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/manifest.py
@@ -0,0 +1,513 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Client interface for connecting to a manifest."""
+
+import abc
+import dataclasses
+import logging
+import datetime
+import enum
+import json
+import pandas as pd
+import time
+import traceback
+import typing as t
+
+from .util import (
+ to_json_serializable_type,
+ fetch_geo_polygon,
+ get_file_size,
+ get_wait_interval,
+ generate_md5_hash,
+ GLOBAL_COVERAGE_AREA,
+)
+
+import firebase_admin
+from firebase_admin import credentials
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentReference
+from google.cloud.firestore_v1.types import WriteResult
+from server_config import get_config
+from database.session import Database
+
+"""An implementation-dependent Manifest URI."""
+Location = t.NewType("Location", str)
+
+logger = logging.getLogger(__name__)
+
+
+class ManifestException(Exception):
+ """Errors that occur in Manifest Clients."""
+
+ pass
+
+
+class Stage(enum.Enum):
+ """A request can be in only one of the following stages at a time:
+
+ fetch : The request is currently in the fetch stage, i.e. it has been placed on the client's server
+ & is waiting for some result before the download can start (eg. MARS client).
+ download : The request is currently in the download stage, i.e. data is being downloaded from the client's
+ server to the worker's local file system.
+ upload : The request is currently in the upload stage, i.e. data is being uploaded from the worker's local
+ file system to the target location (GCS path).
+ retrieve : For clients where there is no proper separation of the fetch & download stages (eg. CDS client),
+ the request will be in the retrieve stage, i.e. fetch + download.
+ """
+
+ RETRIEVE = "retrieve"
+ FETCH = "fetch"
+ DOWNLOAD = "download"
+ UPLOAD = "upload"
+
+
+class Status(enum.Enum):
+ """Depicts the request's status:
+
+ scheduled : A request partition is created & scheduled for processing.
+ Note: its corresponding stage can only be None.
+ in-progress : The request is currently in progress (i.e. running).
+ The next status will be "success" or "failure".
+ success : The request's execution completed successfully without any error.
+ failure : The request's execution failed.
+ """
+
+ SCHEDULED = "scheduled"
+ IN_PROGRESS = "in-progress"
+ SUCCESS = "success"
+ FAILURE = "failure"
+
+
+@dataclasses.dataclass
+class DownloadStatus:
+ """Data recorded in `Manifest`s reflecting the status of a download."""
+
+ """The name of the config file associated with the request."""
+ config_name: str = ""
+
+ """Represents the dataset field of the configuration."""
+ dataset: t.Optional[str] = ""
+
+ """Copy of selection section of the configuration."""
+ selection: t.Dict = dataclasses.field(default_factory=dict)
+
+ """Location of the downloaded data."""
+ location: str = ""
+
+ """Represents area covered by the shard."""
+ area: str = ""
+
+ """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None."""
+ stage: t.Optional[Stage] = None
+
+ """Download status: 'scheduled', 'in-progress', 'success', or 'failure'."""
+ status: t.Optional[Status] = None
+
+ """Cause of error, if any."""
+ error: t.Optional[str] = ""
+
+ """Identifier for the user running the download."""
+ username: str = ""
+
+ """Shard size in GB."""
+ size: t.Optional[float] = 0
+
+ """A UTC datetime when download was scheduled."""
+ scheduled_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve stage starts."""
+ retrieve_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve stage ends."""
+ retrieve_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch stage starts."""
+ fetch_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch stage ends."""
+ fetch_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download stage starts."""
+ download_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download stage ends."""
+ download_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload stage starts."""
+ upload_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload stage ends."""
+ upload_end_time: t.Optional[str] = ""
+
+ @classmethod
+ def from_dict(cls, download_status: t.Dict) -> "DownloadStatus":
+ """Instantiate DownloadStatus dataclass from dict."""
+ download_status_instance = cls()
+ for key, value in download_status.items():
+ if key == "status":
+ setattr(download_status_instance, key, Status(value))
+ elif key == "stage" and value is not None:
+ setattr(download_status_instance, key, Stage(value))
+ else:
+ setattr(download_status_instance, key, value)
+ return download_status_instance
+
+ @classmethod
+ def to_dict(cls, instance) -> t.Dict:
+ """Return the fields of a dataclass instance as a manifest ingestible
+ dictionary mapping of field names to field values."""
+ download_status_dict = {}
+ for field in dataclasses.fields(instance):
+ key = field.name
+ value = getattr(instance, field.name)
+ if isinstance(value, Status) or isinstance(value, Stage):
+ download_status_dict[key] = value.value
+ elif isinstance(value, pd.Timestamp):
+ download_status_dict[key] = value.isoformat()
+ elif key == "selection" and value is not None:
+ download_status_dict[key] = json.dumps(value)
+ else:
+ download_status_dict[key] = value
+ return download_status_dict
+
+
+@dataclasses.dataclass
+class Manifest(abc.ABC):
+ """Abstract manifest of download statuses.
+
+ Update download statuses to some storage medium.
+
+ This class lets one indicate that a download is `scheduled` or in a transaction process.
+ In the event of a transaction, a download will be updated with an `in-progress`, `success`
+ or `failure` status (with accompanying metadata).
+
+ Example:
+ ```
+ my_manifest = parse_manifest_location(Location('fs://some-firestore-collection'))
+
+ # Schedule data for download
+ my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username')
+
+ # ...
+
+ # Initiate a transaction – it will record that the download is `in-progress`
+ with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx:
+ # download logic here
+ pass
+
+ # ...
+
+ # on error, will record the download as a `failure` before propagating the error. By default, it will
+ # record download as a `success`.
+ ```
+
+ Attributes:
+ status: The current `DownloadStatus` of the Manifest.
+ """
+
+ # To reduce the impact of _read() and _update() calls
+ # on the start time of the stage.
+ prev_stage_precise_start_time: t.Optional[str] = None
+ status: t.Optional[DownloadStatus] = None
+
+ # This is overridden in subclass.
+ def __post_init__(self):
+ """Initialize the manifest."""
+ pass
+
+ def schedule(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Indicate that a job has been scheduled for download.
+
+ 'scheduled' jobs occur before 'in-progress', 'success' or 'failure'.
+ """
+ scheduled_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ self.status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=None,
+ status=Status.SCHEDULED,
+ error=None,
+ size=None,
+ scheduled_time=scheduled_time,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=None,
+ upload_end_time=None,
+ )
+ self._update(self.status)
+
+ def skip(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Updates the manifest to mark the shards that were skipped in the current job
+ as 'upload' stage and 'success' status, indicating that they have already been downloaded.
+ """
+ old_status = self._read(location)
+ # The manifest needs to be updated for a skipped shard if its entry is not present, or
+ # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'.
+ if (
+ old_status.location != location
+ or old_status.stage != Stage.UPLOAD
+ or old_status.status != Status.SUCCESS
+ ):
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ size = get_file_size(location)
+
+ status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=Stage.UPLOAD,
+ status=Status.SUCCESS,
+ error=None,
+ size=size,
+ scheduled_time=None,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=current_utc_time,
+ upload_end_time=current_utc_time,
+ )
+ self._update(status)
+ logger.info(
+ f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}."
+ )
+
+ def _set_for_transaction(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Reset Manifest state in preparation for a new transaction."""
+ self.status = dataclasses.replace(self._read(location))
+ self.status.config_name = config_name
+ self.status.dataset = dataset if dataset else None
+ self.status.selection = selection
+ self.status.location = location
+ self.status.username = user
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, exc_type, exc_inst, exc_tb) -> None:
+ """Record end status of a transaction as either 'success' or 'failure'."""
+ if exc_type is None:
+ status = Status.SUCCESS
+ error = None
+ else:
+ status = Status.FAILURE
+ # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception
+ error = "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb))
+
+ new_status = dataclasses.replace(self.status)
+ new_status.error = error
+ new_status.status = status
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ # This is necessary for setting the precise start time of the previous stage
+ # and end time of the final stage, as well as handling the case of Status.FAILURE.
+ if new_status.stage == Stage.FETCH:
+ new_status.fetch_start_time = self.prev_stage_precise_start_time
+ new_status.fetch_end_time = current_utc_time
+ elif new_status.stage == Stage.RETRIEVE:
+ new_status.retrieve_start_time = self.prev_stage_precise_start_time
+ new_status.retrieve_end_time = current_utc_time
+ elif new_status.stage == Stage.DOWNLOAD:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ else:
+ new_status.upload_start_time = self.prev_stage_precise_start_time
+ new_status.upload_end_time = current_utc_time
+
+ new_status.size = get_file_size(new_status.location)
+
+ self.status = new_status
+
+ self._update(self.status)
+
+ def transact(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> "Manifest":
+ """Create a download transaction."""
+ self._set_for_transaction(config_name, dataset, selection, location, user)
+ return self
+
+ def set_stage(self, stage: Stage) -> None:
+ """Sets the current stage in manifest."""
+ prev_stage = self.status.stage
+ new_status = dataclasses.replace(self.status)
+ new_status.stage = stage
+ new_status.status = Status.IN_PROGRESS
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ if stage == Stage.FETCH:
+ new_status.fetch_start_time = current_utc_time
+ elif stage == Stage.RETRIEVE:
+ new_status.retrieve_start_time = current_utc_time
+ elif stage == Stage.DOWNLOAD:
+ new_status.fetch_start_time = self.prev_stage_precise_start_time
+ new_status.fetch_end_time = current_utc_time
+ new_status.download_start_time = current_utc_time
+ else:
+ if prev_stage == Stage.DOWNLOAD:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ else:
+ new_status.retrieve_start_time = self.prev_stage_precise_start_time
+ new_status.retrieve_end_time = current_utc_time
+ new_status.upload_start_time = current_utc_time
+
+ self.status = new_status
+ self._update(self.status)
+
+ @abc.abstractmethod
+ def _read(self, location: str) -> DownloadStatus:
+ pass
+
+ @abc.abstractmethod
+ def _update(self, download_status: DownloadStatus) -> None:
+ pass
+
+
+class FirestoreManifest(Manifest, Database):
+ """A Firestore Manifest.
+ This Manifest implementation stores DownloadStatuses in a Firebase document store.
+ The document hierarchy for the manifest is as follows:
+ [manifest ]
+ ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... }
+ └── etc...
+    Where `[]` indicates a collection and `{...}` indicates a document.
+ """
+
+ def _get_db(self) -> firestore.firestore.Client:
+ """Acquire a firestore client, initializing the firebase app if necessary.
+ Will attempt to get the db client five times. If it's still unsuccessful, a
+ `ManifestException` will be raised.
+ """
+ db = None
+ attempts = 0
+
+ while db is None:
+ try:
+ db = firestore.client()
+ except ValueError as e:
+ # The above call will fail with a value error when the firebase app is not initialized.
+ # Initialize the app here, and try again.
+ # Use the application default credentials.
+ cred = credentials.ApplicationDefault()
+
+ firebase_admin.initialize_app(cred)
+ logger.info("Initialized Firebase App.")
+
+ if attempts > 4:
+ raise ManifestException(
+ "Exceeded number of retries to get firestore client."
+ ) from e
+
+ time.sleep(get_wait_interval(attempts))
+
+ attempts += 1
+
+ return db
+
+ def _read(self, location: str) -> DownloadStatus:
+ """Reads the JSON data from a manifest."""
+
+ doc_id = generate_md5_hash(location)
+
+ # Update document with download status
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result = download_doc_ref.get()
+ row = {}
+ if result.exists:
+ records = result.to_dict()
+ row = {n: to_json_serializable_type(v) for n, v in records.items()}
+ return DownloadStatus.from_dict(row)
+
+ def _update(self, download_status: DownloadStatus) -> None:
+ """Update or create a download status record."""
+ logger.info("Updating Firestore Manifest.")
+
+ status = DownloadStatus.to_dict(download_status)
+ doc_id = generate_md5_hash(status["location"])
+
+ # Update document with download status
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result: WriteResult = download_doc_ref.set(status)
+
+ logger.info(
+ f"Firestore manifest updated. "
+ f"update_time={result.update_time}, "
+ f"filename={download_status.location}."
+ )
+
+ def root_document_for_store(self, store_scheme: str) -> DocumentReference:
+ """Get the root manifest document given the user's config and current document's storage location."""
+ root_collection = get_config().manifest_collection
+ return self._get_db().collection(root_collection).document(store_scheme)
diff --git a/weather_dl_v2/fastapi-server/config_processing/parsers.py b/weather_dl_v2/fastapi-server/config_processing/parsers.py
new file mode 100644
index 00000000..5f9e1f5c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/parsers.py
@@ -0,0 +1,507 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Parsers for ECMWF download configuration."""
+
+import ast
+import configparser
+import copy as cp
+import datetime
+import json
+import string
+import textwrap
+import typing as t
+import numpy as np
+from collections import OrderedDict
+from .config import Config
+
+CLIENTS = ["cds", "mars", "ecpublic"]
+
+
+def date(candidate: str) -> datetime.date:
+ """Converts ECMWF-format date strings into a `datetime.date`.
+
+ Accepted absolute date formats:
+ - YYYY-MM-DD
+ - YYYYMMDD
+ - YYYY-DDD, where DDD refers to the day of the year
+
+ For example:
+ - 2021-10-31
+ - 19700101
+ - 1950-007
+
+ See https://confluence.ecmwf.int/pages/viewpage.action?pageId=118817289 for date format spec.
+ Note: Name of month is not supported.
+ """
+ converted = None
+
+ # Parse relative day value.
+ if candidate.startswith("-"):
+ return datetime.date.today() + datetime.timedelta(days=int(candidate))
+
+ accepted_formats = ["%Y-%m-%d", "%Y%m%d", "%Y-%j"]
+
+ for fmt in accepted_formats:
+ try:
+ converted = datetime.datetime.strptime(candidate, fmt).date()
+ break
+ except ValueError:
+ pass
+
+ if converted is None:
+ raise ValueError(
+ f"Not a valid date: '{candidate}'. Please use valid relative or absolute format."
+ )
+
+ return converted
+
+
+def time(candidate: str) -> datetime.time:
+ """Converts ECMWF-format time strings into a `datetime.time`.
+
+ Accepted time formats:
+ - HH:MM
+ - HHMM
+ - HH
+
+ For example:
+ - 18:00
+ - 1820
+ - 18
+
+ Note: If MM is omitted it defaults to 00.
+ """
+ converted = None
+
+ accepted_formats = ["%H", "%H:%M", "%H%M"]
+
+ for fmt in accepted_formats:
+ try:
+ converted = datetime.datetime.strptime(candidate, fmt).time()
+ break
+ except ValueError:
+ pass
+
+ if converted is None:
+ raise ValueError(f"Not a valid time: '{candidate}'. Please use valid format.")
+
+ return converted
+
+
+def day_month_year(candidate: t.Any) -> int:
+ """Converts day, month and year strings into 'int'."""
+ try:
+ if isinstance(candidate, str) or isinstance(candidate, int):
+ return int(candidate)
+ raise ValueError("must be a str or int.")
+ except ValueError as e:
+ raise ValueError(
+ f"Not a valid day, month, or year value: {candidate}. Please use valid value."
+ ) from e
+
+
+def parse_literal(candidate: t.Any) -> t.Any:
+ try:
+ # Support parsing ints with leading zeros, e.g. '01'
+ if isinstance(candidate, str) and candidate.isdigit():
+ return int(candidate)
+ return ast.literal_eval(candidate)
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ return candidate
+
+
+def validate(key: str, value: int) -> None:
+ """Validates value based on the key."""
+ if key == "day":
+ assert 1 <= value <= 31, "Day value must be between 1 to 31."
+ if key == "month":
+ assert 1 <= value <= 12, "Month value must be between 1 to 12."
+
+
+def typecast(key: str, value: t.Any) -> t.Any:
+ """Type the value to its appropriate datatype."""
+ SWITCHER = {
+ "date": date,
+ "time": time,
+ "day": day_month_year,
+ "month": day_month_year,
+ "year": day_month_year,
+ }
+ converted = SWITCHER.get(key, parse_literal)(value)
+ validate(key, converted)
+ return converted
+
+
+def _read_config_file(file: t.IO) -> t.Dict:
+ """Reads `*.json` or `*.cfg` files."""
+ try:
+ return json.load(file)
+ except json.JSONDecodeError:
+ pass
+
+ file.seek(0)
+
+ try:
+ config = configparser.ConfigParser()
+ config.read_file(file)
+ config = {s: dict(config.items(s)) for s in config.sections()}
+ return config
+ except configparser.ParsingError:
+ return {}
+
+
+def parse_config(file: t.IO) -> t.Dict:
+ """Parses a `*.json` or `*.cfg` file into a configuration dictionary."""
+ config = _read_config_file(file)
+ config_by_section = {s: _parse_lists(v, s) for s, v in config.items()}
+ config_with_nesting = parse_subsections(config_by_section)
+ return config_with_nesting
+
+
+def _splitlines(block: str) -> t.List[str]:
+ """Converts a multi-line block into a list of strings."""
+ return [line.strip() for line in block.strip().splitlines()]
+
+
+def mars_range_value(token: str) -> t.Union[datetime.date, int, float]:
+ """Converts a range token into either a date, int, or float."""
+ try:
+ return date(token)
+ except ValueError:
+ pass
+
+ if token.isdecimal():
+ return int(token)
+
+ try:
+ return float(token)
+ except ValueError:
+ raise ValueError(
+ "Token string must be an 'int', 'float', or 'datetime.date()'."
+ )
+
+
+def mars_increment_value(token: str) -> t.Union[int, float]:
+ """Converts an increment token into either an int or a float."""
+ try:
+ return int(token)
+ except ValueError:
+ pass
+
+ try:
+ return float(token)
+ except ValueError:
+ raise ValueError("Token string must be an 'int' or a 'float'.")
+
+
+def parse_mars_syntax(block: str) -> t.List[str]:
+ """Parses MARS list or range into a list of arguments; ranges are inclusive.
+
+ Types for the range and value are inferred.
+
+ Examples:
+ >>> parse_mars_syntax("10/to/12")
+ ['10', '11', '12']
+ >>> parse_mars_syntax("12/to/10/by/-1")
+ ['12', '11', '10']
+ >>> parse_mars_syntax("0.0/to/0.5/by/0.1")
+ ['0.0', '0.1', '0.2', '0.30000000000000004', '0.4', '0.5']
+ >>> parse_mars_syntax("2020-01-07/to/2020-01-14/by/2")
+ ['2020-01-07', '2020-01-09', '2020-01-11', '2020-01-13']
+ >>> parse_mars_syntax("2020-01-14/to/2020-01-07/by/-2")
+ ['2020-01-14', '2020-01-12', '2020-01-10', '2020-01-08']
+
+ Returns:
+ A list of strings representing a range from start to finish, based on the
+ type of the values in the range.
+ If all range values are integers, it will return a list of strings of integers.
+ If range values are floats, it will return a list of strings of floats.
+ If the range values are dates, it will return a list of strings of dates in
+ YYYY-MM-DD format. (Note: here, the increment value should be an integer).
+ """
+
+ # Split into tokens, omitting empty strings.
+ tokens = [b.strip() for b in block.split("/") if b != ""]
+
+ # Return list if no range operators are present.
+ if "to" not in tokens and "by" not in tokens:
+ return tokens
+
+ # Parse range values, honoring 'to' and 'by' operators.
+ try:
+ to_idx = tokens.index("to")
+ assert to_idx != 0, "There must be a start token."
+ start_token, end_token = tokens[to_idx - 1], tokens[to_idx + 1]
+ start, end = mars_range_value(start_token), mars_range_value(end_token)
+
+ # Parse increment token, or choose default increment.
+ increment_token = "1"
+ increment = 1
+ if "by" in tokens:
+ increment_token = tokens[tokens.index("by") + 1]
+ increment = mars_increment_value(increment_token)
+ except (AssertionError, IndexError, ValueError):
+ raise SyntaxError(f"Improper range syntax in '{block}'.")
+
+ # Return a range of values with appropriate data type.
+ if isinstance(start, datetime.date) and isinstance(end, datetime.date):
+ if not isinstance(increment, int):
+ raise ValueError(
+ f"Increments on a date range must be integer number of days, '{increment_token}' is invalid."
+ )
+ return [d.strftime("%Y-%m-%d") for d in date_range(start, end, increment)]
+ elif (isinstance(start, float) or isinstance(end, float)) and not isinstance(
+ increment, datetime.date
+ ):
+ # Increment can be either an int or a float.
+ _round_places = 4
+ return [
+ str(round(x, _round_places)).zfill(len(start_token))
+ for x in np.arange(start, end + increment, increment)
+ ]
+ elif isinstance(start, int) and isinstance(end, int) and isinstance(increment, int):
+ # Honor leading zeros.
+ offset = 1 if start <= end else -1
+ return [
+ str(x).zfill(len(start_token))
+ for x in range(start, end + offset, increment)
+ ]
+ else:
+ raise ValueError(
+ f"Range tokens (start='{start_token}', end='{end_token}', increment='{increment_token}')"
+ f" are inconsistent types."
+ )
+
+
+def date_range(
+ start: datetime.date, end: datetime.date, increment: int = 1
+) -> t.Iterable[datetime.date]:
+ """Gets a range of dates, inclusive."""
+ offset = 1 if start <= end else -1
+ return (
+ start + datetime.timedelta(days=x)
+ for x in range(0, (end - start).days + offset, increment)
+ )
+
+
+def _parse_lists(config: dict, section: str = "") -> t.Dict:
+ """Parses multiline blocks in *.cfg and *.json files as lists."""
+ for key, val in config.items():
+        # Check str type for backward compatibility, since JSON configs may also contain non-string values (e.g. "padding": 0).
+ if not isinstance(val, str):
+ continue
+
+ if "/" in val and "parameters" not in section:
+ config[key] = parse_mars_syntax(val)
+ elif "\n" in val:
+ config[key] = _splitlines(val)
+
+ return config
+
+
+def _number_of_replacements(s: t.Text):
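+    """Counts replacement fields in a format string (named fields are deduplicated; positional '{}' fields each count)."""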
+ format_names = [v[1] for v in string.Formatter().parse(s) if v[1] is not None]
+ num_empty_names = len([empty for empty in format_names if empty == ""])
+ if num_empty_names != 0:
+ num_empty_names -= 1
+ return len(set(format_names)) + num_empty_names
+
+
+def parse_subsections(config: t.Dict) -> t.Dict:
+ """Interprets [section.subsection] as nested dictionaries in `.cfg` files."""
+ copy = cp.deepcopy(config)
+ for key, val in copy.items():
+ path = key.split(".")
+ runner = copy
+ parent = {}
+ p = None
+ for p in path:
+ if p not in runner:
+ runner[p] = {}
+ parent = runner
+ runner = runner[p]
+ parent[p] = val
+
+ for_cleanup = [key for key, _ in copy.items() if "." in key]
+ for target in for_cleanup:
+ del copy[target]
+ return copy
+
+
+def require(
+ condition: bool, message: str, error_type: t.Type[Exception] = ValueError
+) -> None:
+ """A assert-like helper that wraps text and throws an error."""
+ if not condition:
+ raise error_type(textwrap.dedent(message))
+
+
+def process_config(file: t.IO, config_name: str) -> Config:
+ """Read the config file and prompt the user if it is improperly structured."""
+ config = parse_config(file)
+
+ require(bool(config), "Unable to parse configuration file.")
+ require(
+ "parameters" in config,
+ """
+ 'parameters' section required in configuration file.
+
+ The 'parameters' section specifies the 'client', 'dataset', 'target_path', and
+ 'partition_key' for the API client.
+
+ Please consult the documentation for more information.""",
+ )
+
+ params = config.get("parameters", {})
+ require(
+ "target_template" not in params,
+ """
+ 'target_template' is deprecated, use 'target_path' instead.
+
+ Please consult the documentation for more information.""",
+ )
+ require(
+ "target_path" in params,
+ """
+ 'parameters' section requires a 'target_path' key.
+
+ The 'target_path' is used to format the name of the output files. It
+ accepts Python 3.5+ string format symbols (e.g. '{}'). The number of symbols
+ should match the length of the 'partition_keys', as the 'partition_keys' args
+ are used to create the templates.""",
+ )
+ require(
+ "client" in params,
+ """
+ 'parameters' section requires a 'client' key.
+
+ Supported clients are {}
+ """.format(
+ str(CLIENTS)
+ ),
+ )
+ require(
+ params.get("client") in CLIENTS,
+ """
+ Invalid 'client' parameter.
+
+ Supported clients are {}
+ """.format(
+ str(CLIENTS)
+ ),
+ )
+ require(
+ "append_date_dirs" not in params,
+ """
+ The current version of 'google-weather-tools' no longer supports 'append_date_dirs'!
+
+ Please refer to documentation for creating date-based directory hierarchy :
+ https://weather-tools.readthedocs.io/en/latest/Configuration.html#"""
+ """creating-a-date-based-directory-hierarchy.""",
+ NotImplementedError,
+ )
+ require(
+ "target_filename" not in params,
+ """
+ The current version of 'google-weather-tools' no longer supports 'target_filename'!
+
+ Please refer to documentation :
+ https://weather-tools.readthedocs.io/en/latest/Configuration.html#parameters-section.""",
+ NotImplementedError,
+ )
+
+ partition_keys = params.get("partition_keys", list())
+ if isinstance(partition_keys, str):
+ partition_keys = [partition_keys.strip()]
+
+ selection = config.get("selection", dict())
+ require(
+ all((key in selection for key in partition_keys)),
+ """
+ All 'partition_keys' must appear in the 'selection' section.
+
+ 'partition_keys' specify how to split data for workers. Please consult
+ documentation for more information.""",
+ )
+
+ num_template_replacements = _number_of_replacements(params["target_path"])
+ num_partition_keys = len(partition_keys)
+
+ require(
+ num_template_replacements == num_partition_keys,
+ """
+ 'target_path' has {0} replacements. Expected {1}, since there are {1}
+ partition keys.
+ """.format(
+ num_template_replacements, num_partition_keys
+ ),
+ )
+
+ if "day" in partition_keys:
+ require(
+ selection["day"] != "all",
+ """If 'all' is used for a selection value, it cannot appear as a partition key.""",
+ )
+
+ # Ensure consistent lookup.
+ config["parameters"]["partition_keys"] = partition_keys
+ # Add config file name.
+ config["parameters"]["config_name"] = config_name
+
+ # Ensure the cartesian-cross can be taken on singleton values for the partition.
+ for key in partition_keys:
+ if not isinstance(selection[key], list):
+ selection[key] = [selection[key]]
+
+ return Config.from_dict(config)
+
+
+def prepare_target_name(config: Config) -> str:
+ """Returns name of target location."""
+ partition_dict = OrderedDict(
+ (key, typecast(key, config.selection[key][0])) for key in config.partition_keys
+ )
+ target = config.target_path.format(*partition_dict.values(), **partition_dict)
+
+ return target
+
+
+def get_subsections(config: Config) -> t.List[t.Tuple[str, t.Dict]]:
+ """Collect parameter subsections from main configuration.
+
+ If the `parameters` section contains subsections (e.g. '[parameters.1]',
+ '[parameters.2]'), collect the subsection key-value pairs. Otherwise,
+ return an empty dictionary (i.e. there are no subsections).
+
+ This is useful for specifying multiple API keys for your configuration.
+ For example:
+ ```
+ [parameters.alice]
+ api_key=KKKKK1
+ api_url=UUUUU1
+ [parameters.bob]
+ api_key=KKKKK2
+ api_url=UUUUU2
+ [parameters.eve]
+ api_key=KKKKK3
+ api_url=UUUUU3
+ ```
+ """
+ return [
+ (name, params)
+ for name, params in config.kwargs.items()
+ if isinstance(params, dict)
+ ] or [("default", {})]
diff --git a/weather_dl_v2/fastapi-server/config_processing/partition.py b/weather_dl_v2/fastapi-server/config_processing/partition.py
new file mode 100644
index 00000000..a9f6a9e2
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/partition.py
@@ -0,0 +1,129 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import copy as cp
+import dataclasses
+import itertools
+import typing as t
+
+from .manifest import Manifest
+from .parsers import prepare_target_name
+from .config import Config
+from .stores import Store, FSStore
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class PartitionConfig:
+ """Partition a config into multiple data requests.
+
+    Partitioning involves three main operations: first, we fan out shards based on
+    the partition keys (a cross product of their values); second, we filter out
+    existing downloads (unless downloads are forced); last, we assemble each
+    partition into a single Config.
+
+    Attributes:
+        config: The download config, including the 'parameters' and 'selection' sections.
+        store: A cloud storage system, used for checking the existence of downloads.
+        manifest: A download manifest to register preparation state.
+    """
+
+ config: Config
+ store: Store
+ manifest: Manifest
+
+ def _create_partition_config(self, option: t.Tuple) -> Config:
+ """Create a config for a single partition option.
+
+        Outputs a Config, overriding the range of values for each key in
+        'selection' with a single partition instance.
+ Continuing the example from prepare_partitions, the selection section
+ would be:
+ { 'foo': ..., 'year': ['2020'], 'month': ['01'], ... }
+ { 'foo': ..., 'year': ['2020'], 'month': ['02'], ... }
+ { 'foo': ..., 'year': ['2020'], 'month': ['03'], ... }
+
+        Args:
+            option: A single item from the cross product of the partition_keys' values.
+
+        Returns:
+            A configuration that selects a single download partition.
+ """
+ copy = cp.deepcopy(self.config.selection)
+ out = cp.deepcopy(self.config)
+ for idx, key in enumerate(self.config.partition_keys):
+ copy[key] = [option[idx]]
+
+ out.selection = copy
+ return out
+
+ def skip_partition(self, config: Config) -> bool:
+ """Return true if partition should be skipped."""
+
+ if config.force_download:
+ return False
+
+ target = prepare_target_name(config)
+ if self.store.exists(target):
+ logger.info(f"file {target} found, skipping.")
+ self.manifest.skip(
+ config.config_name,
+ config.dataset,
+ config.selection,
+ target,
+ config.user_id,
+ )
+ return True
+
+ return False
+
+ def prepare_partitions(self) -> t.Iterator[Config]:
+ """Iterate over client parameters, partitioning over `partition_keys`.
+
+ This produces a Cartesian-Cross over the range of keys.
+
+ For example, if the keys were 'year' and 'month', it would produce
+ an iterable like:
+ ( ('2020', '01'), ('2020', '02'), ('2020', '03'), ...)
+
+ Returns:
+ An iterator of `Config`s.
+ """
+ for option in itertools.product(
+ *[self.config.selection[key] for key in self.config.partition_keys]
+ ):
+ yield self._create_partition_config(option)
+
+ def new_downloads_only(self, candidate: Config) -> bool:
+ """Predicate function to skip already downloaded partitions."""
+ if self.store is None:
+ self.store = FSStore()
+ should_skip = self.skip_partition(candidate)
+
+ return not should_skip
+
+    def update_manifest_collection(self, partition: Config) -> None:
+ """Updates the DB."""
+ location = prepare_target_name(partition)
+ self.manifest.schedule(
+ partition.config_name,
+ partition.dataset,
+ partition.selection,
+ location,
+ partition.user_id,
+ )
+ logger.info(f"Created partition {location!r}.")
diff --git a/weather_dl_v2/fastapi-server/config_processing/pipeline.py b/weather_dl_v2/fastapi-server/config_processing/pipeline.py
new file mode 100644
index 00000000..175dd798
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/pipeline.py
@@ -0,0 +1,69 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import getpass
+import logging
+import os
+from .parsers import process_config
+from .partition import PartitionConfig
+from .manifest import FirestoreManifest
+from database.download_handler import get_download_handler
+from database.queue_handler import get_queue_handler
+from fastapi.concurrency import run_in_threadpool
+
+logger = logging.getLogger(__name__)
+
+download_handler = get_download_handler()
+queue_handler = get_queue_handler()
+
+
+def _do_partitions(partition_obj: PartitionConfig):
+ for partition in partition_obj.prepare_partitions():
+ # Skip existing downloads
+ if partition_obj.new_downloads_only(partition):
+ partition_obj.update_manifest_collection(partition)
+
+
+# TODO: Make partitioning faster.
+async def start_processing_config(config_file, licenses, force_download):
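+    """Parses the config file, partitions it, and records the partitions in the manifest and license queues."""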
+ config = {}
+ manifest = FirestoreManifest()
+
+ with open(config_file, "r", encoding="utf-8") as f:
+ # configs/example.cfg -> example.cfg
+ config_name = os.path.split(config_file)[1]
+ config = process_config(f, config_name)
+
+ config.force_download = force_download
+ config.user_id = getpass.getuser()
+
+ partition_obj = PartitionConfig(config, None, manifest)
+
+ # Make entry in 'download' & 'queues' collection.
+ await download_handler._start_download(config_name, config.client)
+ await download_handler._mark_partitioning_status(
+ config_name, "Partitioning in-progress."
+ )
+ try:
+ # Prepare partitions
+ await run_in_threadpool(_do_partitions, partition_obj)
+ await download_handler._mark_partitioning_status(
+ config_name, "Partitioning completed."
+ )
+ await queue_handler._update_queues_on_start_download(config_name, licenses)
+ except Exception as e:
+ error_str = f"Partitioning failed for {config_name} due to {e}."
+ logger.error(error_str)
+ await download_handler._mark_partitioning_status(config_name, error_str)
diff --git a/weather_dl_v2/fastapi-server/config_processing/stores.py b/weather_dl_v2/fastapi-server/config_processing/stores.py
new file mode 100644
index 00000000..4f60e337
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/stores.py
@@ -0,0 +1,122 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Download destinations, or `Store`s."""
+
+import abc
+import io
+import os
+import tempfile
+import typing as t
+
+from apache_beam.io.filesystems import FileSystems
+
+
+class Store(abc.ABC):
+ """A interface to represent where downloads are stored.
+
+ Default implementation uses Apache Beam's Filesystems.
+ """
+
+ @abc.abstractmethod
+ def open(self, filename: str, mode: str = "r") -> t.IO:
+ pass
+
+ @abc.abstractmethod
+ def exists(self, filename: str) -> bool:
+ pass
+
+
+class InMemoryStore(Store):
+ """Store file data in memory."""
+
+ def __init__(self):
+ self.store = {}
+
+ def open(self, filename: str, mode: str = "r") -> t.IO:
+ """Create or read in-memory data."""
+ if "b" in mode:
+ file = io.BytesIO()
+ else:
+ file = io.StringIO()
+ self.store[filename] = file
+ return file
+
+ def exists(self, filename: str) -> bool:
+ """Return true if the 'file' exists in memory."""
+ return filename in self.store
+
+
+class TempFileStore(Store):
+ """Store data into temporary files."""
+
+ def __init__(self, directory: t.Optional[str] = None) -> None:
+ """Optionally specify the directory that contains all temporary files."""
+ self.dir = directory
+ if self.dir and not os.path.exists(self.dir):
+ os.makedirs(self.dir)
+
+ def open(self, filename: str, mode: str = "r") -> t.IO:
+ """Create a temporary file in the store directory."""
+ return tempfile.TemporaryFile(mode, dir=self.dir)
+
+ def exists(self, filename: str) -> bool:
+ """Return true if file exists."""
+ return os.path.exists(filename)
+
+
+class LocalFileStore(Store):
+ """Store data into local files."""
+
+ def __init__(self, directory: t.Optional[str] = None) -> None:
+ """Optionally specify the directory that contains all downloaded files."""
+ self.dir = directory
+ if self.dir and not os.path.exists(self.dir):
+ os.makedirs(self.dir)
+
+ def open(self, filename: str, mode: str = "r") -> t.IO:
+ """Open a local file from the store directory."""
+ return open(os.sep.join([self.dir, filename]), mode)
+
+ def exists(self, filename: str) -> bool:
+ """Returns true if local file exists."""
+ return os.path.exists(os.sep.join([self.dir, filename]))
+
+
+class FSStore(Store):
+ """Store data into any store supported by Apache Beam's FileSystems."""
+
+ def open(self, filename: str, mode: str = "r") -> t.IO:
+ """Open object in cloud bucket (or local file system) as a read or write channel.
+
+        To work with cloud storage systems, only a read or a write channel can be opened
+        at one time. Data is treated as bytes, not text (equivalent to `rb` or `wb`).
+
+        Further, append operations, or writes to existing objects, are disallowed (the
+        error thrown depends on the implementation of the underlying cloud provider).
+ """
+ if "r" in mode and "w" not in mode:
+ return FileSystems().open(filename)
+
+ if "w" in mode and "r" not in mode:
+ return FileSystems().create(filename)
+
+ raise ValueError(
+ f"invalid mode {mode!r}: mode must have either 'r' or 'w', but not both."
+ )
+
+ def exists(self, filename: str) -> bool:
+ """Returns true if object exists."""
+ return FileSystems().exists(filename)
diff --git a/weather_dl_v2/fastapi-server/config_processing/util.py b/weather_dl_v2/fastapi-server/config_processing/util.py
new file mode 100644
index 00000000..765a9c47
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/util.py
@@ -0,0 +1,229 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import datetime
+import geojson
+import hashlib
+import itertools
+import os
+import socket
+import subprocess
+import sys
+import typing as t
+
+import numpy as np
+import pandas as pd
+from apache_beam.io.gcp import gcsio
+from apache_beam.utils import retry
+from xarray.core.utils import ensure_us_time_resolution
+from urllib.parse import urlparse
+from google.api_core.exceptions import BadRequest
+
+
+LATITUDE_RANGE = (-90, 90)
+LONGITUDE_RANGE = (-180, 180)
+GLOBAL_COVERAGE_AREA = [90, -180, -90, 180]
+
+logger = logging.getLogger(__name__)
+
+
+def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter(
+ exception,
+) -> bool:
+ if isinstance(exception, socket.timeout):
+ return True
+ if isinstance(exception, TimeoutError):
+ return True
+ # To handle the concurrency issue in BigQuery.
+ if isinstance(exception, BadRequest):
+ return True
+ return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception)
+
+
+class _FakeClock:
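+    """A no-op clock used to skip retry backoff sleeps during unit tests."""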
+
+ def sleep(self, value):
+ pass
+
+
+def retry_with_exponential_backoff(fun):
+ """A retry decorator that doesn't apply during test time."""
+ clock = retry.Clock()
+
+ # Use a fake clock only during test time...
+ if "unittest" in sys.modules.keys():
+ clock = _FakeClock()
+
+ return retry.with_exponential_backoff(
+ retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter,
+ clock=clock,
+ )(fun)
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]:
+ """Yield evenly-sized chunks from an iterable."""
+ input_ = iter(iterable)
+ try:
+ while True:
+ it = itertools.islice(input_, n)
+ # peek to check if 'it' has next item.
+ first = next(it)
+ yield itertools.chain([first], it)
+ except StopIteration:
+ pass
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def copy(src: str, dst: str) -> None:
+ """Copy data via `gsutil cp`."""
+ try:
+ subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True)
+ except subprocess.CalledProcessError as e:
+ logger.info(
+ f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.'
+ )
+ raise
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def to_json_serializable_type(value: t.Any) -> t.Any:
+ """Returns the value with a type serializable to JSON"""
+ # Note: The order of processing is significant.
+ logger.info("Serializing to JSON.")
+
+ if pd.isna(value) or value is None:
+ return None
+ elif np.issubdtype(type(value), np.floating):
+ return float(value)
+ elif isinstance(value, np.ndarray):
+        # tolist() returns a Python scalar for 0-d arrays, else a (possibly nested) list.
+ return value.tolist()
+ elif (
+ isinstance(value, datetime.datetime)
+ or isinstance(value, str)
+ or isinstance(value, np.datetime64)
+ ):
+ # Assume strings are ISO format timestamps...
+ try:
+ value = datetime.datetime.fromisoformat(value)
+ except ValueError:
+ # ... if they are not, assume serialization is already correct.
+ return value
+ except TypeError:
+ # ... maybe value is a numpy datetime ...
+ try:
+ value = ensure_us_time_resolution(value).astype(datetime.datetime)
+ except AttributeError:
+ # ... value is a datetime object, continue.
+ pass
+
+ # We use a string timestamp representation.
+ if value.tzname():
+ return value.isoformat()
+
+ # We assume here that naive timestamps are in UTC timezone.
+ return value.replace(tzinfo=datetime.timezone.utc).isoformat()
+ elif isinstance(value, np.timedelta64):
+ # Return time delta in seconds.
+ return float(value / np.timedelta64(1, "s"))
+ # This check must happen after processing np.timedelta64 and np.datetime64.
+ elif np.issubdtype(type(value), np.integer):
+ return int(value)
+
+ return value
+
+
+def fetch_geo_polygon(area: t.Union[list, str]) -> str:
+ """Calculates a geography polygon from an input area."""
+ # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973
+ if isinstance(area, str):
+ # European area
+ if area == "E":
+ area = [73.5, -27, 33, 45]
+ # Global area
+ elif area == "G":
+ area = GLOBAL_COVERAGE_AREA
+ else:
+ raise RuntimeError(f"Not a valid value for area in config: {area}.")
+
+ n, w, s, e = [float(x) for x in area]
+ if s < LATITUDE_RANGE[0]:
+ raise ValueError(f"Invalid latitude value for south: '{s}'")
+ if n > LATITUDE_RANGE[1]:
+ raise ValueError(f"Invalid latitude value for north: '{n}'")
+ if w < LONGITUDE_RANGE[0]:
+ raise ValueError(f"Invalid longitude value for west: '{w}'")
+ if e > LONGITUDE_RANGE[1]:
+ raise ValueError(f"Invalid longitude value for east: '{e}'")
+
+ # Define the coordinates of the bounding box.
+ coords = [[w, n], [w, s], [e, s], [e, n], [w, n]]
+
+ # Create the GeoJSON polygon object.
+ polygon = geojson.dumps(geojson.Polygon([coords]))
+ return polygon
+
+
+def get_file_size(path: str) -> float:
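+    """Returns the file size in GiB, or 0 if the path does not exist."""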
+ parsed_gcs_path = urlparse(path)
+ if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "":
+ return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0
+ else:
+ return (
+ gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0
+ )
+
+
+def get_wait_interval(num_retries: int = 0) -> float:
+ """Returns next wait interval in seconds, using an exponential backoff algorithm."""
+ if 0 == num_retries:
+ return 0
+ return 2**num_retries
+
+
+def generate_md5_hash(input: str) -> str:
+ """Generates md5 hash for the input string."""
+ return hashlib.md5(input.encode("utf-8")).hexdigest()
+
+
+def download_with_aria2(url: str, path: str) -> None:
+ """Downloads a file from the given URL using the `aria2c` command-line utility,
+ with options set to improve download speed and reliability."""
+ dir_path, file_name = os.path.split(path)
+ try:
+ subprocess.run(
+ [
+ "aria2c",
+ "-x",
+ "16",
+ "-s",
+ "16",
+ url,
+ "-d",
+ dir_path,
+ "-o",
+ file_name,
+ "--allow-overwrite",
+ ],
+ check=True,
+ capture_output=True,
+ )
+ except subprocess.CalledProcessError as e:
+ logger.info(
+ f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.'
+ )
+ raise
diff --git a/weather_dl_v2/fastapi-server/database/__init__.py b/weather_dl_v2/fastapi-server/database/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/fastapi-server/database/download_handler.py b/weather_dl_v2/fastapi-server/database/download_handler.py
new file mode 100644
index 00000000..1377e5b4
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/download_handler.py
@@ -0,0 +1,160 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter
+from google.cloud.firestore_v1.types import WriteResult
+from database.session import get_async_client
+from server_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+def get_download_handler():
+ return DownloadHandlerFirestore(db=get_async_client())
+
+
+def get_mock_download_handler():
+ return DownloadHandlerMock()
+
+
+class DownloadHandler(abc.ABC):
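+    """Interface for creating, updating, and querying download entries (one per config)."""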
+
+ @abc.abstractmethod
+ async def _start_download(self, config_name: str, client_name: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _stop_download(self, config_name: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _mark_partitioning_status(self, config_name: str, status: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _check_download_exists(self, config_name: str) -> bool:
+ pass
+
+ @abc.abstractmethod
+ async def _get_downloads(self, client_name: str) -> list:
+ pass
+
+ @abc.abstractmethod
+ async def _get_download_by_config_name(self, config_name: str):
+ pass
+
+
+class DownloadHandlerMock(DownloadHandler):
+
+ def __init__(self):
+ pass
+
+ async def _start_download(self, config_name: str, client_name: str) -> None:
+ logger.info(
+ f"Added {config_name} in 'download' collection. Update_time: 000000."
+ )
+
+ async def _stop_download(self, config_name: str) -> None:
+ logger.info(
+ f"Removed {config_name} in 'download' collection. Update_time: 000000."
+ )
+
+ async def _mark_partitioning_status(self, config_name: str, status: str) -> None:
+ logger.info(
+ f"Updated {config_name} in 'download' collection. Update_time: 000000."
+ )
+
+ async def _check_download_exists(self, config_name: str) -> bool:
+ if config_name == "not_exist":
+ return False
+ elif config_name == "not_exist.cfg":
+ return False
+ else:
+ return True
+
+ async def _get_downloads(self, client_name: str) -> list:
+ return [{"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."}]
+
+ async def _get_download_by_config_name(self, config_name: str):
+ if config_name == "not_exist":
+ return None
+ return {"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."}
+
+
+class DownloadHandlerFirestore(DownloadHandler):
+
+    def __init__(self, db: firestore.firestore.AsyncClient):
+ self.db = db
+ self.collection = get_config().download_collection
+
+ async def _start_download(self, config_name: str, client_name: str) -> None:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(config_name)
+ .set({"config_name": config_name, "client_name": client_name})
+ )
+
+ logger.info(
+ f"Added {config_name} in 'download' collection. Update_time: {result.update_time}."
+ )
+
+ async def _stop_download(self, config_name: str) -> None:
+ timestamp = (
+ await self.db.collection(self.collection).document(config_name).delete()
+ )
+ logger.info(
+ f"Removed {config_name} in 'download' collection. Update_time: {timestamp}."
+ )
+
+ async def _mark_partitioning_status(self, config_name: str, status: str) -> None:
+ timestamp = (
+ await self.db.collection(self.collection)
+ .document(config_name)
+ .update({"status": status})
+ )
+ logger.info(
+ f"Updated {config_name} in 'download' collection. Update_time: {timestamp}."
+ )
+
+ async def _check_download_exists(self, config_name: str) -> bool:
+ result: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(config_name).get()
+ )
+ return result.exists
+
+ async def _get_downloads(self, client_name: str) -> list:
+ docs = []
+ if client_name:
+ docs = (
+ self.db.collection(self.collection)
+ .where(filter=FieldFilter("client_name", "==", client_name))
+ .stream()
+ )
+ else:
+ docs = self.db.collection(self.collection).stream()
+
+ return [doc.to_dict() async for doc in docs]
+
+ async def _get_download_by_config_name(self, config_name: str):
+ result: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(config_name).get()
+ )
+ if result.exists:
+ return result.to_dict()
+ else:
+ return None
diff --git a/weather_dl_v2/fastapi-server/database/license_handler.py b/weather_dl_v2/fastapi-server/database/license_handler.py
new file mode 100644
index 00000000..d4878e25
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/license_handler.py
@@ -0,0 +1,200 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter
+from google.cloud.firestore_v1.types import WriteResult
+from database.session import get_async_client
+from server_config import get_config
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_license_handler():
+ return LicenseHandlerFirestore(db=get_async_client())
+
+
+def get_mock_license_handler():
+ return LicenseHandlerMock()
+
+
+class LicenseHandler(abc.ABC):
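+    """Interface for creating, updating, and querying license entries."""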
+
+ @abc.abstractmethod
+ async def _add_license(self, license_dict: dict) -> str:
+ pass
+
+ @abc.abstractmethod
+ async def _delete_license(self, license_id: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _check_license_exists(self, license_id: str) -> bool:
+ pass
+
+ @abc.abstractmethod
+ async def _get_license_by_license_id(self, license_id: str) -> dict:
+ pass
+
+ @abc.abstractmethod
+ async def _get_license_by_client_name(self, client_name: str) -> list:
+ pass
+
+ @abc.abstractmethod
+ async def _get_licenses(self) -> list:
+ pass
+
+ @abc.abstractmethod
+ async def _update_license(self, license_id: str, license_dict: dict) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _get_license_without_deployment(self) -> list:
+ pass
+
+
+class LicenseHandlerMock(LicenseHandler):
+
+ def __init__(self):
+ pass
+
+ async def _add_license(self, license_dict: dict) -> str:
+ license_id = "L1"
+ logger.info(f"Added {license_id} in 'license' collection. Update_time: 00000.")
+ return license_id
+
+ async def _delete_license(self, license_id: str) -> None:
+ logger.info(
+ f"Removed {license_id} in 'license' collection. Update_time: 00000."
+ )
+
+ async def _update_license(self, license_id: str, license_dict: dict) -> None:
+ logger.info(
+ f"Updated {license_id} in 'license' collection. Update_time: 00000."
+ )
+
+ async def _check_license_exists(self, license_id: str) -> bool:
+ if license_id == "not_exist":
+ return False
+ elif license_id == "no-exists":
+ return False
+ else:
+ return True
+
+ async def _get_license_by_license_id(self, license_id: str) -> dict:
+ if license_id == "not_exist":
+ return None
+ return {
+ "license_id": license_id,
+ "secret_id": "xxxx",
+ "client_name": "dummy_client",
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }
+
+ async def _get_license_by_client_name(self, client_name: str) -> list:
+ return [{
+ "license_id": "L1",
+ "secret_id": "xxxx",
+ "client_name": client_name,
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }]
+
+ async def _get_licenses(self) -> list:
+ return [{
+ "license_id": "L1",
+ "secret_id": "xxxx",
+ "client_name": "dummy_client",
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }]
+
+ async def _get_license_without_deployment(self) -> list:
+ return []
+
+
+class LicenseHandlerFirestore(LicenseHandler):
+
+ def __init__(self, db: firestore.firestore.AsyncClient):
+ self.db = db
+ self.collection = get_config().license_collection
+
+ async def _add_license(self, license_dict: dict) -> str:
+ license_dict["license_id"] = license_dict["license_id"].lower()
+ license_id = license_dict["license_id"]
+
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .set(license_dict)
+ )
+ logger.info(
+ f"Added {license_id} in 'license' collection. Update_time: {result.update_time}."
+ )
+ return license_id
+
+ async def _delete_license(self, license_id: str) -> None:
+ timestamp = (
+ await self.db.collection(self.collection).document(license_id).delete()
+ )
+ logger.info(
+ f"Removed {license_id} in 'license' collection. Update_time: {timestamp}."
+ )
+
+ async def _update_license(self, license_id: str, license_dict: dict) -> None:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .update(license_dict)
+ )
+ logger.info(
+ f"Updated {license_id} in 'license' collection. Update_time: {result.update_time}."
+ )
+
+ async def _check_license_exists(self, license_id: str) -> bool:
+ result: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(license_id).get()
+ )
+ return result.exists
+
+ async def _get_license_by_license_id(self, license_id: str) -> dict:
+ result: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(license_id).get()
+ )
+ return result.to_dict()
+
+ async def _get_license_by_client_name(self, client_name: str) -> list:
+ docs = (
+ self.db.collection(self.collection)
+ .where(filter=FieldFilter("client_name", "==", client_name))
+ .stream()
+ )
+ return [doc.to_dict() async for doc in docs]
+
+ async def _get_licenses(self) -> list:
+ docs = self.db.collection(self.collection).stream()
+ return [doc.to_dict() async for doc in docs]
+
+ async def _get_license_without_deployment(self) -> list:
+ docs = (
+ self.db.collection(self.collection)
+ .where(filter=FieldFilter("k8s_deployment_id", "==", ""))
+ .stream()
+ )
+ return [doc.to_dict() async for doc in docs]
diff --git a/weather_dl_v2/fastapi-server/database/manifest_handler.py b/weather_dl_v2/fastapi-server/database/manifest_handler.py
new file mode 100644
index 00000000..d5facfab
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/manifest_handler.py
@@ -0,0 +1,181 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+from firebase_admin import firestore
+from google.cloud.firestore_v1.base_query import FieldFilter, Or, And
+from server_config import get_config
+from database.session import get_async_client
+
+logger = logging.getLogger(__name__)
+
+
+def get_manifest_handler():
+ return ManifestHandlerFirestore(db=get_async_client())
+
+
+def get_mock_manifest_handler():
+ return ManifestHandlerMock()
+
+
+class ManifestHandler(abc.ABC):
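+    """Interface for querying per-config download statistics from the manifest."""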
+
+ @abc.abstractmethod
+ async def _get_download_success_count(self, config_name: str) -> int:
+ pass
+
+ @abc.abstractmethod
+ async def _get_download_failure_count(self, config_name: str) -> int:
+ pass
+
+ @abc.abstractmethod
+ async def _get_download_scheduled_count(self, config_name: str) -> int:
+ pass
+
+ @abc.abstractmethod
+ async def _get_download_inprogress_count(self, config_name: str) -> int:
+ pass
+
+ @abc.abstractmethod
+ async def _get_download_total_count(self, config_name: str) -> int:
+ pass
+
+ @abc.abstractmethod
+ async def _get_non_successfull_downloads(self, config_name: str) -> list:
+ pass
+
+
+class ManifestHandlerMock(ManifestHandler):
+
+ async def _get_download_failure_count(self, config_name: str) -> int:
+ return 0
+
+ async def _get_download_inprogress_count(self, config_name: str) -> int:
+ return 0
+
+ async def _get_download_scheduled_count(self, config_name: str) -> int:
+ return 0
+
+ async def _get_download_success_count(self, config_name: str) -> int:
+ return 0
+
+ async def _get_download_total_count(self, config_name: str) -> int:
+ return 0
+
+ async def _get_non_successfull_downloads(self, config_name: str) -> list:
+ return []
+
+
+class ManifestHandlerFirestore(ManifestHandler):
+
+    def __init__(self, db: firestore.firestore.AsyncClient):
+ self.db = db
+ self.collection = get_config().manifest_collection
+
+ async def _get_download_success_count(self, config_name: str) -> int:
+ result = (
+ await self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .where(filter=FieldFilter("stage", "==", "upload"))
+ .where(filter=FieldFilter("status", "==", "success"))
+ .count()
+ .get()
+ )
+
+ count = result[0][0].value
+
+ return count
+
+ async def _get_download_failure_count(self, config_name: str) -> int:
+ result = (
+ await self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .where(filter=FieldFilter("status", "==", "failure"))
+ .count()
+ .get()
+ )
+
+ count = result[0][0].value
+
+ return count
+
+ async def _get_download_scheduled_count(self, config_name: str) -> int:
+ result = (
+ await self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .where(filter=FieldFilter("status", "==", "scheduled"))
+ .count()
+ .get()
+ )
+
+ count = result[0][0].value
+
+ return count
+
+ async def _get_download_inprogress_count(self, config_name: str) -> int:
+ and_filter = And(
+ filters=[
+ FieldFilter("status", "==", "success"),
+ FieldFilter("stage", "!=", "upload"),
+ ]
+ )
+ or_filter = Or(filters=[FieldFilter("status", "==", "in-progress"), and_filter])
+
+ result = (
+ await self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .where(filter=or_filter)
+ .count()
+ .get()
+ )
+
+ count = result[0][0].value
+
+ return count
+
+ async def _get_download_total_count(self, config_name: str) -> int:
+ result = (
+ await self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .count()
+ .get()
+ )
+
+ count = result[0][0].value
+
+ return count
+
+ async def _get_non_successfull_downloads(self, config_name: str) -> list:
+ or_filter = Or(
+ filters=[
+ FieldFilter("stage", "==", "fetch"),
+ FieldFilter("stage", "==", "download"),
+ And(
+ filters=[
+ FieldFilter("status", "!=", "success"),
+ FieldFilter("stage", "==", "upload"),
+ ]
+ ),
+ ]
+ )
+
+ docs = (
+ self.db.collection(self.collection)
+ .where(filter=FieldFilter("config_name", "==", config_name))
+ .where(filter=or_filter)
+ .stream()
+ )
+ return [doc.to_dict() async for doc in docs]
diff --git a/weather_dl_v2/fastapi-server/database/queue_handler.py b/weather_dl_v2/fastapi-server/database/queue_handler.py
new file mode 100644
index 00000000..1909d583
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/queue_handler.py
@@ -0,0 +1,247 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import logging
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter
+from google.cloud.firestore_v1.types import WriteResult
+from database.session import get_async_client
+from server_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+def get_queue_handler():
+ return QueueHandlerFirestore(db=get_async_client())
+
+
+def get_mock_queue_handler():
+ return QueueHandlerMock()
+
+
+class QueueHandler(abc.ABC):
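+    """Interface for managing per-license config queues."""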
+
+ @abc.abstractmethod
+ async def _create_license_queue(self, license_id: str, client_name: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _remove_license_queue(self, license_id: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _get_queues(self) -> list:
+ pass
+
+ @abc.abstractmethod
+ async def _get_queue_by_license_id(self, license_id: str) -> dict:
+ pass
+
+ @abc.abstractmethod
+ async def _get_queue_by_client_name(self, client_name: str) -> list:
+ pass
+
+ @abc.abstractmethod
+ async def _update_license_queue(self, license_id: str, priority_list: list) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _update_queues_on_start_download(
+ self, config_name: str, licenses: list
+ ) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _update_queues_on_stop_download(self, config_name: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _update_config_priority_in_license(
+ self, license_id: str, config_name: str, priority: int
+ ) -> None:
+ pass
+
+ @abc.abstractmethod
+ async def _update_client_name_in_license_queue(
+ self, license_id: str, client_name: str
+ ) -> None:
+ pass
+
+
+class QueueHandlerMock(QueueHandler):
+
+ def __init__(self):
+ pass
+
+ async def _create_license_queue(self, license_id: str, client_name: str) -> None:
+ logger.info(
+ f"Added {license_id} queue in 'queues' collection. Update_time: 000000."
+ )
+
+ async def _remove_license_queue(self, license_id: str) -> None:
+ logger.info(
+ f"Removed {license_id} queue in 'queues' collection. Update_time: 000000."
+ )
+
+ async def _get_queues(self) -> list:
+ return [{"client_name": "dummy_client", "license_id": "L1", "queue": []}]
+
+ async def _get_queue_by_license_id(self, license_id: str) -> dict:
+ if license_id == "not_exist":
+ return None
+ return {"client_name": "dummy_client", "license_id": license_id, "queue": []}
+
+ async def _get_queue_by_client_name(self, client_name: str) -> list:
+ return [{"client_name": client_name, "license_id": "L1", "queue": []}]
+
+ async def _update_license_queue(self, license_id: str, priority_list: list) -> None:
+ logger.info(
+ f"Updated {license_id} queue in 'queues' collection. Update_time: 00000."
+ )
+
+ async def _update_queues_on_start_download(
+ self, config_name: str, licenses: list
+ ) -> None:
+ logger.info(
+ f"Updated {license} queue in 'queues' collection. Update_time: 00000."
+ )
+
+ async def _update_queues_on_stop_download(self, config_name: str) -> None:
+ logger.info(
+ "Updated snapshot.id queue in 'queues' collection. Update_time: 00000."
+ )
+
+ async def _update_config_priority_in_license(
+ self, license_id: str, config_name: str, priority: int
+ ) -> None:
+ logger.info(
+ "Updated snapshot.id queue in 'queues' collection. Update_time: 00000."
+ )
+
+ async def _update_client_name_in_license_queue(
+ self, license_id: str, client_name: str
+ ) -> None:
+ logger.info(
+ "Updated snapshot.id queue in 'queues' collection. Update_time: 00000."
+ )
+
+
+class QueueHandlerFirestore(QueueHandler):
+
+    def __init__(self, db: firestore.firestore.AsyncClient):
+ self.db = db
+ self.collection = get_config().queues_collection
+
+ async def _create_license_queue(self, license_id: str, client_name: str) -> None:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .set({"license_id": license_id, "client_name": client_name, "queue": []})
+ )
+ logger.info(
+ f"Added {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ async def _remove_license_queue(self, license_id: str) -> None:
+ timestamp = (
+ await self.db.collection(self.collection).document(license_id).delete()
+ )
+ logger.info(
+ f"Removed {license_id} queue in 'queues' collection. Update_time: {timestamp}."
+ )
+
+ async def _get_queues(self) -> list:
+ docs = self.db.collection(self.collection).stream()
+ return [doc.to_dict() async for doc in docs]
+
+ async def _get_queue_by_license_id(self, license_id: str) -> dict:
+ result: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(license_id).get()
+ )
+ return result.to_dict()
+
+ async def _get_queue_by_client_name(self, client_name: str) -> list:
+ docs = (
+ self.db.collection(self.collection)
+ .where(filter=FieldFilter("client_name", "==", client_name))
+ .stream()
+ )
+ return [doc.to_dict() async for doc in docs]
+
+ async def _update_license_queue(self, license_id: str, priority_list: list) -> None:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .update({"queue": priority_list})
+ )
+ logger.info(
+ f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ async def _update_queues_on_start_download(
+ self, config_name: str, licenses: list
+ ) -> None:
+ for license in licenses:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license)
+ .update({"queue": firestore.ArrayUnion([config_name])})
+ )
+ logger.info(
+ f"Updated {license} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ async def _update_queues_on_stop_download(self, config_name: str) -> None:
+ snapshot_list = await self.db.collection(self.collection).get()
+ for snapshot in snapshot_list:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(snapshot.id)
+ .update({"queue": firestore.ArrayRemove([config_name])})
+ )
+ logger.info(
+ f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ async def _update_config_priority_in_license(
+ self, license_id: str, config_name: str, priority: int
+ ) -> None:
+ snapshot: DocumentSnapshot = (
+ await self.db.collection(self.collection).document(license_id).get()
+ )
+ priority_list = snapshot.to_dict()["queue"]
+ new_priority_list = [c for c in priority_list if c != config_name]
+ new_priority_list.insert(priority, config_name)
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .update({"queue": new_priority_list})
+ )
+ logger.info(
+ f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ async def _update_client_name_in_license_queue(
+ self, license_id: str, client_name: str
+ ) -> None:
+ result: WriteResult = (
+ await self.db.collection(self.collection)
+ .document(license_id)
+ .update({"client_name": client_name})
+ )
+ logger.info(
+ f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
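+
+# Illustrative usage only (handler obtained through this module's get_queue_handler()
+# dependency; the license id and config name below are assumed values):
+#
+#     queue_handler = get_queue_handler()
+#     await queue_handler._create_license_queue("l1", "cds")
+#     await queue_handler._update_license_queue("l1", ["example.cfg"])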
diff --git a/weather_dl_v2/fastapi-server/database/session.py b/weather_dl_v2/fastapi-server/database/session.py
new file mode 100644
index 00000000..85dbc8be
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/session.py
@@ -0,0 +1,79 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+import abc
+import logging
+import firebase_admin
+from google.cloud import firestore
+from firebase_admin import credentials
+from config_processing.util import get_wait_interval
+from server_config import get_config
+from gcloud import storage
+
+logger = logging.getLogger(__name__)
+
+
+class Database(abc.ABC):
+
+ @abc.abstractmethod
+ def _get_db(self):
+ pass
+
+
+db: firestore.AsyncClient = None
+gcs: storage.Client = None
+
+
+def get_async_client() -> firestore.AsyncClient:
+ global db
+ attempts = 0
+
+ while db is None:
+ try:
+ db = firestore.AsyncClient()
+ except ValueError as e:
+ # The above call will fail with a value error when the firebase app is not initialized.
+ # Initialize the app here, and try again.
+ # Use the application default credentials.
+ cred = credentials.ApplicationDefault()
+
+ firebase_admin.initialize_app(cred)
+ logger.info("Initialized Firebase App.")
+
+ if attempts > 4:
+ raise RuntimeError(
+ "Exceeded number of retries to get firestore client."
+ ) from e
+
+ time.sleep(get_wait_interval(attempts))
+
+ attempts += 1
+
+ return db
+
+
+def get_gcs_client() -> storage.Client:
+ global gcs
+
+ if gcs:
+ return gcs
+
+ try:
+ gcs = storage.Client(project=get_config().gcs_project)
+ except ValueError as e:
+ logger.error(f"Error initializing GCS client: {e}.")
+
+ return gcs
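+
+# Illustrative usage only: callers fetch the shared clients once and reuse them,
+# e.g. (collection and document names assumed):
+#
+#     db = get_async_client()
+#     doc = await db.collection("queues").document("l1").get()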
diff --git a/weather_dl_v2/fastapi-server/database/storage_handler.py b/weather_dl_v2/fastapi-server/database/storage_handler.py
new file mode 100644
index 00000000..fcdf6a1a
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/database/storage_handler.py
@@ -0,0 +1,77 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import os
+import logging
+import tempfile
+import contextlib
+import typing as t
+from google.cloud import storage
+from database.session import get_gcs_client
+from server_config import get_config
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_storage_handler():
+ return StorageHandlerGCS(client=get_gcs_client())
+
+
+class StorageHandler(abc.ABC):
+
+ @abc.abstractmethod
+ def _upload_file(self, file_path) -> str:
+ pass
+
+ @abc.abstractmethod
+ def _open_local(self, file_name) -> t.Iterator[str]:
+ pass
+
+
+class StorageHandlerMock(StorageHandler):
+
+ def __init__(self) -> None:
+ pass
+
+ def _upload_file(self, file_path) -> None:
+ pass
+
+ def _open_local(self, file_name) -> t.Iterator[str]:
+ pass
+
+
+class StorageHandlerGCS(StorageHandler):
+
+ def __init__(self, client: storage.Client) -> None:
+ self.client = client
+ self.bucket = self.client.get_bucket(get_config().storage_bucket)
+
+ def _upload_file(self, file_path) -> str:
+ filename = os.path.basename(file_path).split("/")[-1]
+
+ blob = self.bucket.blob(filename)
+ blob.upload_from_filename(file_path)
+
+ logger.info(f"Uploaded {filename} to {self.bucket}.")
+ return blob.public_url
+
+ @contextlib.contextmanager
+ def _open_local(self, file_name) -> t.Iterator[str]:
+ blob = self.bucket.blob(file_name)
+ with tempfile.NamedTemporaryFile() as dest_file:
+ blob.download_to_filename(dest_file.name)
+ yield dest_file.name
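+
+# Illustrative usage only (assumes an object named "example.cfg" exists in the bucket):
+#
+#     handler = get_storage_handler()
+#     with handler._open_local("example.cfg") as local_path:
+#         with open(local_path, "r", encoding="utf-8") as f:
+#             contents = f.read()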
diff --git a/weather_dl_v2/fastapi-server/environment.yml b/weather_dl_v2/fastapi-server/environment.yml
new file mode 100644
index 00000000..a6ce07fb
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/environment.yml
@@ -0,0 +1,18 @@
+name: weather-dl-v2-server
+channels:
+ - conda-forge
+dependencies:
+ - python=3.10
+ - xarray
+ - geojson
+ - pip=22.3
+ - google-cloud-sdk=410.0.0
+ - pip:
+ - kubernetes
+ - fastapi[all]==0.97.0
+ - python-multipart
+ - numpy
+ - apache-beam[gcp]
+ - aiohttp
+ - firebase-admin
+ - gcloud
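+# To build this environment locally (illustrative), run:
+#   conda env create -f environment.yml && conda activate weather-dl-v2-server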
diff --git a/weather_dl_v2/fastapi-server/example.cfg b/weather_dl_v2/fastapi-server/example.cfg
new file mode 100644
index 00000000..6747012c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/example.cfg
@@ -0,0 +1,32 @@
+[parameters]
+client=mars
+
+target_path=gs:///test-weather-dl-v2/{date}T00z.gb
+partition_keys=
+ date
+ # step
+
+# API Keys & Subsections go here...
+
+[selection]
+class=od
+type=pf
+stream=enfo
+expver=0001
+levtype=pl
+levelist=100
+# params:
+# (z) Geopotential 129, (t) Temperature 130,
+# (u) U component of wind 131, (v) V component of wind 132,
+# (q) Specific humidity 133, (w) vertical velocity 135,
+# (vo) Vorticity (relative) 138, (d) Divergence 155,
+# (r) Relative humidity 157
+param=129.128
+#
+# next: 2019-01-01/to/existing
+#
+date=2019-07-18/to/2019-07-20
+time=0000
+step=0/to/2
+number=1/to/2
+grid=F640
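+
+# Illustrative only -- one way to submit this config to a running server
+# (address, port and license id are assumed):
+#   curl -X POST http://localhost:8080/download -F file=@example.cfg -F licenses=L1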
diff --git a/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py b/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py
new file mode 100644
index 00000000..f79521d2
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py
@@ -0,0 +1,67 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from os import path
+import yaml
+from kubernetes import client, config
+from server_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+def create_license_deployment(license_id: str) -> str:
+ """Creates a kubernetes workflow of type Job for downloading the data."""
+ config.load_config()
+
+ with open(path.join(path.dirname(__file__), "license_deployment.yaml")) as f:
+ deployment_manifest = yaml.safe_load(f)
+ deployment_name = f"weather-dl-v2-license-dep-{license_id}".lower()
+
+ # Update the deployment name with a unique identifier
+ deployment_manifest["metadata"]["name"] = deployment_name
+ deployment_manifest["spec"]["template"]["spec"]["containers"][0]["args"] = [
+ "--license",
+ license_id,
+ ]
+ deployment_manifest["spec"]["template"]["spec"]["containers"][0][
+ "image"
+ ] = get_config().license_deployment_image
+
+ # Create an instance of the Kubernetes API client
+ api_instance = client.AppsV1Api()
+ # Create the deployment in the specified namespace
+ response = api_instance.create_namespaced_deployment(
+ body=deployment_manifest, namespace="default"
+ )
+
+ logger.info(f"Deployment created successfully: {response.metadata.name}.")
+ return deployment_name
+
+
+def terminate_license_deployment(license_id: str) -> None:
+ # Load Kubernetes configuration
+ config.load_config()
+
+ # Create an instance of the Kubernetes API client
+ api_instance = client.AppsV1Api()
+
+ # Specify the name and namespace of the deployment to delete
+ deployment_name = f"weather-dl-v2-license-dep-{license_id}".lower()
+
+ # Delete the deployment
+ api_instance.delete_namespaced_deployment(name=deployment_name, namespace="default")
+
+ logger.info(f"Deployment '{deployment_name}' deleted successfully.")
diff --git a/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml b/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml
new file mode 100644
index 00000000..707e5b91
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml
@@ -0,0 +1,35 @@
+# weather-dl-v2-license-dep Deployment
+# Defines the deployment of the app running in a pod on any worker node
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: weather-dl-v2-license-dep
+ labels:
+ app: weather-dl-v2-license-dep
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: weather-dl-v2-license-dep
+ template:
+ metadata:
+ labels:
+ app: weather-dl-v2-license-dep
+ spec:
+ containers:
+ - name: weather-dl-v2-license-dep
+ image: XXXXXXX
+ imagePullPolicy: Always
+ args: []
+ volumeMounts:
+ - name: config-volume
+ mountPath: ./config
+ volumes:
+ - name: config-volume
+ configMap:
+ name: dl-v2-config
+ # resources:
+ # # You must specify requests for CPU to autoscale
+ # # based on CPU utilization
+ # requests:
+ # cpu: "250m"
\ No newline at end of file
diff --git a/weather_dl_v2/fastapi-server/logging.conf b/weather_dl_v2/fastapi-server/logging.conf
new file mode 100644
index 00000000..ed0a5e29
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/logging.conf
@@ -0,0 +1,36 @@
+[loggers]
+keys=root,server
+
+[handlers]
+keys=consoleHandler,detailedConsoleHandler
+
+[formatters]
+keys=normalFormatter,detailedFormatter
+
+[logger_root]
+level=INFO
+handlers=consoleHandler
+
+[logger_server]
+level=DEBUG
+handlers=detailedConsoleHandler
+qualname=server
+propagate=0
+
+[handler_consoleHandler]
+class=StreamHandler
+level=DEBUG
+formatter=normalFormatter
+args=(sys.stdout,)
+
+[handler_detailedConsoleHandler]
+class=StreamHandler
+level=DEBUG
+formatter=detailedFormatter
+args=(sys.stdout,)
+
+[formatter_normalFormatter]
+format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s
+
+[formatter_detailedFormatter]
+format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s call_trace=%(pathname)s L%(lineno)-4d
\ No newline at end of file
diff --git a/weather_dl_v2/fastapi-server/main.py b/weather_dl_v2/fastapi-server/main.py
new file mode 100644
index 00000000..05124123
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/main.py
@@ -0,0 +1,70 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import logging.config
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from routers import license, download, queues
+from database.license_handler import get_license_handler
+from routers.license import get_create_deployment
+from server_config import get_config
+
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+# set up logger.
+logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
+logger = logging.getLogger(__name__)
+
+
+async def create_pending_license_deployments():
+ """Creates license deployments for Licenses whose deployments does not exist."""
+ license_handler = get_license_handler()
+ create_deployment = get_create_deployment()
+ license_list = await license_handler._get_license_without_deployment()
+
+ for _license in license_list:
+ license_id = _license["license_id"]
+ try:
+ logger.info(f"Creating license deployment for {license_id}.")
+ await create_deployment(license_id, license_handler)
+ except Exception as e:
+ logger.error(f"License deployment failed for {license_id}. Exception: {e}.")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ logger.info("Started FastAPI server.")
+ # Boot up
+ # Make directory to store the uploaded config files.
+ os.makedirs(os.path.join(os.getcwd(), "config_files"), exist_ok=True)
+ # Retrieve license information & create license deployment if needed.
+ await create_pending_license_deployments()
+ # TODO: Automatically create required indexes on firestore collections on server startup.
+ yield
+ # Clean up
+
+
+app = FastAPI(lifespan=lifespan)
+
+app.include_router(license.router)
+app.include_router(download.router)
+app.include_router(queues.router)
+
+
+@app.get("/")
+async def main():
+ return {"msg": get_config().welcome_message}
diff --git a/weather_dl_v2/fastapi-server/routers/download.py b/weather_dl_v2/fastapi-server/routers/download.py
new file mode 100644
index 00000000..e3de4b57
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/routers/download.py
@@ -0,0 +1,386 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import logging
+import os
+import shutil
+import json
+
+from enum import Enum
+from config_processing.parsers import parse_config, process_config
+from config_processing.config import Config
+from fastapi import APIRouter, HTTPException, BackgroundTasks, UploadFile, Depends, Body
+from config_processing.pipeline import start_processing_config
+from database.download_handler import DownloadHandler, get_download_handler
+from database.queue_handler import QueueHandler, get_queue_handler
+from database.license_handler import LicenseHandler, get_license_handler
+from database.manifest_handler import ManifestHandler, get_manifest_handler
+from database.storage_handler import StorageHandler, get_storage_handler
+from config_processing.manifest import FirestoreManifest, Manifest
+from fastapi.concurrency import run_in_threadpool
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/download",
+ tags=["download"],
+ responses={404: {"description": "Not found"}},
+)
+
+
+async def fetch_config_stats(
+ config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler
+):
+ """Get all the config stats parallely."""
+
+ success_coroutine = manifest_handler._get_download_success_count(config_name)
+ scheduled_coroutine = manifest_handler._get_download_scheduled_count(config_name)
+ failure_coroutine = manifest_handler._get_download_failure_count(config_name)
+ inprogress_coroutine = manifest_handler._get_download_inprogress_count(config_name)
+ total_coroutine = manifest_handler._get_download_total_count(config_name)
+
+ (
+ success_count,
+ scheduled_count,
+ failure_count,
+ inprogress_count,
+ total_count,
+ ) = await asyncio.gather(
+ success_coroutine,
+ scheduled_coroutine,
+ failure_coroutine,
+ inprogress_coroutine,
+ total_coroutine,
+ )
+
+ return {
+ "config_name": config_name,
+ "client_name": client_name,
+ "partitioning_status": status,
+ "downloaded_shards": success_count,
+ "scheduled_shards": scheduled_count,
+ "failed_shards": failure_count,
+ "in-progress_shards": inprogress_count,
+ "total_shards": total_count,
+ }
+
+
+def get_fetch_config_stats():
+ return fetch_config_stats
+
+
+def get_fetch_config_stats_mock():
+ async def fetch_config_stats(
+ config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler
+ ):
+ return {
+ "config_name": config_name,
+ "client_name": client_name,
+ "downloaded_shards": 0,
+ "scheduled_shards": 0,
+ "failed_shards": 0,
+ "in-progress_shards": 0,
+ "total_shards": 0,
+ }
+
+ return fetch_config_stats
+
+
+def get_upload():
+ def upload(file: UploadFile):
+ dest = os.path.join(os.getcwd(), "config_files", file.filename)
+ with open(dest, "wb+") as dest_:
+ shutil.copyfileobj(file.file, dest_)
+
+ logger.info(f"Uploading {file.filename} to gcs bucket.")
+ storage_handler: StorageHandler = get_storage_handler()
+ storage_handler._upload_file(dest)
+ return dest
+
+ return upload
+
+
+def get_upload_mock():
+ def upload(file: UploadFile):
+ return f"{os.getcwd()}/tests/test_data/{file.filename}"
+
+ return upload
+
+
+def get_reschedule_partitions():
+ def invoke_manifest_schedule(
+ partition_list: list, config: Config, manifest: Manifest
+ ):
+ for partition in partition_list:
+ logger.info(f"Rescheduling partition {partition}.")
+ manifest.schedule(
+ config.config_name,
+ config.dataset,
+ json.loads(partition["selection"]),
+ partition["location"],
+ partition["username"],
+ )
+
+ async def reschedule_partitions(config_name: str, licenses: list):
+ manifest_handler: ManifestHandler = get_manifest_handler()
+ download_handler: DownloadHandler = get_download_handler()
+ queue_handler: QueueHandler = get_queue_handler()
+ storage_handler: StorageHandler = get_storage_handler()
+
+ partition_list = await manifest_handler._get_non_successfull_downloads(
+ config_name
+ )
+
+ config = None
+ manifest = FirestoreManifest()
+
+ with storage_handler._open_local(config_name) as local_path:
+ with open(local_path, "r", encoding="utf-8") as f:
+ config = process_config(f, config_name)
+
+ await download_handler._mark_partitioning_status(
+ config_name, "Partitioning in-progress."
+ )
+
+ try:
+ if config is None:
+ logger.error(
+ f"Failed reschedule_partitions. Could not open {config_name}."
+ )
+ raise FileNotFoundError(
+ f"Failed reschedule_partitions. Could not open {config_name}."
+ )
+
+ await run_in_threadpool(
+ invoke_manifest_schedule, partition_list, config, manifest
+ )
+ await download_handler._mark_partitioning_status(
+ config_name, "Partitioning completed."
+ )
+ await queue_handler._update_queues_on_start_download(config_name, licenses)
+ except Exception as e:
+ error_str = f"Partitioning failed for {config_name} due to {e}."
+ logger.error(error_str)
+ await download_handler._mark_partitioning_status(config_name, error_str)
+
+ return reschedule_partitions
+
+
+def get_reschedule_partitions_mock():
+ def reschedule_partitions(config_name: str, licenses: list):
+ pass
+
+ return reschedule_partitions
+
+
+# Can submit a config to the server.
+@router.post("/")
+async def submit_download(
+ file: UploadFile | None = None,
+ licenses: list = [],
+ force_download: bool = False,
+ background_tasks: BackgroundTasks = BackgroundTasks(),
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ upload=Depends(get_upload),
+):
+ if not file:
+ logger.error("No upload file sent.")
+ raise HTTPException(status_code=404, detail="No upload file sent.")
+ else:
+ if await download_handler._check_download_exists(file.filename):
+ logger.error(
+ f"Please stop the ongoing download of the config file '{file.filename}' "
+ "before attempting to start a new download."
+ )
+ raise HTTPException(
+ status_code=400,
+ detail=f"Please stop the ongoing download of the config file '{file.filename}' "
+ "before attempting to start a new download.",
+ )
+
+ for license_id in licenses:
+ if not await license_handler._check_license_exists(license_id):
+ logger.info(f"No such license {license_id}.")
+ raise HTTPException(
+ status_code=404, detail=f"No such license {license_id}."
+ )
+ try:
+ dest = upload(file)
+ # Start processing config.
+ background_tasks.add_task(
+ start_processing_config, dest, licenses, force_download
+ )
+ return {
+ "message": f"file '{file.filename}' saved at '{dest}' successfully."
+ }
+ except Exception as e:
+ logger.error(f"Failed to save file '{file.filename} due to {e}.")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to save file '{file.filename}'."
+ )
+
+
+class DownloadStatus(str, Enum):
+ COMPLETED = "completed"
+ FAILED = "failed"
+ IN_PROGRESS = "in-progress"
+
+
+@router.get("/show/{config_name}")
+async def show_download_config(
+ config_name: str,
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ storage_handler: StorageHandler = Depends(get_storage_handler),
+):
+ if not await download_handler._check_download_exists(config_name):
+ logger.error(f"No such download config {config_name} to show.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"No such download config {config_name} to show.",
+ )
+
+ contents = None
+
+ with storage_handler._open_local(config_name) as local_path:
+ with open(local_path, "r", encoding="utf-8") as f:
+ contents = parse_config(f)
+ logger.info(f"Contents of {config_name}: {contents}.")
+
+ return {"config_name": config_name, "contents": contents}
+
+
+# Can check the current status of the submitted config.
+# List status for all the downloads + handle filters
+@router.get("/")
+async def get_downloads(
+ client_name: str | None = None,
+ status: DownloadStatus | None = None,
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ manifest_handler: ManifestHandler = Depends(get_manifest_handler),
+ fetch_config_stats=Depends(get_fetch_config_stats),
+):
+ downloads = await download_handler._get_downloads(client_name)
+ coroutines = []
+
+ for download in downloads:
+ coroutines.append(
+ fetch_config_stats(
+ download["config_name"],
+ download["client_name"],
+ download["status"],
+ manifest_handler,
+ )
+ )
+
+ config_details = await asyncio.gather(*coroutines)
+
+ if status is None:
+ return config_details
+
+ if status.value == DownloadStatus.COMPLETED:
+ return list(
+ filter(
+ lambda detail: detail["downloaded_shards"] == detail["total_shards"],
+ config_details,
+ )
+ )
+ elif status.value == DownloadStatus.FAILED:
+ return list(filter(lambda detail: detail["failed_shards"] > 0, config_details))
+ elif status.value == DownloadStatus.IN_PROGRESS:
+ return list(
+ filter(
+ lambda detail: detail["downloaded_shards"] != detail["total_shards"],
+ config_details,
+ )
+ )
+ else:
+ return config_details
+
+
+# Get status of particular download
+@router.get("/{config_name}")
+async def get_download_by_config_name(
+ config_name: str,
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ manifest_handler: ManifestHandler = Depends(get_manifest_handler),
+ fetch_config_stats=Depends(get_fetch_config_stats),
+):
+ download = await download_handler._get_download_by_config_name(config_name)
+
+ if download is None:
+ logger.error(f"Download config {config_name} not found in weather-dl v2.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"Download config {config_name} not found in weather-dl v2.",
+ )
+
+ return await fetch_config_stats(
+ download["config_name"],
+ download["client_name"],
+ download["status"],
+ manifest_handler,
+ )
+
+
+# Stop & remove the execution of the config.
+@router.delete("/{config_name}")
+async def delete_download(
+ config_name: str,
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+):
+ if not await download_handler._check_download_exists(config_name):
+ logger.error(f"No such download config {config_name} to stop & remove.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"No such download config {config_name} to stop & remove.",
+ )
+
+ await download_handler._stop_download(config_name)
+ await queue_handler._update_queues_on_stop_download(config_name)
+ return {
+ "config_name": config_name,
+ "message": "Download config stopped & removed successfully.",
+ }
+
+
+@router.post("/retry/{config_name}")
+async def retry_config(
+ config_name: str,
+ licenses: list = Body(embed=True),
+ background_tasks: BackgroundTasks = BackgroundTasks(),
+ download_handler: DownloadHandler = Depends(get_download_handler),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ reschedule_partitions=Depends(get_reschedule_partitions),
+):
+ if not await download_handler._check_download_exists(config_name):
+ logger.error(f"No such download config {config_name} to retry.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"No such download config {config_name} to retry.",
+ )
+
+ for license_id in licenses:
+ if not await license_handler._check_license_exists(license_id):
+ logger.info(f"No such license {license_id}.")
+ raise HTTPException(
+ status_code=404, detail=f"No such license {license_id}."
+ )
+
+ background_tasks.add_task(reschedule_partitions, config_name, licenses)
+
+ return {"msg": "Refetch initiated successfully."}
diff --git a/weather_dl_v2/fastapi-server/routers/license.py b/weather_dl_v2/fastapi-server/routers/license.py
new file mode 100644
index 00000000..05ac5139
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/routers/license.py
@@ -0,0 +1,202 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import re
+from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends
+from pydantic import BaseModel
+from license_dep.deployment_creator import create_license_deployment, terminate_license_deployment
+from database.license_handler import LicenseHandler, get_license_handler
+from database.queue_handler import QueueHandler, get_queue_handler
+
+logger = logging.getLogger(__name__)
+
+
+class License(BaseModel):
+ license_id: str
+ client_name: str
+ number_of_requests: int
+ secret_id: str
+
+
+class LicenseInternal(License):
+ k8s_deployment_id: str
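+
+# Illustrative payload for the License model above (all values assumed):
+#   {"license_id": "license-1", "client_name": "cds",
+#    "number_of_requests": 5, "secret_id": "my-secret"}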
+
+
+# Can perform CRUD on license table -- helps in handling API KEY expiry.
+router = APIRouter(
+ prefix="/license",
+ tags=["license"],
+ responses={404: {"description": "Not found"}},
+)
+
+
+# Add/Update k8s deployment ID for existing license (internally).
+async def update_license_internal(
+ license_id: str,
+ k8s_deployment_id: str,
+ license_handler: LicenseHandler,
+):
+ if not await license_handler._check_license_exists(license_id):
+ logger.info(f"No such license {license_id} to update.")
+ raise HTTPException(
+ status_code=404, detail=f"No such license {license_id} to update."
+ )
+ license_dict = {"k8s_deployment_id": k8s_deployment_id}
+
+ await license_handler._update_license(license_id, license_dict)
+ return {"license_id": license_id, "message": "License updated successfully."}
+
+
+def get_create_deployment():
+ async def create_deployment(license_id: str, license_handler: LicenseHandler):
+ k8s_deployment_id = create_license_deployment(license_id)
+ await update_license_internal(license_id, k8s_deployment_id, license_handler)
+
+ return create_deployment
+
+
+def get_create_deployment_mock():
+ async def create_deployment_mock(license_id: str, license_handler: LicenseHandler):
+ logger.info("create deployment mock.")
+
+ return create_deployment_mock
+
+
+def get_terminate_license_deployment():
+ return terminate_license_deployment
+
+
+def get_terminate_license_deployment_mock():
+ def get_terminate_license_deployment_mock(license_id):
+ logger.info(f"terminating license deployment for {license_id}.")
+
+ return get_terminate_license_deployment_mock
+
+
+# List all the licenses + handle {client_name} filter
+@router.get("/")
+async def get_licenses(
+ client_name: str | None = None,
+ license_handler: LicenseHandler = Depends(get_license_handler),
+):
+ if client_name:
+ result = await license_handler._get_license_by_client_name(client_name)
+ else:
+ result = await license_handler._get_licenses()
+ return result
+
+
+# Get particular license
+@router.get("/{license_id}")
+async def get_license_by_license_id(
+ license_id: str, license_handler: LicenseHandler = Depends(get_license_handler)
+):
+ result = await license_handler._get_license_by_license_id(license_id)
+ if not result:
+ logger.info(f"License {license_id} not found.")
+ raise HTTPException(status_code=404, detail=f"License {license_id} not found.")
+ return result
+
+
+# Update existing license
+@router.put("/{license_id}")
+async def update_license(
+ license_id: str,
+ license: License,
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+ create_deployment=Depends(get_create_deployment),
+ terminate_license_deployment=Depends(get_terminate_license_deployment),
+):
+ if not await license_handler._check_license_exists(license_id):
+ logger.error(f"No such license {license_id} to update.")
+ raise HTTPException(
+ status_code=404, detail=f"No such license {license_id} to update."
+ )
+
+ license_dict = license.dict()
+ await license_handler._update_license(license_id, license_dict)
+ await queue_handler._update_client_name_in_license_queue(
+ license_id, license_dict["client_name"]
+ )
+
+ terminate_license_deployment(license_id)
+ await create_deployment(license_id, license_handler)
+ return {"license_id": license_id, "name": "License updated successfully."}
+
+
+# Add new license
+@router.post("/")
+async def add_license(
+ license: License,
+ background_tasks: BackgroundTasks = BackgroundTasks(),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+ create_deployment=Depends(get_create_deployment),
+):
+ license_id = license.license_id.lower()
+
+ # Check if license id is in correct format.
+ LICENSE_REGEX = re.compile(
+ r"[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*"
+ )
+ if not bool(LICENSE_REGEX.fullmatch(license_id)):
+ logger.error(
+ """Invalid format for license_id. License id must consist of lower case alphanumeric"""
+ """ characters, '-' or '.', and must start and end with an alphanumeric character"""
+ )
+ raise HTTPException(
+ status_code=400,
+ detail="""Invalid format for license_id. License id must consist of lower case alphanumeric"""
+ """ characters, '-' or '.', and must start and end with an alphanumeric character""",
+ )
+
+ if await license_handler._check_license_exists(license_id):
+ logger.error(f"License with license_id {license_id} already exist.")
+ raise HTTPException(
+ status_code=409,
+ detail=f"License with license_id {license_id} already exist.",
+ )
+
+ license_dict = license.dict()
+ license_dict["k8s_deployment_id"] = ""
+ license_id = await license_handler._add_license(license_dict)
+ await queue_handler._create_license_queue(license_id, license_dict["client_name"])
+ background_tasks.add_task(create_deployment, license_id, license_handler)
+ return {"license_id": license_id, "message": "License added successfully."}
+
+
+# Remove license
+@router.delete("/{license_id}")
+async def delete_license(
+ license_id: str,
+ background_tasks: BackgroundTasks = BackgroundTasks(),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+ terminate_license_deployment=Depends(get_terminate_license_deployment),
+):
+ if not await license_handler._check_license_exists(license_id):
+ logger.error(f"No such license {license_id} to delete.")
+ raise HTTPException(
+ status_code=404, detail=f"No such license {license_id} to delete."
+ )
+ await license_handler._delete_license(license_id)
+ await queue_handler._remove_license_queue(license_id)
+ background_tasks.add_task(terminate_license_deployment, license_id)
+ return {"license_id": license_id, "message": "License removed successfully."}
+
+
+# TODO: Add route to re-deploy license deployments.
diff --git a/weather_dl_v2/fastapi-server/routers/queues.py b/weather_dl_v2/fastapi-server/routers/queues.py
new file mode 100644
index 00000000..eda6a7c5
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/routers/queues.py
@@ -0,0 +1,124 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from fastapi import APIRouter, HTTPException, Depends
+from database.queue_handler import QueueHandler, get_queue_handler
+from database.license_handler import LicenseHandler, get_license_handler
+from database.download_handler import DownloadHandler, get_download_handler
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/queues",
+ tags=["queues"],
+ responses={404: {"description": "Not found"}},
+)
+
+
+# Users can change the execution order of configs on a per-license basis.
+# List the licenses priority + {client_name} filter
+@router.get("/")
+async def get_all_license_queue(
+ client_name: str | None = None,
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+):
+ if client_name:
+ result = await queue_handler._get_queue_by_client_name(client_name)
+ else:
+ result = await queue_handler._get_queues()
+ return result
+
+
+# Get particular license priority
+@router.get("/{license_id}")
+async def get_license_queue(
+ license_id: str, queue_handler: QueueHandler = Depends(get_queue_handler)
+):
+ result = await queue_handler._get_queue_by_license_id(license_id)
+ if not result:
+ logger.error(f"License priority for {license_id} not found.")
+ raise HTTPException(
+ status_code=404, detail=f"License priority for {license_id} not found."
+ )
+ return result
+
+
+# Change priority queue of particular license
+@router.post("/{license_id}")
+async def modify_license_queue(
+ license_id: str,
+ priority_list: list | None = [],
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ download_handler: DownloadHandler = Depends(get_download_handler),
+):
+ if not await license_handler._check_license_exists(license_id):
+ logger.error(f"License {license_id} not found.")
+ raise HTTPException(status_code=404, detail=f"License {license_id} not found.")
+
+ for config_name in priority_list:
+ config = await download_handler._get_download_by_config_name(config_name)
+ if config is None:
+ logger.error(f"Download config {config_name} not found in weather-dl v2.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"Download config {config_name} not found in weather-dl v2.",
+ )
+ try:
+ await queue_handler._update_license_queue(license_id, priority_list)
+ return {"message": f"'{license_id}' license priority updated successfully."}
+ except Exception as e:
+ logger.error(f"Failed to update '{license_id}' license priority due to {e}.")
+ raise HTTPException(
+ status_code=404, detail=f"Failed to update '{license_id}' license priority."
+ )
+
+
+# Change config's priority in particular license
+@router.put("/priority/{license_id}")
+async def modify_config_priority_in_license(
+ license_id: str,
+ config_name: str,
+ priority: int,
+ queue_handler: QueueHandler = Depends(get_queue_handler),
+ license_handler: LicenseHandler = Depends(get_license_handler),
+ download_handler: DownloadHandler = Depends(get_download_handler),
+):
+ if not await license_handler._check_license_exists(license_id):
+ logger.error(f"License {license_id} not found.")
+ raise HTTPException(status_code=404, detail=f"License {license_id} not found.")
+
+ config = await download_handler._get_download_by_config_name(config_name)
+ if config is None:
+ logger.error(f"Download config {config_name} not found in weather-dl v2.")
+ raise HTTPException(
+ status_code=404,
+ detail=f"Download config {config_name} not found in weather-dl v2.",
+ )
+
+ try:
+ await queue_handler._update_config_priority_in_license(
+ license_id, config_name, priority
+ )
+ return {
+ "message": f"'{license_id}' license -- '{config_name}' priority updated successfully."
+ }
+ except Exception as e:
+ logger.error(f"Failed to update '{license_id}' license priority due to {e}.")
+ raise HTTPException(
+ status_code=404, detail=f"Failed to update '{license_id}' license priority."
+ )
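+
+# Illustrative requests only (server address assumed):
+#   GET http://localhost:8080/queues/L1
+#       -> current priority list for license L1
+#   PUT http://localhost:8080/queues/priority/L1?config_name=example.cfg&priority=0
+#       -> moves example.cfg to the front of L1's queue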
diff --git a/weather_dl_v2/fastapi-server/server.yaml b/weather_dl_v2/fastapi-server/server.yaml
new file mode 100644
index 00000000..b8a2f40d
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/server.yaml
@@ -0,0 +1,93 @@
+# Due to our org-level policy we can't expose an external IP.
+# If your project doesn't have any such restriction, there is no need to
+# create an nginx server on a VM to access this FastAPI server;
+# instead, create the LoadBalancer Service given below.
+#
+# # weather-dl server LoadBalancer Service
+# # Enables the pods in a deployment to be accessible from outside the cluster
+# apiVersion: v1
+# kind: Service
+# metadata:
+# name: weather-dl-v2-server-service
+# spec:
+# selector:
+# app: weather-dl-v2-server-api
+# ports:
+# - protocol: "TCP"
+# port: 8080
+# targetPort: 8080
+# type: LoadBalancer
+
+---
+# weather-dl-server-api Deployment
+# Defines the deployment of the app running in a pod on any worker node
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: weather-dl-v2-server-api
+ labels:
+ app: weather-dl-v2-server-api
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: weather-dl-v2-server-api
+ template:
+ metadata:
+ labels:
+ app: weather-dl-v2-server-api
+ spec:
+ containers:
+ - name: weather-dl-v2-server-api
+ image: XXXXXXX
+ ports:
+ - containerPort: 8080
+ imagePullPolicy: Always
+ volumeMounts:
+ - name: config-volume
+ mountPath: ./config
+ volumes:
+ - name: config-volume
+ configMap:
+ name: dl-v2-config
+ # resources:
+ # # You must specify requests for CPU to autoscale
+ # # based on CPU utilization
+ # requests:
+ # cpu: "250m"
+---
+kind: Role
+apiVersion: rbac.authorization.k8s.io/v1
+metadata:
+ name: weather-dl-v2-server-api
+rules:
+ - apiGroups:
+ - ""
+ - "apps"
+ - "batch"
+ resources:
+ - endpoints
+ - deployments
+ - pods
+ - jobs
+ verbs:
+ - get
+ - list
+ - watch
+ - create
+ - delete
+---
+kind: RoleBinding
+apiVersion: rbac.authorization.k8s.io/v1
+metadata:
+ name: weather-dl-v2-server-api
+ namespace: default
+subjects:
+ - kind: ServiceAccount
+ name: default
+ namespace: default
+roleRef:
+ apiGroup: rbac.authorization.k8s.io
+ kind: Role
+ name: weather-dl-v2-server-api
+---
\ No newline at end of file
diff --git a/weather_dl_v2/fastapi-server/server_config.py b/weather_dl_v2/fastapi-server/server_config.py
new file mode 100644
index 00000000..4ca8c21b
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/server_config.py
@@ -0,0 +1,72 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import dataclasses
+import typing as t
+import json
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class ServerConfig:
+ download_collection: str = ""
+ queues_collection: str = ""
+ license_collection: str = ""
+ manifest_collection: str = ""
+ storage_bucket: str = ""
+ gcs_project: str = ""
+ license_deployment_image: str = ""
+ welcome_message: str = ""
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict):
+ config_instance = cls()
+
+ for key, value in config.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+
+ return config_instance
+
+
+server_config = None
+
+
+def get_config():
+ global server_config
+ if server_config:
+ return server_config
+
+ server_config_json = "config/config.json"
+ if not os.path.exists(server_config_json):
+ server_config_json = os.environ.get("CONFIG_PATH", None)
+
+ if server_config_json is None:
+ logger.error("Couldn't load config file for fastAPI server.")
+ raise FileNotFoundError("Couldn't load config file for fastAPI server.")
+
+ with open(server_config_json) as file:
+ config_dict = json.load(file)
+ server_config = ServerConfig.from_dict(config_dict)
+
+ return server_config
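+
+# Illustrative config/config.json consumed by get_config() (all values assumed):
+# {
+#     "download_collection": "download",
+#     "queues_collection": "queues",
+#     "license_collection": "license",
+#     "manifest_collection": "manifest",
+#     "storage_bucket": "my-dl-v2-config-bucket",
+#     "gcs_project": "my-gcp-project",
+#     "license_deployment_image": "gcr.io/my-gcp-project/weather-dl-v2-license-dep",
+#     "welcome_message": "weather-dl-v2 server is up."
+# }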
diff --git a/weather_dl_v2/fastapi-server/tests/__init__.py b/weather_dl_v2/fastapi-server/tests/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/fastapi-server/tests/integration/__init__.py b/weather_dl_v2/fastapi-server/tests/integration/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/integration/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_download.py b/weather_dl_v2/fastapi-server/tests/integration/test_download.py
new file mode 100644
index 00000000..fc707d10
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/integration/test_download.py
@@ -0,0 +1,175 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+from fastapi.testclient import TestClient
+from main import app, ROOT_DIR
+from database.download_handler import get_download_handler, get_mock_download_handler
+from database.license_handler import get_license_handler, get_mock_license_handler
+from database.queue_handler import get_queue_handler, get_mock_queue_handler
+from routers.download import get_upload, get_upload_mock, get_fetch_config_stats, get_fetch_config_stats_mock
+
+client = TestClient(app)
+
+logger = logging.getLogger(__name__)
+
+app.dependency_overrides[get_download_handler] = get_mock_download_handler
+app.dependency_overrides[get_license_handler] = get_mock_license_handler
+app.dependency_overrides[get_queue_handler] = get_mock_queue_handler
+app.dependency_overrides[get_upload] = get_upload_mock
+app.dependency_overrides[get_fetch_config_stats] = get_fetch_config_stats_mock
+
+
+def _get_download(headers, query, code, expected):
+ response = client.get("/download", headers=headers, params=query)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_downloads_basic():
+ headers = {}
+ query = {}
+ code = 200
+ expected = [{
+ "config_name": "example.cfg",
+ "client_name": "client",
+ "downloaded_shards": 0,
+ "scheduled_shards": 0,
+ "failed_shards": 0,
+ "in-progress_shards": 0,
+ "total_shards": 0,
+ }]
+
+ _get_download(headers, query, code, expected)
+
+
+def _submit_download(headers, file_path, licenses, code, expected):
+ file = None
+ try:
+ file = {"file": open(file_path, "rb")}
+ except FileNotFoundError:
+ logger.info("file not found.")
+
+ payload = {"licenses": licenses}
+
+ response = client.post("/download", headers=headers, files=file, data=payload)
+
+ logger.info(f"resp {response.json()}")
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_submit_download_basic():
+ header = {
+ "accept": "application/json",
+ }
+ file_path = os.path.join(ROOT_DIR, "tests/test_data/not_exist.cfg")
+ licenses = ["L1"]
+ code = 200
+ expected = {
+ "message": f"file 'not_exist.cfg' saved at '{os.getcwd()}/tests/test_data/not_exist.cfg' "
+ "successfully."
+ }
+
+ _submit_download(header, file_path, licenses, code, expected)
+
+
+def test_submit_download_file_not_uploaded():
+ header = {
+ "accept": "application/json",
+ }
+ file_path = os.path.join(ROOT_DIR, "tests/test_data/wrong_file.cfg")
+ licenses = ["L1"]
+ code = 404
+ expected = {"detail": "No upload file sent."}
+
+ _submit_download(header, file_path, licenses, code, expected)
+
+
+def test_submit_download_file_already_exists():
+ header = {
+ "accept": "application/json",
+ }
+ file_path = os.path.join(ROOT_DIR, "tests/test_data/example.cfg")
+ licenses = ["L1"]
+ code = 400
+ expected = {
+ "detail": "Please stop the ongoing download of the config file 'example.cfg' before attempting to start a new download." # noqa: E501
+ }
+
+ _submit_download(header, file_path, licenses, code, expected)
+
+
+def _get_download_by_config(headers, config_name, code, expected):
+ response = client.get(f"/download/{config_name}", headers=headers)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_download_by_config_basic():
+ headers = {}
+ config_name = "example.cfg"
+ code = 200
+ expected = {
+ "config_name": config_name,
+ "client_name": "client",
+ "downloaded_shards": 0,
+ "scheduled_shards": 0,
+ "failed_shards": 0,
+ "in-progress_shards": 0,
+ "total_shards": 0,
+ }
+
+ _get_download_by_config(headers, config_name, code, expected)
+
+
+def test_get_download_by_config_wrong_config():
+ headers = {}
+ config_name = "not_exist"
+ code = 404
+ expected = {"detail": "Download config not_exist not found in weather-dl v2."}
+
+ _get_download_by_config(headers, config_name, code, expected)
+
+
+def _delete_download_by_config(headers, config_name, code, expected):
+ response = client.delete(f"/download/{config_name}", headers=headers)
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_delete_download_by_config_basic():
+ headers = {}
+ config_name = "dummy_config"
+ code = 200
+ expected = {
+ "config_name": "dummy_config",
+ "message": "Download config stopped & removed successfully.",
+ }
+
+ _delete_download_by_config(headers, config_name, code, expected)
+
+
+def test_delete_download_by_config_wrong_config():
+ headers = {}
+ config_name = "not_exist"
+ code = 404
+ expected = {"detail": "No such download config not_exist to stop & remove."}
+
+ _delete_download_by_config(headers, config_name, code, expected)
diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_license.py b/weather_dl_v2/fastapi-server/tests/integration/test_license.py
new file mode 100644
index 00000000..f4a5dea7
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/integration/test_license.py
@@ -0,0 +1,207 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import json
+from fastapi.testclient import TestClient
+from main import app
+from database.download_handler import get_download_handler, get_mock_download_handler
+from database.license_handler import get_license_handler, get_mock_license_handler
+from routers.license import (
+ get_create_deployment,
+ get_create_deployment_mock,
+ get_terminate_license_deployment,
+ get_terminate_license_deployment_mock,
+)
+from database.queue_handler import get_queue_handler, get_mock_queue_handler
+
+client = TestClient(app)
+
+logger = logging.getLogger(__name__)
+
+app.dependency_overrides[get_download_handler] = get_mock_download_handler
+app.dependency_overrides[get_license_handler] = get_mock_license_handler
+app.dependency_overrides[get_queue_handler] = get_mock_queue_handler
+app.dependency_overrides[get_create_deployment] = get_create_deployment_mock
+app.dependency_overrides[
+ get_terminate_license_deployment
+] = get_terminate_license_deployment_mock
+
+
+def _get_license(headers, query, code, expected):
+ response = client.get("/license", headers=headers, params=query)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_license_basic():
+ headers = {}
+ query = {}
+ code = 200
+ expected = [{
+ "license_id": "L1",
+ "secret_id": "xxxx",
+ "client_name": "dummy_client",
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }]
+
+ _get_license(headers, query, code, expected)
+
+
+def test_get_license_client_name():
+ headers = {}
+ client_name = "dummy_client"
+ query = {"client_name": client_name}
+ code = 200
+ expected = [{
+ "license_id": "L1",
+ "secret_id": "xxxx",
+ "client_name": client_name,
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }]
+
+ _get_license(headers, query, code, expected)
+
+
+def _add_license(headers, payload, code, expected):
+ response = client.post(
+ "/license",
+ headers=headers,
+ data=json.dumps(payload),
+ params={"license_id": "L1"},
+ )
+
+ print(f"test add license {response.json()}")
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_add_license_basic():
+ headers = {"accept": "application/json", "Content-Type": "application/json"}
+ license = {
+ "license_id": "no-exists",
+ "client_name": "dummy_client",
+ "number_of_requests": 0,
+ "secret_id": "xxxx",
+ }
+ payload = license
+ code = 200
+ expected = {"license_id": "L1", "message": "License added successfully."}
+
+ _add_license(headers, payload, code, expected)
+
+
+def _get_license_by_license_id(headers, license_id, code, expected):
+ response = client.get(f"/license/{license_id}", headers=headers)
+
+ logger.info(f"response {response.json()}")
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_license_by_license_id():
+ headers = {"accept": "application/json", "Content-Type": "application/json"}
+ license_id = "L1"
+ code = 200
+ expected = {
+ "license_id": license_id,
+ "secret_id": "xxxx",
+ "client_name": "dummy_client",
+ "k8s_deployment_id": "k1",
+ "number_of_requets": 100,
+ }
+
+ _get_license_by_license_id(headers, license_id, code, expected)
+
+
+def test_get_license_wrong_license():
+ headers = {}
+ license_id = "not_exist"
+ code = 404
+ expected = {
+ "detail": "License not_exist not found.",
+ }
+
+ _get_license_by_license_id(headers, license_id, code, expected)
+
+
+def _update_license(headers, license_id, license, code, expected):
+ response = client.put(
+ f"/license/{license_id}", headers=headers, data=json.dumps(license)
+ )
+
+ print(f"_update license {response.json()}")
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_update_license_basic():
+ headers = {}
+ license_id = "L1"
+ license = {
+ "license_id": "L1",
+ "client_name": "dummy_client",
+ "number_of_requests": 0,
+ "secret_id": "xxxx",
+ }
+ code = 200
+ expected = {"license_id": license_id, "name": "License updated successfully."}
+
+ _update_license(headers, license_id, license, code, expected)
+
+
+def test_update_license_wrong_license_id():
+ headers = {}
+ license_id = "no-exists"
+ license = {
+ "license_id": "no-exists",
+ "client_name": "dummy_client",
+ "number_of_requests": 0,
+ "secret_id": "xxxx",
+ }
+ code = 404
+ expected = {"detail": "No such license no-exists to update."}
+
+ _update_license(headers, license_id, license, code, expected)
+
+
+def _delete_license(headers, license_id, code, expected):
+ response = client.delete(f"/license/{license_id}", headers=headers)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_delete_license_basic():
+ headers = {}
+ license_id = "L1"
+ code = 200
+ expected = {"license_id": license_id, "message": "License removed successfully."}
+
+ _delete_license(headers, license_id, code, expected)
+
+
+def test_delete_license_wrong_license():
+ headers = {}
+ license_id = "not_exist"
+ code = 404
+ expected = {"detail": "No such license not_exist to delete."}
+
+ _delete_license(headers, license_id, code, expected)
diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_queues.py b/weather_dl_v2/fastapi-server/tests/integration/test_queues.py
new file mode 100644
index 00000000..5fa7855a
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/integration/test_queues.py
@@ -0,0 +1,148 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from main import app
+from fastapi.testclient import TestClient
+from database.download_handler import get_download_handler, get_mock_download_handler
+from database.license_handler import get_license_handler, get_mock_license_handler
+from database.queue_handler import get_queue_handler, get_mock_queue_handler
+
+client = TestClient(app)
+
+logger = logging.getLogger(__name__)
+
+app.dependency_overrides[get_download_handler] = get_mock_download_handler
+app.dependency_overrides[get_license_handler] = get_mock_license_handler
+app.dependency_overrides[get_queue_handler] = get_mock_queue_handler
+
+
+def _get_all_queue(headers, query, code, expected):
+ response = client.get("/queues", headers=headers, params=query)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_all_queues():
+ headers = {}
+ query = {}
+ code = 200
+ expected = [{"client_name": "dummy_client", "license_id": "L1", "queue": []}]
+
+ _get_all_queue(headers, query, code, expected)
+
+
+def test_get_client_queues():
+ headers = {}
+ client_name = "dummy_client"
+ query = {"client_name": client_name}
+ code = 200
+ expected = [{"client_name": client_name, "license_id": "L1", "queue": []}]
+
+ _get_all_queue(headers, query, code, expected)
+
+
+def _get_queue_by_license(headers, license_id, code, expected):
+ response = client.get(f"/queues/{license_id}", headers=headers)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_get_queue_by_license_basic():
+ headers = {}
+ license_id = "L1"
+ code = 200
+ expected = {"client_name": "dummy_client", "license_id": license_id, "queue": []}
+
+ _get_queue_by_license(headers, license_id, code, expected)
+
+
+def test_get_queue_by_license_wrong_license():
+ headers = {}
+ license_id = "not_exist"
+ code = 404
+ expected = {"detail": 'License priority for not_exist not found.'}
+
+ _get_queue_by_license(headers, license_id, code, expected)
+
+
+def _modify_license_queue(headers, license_id, priority_list, code, expected):
+ response = client.post(f"/queues/{license_id}", headers=headers, data=priority_list)
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_modify_license_queue_basic():
+ headers = {}
+ license_id = "L1"
+ priority_list = []
+ code = 200
+ expected = {"message": f"'{license_id}' license priority updated successfully."}
+
+ _modify_license_queue(headers, license_id, priority_list, code, expected)
+
+
+def test_modify_license_queue_wrong_license_id():
+ headers = {}
+ license_id = "not_exist"
+ priority_list = []
+ code = 404
+ expected = {"detail": 'License not_exist not found.'}
+
+ _modify_license_queue(headers, license_id, priority_list, code, expected)
+
+
+def _modify_config_priority_in_license(headers, license_id, query, code, expected):
+ response = client.put(f"/queues/priority/{license_id}", params=query)
+
+ logger.info(f"response {response.json()}")
+
+ assert response.status_code == code
+ assert response.json() == expected
+
+
+def test_modify_config_priority_in_license_basic():
+ headers = {}
+ license_id = "L1"
+ query = {"config_name": "example.cfg", "priority": 0}
+ code = 200
+ expected = {
+ "message": f"'{license_id}' license -- 'example.cfg' priority updated successfully."
+ }
+
+ _modify_config_priority_in_license(headers, license_id, query, code, expected)
+
+
+def test_modify_config_priority_in_license_wrong_license():
+ headers = {}
+ license_id = "not_exist"
+ query = {"config_name": "example.cfg", "priority": 0}
+ code = 404
+ expected = {"detail": 'License not_exist not found.'}
+
+ _modify_config_priority_in_license(headers, license_id, query, code, expected)
+
+
+def test_modify_config_priority_in_license_wrong_config():
+ headers = {}
+ license_id = "not_exist"
+ query = {"config_name": "wrong.cfg", "priority": 0}
+ code = 404
+ expected = {"detail": 'License not_exist not found.'}
+
+ _modify_config_priority_in_license(headers, license_id, query, code, expected)
diff --git a/weather_dl_v2/fastapi-server/tests/test_data/example.cfg b/weather_dl_v2/fastapi-server/tests/test_data/example.cfg
new file mode 100644
index 00000000..6747012c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/test_data/example.cfg
@@ -0,0 +1,32 @@
+[parameters]
+client=mars
+
+target_path=gs:///test-weather-dl-v2/{date}T00z.gb
+partition_keys=
+ date
+ # step
+
+# API Keys & Subsections go here...
+
+[selection]
+class=od
+type=pf
+stream=enfo
+expver=0001
+levtype=pl
+levelist=100
+# params:
+# (z) Geopotential 129, (t) Temperature 130,
+# (u) U component of wind 131, (v) V component of wind 132,
+# (q) Specific humidity 133, (w) vertical velocity 135,
+# (vo) Vorticity (relative) 138, (d) Divergence 155,
+# (r) Relative humidity 157
+param=129.128
+#
+# next: 2019-01-01/to/existing
+#
+date=2019-07-18/to/2019-07-20
+time=0000
+step=0/to/2
+number=1/to/2
+grid=F640
diff --git a/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg b/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg
new file mode 100644
index 00000000..6747012c
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg
@@ -0,0 +1,32 @@
+[parameters]
+client=mars
+
+target_path=gs:///test-weather-dl-v2/{date}T00z.gb
+partition_keys=
+ date
+ # step
+
+# API Keys & Subsections go here...
+
+[selection]
+class=od
+type=pf
+stream=enfo
+expver=0001
+levtype=pl
+levelist=100
+# params:
+# (z) Geopotential 129, (t) Temperature 130,
+# (u) U component of wind 131, (v) V component of wind 132,
+# (q) Specific humidity 133, (w) vertical velocity 135,
+# (vo) Vorticity (relative) 138, (d) Divergence 155,
+# (r) Relative humidity 157
+param=129.128
+#
+# next: 2019-01-01/to/existing
+#
+date=2019-07-18/to/2019-07-20
+time=0000
+step=0/to/2
+number=1/to/2
+grid=F640
diff --git a/weather_dl_v2/license_deployment/Dockerfile b/weather_dl_v2/license_deployment/Dockerfile
new file mode 100644
index 00000000..68388f78
--- /dev/null
+++ b/weather_dl_v2/license_deployment/Dockerfile
@@ -0,0 +1,34 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+FROM continuumio/miniconda3:latest
+
+# Update miniconda
+RUN conda update conda -y
+
+# Add the mamba solver for faster builds
+RUN conda install -n base conda-libmamba-solver
+RUN conda config --set solver libmamba
+
+COPY . .
+# Create conda env using environment.yml
+RUN conda env create -f environment.yml --debug
+
+# Activate the conda env and update the PATH
+ARG CONDA_ENV_NAME=weather-dl-v2-license-dep
+RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc
+ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH
+
+ENTRYPOINT ["python", "-u", "fetch.py"]
diff --git a/weather_dl_v2/license_deployment/README.md b/weather_dl_v2/license_deployment/README.md
new file mode 100644
index 00000000..4c5cc6a1
--- /dev/null
+++ b/weather_dl_v2/license_deployment/README.md
@@ -0,0 +1,21 @@
+# Deployment Instructions & General Notes
+
+### How to create environment
+```
+conda env create --name weather-dl-v2-license-dep --file=environment.yml
+
+conda activate weather-dl-v2-license-dep
+```
+
+### Make changes in weather_dl_v2/config.json, if required [for running locally]
+```
+export CONFIG_PATH=/path/to/weather_dl_v2/config.json
+```
+
+### Create docker image for license deployment
+```
+export PROJECT_ID=
+export REPO= eg:weather-tools
+
+gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-license-dep" --timeout=79200 --machine-type=e2-highcpu-32
+```
\ No newline at end of file
diff --git a/weather_dl_v2/license_deployment/__init__.py b/weather_dl_v2/license_deployment/__init__.py
new file mode 100644
index 00000000..5678014c
--- /dev/null
+++ b/weather_dl_v2/license_deployment/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/weather_dl_v2/license_deployment/clients.py b/weather_dl_v2/license_deployment/clients.py
new file mode 100644
index 00000000..331888ea
--- /dev/null
+++ b/weather_dl_v2/license_deployment/clients.py
@@ -0,0 +1,417 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""ECMWF Downloader Clients."""
+
+import abc
+import collections
+import contextlib
+import datetime
+import io
+import logging
+import os
+import time
+import typing as t
+import warnings
+from urllib.parse import urljoin
+
+from cdsapi import api as cds_api
+import urllib3
+from ecmwfapi import api
+
+from config import optimize_selection_partition
+from manifest import Manifest, Stage
+from util import download_with_aria2, retry_with_exponential_backoff
+
+warnings.simplefilter("ignore", category=urllib3.connectionpool.InsecureRequestWarning)
+
+
+class Client(abc.ABC):
+ """Weather data provider client interface.
+
+ Defines methods and properties required to efficiently interact with weather
+ data providers.
+
+ Attributes:
+ config: A config that contains pipeline parameters, such as API keys.
+ level: Default log level for the client.
+ """
+
+ def __init__(self, dataset: str, level: int = logging.INFO) -> None:
+ """Clients are initialized with the general CLI configuration."""
+ self.dataset = dataset
+ self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
+ self.logger.setLevel(level)
+
+ @abc.abstractmethod
+    def retrieve(
+        self, dataset: str, selection: t.Dict, manifest: Manifest
+    ) -> None:
+ """Download from data source."""
+ pass
+
+ @classmethod
+ @abc.abstractmethod
+ def num_requests_per_key(cls, dataset: str) -> int:
+ """Specifies the number of workers to be used per api key for the dataset."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def license_url(self):
+ """Specifies the License URL."""
+ pass
+
+
+class SplitCDSRequest(cds_api.Client):
+ """Extended CDS class that separates fetch and download stage."""
+
+ @retry_with_exponential_backoff
+ def _download(self, url, path: str, size: int) -> None:
+ self.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size))
+ start = time.time()
+
+ download_with_aria2(url, path)
+
+ elapsed = time.time() - start
+ if elapsed:
+ self.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed))
+
+ def fetch(self, request: t.Dict, dataset: str) -> t.Dict:
+ result = self.retrieve(dataset, request)
+ return {"href": result.location, "size": result.content_length}
+
+ def download(self, result: cds_api.Result, target: t.Optional[str] = None) -> None:
+ if target:
+ if os.path.exists(target):
+ # Empty the target file, if it already exists, otherwise the
+ # transfer below might be fooled into thinking we're resuming
+ # an interrupted download.
+ open(target, "w").close()
+
+ self._download(result["href"], target, result["size"])
+
+
+class CdsClient(Client):
+ """A client to access weather data from the Cloud Data Store (CDS).
+
+ Datasets on CDS can be found at:
+ https://cds.climate.copernicus.eu/cdsapp#!/search?type=dataset
+
+ The parameters section of the input `config` requires two values: `api_url` and
+ `api_key`. Or, these values can be set as the environment variables: `CDSAPI_URL`
+ and `CDSAPI_KEY`. These can be acquired from the following URL, which requires
+ creating a free account: https://cds.climate.copernicus.eu/api-how-to
+
+    The CDS global queues for data access have dynamic rate limits. These can be viewed
+ live here: https://cds.climate.copernicus.eu/live/limits.
+
+ Attributes:
+ config: A config that contains pipeline parameters, such as API keys.
+ level: Default log level for the client.
+ """
+
+ """Name patterns of datasets that are hosted internally on CDS servers."""
+ cds_hosted_datasets = {"reanalysis-era"}
+
+ def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None:
+ c = CDSClientExtended(
+ url=os.environ.get("CLIENT_URL"),
+ key=os.environ.get("CLIENT_KEY"),
+ debug_callback=self.logger.debug,
+ info_callback=self.logger.info,
+ warning_callback=self.logger.warning,
+ error_callback=self.logger.error,
+ )
+ selection_ = optimize_selection_partition(selection)
+ with StdoutLogger(self.logger, level=logging.DEBUG):
+ manifest.set_stage(Stage.FETCH)
+ precise_fetch_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_fetch_start_time
+ result = c.fetch(selection_, dataset)
+ return result
+
+ @property
+ def license_url(self):
+ return "https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf"
+
+ @classmethod
+ def num_requests_per_key(cls, dataset: str) -> int:
+ """Number of requests per key from the CDS API.
+
+ CDS has dynamic, data-specific limits, defined here:
+ https://cds.climate.copernicus.eu/live/limits
+
+        Typically, the reanalysis dataset allows for 3-5 simultaneous requests.
+ For all standard CDS data (backed on disk drives), it's common that 2
+ requests are allowed, though this is dynamically set, too.
+
+ If the Beam pipeline encounters a user request limit error, please cancel
+ all outstanding requests (per each user account) at the following link:
+ https://cds.climate.copernicus.eu/cdsapp#!/yourrequests
+ """
+ # TODO(#15): Parse live CDS limits API to set data-specific limits.
+ for internal_set in cls.cds_hosted_datasets:
+ if dataset.startswith(internal_set):
+ return 5
+ return 2
+
+
+class StdoutLogger(io.StringIO):
+ """Special logger to redirect stdout to logs."""
+
+ def __init__(self, logger_: logging.Logger, level: int = logging.INFO):
+ super().__init__()
+ self.logger = logger_
+ self.level = level
+ self._redirector = contextlib.redirect_stdout(self)
+
+ def log(self, msg) -> None:
+ self.logger.log(self.level, msg)
+
+ def write(self, msg):
+ if msg and not msg.isspace():
+ self.logger.log(self.level, msg)
+
+ def __enter__(self):
+ self._redirector.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ # let contextlib do any exception handling here
+ self._redirector.__exit__(exc_type, exc_value, traceback)
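+
+    # Illustrative usage (not part of the original class; a sketch only): redirect
+    # anything a client library prints to stdout into this logger, e.g.
+    #
+    #   with StdoutLogger(logging.getLogger(__name__), level=logging.DEBUG):
+    #       print("captured by the logger instead of stdout")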
+
+
+class SplitMARSRequest(api.APIRequest):
+ """Extended MARS APIRequest class that separates fetch and download stage."""
+
+ @retry_with_exponential_backoff
+ def _download(self, url, path: str, size: int) -> None:
+ self.log("Transferring %s into %s" % (self._bytename(size), path))
+ self.log("From %s" % (url,))
+
+ download_with_aria2(url, path)
+
+ def fetch(self, request: t.Dict, dataset: str) -> t.Dict:
+ status = None
+
+ self.connection.submit("%s/%s/requests" % (self.url, self.service), request)
+ self.log("Request submitted")
+ self.log("Request id: " + self.connection.last.get("name"))
+ if self.connection.status != status:
+ status = self.connection.status
+ self.log("Request is %s" % (status,))
+
+ while not self.connection.ready():
+ if self.connection.status != status:
+ status = self.connection.status
+ self.log("Request is %s" % (status,))
+ self.connection.wait()
+
+ if self.connection.status != status:
+ status = self.connection.status
+ self.log("Request is %s" % (status,))
+
+ result = self.connection.result()
+ return result
+
+ def download(self, result: t.Dict, target: t.Optional[str] = None) -> None:
+ if target:
+ if os.path.exists(target):
+ # Empty the target file, if it already exists, otherwise the
+ # transfer below might be fooled into thinking we're resuming
+ # an interrupted download.
+ open(target, "w").close()
+
+ self._download(urljoin(self.url, result["href"]), target, result["size"])
+ self.connection.cleanup()
+
+
+class SplitRequestMixin:
+ c = None
+
+ def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict:
+ return self.c.fetch(req, dataset)
+
+ def download(self, res: t.Dict, target: str) -> None:
+ self.c.download(res, target)
+
+
+class CDSClientExtended(SplitRequestMixin):
+ """Extended CDS Client class that separates fetch and download stage."""
+
+ def __init__(self, *args, **kwargs):
+ self.c = SplitCDSRequest(*args, **kwargs)
+
+
+class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin):
+ """Extended MARS ECMFService class that separates fetch and download stage."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.c = SplitMARSRequest(
+ self.url,
+ "services/%s" % (self.service,),
+ email=self.email,
+ key=self.key,
+ log=self.log,
+ verbose=self.verbose,
+ quiet=self.quiet,
+ )
+
+
+class PublicECMWFServerExtended(api.ECMWFDataServer, SplitRequestMixin):
+
+ def __init__(self, *args, dataset="", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.c = SplitMARSRequest(
+ self.url,
+ "datasets/%s" % (dataset,),
+ email=self.email,
+ key=self.key,
+ log=self.log,
+ verbose=self.verbose,
+ )
+
+
+class MarsClient(Client):
+ """A client to access data from the Meteorological Archival and Retrieval System (MARS).
+
+ See https://www.ecmwf.int/en/forecasts/datasets for a summary of datasets available
+    on MARS. Most notably, MARS provides access to ECMWF's Operational Archive
+ https://www.ecmwf.int/en/forecasts/dataset/operational-archive.
+
+    The client config must contain three parameters to authenticate access to the MARS archive:
+    `api_key`, `api_url`, and `api_email`. These can also be configured by setting the
+    corresponding environment variables: `MARSAPI_KEY`, `MARSAPI_URL`, and `MARSAPI_EMAIL`.
+    These credentials can be looked up after registering for an ECMWF account
+    (https://apps.ecmwf.int/registration/) and visiting: https://api.ecmwf.int/v1/key/.
+
+ MARS server activity can be observed at https://apps.ecmwf.int/mars-activity/.
+
+ Attributes:
+ config: A config that contains pipeline parameters, such as API keys.
+ level: Default log level for the client.
+ """
+
+ def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None:
+ c = MARSECMWFServiceExtended(
+ "mars",
+ key=os.environ.get("CLIENT_KEY"),
+ url=os.environ.get("CLIENT_URL"),
+ email=os.environ.get("CLIENT_EMAIL"),
+ log=self.logger.debug,
+ verbose=True,
+ )
+ selection_ = optimize_selection_partition(selection)
+ with StdoutLogger(self.logger, level=logging.DEBUG):
+ manifest.set_stage(Stage.FETCH)
+ precise_fetch_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_fetch_start_time
+ result = c.fetch(req=selection_)
+ return result
+
+ @property
+ def license_url(self):
+ return "https://apps.ecmwf.int/datasets/licences/general/"
+
+ @classmethod
+ def num_requests_per_key(cls, dataset: str) -> int:
+ """Number of requests per key (or user) for the Mars API.
+
+ Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021.
+ To ensure we never hit a rate limit error during download, we only make use of the active
+ requests.
+ See: https://confluence.ecmwf.int/display/UDOC/Total+number+of+requests+a+user+can+submit+-+Web+API+FAQ
+
+ Queued requests can _only_ be canceled manually from a web dashboard. If the
+ `ERROR 101 (USER_QUEUED_LIMIT_EXCEEDED)` error occurs in the Beam pipeline, then go to
+ http://apps.ecmwf.int/webmars/joblist/ and cancel queued jobs.
+ """
+ return 2
+
+
+class ECMWFPublicClient(Client):
+ """A client for ECMWF's public datasets, like TIGGE."""
+
+ def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None:
+ c = PublicECMWFServerExtended(
+ url=os.environ.get("CLIENT_URL"),
+ key=os.environ.get("CLIENT_KEY"),
+ email=os.environ.get("CLIENT_EMAIL"),
+ log=self.logger.debug,
+ verbose=True,
+ dataset=dataset,
+ )
+ selection_ = optimize_selection_partition(selection)
+ with StdoutLogger(self.logger, level=logging.DEBUG):
+ manifest.set_stage(Stage.FETCH)
+ precise_fetch_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_fetch_start_time
+ result = c.fetch(req=selection_)
+ return result
+
+ @classmethod
+ def num_requests_per_key(cls, dataset: str) -> int:
+ # Experimentally validated request limit.
+ return 5
+
+ @property
+ def license_url(self):
+ if not self.dataset:
+ raise ValueError("must specify a dataset for this client!")
+ return f"https://apps.ecmwf.int/datasets/data/{self.dataset.lower()}/licence/"
+
+
+class FakeClient(Client):
+ """A client that writes the selection arguments to the output file."""
+
+ def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None:
+ manifest.set_stage(Stage.RETRIEVE)
+ precise_retrieve_start_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ manifest.prev_stage_precise_start_time = precise_retrieve_start_time
+ self.logger.debug(f"Downloading {dataset}.")
+
+ @property
+ def license_url(self):
+ return "lorem ipsum"
+
+ @classmethod
+ def num_requests_per_key(cls, dataset: str) -> int:
+ return 1
+
+
+CLIENTS = collections.OrderedDict(
+ cds=CdsClient,
+ mars=MarsClient,
+ ecpublic=ECMWFPublicClient,
+ fake=FakeClient,
+)
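+
+# Illustrative usage (not part of the original module; a sketch only): the deployment
+# looks up a client class by the `client` name from the config and instantiates it
+# with the dataset, e.g.
+#
+#   client = CLIENTS["fake"]("some-dataset")
+#   client.num_requests_per_key("some-dataset")  # -> 1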
diff --git a/weather_dl_v2/license_deployment/config.py b/weather_dl_v2/license_deployment/config.py
new file mode 100644
index 00000000..fe2199b8
--- /dev/null
+++ b/weather_dl_v2/license_deployment/config.py
@@ -0,0 +1,120 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import calendar
+import copy
+import dataclasses
+import typing as t
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class Config:
+ """Contains pipeline parameters.
+
+ Attributes:
+ config_name:
+ Name of the config file.
+ client:
+ Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable.
+ dataset (optional):
+ Name of the target dataset. Allowed options are dictated by the client.
+ partition_keys (optional):
+ Choose the keys from the selection section to partition the data request.
+ This will compute a cartesian cross product of the selected keys
+ and assign each as their own download.
+ target_path:
+ Download artifact filename template. Can make use of Python's standard string formatting.
+ It can contain format symbols to be replaced by partition keys;
+ if this is used, the total number of format symbols must match the number of partition keys.
+ subsection_name:
+ Name of the particular subsection. 'default' if there is no subsection.
+ force_download:
+ Force redownload of partitions that were previously downloaded.
+ user_id:
+ Username from the environment variables.
+ kwargs (optional):
+ For representing subsections or any other parameters.
+ selection:
+ Contains parameters used to select desired data.
+ """
+
+ config_name: str = ""
+ client: str = ""
+ dataset: t.Optional[str] = ""
+ target_path: str = ""
+ partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list)
+ subsection_name: str = "default"
+ force_download: bool = False
+ user_id: str = "unknown"
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+ selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict) -> "Config":
+ config_instance = cls()
+ for section_key, section_value in config.items():
+ if section_key == "parameters":
+ for key, value in section_value.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+ if section_key == "selection":
+ config_instance.selection = section_value
+ return config_instance
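+
+
+# Illustrative sketch (not part of the original module): `Config.from_dict` expects the
+# dict form of a parsed .cfg file, e.g. (values taken from tests/test_data/example.cfg):
+#
+#   Config.from_dict({
+#       "parameters": {"client": "mars", "target_path": "gs:///test-weather-dl-v2/{date}T00z.gb"},
+#       "selection": {"date": "2019-07-18/to/2019-07-20", "time": "0000"},
+#   })
+#
+# Keys under "parameters" that are not Config attributes are collected into `kwargs`.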
+
+
+def optimize_selection_partition(selection: t.Dict) -> t.Dict:
+ """Compute right-hand-side values for the selection section of a single partition.
+
+ Used to support custom syntax and optimizations, such as 'all'.
+ """
+ selection_ = copy.deepcopy(selection)
+
+ if "day" in selection_.keys() and selection_["day"] == "all":
+ year, month = selection_["year"], selection_["month"]
+
+ multiples_error = (
+ "Cannot use keyword 'all' on selections with multiple '{type}'s."
+ )
+
+ if isinstance(year, list):
+ assert len(year) == 1, multiples_error.format(type="year")
+ year = year[0]
+
+ if isinstance(month, list):
+ assert len(month) == 1, multiples_error.format(type="month")
+ month = month[0]
+
+ if isinstance(year, str):
+ assert "/" not in year, multiples_error.format(type="year")
+
+ if isinstance(month, str):
+ assert "/" not in month, multiples_error.format(type="month")
+
+ year, month = int(year), int(month)
+
+ _, n_days_in_month = calendar.monthrange(year, month)
+
+ selection_[
+ "date"
+ ] = f"{year:04d}-{month:02d}-01/to/{year:04d}-{month:02d}-{n_days_in_month:02d}"
+ del selection_["day"]
+ del selection_["month"]
+ del selection_["year"]
+
+ return selection_
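+
+
+# Illustrative sketch (not part of the original module): with the custom 'all' keyword,
+# a whole month of days is collapsed into a single MARS-style date range, e.g.
+#
+#   optimize_selection_partition({"year": "2020", "month": "02", "day": "all"})
+#   # -> {"date": "2020-02-01/to/2020-02-29"}
+#
+# Selections that do not use day == 'all' are returned unchanged (as a deep copy).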
diff --git a/weather_dl_v2/license_deployment/database.py b/weather_dl_v2/license_deployment/database.py
new file mode 100644
index 00000000..24206561
--- /dev/null
+++ b/weather_dl_v2/license_deployment/database.py
@@ -0,0 +1,176 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import time
+import logging
+import firebase_admin
+from firebase_admin import firestore
+from firebase_admin import credentials
+from google.cloud.firestore_v1 import DocumentSnapshot, DocumentReference
+from google.cloud.firestore_v1.types import WriteResult
+from google.cloud.firestore_v1.base_query import FieldFilter, And
+from util import get_wait_interval
+from deployment_config import get_config
+
+logger = logging.getLogger(__name__)
+
+
+class Database(abc.ABC):
+
+ @abc.abstractmethod
+ def _get_db(self):
+ pass
+
+
+class CRUDOperations(abc.ABC):
+
+ @abc.abstractmethod
+ def _initialize_license_deployment(self, license_id: str) -> dict:
+ pass
+
+ @abc.abstractmethod
+    def _get_config_from_queue_by_license_id(self, license_id: str) -> str | None:
+ pass
+
+ @abc.abstractmethod
+ def _remove_config_from_license_queue(
+ self, license_id: str, config_name: str
+ ) -> None:
+ pass
+
+ @abc.abstractmethod
+ def _empty_license_queue(self, license_id: str) -> None:
+ pass
+
+ @abc.abstractmethod
+    def _get_partition_from_manifest(self, config_name: str) -> dict | None:
+ pass
+
+
+class FirestoreClient(Database, CRUDOperations):
+
+ def _get_db(self) -> firestore.firestore.Client:
+ """Acquire a firestore client, initializing the firebase app if necessary.
+ Will attempt to get the db client five times. If it's still unsuccessful, a
+        `RuntimeError` will be raised.
+ """
+ db = None
+ attempts = 0
+
+ while db is None:
+ try:
+ db = firestore.client()
+ except ValueError as e:
+ # The above call will fail with a value error when the firebase app is not initialized.
+ # Initialize the app here, and try again.
+ # Use the application default credentials.
+ cred = credentials.ApplicationDefault()
+
+ firebase_admin.initialize_app(cred)
+ logger.info("Initialized Firebase App.")
+
+ if attempts > 4:
+ raise RuntimeError(
+ "Exceeded number of retries to get firestore client."
+ ) from e
+
+ time.sleep(get_wait_interval(attempts))
+
+ attempts += 1
+
+ return db
+
+ def _initialize_license_deployment(self, license_id: str) -> dict:
+ result: DocumentSnapshot = (
+ self._get_db()
+ .collection(get_config().license_collection)
+ .document(license_id)
+ .get()
+ )
+ return result.to_dict()
+
+ def _get_config_from_queue_by_license_id(self, license_id: str) -> str | None:
+ result: DocumentSnapshot = (
+ self._get_db()
+ .collection(get_config().queues_collection)
+ .document(license_id)
+ .get(["queue"])
+ )
+ if result.exists:
+ queue = result.to_dict()["queue"]
+ if len(queue) > 0:
+ return queue[0]
+ return None
+
+    def _get_partition_from_manifest(self, config_name: str) -> dict | None:
+ transaction = self._get_db().transaction()
+ return get_partition_from_manifest(transaction, config_name)
+
+ def _remove_config_from_license_queue(
+ self, license_id: str, config_name: str
+ ) -> None:
+ result: WriteResult = (
+ self._get_db()
+ .collection(get_config().queues_collection)
+ .document(license_id)
+ .update({"queue": firestore.ArrayRemove([config_name])})
+ )
+ logger.info(
+ f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+ def _empty_license_queue(self, license_id: str) -> None:
+ result: WriteResult = (
+ self._get_db()
+ .collection(get_config().queues_collection)
+ .document(license_id)
+ .update({"queue": []})
+ )
+ logger.info(
+ f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
+ )
+
+
+# TODO: Firestore transactional fails after reading a document 20 times with roll over.
+# This happens when too many licenses try to access the same partition document.
+# Find some alternative approach to handle this.
+@firestore.transactional
+def get_partition_from_manifest(transaction, config_name: str) -> dict | None:
+ db_client = FirestoreClient()
+ filter_1 = FieldFilter("config_name", "==", config_name)
+ filter_2 = FieldFilter("status", "==", "scheduled")
+ and_filter = And(filters=[filter_1, filter_2])
+
+ snapshot = (
+ db_client._get_db()
+ .collection(get_config().manifest_collection)
+ .where(filter=and_filter)
+ .limit(1)
+ .get(transaction=transaction)
+ )
+ if len(snapshot) > 0:
+ snapshot = snapshot[0]
+ else:
+ return None
+
+ ref: DocumentReference = (
+ db_client._get_db()
+ .collection(get_config().manifest_collection)
+ .document(snapshot.id)
+ )
+ transaction.update(ref, {"status": "processing"})
+
+ return snapshot.to_dict()
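+
+
+# Illustrative usage (not part of the original module; a sketch only): the license
+# deployment uses this to atomically claim the next scheduled partition of a config:
+#
+#   db = FirestoreClient()
+#   partition = db._get_partition_from_manifest("example.cfg")
+#   # -> the manifest document as a dict (now marked 'processing'), or None when no
+#   #    partition for that config is in the 'scheduled' state.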
diff --git a/weather_dl_v2/license_deployment/deployment_config.py b/weather_dl_v2/license_deployment/deployment_config.py
new file mode 100644
index 00000000..8ae162ea
--- /dev/null
+++ b/weather_dl_v2/license_deployment/deployment_config.py
@@ -0,0 +1,69 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import dataclasses
+import typing as t
+import json
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class DeploymentConfig:
+ download_collection: str = ""
+ queues_collection: str = ""
+ license_collection: str = ""
+ manifest_collection: str = ""
+ downloader_k8_image: str = ""
+ kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, config: t.Dict):
+ config_instance = cls()
+
+ for key, value in config.items():
+ if hasattr(config_instance, key):
+ setattr(config_instance, key, value)
+ else:
+ config_instance.kwargs[key] = value
+
+ return config_instance
+
+
+deployment_config = None
+
+
+def get_config():
+ global deployment_config
+ if deployment_config:
+ return deployment_config
+
+ deployment_config_json = "config/config.json"
+ if not os.path.exists(deployment_config_json):
+ deployment_config_json = os.environ.get("CONFIG_PATH", None)
+
+ if deployment_config_json is None:
+ logger.error("Couldn't load config file for license deployment.")
+ raise FileNotFoundError("Couldn't load config file for license deployment.")
+
+ with open(deployment_config_json) as file:
+ config_dict = json.load(file)
+ deployment_config = DeploymentConfig.from_dict(config_dict)
+
+ return deployment_config
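+
+
+# Illustrative sketch (not part of the original module): the mounted config/config.json
+# (or the file pointed at by CONFIG_PATH) is expected to provide the fields of
+# DeploymentConfig; all values below are hypothetical placeholders.
+#
+#   {
+#       "download_collection": "download",
+#       "queues_collection": "queues",
+#       "license_collection": "license",
+#       "manifest_collection": "manifest",
+#       "downloader_k8_image": "gcr.io/<project-id>/<repo>:weather-dl-v2-downloader"
+#   }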
diff --git a/weather_dl_v2/license_deployment/downloader.yaml b/weather_dl_v2/license_deployment/downloader.yaml
new file mode 100644
index 00000000..361c2b36
--- /dev/null
+++ b/weather_dl_v2/license_deployment/downloader.yaml
@@ -0,0 +1,33 @@
+apiVersion: batch/v1
+kind: Job
+metadata:
+ name: downloader-with-ttl
+spec:
+ ttlSecondsAfterFinished: 0
+ template:
+ spec:
+ nodeSelector:
+ cloud.google.com/gke-nodepool: downloader-pool
+ containers:
+ - name: downloader
+ image: XXXXXXX
+ imagePullPolicy: Always
+ command: []
+ resources:
+ requests:
+ cpu: "1000m" # CPU: 1 vCPU
+ memory: "2Gi" # RAM: 2 GiB
+ ephemeral-storage: "100Gi" # Storage: 100 GiB
+ volumeMounts:
+ - name: data
+ mountPath: /data
+ - name: config-volume
+ mountPath: ./config
+ restartPolicy: Never
+ volumes:
+ - name: data
+ emptyDir:
+ sizeLimit: 100Gi
+ - name: config-volume
+ configMap:
+ name: dl-v2-config
\ No newline at end of file
diff --git a/weather_dl_v2/license_deployment/environment.yml b/weather_dl_v2/license_deployment/environment.yml
new file mode 100644
index 00000000..4848fafd
--- /dev/null
+++ b/weather_dl_v2/license_deployment/environment.yml
@@ -0,0 +1,17 @@
+name: weather-dl-v2-license-dep
+channels:
+ - conda-forge
+dependencies:
+ - python=3.10
+ - geojson
+ - cdsapi=0.5.1
+ - ecmwf-api-client=1.6.3
+ - pip=22.3
+ - pip:
+ - kubernetes
+ - google-cloud-secret-manager
+ - aiohttp
+ - numpy
+ - xarray
+ - apache-beam[gcp]
+ - firebase-admin
diff --git a/weather_dl_v2/license_deployment/fetch.py b/weather_dl_v2/license_deployment/fetch.py
new file mode 100644
index 00000000..3e69a56f
--- /dev/null
+++ b/weather_dl_v2/license_deployment/fetch.py
@@ -0,0 +1,184 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from concurrent.futures import ThreadPoolExecutor
+from google.cloud import secretmanager
+import json
+import logging
+import time
+import sys
+import os
+
+from database import FirestoreClient
+from job_creator import create_download_job
+from clients import CLIENTS
+from manifest import FirestoreManifest
+from util import exceptionit, ThreadSafeDict
+
+db_client = FirestoreClient()
+secretmanager_client = secretmanager.SecretManagerServiceClient()
+CONFIG_MAX_ERROR_COUNT = 10
+
+def create_job(request, result):
+ res = {
+ "config_name": request["config_name"],
+ "dataset": request["dataset"],
+ "selection": json.loads(request["selection"]),
+ "user_id": request["username"],
+ "url": result["href"],
+ "target_path": request["location"],
+ "license_id": license_id,
+ }
+
+ data_str = json.dumps(res)
+ logger.info(f"Creating download job for res: {data_str}")
+ create_download_job(data_str)
+
+
+@exceptionit
+def make_fetch_request(request, error_map: ThreadSafeDict):
+ client = CLIENTS[client_name](request["dataset"])
+ manifest = FirestoreManifest(license_id=license_id)
+ logger.info(
+ f"By using {client_name} datasets, "
+ f"users agree to the terms and conditions specified in {client.license_url!r}."
+ )
+
+ target = request["location"]
+ selection = json.loads(request["selection"])
+
+ logger.info(f"Fetching data for {target!r}.")
+
+ config_name = request["config_name"]
+
+ if not error_map.has_key(config_name):
+ error_map[config_name] = 0
+
+ if error_map[config_name] >= CONFIG_MAX_ERROR_COUNT:
+ logger.info(f"Error count for config {config_name} exceeded CONFIG_MAX_ERROR_COUNT ({CONFIG_MAX_ERROR_COUNT}).")
+ error_map.remove(config_name)
+ logger.info(f"Removing config {config_name} from license queue.")
+ # Remove config from this license queue.
+ db_client._remove_config_from_license_queue(license_id=license_id, config_name=config_name)
+ return
+
+ # Wait for exponential time based on error count.
+ if error_map[config_name] > 0:
+ logger.info(f"Error count for config {config_name}: {error_map[config_name]}.")
+        wait_time = error_map.exponential_time(config_name)
+        logger.info(f"Sleeping for {wait_time} mins.")
+        time.sleep(wait_time)
+
+ try:
+ with manifest.transact(
+ request["config_name"],
+ request["dataset"],
+ selection,
+ target,
+ request["username"],
+ ):
+ result = client.retrieve(request["dataset"], selection, manifest)
+ except Exception as e:
+        # We handle this as a generic case because the CDS client throws generic exceptions.
+
+ # License expired.
+ if "Access token expired" in str(e):
+ logger.error(f"{license_id} expired. Emptying queue! error: {e}.")
+ db_client._empty_license_queue(license_id=license_id)
+ return
+
+ # Increment error count for a config.
+ logger.error(f"Partition fetching failed. Error {e}.")
+ error_map.increment(config_name)
+ return
+
+    # If any partition is successful, reset the error count.
+ error_map[config_name] = 0
+ create_job(request, result)
+
+
+def fetch_request_from_db():
+ request = None
+ config_name = db_client._get_config_from_queue_by_license_id(license_id)
+ if config_name:
+ try:
+ logger.info(f"Fetching partition for {config_name}.")
+ request = db_client._get_partition_from_manifest(config_name)
+ if not request:
+ db_client._remove_config_from_license_queue(license_id, config_name)
+ except Exception as e:
+ logger.error(
+ f"Error in fetch_request_from_db for {config_name}. error: {e}."
+ )
+ return request
+
+
+def main():
+ logger.info("Started looking at the request.")
+ error_map = ThreadSafeDict()
+ with ThreadPoolExecutor(concurrency_limit) as executor:
+        # Disclaimer: A license will always pick concurrency_limit + 1
+        # partitions. One extra partition will be kept in the threadpool task queue.
+
+ while True:
+ # Fetch a request from the database
+ request = fetch_request_from_db()
+
+ if request is not None:
+ executor.submit(make_fetch_request, request, error_map)
+ else:
+ logger.info("No request available. Waiting...")
+ time.sleep(5)
+
+            # Each license should not pick more partitions than its
+            # concurrency_limit. We limit the threadpool queue size to just 1
+            # to prevent the license from picking more partitions than
+            # its concurrency_limit. When an executor is freed up, the task
+            # in the queue is picked and the license fetches another partition.
+ while executor._work_queue.qsize() >= 1:
+ logger.info("Worker busy. Waiting...")
+ time.sleep(1)
+
+
+def boot_up(license: str) -> None:
+ global license_id, client_name, concurrency_limit
+
+ result = db_client._initialize_license_deployment(license)
+ license_id = license
+ client_name = result["client_name"]
+ concurrency_limit = result["number_of_requests"]
+
+ response = secretmanager_client.access_secret_version(
+ request={"name": result["secret_id"]}
+ )
+ payload = response.payload.data.decode("UTF-8")
+ secret_dict = json.loads(payload)
+
+ os.environ.setdefault("CLIENT_URL", secret_dict.get("api_url", ""))
+ os.environ.setdefault("CLIENT_KEY", secret_dict.get("api_key", ""))
+ os.environ.setdefault("CLIENT_EMAIL", secret_dict.get("api_email", ""))
+
+
+if __name__ == "__main__":
+ license = sys.argv[2]
+ global logger
+ logging.basicConfig(
+ level=logging.INFO, format=f"[{license}] %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+ logger.info(f"Deployment for license: {license}.")
+ boot_up(license)
+ main()
diff --git a/weather_dl_v2/license_deployment/job_creator.py b/weather_dl_v2/license_deployment/job_creator.py
new file mode 100644
index 00000000..f0acd802
--- /dev/null
+++ b/weather_dl_v2/license_deployment/job_creator.py
@@ -0,0 +1,58 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from os import path
+import yaml
+import json
+import uuid
+from kubernetes import client, config
+from deployment_config import get_config
+
+
+def create_download_job(message):
+ """Creates a kubernetes workflow of type Job for downloading the data."""
+ parsed_message = json.loads(message)
+ (
+ config_name,
+ dataset,
+ selection,
+ user_id,
+ url,
+ target_path,
+ license_id,
+ ) = parsed_message.values()
+ selection = str(selection).replace(" ", "")
+ config.load_config()
+
+ with open(path.join(path.dirname(__file__), "downloader.yaml")) as f:
+ dep = yaml.safe_load(f)
+ uid = uuid.uuid4()
+ dep["metadata"]["name"] = f"downloader-job-id-{uid}"
+ dep["spec"]["template"]["spec"]["containers"][0]["command"] = [
+ "python",
+ "downloader.py",
+ config_name,
+ dataset,
+ selection,
+ user_id,
+ url,
+ target_path,
+ license_id,
+ ]
+ dep["spec"]["template"]["spec"]["containers"][0][
+ "image"
+ ] = get_config().downloader_k8_image
+ batch_api = client.BatchV1Api()
+ batch_api.create_namespaced_job(body=dep, namespace="default")
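+
+
+# Illustrative usage (not part of the original module; a sketch only): the license
+# deployment passes a JSON string whose key order matches the unpacking above,
+# e.g. (hypothetical values):
+#
+#   create_download_job(json.dumps({
+#       "config_name": "example.cfg",
+#       "dataset": "reanalysis-era5-pressure-levels",
+#       "selection": {"variable": "temperature"},
+#       "user_id": "dummy_user",
+#       "url": "https://download.example/result/1234",
+#       "target_path": "gs://bucket/era5/{date}.nc",
+#       "license_id": "L1",
+#   }))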
diff --git a/weather_dl_v2/license_deployment/manifest.py b/weather_dl_v2/license_deployment/manifest.py
new file mode 100644
index 00000000..3119de9e
--- /dev/null
+++ b/weather_dl_v2/license_deployment/manifest.py
@@ -0,0 +1,522 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Client interface for connecting to a manifest."""
+
+import abc
+import logging
+import dataclasses
+import datetime
+import enum
+import json
+import pandas as pd
+import time
+import traceback
+import typing as t
+
+from util import (
+ to_json_serializable_type,
+ fetch_geo_polygon,
+ get_file_size,
+ get_wait_interval,
+ generate_md5_hash,
+ GLOBAL_COVERAGE_AREA,
+)
+
+import firebase_admin
+from firebase_admin import credentials
+from firebase_admin import firestore
+from google.cloud.firestore_v1 import DocumentReference
+from google.cloud.firestore_v1.types import WriteResult
+from deployment_config import get_config
+from database import Database
+
+logger = logging.getLogger(__name__)
+
+"""An implementation-dependent Manifest URI."""
+Location = t.NewType("Location", str)
+
+
+class ManifestException(Exception):
+ """Errors that occur in Manifest Clients."""
+
+ pass
+
+
+class Stage(enum.Enum):
+ """A request can be either in one of the following stages at a time:
+
+ fetch : This represents request is currently in fetch stage i.e. request placed on the client's server
+ & waiting for some result before starting download (eg. MARS client).
+ download : This represents request is currently in download stage i.e. data is being downloading from client's
+ server to the worker's local file system.
+ upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local
+ file system to target location (GCS path).
+ retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client),
+ request will be in the retrieve stage i.e. fetch + download.
+ """
+
+ RETRIEVE = "retrieve"
+ FETCH = "fetch"
+ DOWNLOAD = "download"
+ UPLOAD = "upload"
+
+
+class Status(enum.Enum):
+ """Depicts the request's state status:
+
+ scheduled : A request partition is created & scheduled for processing.
+ Note: Its corresponding state can be None only.
+ processing: This represents that the request picked by license deployment.
+ in-progress : This represents the request state is currently in-progress (i.e. running).
+ The next status would be "success" or "failure".
+ success : This represents the request state execution completed successfully without any error.
+ failure : This represents the request state execution failed.
+ """
+
+ PROCESSING = "processing"
+ SCHEDULED = "scheduled"
+ IN_PROGRESS = "in-progress"
+ SUCCESS = "success"
+ FAILURE = "failure"
+
+
+@dataclasses.dataclass
+class DownloadStatus:
+ """Data recorded in `Manifest`s reflecting the status of a download."""
+
+ """The name of the config file associated with the request."""
+ config_name: str = ""
+
+ """Represents the dataset field of the configuration."""
+ dataset: t.Optional[str] = ""
+
+ """Copy of selection section of the configuration."""
+ selection: t.Dict = dataclasses.field(default_factory=dict)
+
+ """Location of the downloaded data."""
+ location: str = ""
+
+ """Represents area covered by the shard."""
+ area: str = ""
+
+ """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None."""
+ stage: t.Optional[Stage] = None
+
+ """Download status: 'scheduled', 'in-progress', 'success', or 'failure'."""
+ status: t.Optional[Status] = None
+
+ """Cause of error, if any."""
+ error: t.Optional[str] = ""
+
+ """Identifier for the user running the download."""
+ username: str = ""
+
+ """Shard size in GB."""
+ size: t.Optional[float] = 0
+
+ """A UTC datetime when download was scheduled."""
+ scheduled_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve stage starts."""
+ retrieve_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the retrieve state ends."""
+ retrieve_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch state starts."""
+ fetch_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the fetch state ends."""
+ fetch_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download state starts."""
+ download_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the download state ends."""
+ download_end_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload state starts."""
+ upload_start_time: t.Optional[str] = ""
+
+ """A UTC datetime when the upload state ends."""
+ upload_end_time: t.Optional[str] = ""
+
+ @classmethod
+ def from_dict(cls, download_status: t.Dict) -> "DownloadStatus":
+ """Instantiate DownloadStatus dataclass from dict."""
+ download_status_instance = cls()
+ for key, value in download_status.items():
+ if key == "status":
+ setattr(download_status_instance, key, Status(value))
+ elif key == "stage" and value is not None:
+ setattr(download_status_instance, key, Stage(value))
+ else:
+ setattr(download_status_instance, key, value)
+ return download_status_instance
+
+ @classmethod
+ def to_dict(cls, instance) -> t.Dict:
+ """Return the fields of a dataclass instance as a manifest ingestible
+ dictionary mapping of field names to field values."""
+ download_status_dict = {}
+ for field in dataclasses.fields(instance):
+ key = field.name
+ value = getattr(instance, field.name)
+ if isinstance(value, Status) or isinstance(value, Stage):
+ download_status_dict[key] = value.value
+ elif isinstance(value, pd.Timestamp):
+ download_status_dict[key] = value.isoformat()
+ elif key == "selection" and value is not None:
+ download_status_dict[key] = json.dumps(value)
+ else:
+ download_status_dict[key] = value
+ return download_status_dict
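+
+    # Illustrative sketch (not part of the original class): enum members are serialized
+    # by value and the selection dict is JSON-encoded, e.g.
+    #
+    #   DownloadStatus.to_dict(DownloadStatus(config_name="example.cfg", status=Status.SCHEDULED))
+    #   # -> {..., "config_name": "example.cfg", "status": "scheduled", "stage": None, ...}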
+
+
+@dataclasses.dataclass
+class Manifest(abc.ABC):
+ """Abstract manifest of download statuses.
+
+ Update download statuses to some storage medium.
+
+ This class lets one indicate that a download is `scheduled` or in a transaction process.
+ In the event of a transaction, a download will be updated with an `in-progress`, `success`
+ or `failure` status (with accompanying metadata).
+
+ Example:
+ ```
+ my_manifest = parse_manifest_location(Location('fs://some-firestore-collection'))
+
+ # Schedule data for download
+ my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username')
+
+ # ...
+
+    # Initiate a transaction – it will record that the download is `in-progress`
+ with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx:
+ # download logic here
+ pass
+
+ # ...
+
+ # on error, will record the download as a `failure` before propagating the error. By default, it will
+ # record download as a `success`.
+ ```
+
+ Attributes:
+ status: The current `DownloadStatus` of the Manifest.
+ """
+
+ # To reduce the impact of _read() and _update() calls
+ # on the start time of the stage.
+ license_id: str = ""
+ prev_stage_precise_start_time: t.Optional[str] = None
+ status: t.Optional[DownloadStatus] = None
+
+ # This is overridden in subclass.
+ def __post_init__(self):
+ """Initialize the manifest."""
+ pass
+
+ def schedule(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Indicate that a job has been scheduled for download.
+
+        'scheduled' jobs occur before 'in-progress', 'success' or 'failure'.
+ """
+ scheduled_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+ self.status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=None,
+ status=Status.SCHEDULED,
+ error=None,
+ size=None,
+ scheduled_time=scheduled_time,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=None,
+ upload_end_time=None,
+ )
+ self._update(self.status)
+
+ def skip(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Updates the manifest to mark the shards that were skipped in the current job
+ as 'upload' stage and 'success' status, indicating that they have already been downloaded.
+ """
+ old_status = self._read(location)
+ # The manifest needs to be updated for a skipped shard if its entry is not present, or
+ # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'.
+ if (
+ old_status.location != location
+ or old_status.stage != Stage.UPLOAD
+ or old_status.status != Status.SUCCESS
+ ):
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ size = get_file_size(location)
+
+ status = DownloadStatus(
+ config_name=config_name,
+ dataset=dataset if dataset else None,
+ selection=selection,
+ location=location,
+ area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)),
+ username=user,
+ stage=Stage.UPLOAD,
+ status=Status.SUCCESS,
+ error=None,
+ size=size,
+ scheduled_time=None,
+ retrieve_start_time=None,
+ retrieve_end_time=None,
+ fetch_start_time=None,
+ fetch_end_time=None,
+ download_start_time=None,
+ download_end_time=None,
+ upload_start_time=current_utc_time,
+ upload_end_time=current_utc_time,
+ )
+ self._update(status)
+ logger.info(
+ f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}."
+ )
+
+ def _set_for_transaction(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> None:
+ """Reset Manifest state in preparation for a new transaction."""
+ self.status = dataclasses.replace(self._read(location))
+ self.status.config_name = config_name
+ self.status.dataset = dataset if dataset else None
+ self.status.selection = selection
+ self.status.location = location
+ self.status.username = user
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, exc_type, exc_inst, exc_tb) -> None:
+ """Record end status of a transaction as either 'success' or 'failure'."""
+ if exc_type is None:
+ status = Status.SUCCESS
+ error = None
+ else:
+ status = Status.FAILURE
+ # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception
+ error = f"license_id: {self.license_id} "
+ error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb))
+
+ new_status = dataclasses.replace(self.status)
+ new_status.error = error
+ new_status.status = status
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ # This is necessary for setting the precise start time of the previous stage
+ # and end time of the final stage, as well as handling the case of Status.FAILURE.
+ if new_status.stage == Stage.FETCH:
+ new_status.fetch_start_time = self.prev_stage_precise_start_time
+ new_status.fetch_end_time = current_utc_time
+ elif new_status.stage == Stage.RETRIEVE:
+ new_status.retrieve_start_time = self.prev_stage_precise_start_time
+ new_status.retrieve_end_time = current_utc_time
+ elif new_status.stage == Stage.DOWNLOAD:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ else:
+ new_status.upload_start_time = self.prev_stage_precise_start_time
+ new_status.upload_end_time = current_utc_time
+
+ new_status.size = get_file_size(new_status.location)
+
+ self.status = new_status
+
+ self._update(self.status)
+
+ def transact(
+ self,
+ config_name: str,
+ dataset: str,
+ selection: t.Dict,
+ location: str,
+ user: str,
+ ) -> "Manifest":
+ """Create a download transaction."""
+ self._set_for_transaction(config_name, dataset, selection, location, user)
+ return self
+
+ def set_stage(self, stage: Stage) -> None:
+ """Sets the current stage in manifest."""
+ prev_stage = self.status.stage
+ new_status = dataclasses.replace(self.status)
+ new_status.stage = stage
+ new_status.status = Status.IN_PROGRESS
+ current_utc_time = (
+ datetime.datetime.utcnow()
+ .replace(tzinfo=datetime.timezone.utc)
+ .isoformat(timespec="seconds")
+ )
+
+ if stage == Stage.FETCH:
+ new_status.fetch_start_time = current_utc_time
+ elif stage == Stage.RETRIEVE:
+ new_status.retrieve_start_time = current_utc_time
+ elif stage == Stage.DOWNLOAD:
+ new_status.fetch_start_time = self.prev_stage_precise_start_time
+ new_status.fetch_end_time = current_utc_time
+ new_status.download_start_time = current_utc_time
+ else:
+ if prev_stage == Stage.DOWNLOAD:
+ new_status.download_start_time = self.prev_stage_precise_start_time
+ new_status.download_end_time = current_utc_time
+ else:
+ new_status.retrieve_start_time = self.prev_stage_precise_start_time
+ new_status.retrieve_end_time = current_utc_time
+ new_status.upload_start_time = current_utc_time
+
+ self.status = new_status
+ self._update(self.status)
+
+ @abc.abstractmethod
+ def _read(self, location: str) -> DownloadStatus:
+ pass
+
+ @abc.abstractmethod
+ def _update(self, download_status: DownloadStatus) -> None:
+ pass
+
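+# Illustrative usage sketch (not part of this module): the intended lifecycle of a
+# manifest, assuming `manifest` is an instance of a concrete subclass such as
+# FirestoreManifest and that `config_name`, `dataset`, `selection`, `location`, and
+# `user` are already defined.
+#
+#   manifest.schedule(config_name, dataset, selection, location, user)
+#   with manifest.transact(config_name, dataset, selection, location, user):
+#       manifest.set_stage(Stage.RETRIEVE)
+#       ...  # perform the work; __exit__ records 'success' or 'failure' with timings.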
+
+class FirestoreManifest(Manifest, Database):
+ """A Firestore Manifest.
+ This Manifest implementation stores DownloadStatuses in a Firebase document store.
+ The document hierarchy for the manifest is as follows:
+ [manifest ]
+ ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... }
+ └── etc...
+ Where `[]` indicates a collection and `{...}` indicates a document.
+ """
+
+ def _get_db(self) -> firestore.firestore.Client:
+ """Acquire a firestore client, initializing the firebase app if necessary.
+ Will attempt to get the db client five times. If it's still unsuccessful, a
+ `ManifestException` will be raised.
+ """
+ db = None
+ attempts = 0
+
+ while db is None:
+ try:
+ db = firestore.client()
+ except ValueError as e:
+ # The above call will fail with a value error when the firebase app is not initialized.
+ # Initialize the app here, and try again.
+ # Use the application default credentials.
+ cred = credentials.ApplicationDefault()
+
+ firebase_admin.initialize_app(cred)
+ logger.info("Initialized Firebase App.")
+
+ if attempts > 4:
+ raise ManifestException(
+ "Exceeded number of retries to get firestore client."
+ ) from e
+
+ time.sleep(get_wait_interval(attempts))
+
+ attempts += 1
+
+ return db
+
+ def _read(self, location: str) -> DownloadStatus:
+ """Reads the JSON data from a manifest."""
+
+ doc_id = generate_md5_hash(location)
+
+ # Get a reference to this location's manifest document.
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result = download_doc_ref.get()
+ row = {}
+ if result.exists:
+ records = result.to_dict()
+ row = {n: to_json_serializable_type(v) for n, v in records.items()}
+ return DownloadStatus.from_dict(row)
+
+ def _update(self, download_status: DownloadStatus) -> None:
+ """Update or create a download status record."""
+ logger.info("Updating Firestore Manifest.")
+
+ status = DownloadStatus.to_dict(download_status)
+ doc_id = generate_md5_hash(status["location"])
+
+ # Update document with download status.
+ download_doc_ref = self.root_document_for_store(doc_id)
+
+ result: WriteResult = download_doc_ref.set(status)
+
+ logger.info(
+ "Firestore manifest updated. " +
+ f"update_time={result.update_time}, " +
+ f"status={status['status']} " +
+ f"stage={status['stage']} " +
+ f"filename={download_status.location}."
+ )
+
+ def root_document_for_store(self, store_scheme: str) -> DocumentReference:
+ """Get the root manifest document given the user's config and current document's storage location."""
+ return (
+ self._get_db()
+ .collection(get_config().manifest_collection)
+ .document(store_scheme)
+ )
diff --git a/weather_dl_v2/license_deployment/util.py b/weather_dl_v2/license_deployment/util.py
new file mode 100644
index 00000000..d24a1405
--- /dev/null
+++ b/weather_dl_v2/license_deployment/util.py
@@ -0,0 +1,302 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+import logging
+import geojson
+import hashlib
+import itertools
+import os
+import socket
+import subprocess
+import sys
+import typing as t
+
+import numpy as np
+import pandas as pd
+from apache_beam.io.gcp import gcsio
+from apache_beam.utils import retry
+from xarray.core.utils import ensure_us_time_resolution
+from urllib.parse import urlparse
+from google.api_core.exceptions import BadRequest
+from threading import Lock
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+LATITUDE_RANGE = (-90, 90)
+LONGITUDE_RANGE = (-180, 180)
+GLOBAL_COVERAGE_AREA = [90, -180, -90, 180]
+
+
+def exceptionit(func):
+ def inner_function(*args, **kwargs):
+ try:
+ func(*args, **kwargs)
+ except Exception as e:
+ logger.error(f"exception in {func.__name__} {e.__class__.__name__} {e}.")
+
+ return inner_function
+
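+# Illustrative only: `exceptionit` logs and swallows any exception raised by the
+# wrapped function, which suits fire-and-forget callbacks. `on_message` below is
+# hypothetical, not part of this module.
+#
+# @exceptionit
+# def on_message(message):
+#     ...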
+
+def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter(
+ exception,
+) -> bool:
+ if isinstance(exception, socket.timeout):
+ return True
+ if isinstance(exception, TimeoutError):
+ return True
+ # To handle the concurrency issue in BigQuery.
+ if isinstance(exception, BadRequest):
+ return True
+ return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception)
+
+
+class _FakeClock:
+
+ def sleep(self, value):
+ pass
+
+
+def retry_with_exponential_backoff(fun):
+ """A retry decorator that doesn't apply during test time."""
+ clock = retry.Clock()
+
+ # Use a fake clock only during test time...
+ if "unittest" in sys.modules.keys():
+ clock = _FakeClock()
+
+ return retry.with_exponential_backoff(
+ retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter,
+ clock=clock,
+ )(fun)
+
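+# Illustrative only: wrapping a flaky call so transient server, socket, or BigQuery
+# concurrency errors are retried with exponential backoff. `read_manifest_doc` is a
+# hypothetical function, not part of this module.
+#
+# @retry_with_exponential_backoff
+# def read_manifest_doc(doc_ref):
+#     return doc_ref.get()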
+
+# TODO(#245): Group with common utilities (duplicated)
+def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]:
+ """Yield evenly-sized chunks from an iterable."""
+ input_ = iter(iterable)
+ try:
+ while True:
+ it = itertools.islice(input_, n)
+ # Peek to check if 'it' has a next item.
+ first = next(it)
+ yield itertools.chain([first], it)
+ except StopIteration:
+ pass
+
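+# Illustrative only: iterating chunks of size 3 over range(8) yields
+# [0, 1, 2], [3, 4, 5], [6, 7] once each chunk is materialized, e.g.:
+#
+#   for chunk in ichunked(range(8), 3):
+#       print(list(chunk))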
+
+# TODO(#245): Group with common utilities (duplicated)
+def copy(src: str, dst: str) -> None:
+ """Copy data via `gsutil cp`."""
+ try:
+ subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True)
+ except subprocess.CalledProcessError as e:
+ logger.info(
+ f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.'
+ )
+ raise
+
+
+# TODO(#245): Group with common utilities (duplicated)
+def to_json_serializable_type(value: t.Any) -> t.Any:
+ """Returns the value with a type serializable to JSON"""
+ # Note: The order of processing is significant.
+ logger.info("Serializing to JSON.")
+
+ if pd.isna(value) or value is None:
+ return None
+ elif np.issubdtype(type(value), np.floating):
+ return float(value)
+ elif isinstance(value, np.ndarray):
+ # Will return a scalar if the array is of size 1, else will return a list.
+ return value.tolist()
+ elif (
+ isinstance(value, datetime.datetime)
+ or isinstance(value, str)
+ or isinstance(value, np.datetime64)
+ ):
+ # Assume strings are ISO format timestamps...
+ try:
+ value = datetime.datetime.fromisoformat(value)
+ except ValueError:
+ # ... if they are not, assume serialization is already correct.
+ return value
+ except TypeError:
+ # ... maybe value is a numpy datetime ...
+ try:
+ value = ensure_us_time_resolution(value).astype(datetime.datetime)
+ except AttributeError:
+ # ... value is a datetime object, continue.
+ pass
+
+ # We use a string timestamp representation.
+ if value.tzname():
+ return value.isoformat()
+
+ # We assume here that naive timestamps are in UTC timezone.
+ return value.replace(tzinfo=datetime.timezone.utc).isoformat()
+ elif isinstance(value, np.timedelta64):
+ # Return time delta in seconds.
+ return float(value / np.timedelta64(1, "s"))
+ # This check must happen after processing np.timedelta64 and np.datetime64.
+ elif np.issubdtype(type(value), np.integer):
+ return int(value)
+
+ return value
+
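+# Illustrative conversions (assuming the imports above):
+#
+#   to_json_serializable_type(np.float32(1.5))          # -> 1.5
+#   to_json_serializable_type(np.timedelta64(2, 'h'))   # -> 7200.0
+#   to_json_serializable_type('2021-07-18T00:00:00')    # -> '2021-07-18T00:00:00+00:00'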
+
+def fetch_geo_polygon(area: t.Union[list, str]) -> str:
+ """Calculates a geography polygon from an input area."""
+ # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973
+ if isinstance(area, str):
+ # European area
+ if area == "E":
+ area = [73.5, -27, 33, 45]
+ # Global area
+ elif area == "G":
+ area = GLOBAL_COVERAGE_AREA
+ else:
+ raise RuntimeError(f"Not a valid value for area in config: {area}.")
+
+ n, w, s, e = [float(x) for x in area]
+ if s < LATITUDE_RANGE[0]:
+ raise ValueError(f"Invalid latitude value for south: '{s}'")
+ if n > LATITUDE_RANGE[1]:
+ raise ValueError(f"Invalid latitude value for north: '{n}'")
+ if w < LONGITUDE_RANGE[0]:
+ raise ValueError(f"Invalid longitude value for west: '{w}'")
+ if e > LONGITUDE_RANGE[1]:
+ raise ValueError(f"Invalid longitude value for east: '{e}'")
+
+ # Define the coordinates of the bounding box.
+ coords = [[w, n], [w, s], [e, s], [e, n], [w, n]]
+
+ # Create the GeoJSON polygon object.
+ polygon = geojson.dumps(geojson.Polygon([coords]))
+ return polygon
+
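+# Illustrative only: a bounding box given as [north, west, south, east] is converted
+# into a closed GeoJSON polygon string, e.g.:
+#
+#   fetch_geo_polygon([10, -20, -10, 20])
+#   # -> '{"type": "Polygon", "coordinates": [[[-20, 10], [-20, -10], [20, -10], [20, 10], [-20, 10]]]}'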
+
+def get_file_size(path: str) -> float:
+ parsed_gcs_path = urlparse(path)
+ if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "":
+ return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0
+ else:
+ return (
+ gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0
+ )
+
+
+def get_wait_interval(num_retries: int = 0) -> float:
+ """Returns next wait interval in seconds, using an exponential backoff algorithm."""
+ if num_retries == 0:
+ return 0
+ return 2**num_retries
+
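+# Illustrative only: get_wait_interval(0) == 0, get_wait_interval(1) == 2,
+# get_wait_interval(3) == 8 (seconds).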
+
+def generate_md5_hash(input: str) -> str:
+ """Generates md5 hash for the input string."""
+ return hashlib.md5(input.encode("utf-8")).hexdigest()
+
+
+def download_with_aria2(url: str, path: str) -> None:
+ """Downloads a file from the given URL using the `aria2c` command-line utility,
+ with options set to improve download speed and reliability."""
+ dir_path, file_name = os.path.split(path)
+ try:
+ subprocess.run(
+ [
+ "aria2c",
+ "-x",
+ "16",
+ "-s",
+ "16",
+ url,
+ "-d",
+ dir_path,
+ "-o",
+ file_name,
+ "--allow-overwrite",
+ ],
+ check=True,
+ capture_output=True,
+ )
+ except subprocess.CalledProcessError as e:
+ logger.info(
+ f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.'
+ )
+ raise
+
+
+class ThreadSafeDict:
+ """A thread safe dict with crud operations."""
+
+
+ def __init__(self) -> None:
+ self._dict = {}
+ self._lock = Lock()
+ self.initial_delay = 1
+ self.factor = 0.5
+
+
+ def __getitem__(self, key):
+ val = None
+ with self._lock:
+ val = self._dict[key]
+ return val
+
+
+ def __setitem__(self, key, value):
+ with self._lock:
+ self._dict[key] = value
+
+
+ def remove(self, key):
+ with self._lock:
+ self._dict.__delitem__(key)
+
+
+ def has_key(self, key):
+ present = False
+ with self._lock:
+ present = key in self._dict
+ return present
+
+
+ def increment(self, key, delta=1):
+ with self._lock:
+ if key in self._dict:
+ self._dict[key] += delta
+
+
+ def decrement(self, key, delta=1):
+ with self._lock:
+ if key in self._dict:
+ self._dict[key] -= delta
+
+
+ def find_exponential_delay(self, n: int) -> int:
+ delay = self.initial_delay
+ for _ in range(n):
+ delay += delay*self.factor
+ return delay
+
+
+ def exponential_time(self, key):
+ """Returns exponential time based on dict value. Time in seconds."""
+ delay = 0
+ with self._lock:
+ if key in self._dict:
+ delay = self.find_exponential_delay(self._dict[key])
+ return delay * 60
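+
+# Illustrative usage sketch (not part of this module): tracking per-license error
+# counts and backing off accordingly.
+#
+#   error_counts = ThreadSafeDict()
+#   error_counts['license-1'] = 0
+#   error_counts.increment('license-1')
+#   time.sleep(error_counts.exponential_time('license-1'))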
diff --git a/weather_mv/README.md b/weather_mv/README.md
index ce39fbfb..a1045eab 100644
--- a/weather_mv/README.md
+++ b/weather_mv/README.md
@@ -61,7 +61,8 @@ usage: weather-mv bigquery [-h] -i URIS [--topic TOPIC] [--window_size WINDOW_SI
-o OUTPUT_TABLE [-v variables [variables ...]] [-a area [area ...]]
[--import_time IMPORT_TIME] [--infer_schema]
[--xarray_open_dataset_kwargs XARRAY_OPEN_DATASET_KWARGS]
- [--tif_metadata_for_datetime TIF_METADATA_FOR_DATETIME] [-s]
+ [--tif_metadata_for_start_time TIF_METADATA_FOR_START_TIME]
+ [--tif_metadata_for_end_time TIF_METADATA_FOR_END_TIME] [-s]
[--coordinate_chunk_size COORDINATE_CHUNK_SIZE] ['--skip_creating_polygon']
```
@@ -80,7 +81,8 @@ _Command options_:
* `--xarray_open_dataset_kwargs`: Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.
* `--coordinate_chunk_size`: The size of the chunk of coordinates used for extracting vector data into BigQuery. Used to
tune parallel uploads.
-* `--tif_metadata_for_datetime` : Metadata that contains tif file's timestamp. Applicable only for tif files.
+* `--tif_metadata_for_start_time` : Metadata that contains the tif file's start/initialization time. Applicable only for tif files.
+* `--tif_metadata_for_end_time` : Metadata that contains the tif file's end/forecast time. Applicable only for tif files (optional).
* `-s, --skip-region-validation` : Skip validation of regions for data migration. Default: off.
* `--disable_grib_schema_normalization` : To disable grib's schema normalization. Default: off.
* `--skip_creating_polygon` : Not ingest grid points as polygons in BigQuery. Default: Ingest grid points as Polygon in
@@ -139,7 +141,8 @@ weather-mv bq --uris "gs://your-bucket/*.tif" \
--output_table $PROJECT.$DATASET_ID.$TABLE_ID \
--temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery
--direct_num_workers 2 \
- --tif_metadata_for_datetime start_time
+ --tif_metadata_for_start_time start_time \
+ --tif_metadata_for_end_time end_time
```
Upload only a subset of variables:
@@ -162,6 +165,39 @@ weather-mv bq --uris "gs://your-bucket/*.nc" \
--direct_num_workers 2
```
+Upload a zarr file:
+
+```bash
+weather-mv bq --uris "gs://your-bucket/*.zarr" \
+ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \
+ --temp_location "gs://$BUCKET/tmp" \
+ --use-local-code \
+ --zarr \
+ --direct_num_workers 2
+```
+
+Upload a specific date range's data from the zarr file:
+
+```bash
+weather-mv bq --uris "gs://your-bucket/*.zarr" \
+ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \
+ --temp_location "gs://$BUCKET/tmp" \
+ --use-local-code \
+ --zarr \
+ --zarr_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \
+ --direct_num_workers 2
+```
+
+Upload a specific date range's data from the file:
+
+```bash
+weather-mv bq --uris "gs://your-bucket/*.nc" \
+ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \
+ --temp_location "gs://$BUCKET/tmp" \
+ --use-local-code \
+ --xarray_open_dataset_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \
+```
+
Control how weather data is opened with XArray:
```bash
diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py
index c23a7f44..3b2e15ba 100644
--- a/weather_mv/loader_pipeline/bq.py
+++ b/weather_mv/loader_pipeline/bq.py
@@ -24,8 +24,10 @@
import geojson
import numpy as np
import xarray as xr
+import xarray_beam as xbeam
from apache_beam.io import WriteToBigQuery, BigQueryDisposition
from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.transforms import window
from google.cloud import bigquery
from xarray.core.utils import ensure_us_time_resolution
@@ -75,8 +77,10 @@ class ToBigQuery(ToDataSink):
infer_schema: If true, this sink will attempt to read in an example data file
read all its variables, and generate a BigQuery schema.
xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset().
- tif_metadata_for_datetime: If the input is a .tif file, parse the tif metadata at
- this location for a timestamp.
+ tif_metadata_for_start_time: If the input is a .tif file, parse the tif metadata at
+ this location for a start/initialization time.
+ tif_metadata_for_end_time: If the input is a .tif file, parse the tif metadata at
+ this location for an end/forecast time.
skip_region_validation: Turn off validation that checks if all Cloud resources
are in the same region.
disable_grib_schema_normalization: Turn off grib's schema normalization; Default: normalization enabled.
@@ -92,7 +96,8 @@ class ToBigQuery(ToDataSink):
import_time: t.Optional[datetime.datetime]
infer_schema: bool
xarray_open_dataset_kwargs: t.Dict
- tif_metadata_for_datetime: t.Optional[str]
+ tif_metadata_for_start_time: t.Optional[str]
+ tif_metadata_for_end_time: t.Optional[str]
skip_region_validation: bool
disable_grib_schema_normalization: bool
coordinate_chunk_size: int = 10_000
@@ -123,8 +128,11 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser):
'off')
subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}',
help='Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.')
- subparser.add_argument('--tif_metadata_for_datetime', type=str, default=None,
- help='Metadata that contains tif file\'s timestamp. '
+ subparser.add_argument('--tif_metadata_for_start_time', type=str, default=None,
+ help='Metadata that contains tif file\'s start/initialization time. '
+ 'Applicable only for tif files.')
+ subparser.add_argument('--tif_metadata_for_end_time', type=str, default=None,
+ help='Metadata that contains tif file\'s end/forecast time. '
'Applicable only for tif files.')
subparser.add_argument('-s', '--skip-region-validation', action='store_true', default=False,
help='Skip validation of regions for data migration. Default: off')
@@ -148,10 +156,14 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis
# Check that all arguments are supplied for COG input.
_, uri_extension = os.path.splitext(known_args.uris)
- if uri_extension == '.tif' and not known_args.tif_metadata_for_datetime:
- raise RuntimeError("'--tif_metadata_for_datetime' is required for tif files.")
- elif uri_extension != '.tif' and known_args.tif_metadata_for_datetime:
- raise RuntimeError("'--tif_metadata_for_datetime' can be specified only for tif files.")
+ if (uri_extension in ['.tif', '.tiff'] and not known_args.tif_metadata_for_start_time):
+ raise RuntimeError("'--tif_metadata_for_start_time' is required for tif files.")
+ elif uri_extension not in ['.tif', '.tiff'] and (
+ known_args.tif_metadata_for_start_time
+ or known_args.tif_metadata_for_end_time
+ ):
+ raise RuntimeError("'--tif_metadata_for_start_time' and "
+ "'--tif_metadata_for_end_time' can be specified only for tif files.")
# Check that Cloud resource regions are consistent.
if not (known_args.dry_run or known_args.skip_region_validation):
@@ -166,8 +178,8 @@ def __post_init__(self):
if self.zarr:
self.xarray_open_dataset_kwargs = self.zarr_kwargs
with open_dataset(self.first_uri, self.xarray_open_dataset_kwargs,
- self.disable_grib_schema_normalization, self.tif_metadata_for_datetime,
- is_zarr=self.zarr) as open_ds:
+ self.disable_grib_schema_normalization, self.tif_metadata_for_start_time,
+ self.tif_metadata_for_end_time, is_zarr=self.zarr) as open_ds:
if not self.skip_creating_polygon:
logger.warning("Assumes that equal distance between consecutive points of latitude "
@@ -219,7 +231,7 @@ def prepare_coordinates(self, uri: str) -> t.Iterator[t.Tuple[str, t.List[t.Dict
logger.info(f'Preparing coordinates for: {uri!r}.')
with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization,
- self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds:
+ self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds:
data_ds: xr.Dataset = _only_target_vars(ds, self.variables)
if self.area:
n, w, s, e = self.area
@@ -238,75 +250,100 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di
self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization,
- self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds:
+ self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds:
data_ds: xr.Dataset = _only_target_vars(ds, self.variables)
+ yield from self.to_rows(coordinates, data_ds, uri)
- first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values,
- np.ndarray) else data_ds.time.values
- first_time_step = to_json_serializable_type(first_ts_raw)
-
- for it in coordinates:
- # Use those index values to select a Dataset containing one row of data.
- row_ds = data_ds.loc[it]
-
- # Create a Name-Value map for data columns. Result looks like:
- # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
- row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
- for n, v in row_ds.data_vars.items()}
-
- # Serialize coordinates.
- it = {k: to_json_serializable_type(v) for k, v in it.items()}
-
- # Add indexed coordinates.
- row.update(it)
- # Add un-indexed coordinates.
- for c in row_ds.coords:
- if c not in it and (not self.variables or c in self.variables):
- row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))
-
- # Add import metadata.
- row[DATA_IMPORT_TIME_COLUMN] = self.import_time
- row[DATA_URI_COLUMN] = uri
- row[DATA_FIRST_STEP] = first_time_step
-
- longitude = ((row['longitude'] + 180) % 360) - 180
- row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
- row[GEO_POLYGON_COLUMN] = (
- fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
- if not self.skip_creating_polygon
- else None
- )
- # 'row' ends up looking like:
- # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
- # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
- beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
- yield row
+ def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]:
+ first_ts_raw = (
+ ds.time[0].values if isinstance(ds.time.values, np.ndarray)
+ else ds.time.values
+ )
+ first_time_step = to_json_serializable_type(first_ts_raw)
+ for it in coordinates:
+ # Use those index values to select a Dataset containing one row of data.
+ row_ds = ds.loc[it]
+
+ # Create a Name-Value map for data columns. Result looks like:
+ # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
+ row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
+ for n, v in row_ds.data_vars.items()}
+
+ # Serialize coordinates.
+ it = {k: to_json_serializable_type(v) for k, v in it.items()}
+
+ # Add indexed coordinates.
+ row.update(it)
+ # Add un-indexed coordinates.
+ for c in row_ds.coords:
+ if c not in it and (not self.variables or c in self.variables):
+ row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))
+
+ # Add import metadata.
+ row[DATA_IMPORT_TIME_COLUMN] = self.import_time
+ row[DATA_URI_COLUMN] = uri
+ row[DATA_FIRST_STEP] = first_time_step
+
+ longitude = ((row['longitude'] + 180) % 360) - 180
+ row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
+ row[GEO_POLYGON_COLUMN] = (
+ fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
+ if not self.skip_creating_polygon
+ else None
+ )
+ # 'row' ends up looking like:
+ # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
+ # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
+ beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
+ yield row
+
+ def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]:
+ uri = ds.attrs.get(DATA_URI_COLUMN, '')
+ # Re-calculate import time for streaming extractions.
+ if not self.import_time or self.zarr:
+ self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
+ yield from self.to_rows(get_coordinates(ds, uri), ds, uri)
def expand(self, paths):
"""Extract rows of variables from data paths into a BigQuery table."""
- extracted_rows = (
+ if not self.zarr:
+ extracted_rows = (
paths
| 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates)
| beam.Reshuffle()
| 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows)
- )
-
- if not self.dry_run:
- (
- extracted_rows
- | 'WriteToBigQuery' >> WriteToBigQuery(
- project=self.table.project,
- dataset=self.table.dataset_id,
- table=self.table.table_id,
- write_disposition=BigQueryDisposition.WRITE_APPEND,
- create_disposition=BigQueryDisposition.CREATE_NEVER)
)
else:
- (
- extracted_rows
- | 'Log Extracted Rows' >> beam.Map(logger.debug)
+ xarray_open_dataset_kwargs = self.xarray_open_dataset_kwargs.copy()
+ xarray_open_dataset_kwargs.pop('chunks')
+ start_date = xarray_open_dataset_kwargs.pop('start_date', None)
+ end_date = xarray_open_dataset_kwargs.pop('end_date', None)
+ ds, chunks = xbeam.open_zarr(self.first_uri, **xarray_open_dataset_kwargs)
+
+ if start_date is not None and end_date is not None:
+ ds = ds.sel(time=slice(start_date, end_date))
+
+ ds.attrs[DATA_URI_COLUMN] = self.first_uri
+ extracted_rows = (
+ paths
+ | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks)
+ | 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows)
+ | 'Window' >> beam.WindowInto(window.FixedWindows(60))
+ | 'AddTimestamp' >> beam.Map(timestamp_row)
)
+ if self.dry_run:
+ return extracted_rows | 'Log Rows' >> beam.Map(logger.info)
+ return (
+ extracted_rows
+ | 'WriteToBigQuery' >> WriteToBigQuery(
+ project=self.table.project,
+ dataset=self.table.dataset_id,
+ table=self.table.table_id,
+ write_disposition=BigQueryDisposition.WRITE_APPEND,
+ create_disposition=BigQueryDisposition.CREATE_NEVER)
+ )
+
def map_dtype_to_sql_type(var_type: np.dtype) -> str:
"""Maps a np.dtype to a suitable BigQuery column type."""
@@ -342,11 +379,17 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.Schem
fields.append(bigquery.SchemaField(DATA_URI_COLUMN, 'STRING', mode='NULLABLE'))
fields.append(bigquery.SchemaField(DATA_FIRST_STEP, 'TIMESTAMP', mode='NULLABLE'))
fields.append(bigquery.SchemaField(GEO_POINT_COLUMN, 'GEOGRAPHY', mode='NULLABLE'))
- fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'STRING', mode='NULLABLE'))
+ fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'GEOGRAPHY', mode='NULLABLE'))
return fields
+def timestamp_row(it: t.Dict) -> window.TimestampedValue:
+ """Associate an extracted row with the import_time timestamp."""
+ timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp()
+ return window.TimestampedValue(it, timestamp)
+
+
def fetch_geo_point(lat: float, long: float) -> str:
"""Calculates a geography point from an input latitude and longitude."""
if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]:
@@ -371,13 +414,13 @@ def fetch_geo_polygon(latitude: float, longitude: float, lat_grid_resolution: fl
The `get_lat_lon_range` function gives the `.` point and `bound_point` gives the `*` point.
"""
lat_lon_bound = bound_point(latitude, longitude, lat_grid_resolution, lon_grid_resolution)
- polygon = geojson.dumps(geojson.Polygon([
+ polygon = geojson.dumps(geojson.Polygon([[
(lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left
(lat_lon_bound[1][0], lat_lon_bound[1][1]), # upper_left
(lat_lon_bound[2][0], lat_lon_bound[2][1]), # upper_right
(lat_lon_bound[3][0], lat_lon_bound[3][1]), # lower_right
(lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left
- ]))
+ ]]))
return polygon
diff --git a/weather_mv/loader_pipeline/bq_test.py b/weather_mv/loader_pipeline/bq_test.py
index ed96cd9d..fae7ab31 100644
--- a/weather_mv/loader_pipeline/bq_test.py
+++ b/weather_mv/loader_pipeline/bq_test.py
@@ -15,6 +15,7 @@
import json
import logging
import os
+import tempfile
import typing as t
import unittest
@@ -23,6 +24,8 @@
import pandas as pd
import simplejson
import xarray as xr
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that, is_not_empty
from google.cloud.bigquery import SchemaField
from .bq import (
@@ -75,7 +78,7 @@ def test_schema_generation(self):
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -91,7 +94,7 @@ def test_schema_generation__with_schema_normalization(self):
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -107,7 +110,7 @@ def test_schema_generation__with_target_columns(self):
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -123,7 +126,7 @@ def test_schema_generation__with_target_columns__with_schema_normalization(self)
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -140,7 +143,7 @@ def test_schema_generation__no_targets_specified(self):
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -157,7 +160,7 @@ def test_schema_generation__no_targets_specified__with_schema_normalization(self
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -191,7 +194,7 @@ def test_schema_generation__non_index_coords(self):
SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None),
SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None),
SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None),
- SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None)
+ SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None)
]
self.assertListEqual(schema, expected_schema)
@@ -201,17 +204,18 @@ class ExtractRowsTestBase(TestDataBase):
def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=None,
import_time=DEFAULT_IMPORT_TIME, disable_grib_schema_normalization=False,
- tif_metadata_for_datetime=None, zarr: bool = False, zarr_kwargs=None,
+ tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, zarr: bool = False, zarr_kwargs=None,
skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]:
if zarr_kwargs is None:
zarr_kwargs = {}
- op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
- output_table='foo.bar.baz', variables=variables, area=area,
- xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time,
- infer_schema=False, tif_metadata_for_datetime=tif_metadata_for_datetime,
- skip_region_validation=True,
- disable_grib_schema_normalization=disable_grib_schema_normalization,
- coordinate_chunk_size=1000, skip_creating_polygon=skip_creating_polygon)
+ op = ToBigQuery.from_kwargs(
+ first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
+ output_table='foo.bar.baz', variables=variables, area=area,
+ xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False,
+ tif_metadata_for_start_time=tif_metadata_for_start_time,
+ tif_metadata_for_end_time=tif_metadata_for_end_time, skip_region_validation=True,
+ disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000,
+ skip_creating_polygon=skip_creating_polygon)
coords = op.prepare_coordinates(data_path)
for uri, chunk in coords:
yield from op.extract_rows(uri, chunk)
@@ -398,18 +402,18 @@ def test_extract_rows__with_valid_lat_long_with_polygon(self):
valid_lat_long = [[-90, 0], [-90, -180], [-45, -180], [-45, 180], [0, 0], [90, 180], [45, -180], [-90, 180],
[90, 1], [0, 180], [1, -180], [90, -180]]
actual_val = [
- '{"type": "Polygon", "coordinates": [[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]}',
- '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}',
- '{"type": "Polygon", "coordinates": [[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]}',
- '{"type": "Polygon", "coordinates": [[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]}',
- '{"type": "Polygon", "coordinates": [[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]}',
- '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}',
- '{"type": "Polygon", "coordinates": [[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]}',
- '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}',
- '{"type": "Polygon", "coordinates": [[0, 89], [0, -89], [2, -89], [2, 89], [0, 89]]}',
- '{"type": "Polygon", "coordinates": [[179, -1], [179, 1], [-179, 1], [-179, -1], [179, -1]]}',
- '{"type": "Polygon", "coordinates": [[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]}',
- '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}'
+ '{"type": "Polygon", "coordinates": [[[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}',
+ '{"type": "Polygon", "coordinates": [[[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}',
+ '{"type": "Polygon", "coordinates": [[[0, 89], [0, -89], [2, -89], [2, 89], [0, 89]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, -1], [179, 1], [-179, 1], [-179, -1], [179, -1]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]]}',
+ '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}'
]
lat_grid_resolution = 1
lon_grid_resolution = 1
@@ -469,10 +473,35 @@ class ExtractRowsTifSupportTest(ExtractRowsTestBase):
def setUp(self) -> None:
super().setUp()
- self.test_data_path = f'{self.test_data_folder}/test_data_tif_start_time.tif'
+ self.test_data_path = f'{self.test_data_folder}/test_data_tif_time.tif'
- def test_extract_rows(self):
- actual = next(self.extract(self.test_data_path, tif_metadata_for_datetime='start_time'))
+ def test_extract_rows_with_end_time(self):
+ actual = next(
+ self.extract(self.test_data_path, tif_metadata_for_start_time='start_time',
+ tif_metadata_for_end_time='end_time')
+ )
+ expected = {
+ 'dewpoint_temperature_2m': 281.09349060058594,
+ 'temperature_2m': 296.8329772949219,
+ 'data_import_time': '1970-01-01T00:00:00+00:00',
+ 'data_first_step': '2020-07-01T00:00:00+00:00',
+ 'data_uri': self.test_data_path,
+ 'latitude': 42.09783344918844,
+ 'longitude': -123.66686981141397,
+ 'time': '2020-07-01T00:00:00+00:00',
+ 'valid_time': '2020-07-01T00:00:00+00:00',
+ 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))),
+ 'geo_polygon': geojson.dumps(geojson.Polygon([
+ (-123.669853, 42.095605), (-123.669853, 42.100066),
+ (-123.663885, 42.100066), (-123.663885, 42.095605),
+ (-123.669853, 42.095605)]))
+ }
+ self.assertRowsEqual(actual, expected)
+
+ def test_extract_rows_without_end_time(self):
+ actual = next(
+ self.extract(self.test_data_path, tif_metadata_for_start_time='start_time')
+ )
expected = {
'dewpoint_temperature_2m': 281.09349060058594,
'temperature_2m': 296.8329772949219,
@@ -737,5 +766,38 @@ def test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema
self.assertRowsEqual(actual, expected)
+class ExtractRowsFromZarrTest(ExtractRowsTestBase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.tmpdir = tempfile.TemporaryDirectory()
+
+ def tearDown(self) -> None:
+ super().tearDown()
+ self.tmpdir.cleanup()
+
+ def test_extracts_rows(self):
+ input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr')
+
+ ds = (
+ xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder)
+ .isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4))
+ .rename(dict(lon='longitude', lat='latitude'))
+ )
+ ds.to_zarr(input_zarr)
+
+ op = ToBigQuery.from_kwargs(
+ first_uri=input_zarr, zarr_kwargs=dict(chunks=None, consolidated=True), dry_run=True, zarr=True,
+ output_table='foo.bar.baz',
+ variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False,
+ tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, skip_region_validation=True,
+ disable_grib_schema_normalization=False,
+ )
+
+ with TestPipeline() as p:
+ result = p | op
+ assert_that(result, is_not_empty())
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py
index 9ac18327..a5d6bca0 100644
--- a/weather_mv/loader_pipeline/ee.py
+++ b/weather_mv/loader_pipeline/ee.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import csv
import dataclasses
import json
import logging
+import math
import os
import re
import shutil
@@ -27,7 +29,6 @@
import apache_beam as beam
import ee
import numpy as np
-import xarray as xr
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
from apache_beam.options.pipeline_options import PipelineOptions
@@ -36,8 +37,8 @@
from google.auth.transport import requests
from rasterio.io import MemoryFile
-from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin
-from .util import make_attrs_ee_compatible, RateLimit, validate_region
+from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload
+from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp
logger = logging.getLogger(__name__)
@@ -51,6 +52,7 @@
'IMAGE': '.tiff',
'TABLE': '.csv'
}
+ROWS_PER_WRITE = 10_000 # Number of rows per feature collection write.
def is_compute_engine() -> bool:
@@ -155,7 +157,12 @@ def setup(self):
def check_setup(self):
"""Ensures that setup has been called."""
if not self._has_setup:
- self.setup()
+ try:
+ # This throws an exception if ee is not initialized.
+ ee.data.getAlgorithms()
+ self._has_setup = True
+ except ee.EEException:
+ self.setup()
def process(self, *args, **kwargs):
"""Checks that setup has been called then call the process implementation."""
@@ -438,6 +445,8 @@ def add_to_queue(self, queue: Queue, item: t.Any):
def convert_to_asset(self, queue: Queue, uri: str):
"""Converts source data into EE asset (GeoTiff or CSV) and uploads it to the bucket."""
logger.info(f'Converting {uri!r} to COGs...')
+ job_start_time = get_utc_timestamp()
+
with open_dataset(uri,
self.open_dataset_kwargs,
self.disable_grib_schema_normalization,
@@ -457,6 +466,8 @@ def convert_to_asset(self, queue: Queue, uri: str):
('start_time', 'end_time', 'is_normalized','forecast_hour'))
dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
attrs.update({'is_normalized': str(is_normalized)}) # EE properties does not support bool.
+ # Add job_start_time to properties.
+ attrs["job_start_time"] = job_start_time
# Make attrs EE ingestable.
attrs = make_attrs_ee_compatible(attrs)
@@ -507,21 +518,40 @@ def convert_to_asset(self, queue: Queue, uri: str):
channel_names = []
file_name = f'{asset_name}.csv'
- df = xr.Dataset.to_dataframe(ds)
- df = df.reset_index()
- # NULL and NaN create data-type mismatch issue in ee therefore replacing all of them.
- # fillna fills in NaNs, NULLs, and NaTs but we have to exclude NaTs.
- non_nat = df.select_dtypes(exclude=['datetime', 'timedelta', 'datetimetz'])
- df[non_nat.columns] = non_nat.fillna(-9999)
+ shape = math.prod(list(ds.dims.values()))
+ # Names of dimensions, coordinates and data variables.
+ dims = list(ds.dims)
+ coords = [c for c in list(ds.coords) if c not in dims]
+ vars = list(ds.data_vars)
+ header = dims + coords + vars
+
+ # Data of dimensions, coordinates and data variables.
+ dims_data = [ds[dim].data for dim in dims]
+ coords_data = [np.full((shape,), ds[coord].data) for coord in coords]
+ vars_data = [ds[var].data.flatten() for var in vars]
+ data = coords_data + vars_data
+
+ dims_shape = [len(ds[dim].data) for dim in dims]
- # Copy in-memory dataframe to gcs.
+ def get_dims_data(index: int) -> t.List[t.Any]:
+ """Returns dimensions for the given flattened index."""
+ return [
+ dim[int(index / math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data)
+ ]
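+
+ # Worked example (illustrative): with dims_shape == [4, 3] (say, 4 time steps x 3
+ # latitudes), flattened index 7 maps to time index int(7 / 3) % 4 == 2 and
+ # latitude index 7 % 3 == 1.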
+
+ # Copy CSV to gcs.
target_path = os.path.join(self.asset_location, file_name)
- with tempfile.NamedTemporaryFile() as tmp_df:
- df.to_csv(tmp_df.name, index=False)
- tmp_df.flush()
- tmp_df.seek(0)
- with FileSystems().create(target_path) as dst:
- shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE)
+ with tempfile.NamedTemporaryFile() as temp:
+ with open(temp.name, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerows([header])
+ # Write rows in batches.
+ for i in range(0, shape, ROWS_PER_WRITE):
+ writer.writerows(
+ [get_dims_data(i) + list(row) for row in zip(*[d[i:i + ROWS_PER_WRITE] for d in data])]
+ )
+
+ upload(temp.name, target_path)
asset_data = AssetData(
name=asset_name,
target_path=target_path,
@@ -626,6 +656,8 @@ def start_ingestion(self, asset_request: t.Dict) -> str:
"""Creates COG-backed asset in earth engine. Returns the asset id."""
self.check_setup()
+ asset_request['properties']['ingestion_time'] = get_utc_timestamp()
+
try:
if self.ee_asset_type == 'IMAGE':
result = ee.data.createAsset(asset_request)
diff --git a/weather_mv/loader_pipeline/pipeline.py b/weather_mv/loader_pipeline/pipeline.py
index c12bd5f5..ef685473 100644
--- a/weather_mv/loader_pipeline/pipeline.py
+++ b/weather_mv/loader_pipeline/pipeline.py
@@ -17,6 +17,7 @@
import json
import logging
import typing as t
+import warnings
import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
@@ -27,7 +28,7 @@
from .streaming import GroupMessagesByFixedWindows, ParsePaths
logger = logging.getLogger(__name__)
-SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0'
+SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0'
def configure_logger(verbosity: int) -> None:
@@ -55,8 +56,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None
known_args.first_uri = next(iter(all_uris))
with beam.Pipeline(argv=pipeline_args) as p:
- if known_args.topic or known_args.subscription:
-
+ if known_args.zarr:
+ paths = p
+ elif known_args.topic or known_args.subscription:
paths = (
p
# Windowing is based on this code sample:
@@ -140,11 +142,19 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]:
# Validate Zarr arguments
if known_args.uris.endswith('.zarr'):
known_args.zarr = True
- known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None)
if known_args.zarr_kwargs and not known_args.zarr:
raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.')
+ if known_args.zarr_kwargs:
+ if not known_args.zarr_kwargs.get('start_date') or not known_args.zarr_kwargs.get('end_date'):
+ warnings.warn('`--zarr_kwargs` does not contain both `start_date` and `end_date`, '
+ 'so the whole Zarr dataset will be ingested.')
+
+ if known_args.zarr:
+ known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None)
+ known_args.zarr_kwargs['consolidated'] = known_args.zarr_kwargs.get('consolidated', True)
+
# Validate subcommand
if known_args.subcommand == 'bigquery' or known_args.subcommand == 'bq':
ToBigQuery.validate_arguments(known_args, pipeline_args)
diff --git a/weather_mv/loader_pipeline/pipeline_test.py b/weather_mv/loader_pipeline/pipeline_test.py
index 4d546192..3834b537 100644
--- a/weather_mv/loader_pipeline/pipeline_test.py
+++ b/weather_mv/loader_pipeline/pipeline_test.py
@@ -30,7 +30,7 @@ def setUp(self) -> None:
).split()
self.tif_base_cli_args = (
'weather-mv bq '
- f'-i {self.test_data_folder}/test_data_tif_start_time.tif '
+ f'-i {self.test_data_folder}/test_data_tif_time.tif '
'-o myproject.mydataset.mytable '
'--import_time 2022-02-04T22:22:12.125893 '
'-s'
@@ -62,7 +62,8 @@ def setUp(self) -> None:
'xarray_open_dataset_kwargs': {},
'coordinate_chunk_size': 10_000,
'disable_grib_schema_normalization': False,
- 'tif_metadata_for_datetime': None,
+ 'tif_metadata_for_start_time': None,
+ 'tif_metadata_for_end_time': None,
'zarr': False,
'zarr_kwargs': {},
'log_level': 2,
@@ -83,7 +84,8 @@ def test_log_level_arg(self):
def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self):
with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'):
- run(self.base_cli_args + '--tif_metadata_for_datetime start_time'.split())
+ run(self.base_cli_args + '--tif_metadata_for_start_time start_time '
+ '--tif_metadata_for_end_time end_time'.split())
def test_tif_metadata_for_datetime_raise_error_if_flag_is_absent(self):
with self.assertRaisesRegex(RuntimeError, 'is required for tif files.'):
diff --git a/weather_mv/loader_pipeline/regrid_test.py b/weather_mv/loader_pipeline/regrid_test.py
index 5cc5b2a1..87ffaad4 100644
--- a/weather_mv/loader_pipeline/regrid_test.py
+++ b/weather_mv/loader_pipeline/regrid_test.py
@@ -122,7 +122,7 @@ def test_zarr__coarsen(self):
self.Op,
first_uri=input_zarr,
output_path=output_zarr,
- zarr_input_chunks={"time": 5},
+ zarr_input_chunks={"time": 25},
zarr=True
)
diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py
index 8d4f5681..bc064599 100644
--- a/weather_mv/loader_pipeline/sinks.py
+++ b/weather_mv/loader_pipeline/sinks.py
@@ -141,8 +141,9 @@ def rearrange_time_list(order_list: t.List, time_list: t.List) -> t.List:
return datetime.datetime(*time_list)
-def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_datetime: str, uri: str,
- band_names_dict: t.Dict, initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset:
+def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_start_time: str,
+ tif_metadata_for_end_time: str, uri: str, band_names_dict: t.Dict,
+ initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset:
"""Transforms (y, x) coordinates into (lat, long) and adds bands data in data variables.
This also retrieves datetime from tif's metadata and stores it into dataset.
@@ -165,6 +166,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
ds = _replace_dataarray_names_with_long_names(ds)
end_time = None
+ start_time = None
if initialization_time_regex and forecast_time_regex:
try:
start_time = match_datetime(uri, initialization_time_regex)
@@ -177,15 +179,40 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
ds.attrs['start_time'] = start_time
ds.attrs['end_time'] = end_time
- datetime_value_ms = None
+ init_time = None
+ forecast_time = None
+ coords = {}
try:
- datetime_value_s = (int(end_time.timestamp()) if end_time is not None
- else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
- ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
- except KeyError:
- raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
+ # If start_time/end_time are given as epoch milliseconds.
+ init_time = (int(start_time.timestamp()) if start_time is not None
+ else int(ds.attrs[tif_metadata_for_start_time]) / 1000.0)
+ coords['time'] = datetime.datetime.utcfromtimestamp(init_time)
+
+ if tif_metadata_for_end_time:
+ forecast_time = (int(end_time.timestamp()) if end_time is not None
+ else int(ds.attrs[tif_metadata_for_end_time]) / 1000.0)
+ coords['valid_time'] = datetime.datetime.utcfromtimestamp(forecast_time)
+
+ ds = ds.assign_coords(coords)
+ except KeyError as e:
+ raise RuntimeError(f"Invalid datetime metadata of tif: {e}.")
except ValueError:
- raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.")
+ try:
+ # If start_time/end_time are given as UTC-formatted strings.
+ init_time = (int(start_time.timestamp()) if start_time is not None
+ else datetime.datetime.strptime(ds.attrs[tif_metadata_for_start_time],
+ '%Y-%m-%dT%H:%M:%SZ'))
+ coords['time'] = init_time
+
+ if tif_metadata_for_end_time:
+ forecast_time = (int(end_time.timestamp()) if end_time is not None
+ else datetime.datetime.strptime(ds.attrs[tif_metadata_for_end_time],
+ '%Y-%m-%dT%H:%M:%SZ'))
+ coords['valid_time'] = forecast_time
+
+ ds = ds.assign_coords(coords)
+ except ValueError as e:
+ raise RuntimeError(f"Invalid datetime value in tif's metadata: {e}.")
return ds
@@ -583,6 +610,11 @@ def __open_dataset_file(filename: str,
False)
+def upload(src: str, dst: str) -> None:
+ """Uploads a file to the specified GCS bucket destination."""
+ subprocess.run(f'gsutil -m cp {src} {dst}'.split(), check=True, capture_output=True, text=True, input="n/n")
+
+
def copy(src: str, dst: str) -> None:
"""Copy data via `gcloud alpha storage` or `gsutil`."""
errors: t.List[subprocess.CalledProcessError] = []
@@ -624,7 +656,8 @@ def open_local(uri: str) -> t.Iterator[str]:
def open_dataset(uri: str,
open_dataset_kwargs: t.Optional[t.Dict] = None,
disable_grib_schema_normalization: bool = False,
- tif_metadata_for_datetime: t.Optional[str] = None,
+ tif_metadata_for_start_time: t.Optional[str] = None,
+ tif_metadata_for_end_time: t.Optional[str] = None,
band_names_dict: t.Optional[t.Dict] = None,
initialization_time_regex: t.Optional[str] = None,
forecast_time_regex: t.Optional[str] = None,
@@ -632,8 +665,17 @@ def open_dataset(uri: str,
tiff_config: t.Optional[t.Dict] = None,) -> t.Iterator[t.Union[xr.Dataset, t.List[xr.Dataset]]]:
"""Open the dataset at 'uri' and return a xarray.Dataset."""
try:
+ local_open_dataset_kwargs = start_date = end_date = None
+ if open_dataset_kwargs is not None:
+ local_open_dataset_kwargs = open_dataset_kwargs.copy()
+ start_date = local_open_dataset_kwargs.pop('start_date', None)
+ end_date = local_open_dataset_kwargs.pop('end_date', None)
+
if is_zarr:
- ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs)
+ ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr',
+ **local_open_dataset_kwargs), False)
+ if start_date is not None and end_date is not None:
+ ds = ds.sel(time=slice(start_date, end_date))
beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
yield ds
ds.close()
@@ -641,11 +683,11 @@ def open_dataset(uri: str,
with open_local(uri) as local_path:
_, uri_extension = os.path.splitext(uri)
- xr_datasets: xr.Dataset = __open_dataset_file(local_path,
- uri_extension,
- disable_grib_schema_normalization,
- open_dataset_kwargs,
- tiff_config)
+ xr_dataset: xr.Dataset = __open_dataset_file(local_path,
+ uri_extension,
+ disable_grib_schema_normalization,
+ local_open_dataset_kwargs,
+ tiff_config)
# Extracting dtype, crs and transform from the dataset & storing them as attributes.
try:
with rasterio.open(local_path, 'r') as f:
@@ -661,14 +703,17 @@ def open_dataset(uri: str,
logger.info(f'opened dataset size: {total_size_in_bytes}')
else:
+ if start_date is not None and end_date is not None:
+ xr_dataset = xr_dataset.sel(time=slice(start_date, end_date))
if uri_extension in ['.tif', '.tiff']:
xr_dataset = _preprocess_tif(xr_datasets,
- local_path,
- tif_metadata_for_datetime,
- uri,
- band_names_dict,
- initialization_time_regex,
- forecast_time_regex)
+ local_path,
+ tif_metadata_for_start_time,
+ tif_metadata_for_end_time,
+ uri,
+ band_names_dict,
+ initialization_time_regex,
+ forecast_time_regex)
else:
xr_dataset = xr_datasets
xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})
diff --git a/weather_mv/loader_pipeline/sinks_test.py b/weather_mv/loader_pipeline/sinks_test.py
index 01cdce9a..69150dda 100644
--- a/weather_mv/loader_pipeline/sinks_test.py
+++ b/weather_mv/loader_pipeline/sinks_test.py
@@ -84,7 +84,7 @@ def setUp(self) -> None:
super().setUp()
self.test_data_path = os.path.join(self.test_data_folder, 'test_data_20180101.nc')
self.test_grib_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep')
- self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_start_time.tif')
+ self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_time.tif')
self.test_zarr_path = os.path.join(self.test_data_folder, 'test_data.zarr')
def test_opens_grib_files(self):
@@ -104,7 +104,8 @@ def test_accepts_xarray_kwargs(self):
self.assertDictContainsSubset({'is_normalized': False}, ds2.attrs)
def test_opens_tif_files(self):
- with open_dataset(self.test_tif_path, tif_metadata_for_datetime='start_time') as ds:
+ with open_dataset(self.test_tif_path, tif_metadata_for_start_time='start_time',
+ tif_metadata_for_end_time='end_time') as ds:
self.assertIsNotNone(ds)
self.assertDictContainsSubset({'is_normalized': False}, ds.attrs)
@@ -112,6 +113,7 @@ def test_opens_zarr(self):
with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds:
self.assertIsNotNone(ds)
self.assertEqual(list(ds.data_vars), ['cape', 'd2m'])
+
def test_open_dataset__fits_memory_bounds(self):
with write_netcdf() as test_netcdf_path:
with limit_memory(max_memory=30):
diff --git a/weather_mv/loader_pipeline/streaming.py b/weather_mv/loader_pipeline/streaming.py
index 7210b2e7..3a7a8f49 100644
--- a/weather_mv/loader_pipeline/streaming.py
+++ b/weather_mv/loader_pipeline/streaming.py
@@ -84,7 +84,7 @@ def try_parse_message(cls, message_body: t.Union[str, t.Dict]) -> t.Dict:
try:
return json.loads(message_body)
except (json.JSONDecodeError, TypeError):
- if type(message_body) is dict:
+ if isinstance(message_body, dict):
return message_body
raise
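
The switch to `isinstance()` matters because a message body that is already parsed may arrive as a dict subclass; `type(x) is dict` rejects those, while `isinstance(x, dict)` accepts them. A small self-contained illustration:

```python
# Why isinstance(...) is preferred over `type(...) is dict` in try_parse_message.
from collections import OrderedDict

body = OrderedDict(name='era5', status='done')
print(type(body) is dict)      # False -- the old check would fall through and re-raise
print(isinstance(body, dict))  # True  -- the new check returns the body as-is
```
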
diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py
index a31a06a9..079b86de 100644
--- a/weather_mv/loader_pipeline/util.py
+++ b/weather_mv/loader_pipeline/util.py
@@ -28,7 +28,6 @@
import uuid
from functools import partial
from urllib.parse import urlparse
-
import apache_beam as beam
import numpy as np
import pandas as pd
@@ -134,6 +133,9 @@ def _check_for_coords_vars(ds_data_var: str, target_var: str) -> bool:
specified by the user."""
return ds_data_var.endswith('_'+target_var) or ds_data_var.startswith(target_var+'_')
+def get_utc_timestamp() -> float:
+ """Returns the current UTC Timestamp."""
+ return datetime.datetime.now().timestamp()
def _only_target_coordinate_vars(ds: xr.Dataset, data_vars: t.List[str]) -> t.List[str]:
"""If the user specifies target fields in the dataset, get all the matching coords & data vars."""
diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py
index dae9c873..65d9169e 100644
--- a/weather_mv/loader_pipeline/util_test.py
+++ b/weather_mv/loader_pipeline/util_test.py
@@ -38,9 +38,10 @@ def test_gets_indexed_coordinates(self):
ds = xr.open_dataset(self.test_data_path)
self.assertEqual(
next(get_coordinates(ds)),
- {'latitude': 49.0,
- 'longitude':-108.0,
- 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)}
+ {
+ 'latitude': 49.0,
+ 'longitude': -108.0,
+ 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)}
)
def test_no_duplicate_coordinates(self):
@@ -91,24 +92,28 @@ def test_get_coordinates(self):
actual,
[
[
- {'longitude': -108.0,
- 'latitude': 49.0,
- 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)
+ {
+ 'longitude': -108.0,
+ 'latitude': 49.0,
+ 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)
},
- {'longitude': -108.0,
- 'latitude': 49.0,
- 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None)
+ {
+ 'longitude': -108.0,
+ 'latitude': 49.0,
+ 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None)
},
- {'longitude': -108.0,
- 'latitude': 49.0,
- 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None)
+ {
+ 'longitude': -108.0,
+ 'latitude': 49.0,
+ 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None)
},
],
[
- {'longitude': -108.0,
- 'latitude': 49.0,
- 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None)
- }
+ {
+ 'longitude': -108.0,
+ 'latitude': 49.0,
+ 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None)
+ }
]
]
)
diff --git a/weather_mv/setup.py b/weather_mv/setup.py
index bfe09713..4bdb4a0b 100644
--- a/weather_mv/setup.py
+++ b/weather_mv/setup.py
@@ -45,6 +45,7 @@
"numpy==1.22.4",
"pandas==1.5.1",
"xarray==2023.1.0",
+ "xarray-beam==0.6.2",
"cfgrib==0.9.10.2",
"netcdf4==1.6.1",
"geojson==2.5.0",
@@ -55,6 +56,8 @@
"earthengine-api>=0.1.263",
"pyproj==3.4.0", # requires separate binary installation!
"gdal==3.5.1", # requires separate binary installation!
+ "gcsfs==2022.11.0",
+ "zarr==2.15.0",
]
setup(
@@ -62,7 +65,7 @@
packages=find_packages(),
author='Anthromets',
author_email='anthromets-ecmwf@google.com',
- version='0.2.17',
+ version='0.2.19',
url='https://weather-tools.readthedocs.io/en/latest/weather_mv/',
description='A tool to load weather data into BigQuery.',
install_requires=beam_gcp_requirements + base_requirements,
diff --git a/weather_mv/test_data/test_data_tif_start_time.tif b/weather_mv/test_data/test_data_tif_start_time.tif
deleted file mode 100644
index 82f32dd7..00000000
Binary files a/weather_mv/test_data/test_data_tif_start_time.tif and /dev/null differ
diff --git a/weather_mv/test_data/test_data_tif_time.tif b/weather_mv/test_data/test_data_tif_time.tif
new file mode 100644
index 00000000..32b7f63b
Binary files /dev/null and b/weather_mv/test_data/test_data_tif_time.tif differ
diff --git a/weather_sp/setup.py b/weather_sp/setup.py
index 59786c8e..e22279cd 100644
--- a/weather_sp/setup.py
+++ b/weather_sp/setup.py
@@ -44,7 +44,7 @@
packages=find_packages(),
author='Anthromets',
author_email='anthromets-ecmwf@google.com',
- version='0.3.1',
+ version='0.3.2',
url='https://weather-tools.readthedocs.io/en/latest/weather_sp/',
description='A tool to split weather data files into per-variable files.',
install_requires=beam_gcp_requirements + base_requirements,
diff --git a/weather_sp/splitter_pipeline/file_splitters.py b/weather_sp/splitter_pipeline/file_splitters.py
index 6456c426..7a4d3c77 100644
--- a/weather_sp/splitter_pipeline/file_splitters.py
+++ b/weather_sp/splitter_pipeline/file_splitters.py
@@ -16,6 +16,7 @@
import itertools
import logging
import os
+import re
import shutil
import string
import subprocess
@@ -158,6 +159,10 @@ class GribSplitterV2(GribSplitter):
See https://confluence.ecmwf.int/display/ECC/grib_copy.
"""
+ def replace_non_numeric_bracket(self, match: re.Match) -> str:
+ value = match.group(1)
+ return f"[{value}]" if not value.isdigit() else "{" + value + "}"
+
def split_data(self) -> None:
if not self.output_info.split_dims():
raise ValueError('No splitting specified in template.')
@@ -172,7 +177,10 @@ def split_data(self) -> None:
unformatted_output_path = self.output_info.unformatted_output_path()
prefix, _ = os.path.split(next(iter(string.Formatter().parse(unformatted_output_path)))[0])
_, tail = unformatted_output_path.split(prefix)
- output_template = tail.replace('{', '[').replace('}', ']')
+
+ # Rewrite non-numeric {placeholders} in tail as [placeholders]; numeric ones stay as format fields
+ output_str = re.sub(r'\{(\w+)\}', self.replace_non_numeric_bracket, tail)
+ output_template = output_str.format(*self.output_info.template_folders)
slash = '/'
delimiter = 'DELIMITER'
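
The substitution above keeps numeric placeholders such as `{0}` as Python format fields (later filled from `template_folders`) while rewriting named placeholders into grib_copy's `[key]` syntax. A standalone sketch with a made-up output template; `shortName` and `level` are chosen here as typical grib_copy keys and are not taken from the patch:

```python
import re

def replace_non_numeric_bracket(match: re.Match) -> str:
    # Named keys become grib_copy's [key] form; purely numeric ones stay as {0}, {1}, ...
    value = match.group(1)
    return f"[{value}]" if not value.isdigit() else "{" + value + "}"

tail = '/{0}/era5_{shortName}_{level}.grib'
output_str = re.sub(r'\{(\w+)\}', replace_non_numeric_bracket, tail)
# grib_copy expands [shortName]/[level]; the numeric field is filled from the template folders.
print(output_str.format('2020'))  # -> /2020/era5_[shortName]_[level].grib
```
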