diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0325c936..b691aadb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,12 +56,15 @@ jobs: channels: conda-forge environment-file: ci${{ matrix.python-version}}.yml activate-environment: weather-tools + miniforge-variant: Mambaforge + miniforge-version: latest + use-mamba: true - name: Check MetView's installation shell: bash -l {0} run: python -m metview selfcheck - name: Run unit tests shell: bash -l {0} - run: pytest --memray + run: pytest --memray --ignore=weather_dl_v2 # Ignoring dl-v2 as it only supports py3.10 lint: runs-on: ubuntu-latest strategy: @@ -84,13 +87,15 @@ jobs: echo "::set-output name=dir::$(pip cache dir)" - name: Install linter run: | - pip install ruff + pip install ruff==0.1.2 - name: Lint project run: ruff check . type-check: runs-on: ubuntu-latest strategy: fail-fast: false + matrix: + python-version: ["3.8"] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 @@ -98,28 +103,27 @@ jobs: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 + - name: conda cache + uses: actions/cache@v2 + env: + # Increase this value to reset cache if etc/example-environment.yml has not changed + CACHE_NUMBER: 0 with: - python-version: "3.8" - - name: Setup conda - uses: s-weigand/setup-conda@v1 + path: ~/conda_pkgs_dir + key: + ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ matrix.python-version }}-${{ hashFiles('ci3.8.yml') }} + - name: Setup conda environment + uses: conda-incubator/setup-miniconda@v2 with: - update-conda: true - python-version: "3.8" - conda-channels: anaconda, conda-forge - - name: Install ecCodes - run: | - conda install -y eccodes>=2.21.0 -c conda-forge - conda install -y pyproj -c conda-forge - conda install -y gdal -c conda-forge - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "::set-output name=dir::$(pip cache dir)" - - name: Install weather-tools + python-version: ${{ matrix.python-version }} + channels: conda-forge + environment-file: ci${{ matrix.python-version}}.yml + activate-environment: weather-tools + miniforge-variant: Mambaforge + miniforge-version: latest + use-mamba: true + - name: Install weather-tools[test] run: | - pip install -e .[test] --use-deprecated=legacy-resolver + conda run -n weather-tools pip install -e .[test] --use-deprecated=legacy-resolver - name: Run type checker - run: pytype + run: conda run -n weather-tools pytype diff --git a/Dockerfile b/Dockerfile index 057442f5..30a72f57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,6 +29,7 @@ ARG weather_tools_git_rev=main RUN git clone https://github.com/google/weather-tools.git /weather WORKDIR /weather RUN git checkout "${weather_tools_git_rev}" +RUN rm -r /weather/weather_*/test_data/ RUN conda env create -f environment.yml --debug # Activate the conda env and update the PATH diff --git a/ci3.8.yml b/ci3.8.yml index d6a1e0bd..211d36be 100644 --- a/ci3.8.yml +++ b/ci3.8.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -30,9 +30,11 @@ dependencies: - pip=22.3 - pygrib=2.1.4 - xarray==2023.1.0 - - ruff==0.0.260 + - ruff==0.1.2 - google-cloud-sdk=410.0.0 - aria2=1.36.0 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - .[test] diff --git 
a/ci3.9.yml b/ci3.9.yml index a43cec16..86f0968d 100644 --- a/ci3.9.yml +++ b/ci3.9.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -32,7 +32,9 @@ dependencies: - google-cloud-sdk=410.0.0 - aria2=1.36.0 - xarray==2023.1.0 - - ruff==0.0.260 + - ruff==0.1.2 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - .[test] diff --git a/environment.yml b/environment.yml index eae35f9c..0b043980 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: dependencies: - python=3.8.13 - apache-beam=2.40.0 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - xarray=2023.1.0 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -25,7 +25,9 @@ dependencies: - google-cloud-sdk=410.0.0 - aria2=1.36.0 - pip=22.3 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - firebase-admin==6.0.1 - . diff --git a/pyproject.toml b/pyproject.toml index 8e782e8d..2eaadb0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,4 +42,4 @@ target-version = "py310" [tool.ruff.mccabe] # Unlike Flake8, default to a complexity level of 10. -max-complexity = 10 \ No newline at end of file +max-complexity = 10 diff --git a/setup.py b/setup.py index af798889..233e5193 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,9 @@ "earthengine-api>=0.1.263", "pyproj", # requires separate binary installation! "gdal", # requires separate binary installation! - "xarray-beam==0.3.1", + "xarray-beam==0.6.2", "gcsfs==2022.11.0", + "zarr==2.15.0", ] weather_sp_requirements = [ @@ -70,7 +71,7 @@ test_requirements = [ "pytype==2021.11.29", - "ruff", + "ruff==0.1.2", "pytest", "pytest-subtests", "netcdf4", @@ -82,6 +83,7 @@ "memray", "pytest-memray", "h5py", + "pooch", ] all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \ @@ -115,7 +117,7 @@ ], python_requires='>=3.8, <3.10', - install_requires=['apache-beam[gcp]==2.40.0'], + install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'], use_scm_version=True, setup_requires=['setuptools_scm'], scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'], diff --git a/weather_dl/download_pipeline/util.py b/weather_dl/download_pipeline/util.py index 1ee9e24e..dd09ba29 100644 --- a/weather_dl/download_pipeline/util.py +++ b/weather_dl/download_pipeline/util.py @@ -102,10 +102,10 @@ def to_json_serializable_type(value: t.Any) -> t.Any: return None elif np.issubdtype(type(value), np.floating): return float(value) - elif type(value) == np.ndarray: + elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. return value.tolist() - elif type(value) == datetime.datetime or type(value) == str or type(value) == np.datetime64: + elif isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) @@ -126,7 +126,7 @@ def to_json_serializable_type(value: t.Any) -> t.Any: # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() - elif type(value) == np.timedelta64: + elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, 's')) # This check must happen after processing np.timedelta64 and np.datetime64. 
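The `util.py` hunk above replaces exact `type(value) == ...` comparisons with `isinstance(...)`, so subclasses of the listed types are now serialized as well. A minimal sketch of the difference (the `UtcTimestamp` subclass is hypothetical, purely for illustration):
```
import datetime

import numpy as np


class UtcTimestamp(datetime.datetime):
    """Hypothetical subclass standing in for any datetime-derived value."""


ts = UtcTimestamp(2023, 1, 1)
print(type(ts) == datetime.datetime)      # False: the old exact-type check misses subclasses
print(isinstance(ts, datetime.datetime))  # True: the new check accepts them
print(isinstance(np.timedelta64(90, "s"), np.timedelta64))  # True: behavior for exact types is unchanged
```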
diff --git a/weather_dl_v2/README.md b/weather_dl_v2/README.md new file mode 100644 index 00000000..ea7b7bb5 --- /dev/null +++ b/weather_dl_v2/README.md @@ -0,0 +1,12 @@ +## weather-dl-v2 + + + +> **_NOTE:_** weather-dl-v2 only supports python 3.10 + +### Sequence of steps: +1) Refer to downloader_kubernetes/README.md +2) Refer to license_deployment/README.md +3) Refer to fastapi-server/README.md +4) Refer to cli/README.md + diff --git a/weather_dl_v2/__init__.py b/weather_dl_v2/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/cli/CLI-Documentation.md b/weather_dl_v2/cli/CLI-Documentation.md new file mode 100644 index 00000000..ea16bb6b --- /dev/null +++ b/weather_dl_v2/cli/CLI-Documentation.md @@ -0,0 +1,306 @@ +# CLI Documentation +The following doc provides cli commands and their various arguments and options. + +Base Command: +``` +weather-dl-v2 +``` + +## Ping +Ping the FastAPI server and check if it’s live and reachable. + + weather-dl-v2 ping + +##### Usage +``` +weather-dl-v2 ping +``` + + +
+ +## Download +Manage download configs. + + +### Add Downloads + weather-dl-v2 download add
+ Adds a new download config to specific licenses. +
+ + +##### Arguments +> `FILE_PATH` : Path to config file. + +##### Options +> `-l/--license` (Required): License ID to which this download has to be added to. +> `-f/--force-download` : Force redownload of partitions that were previously downloaded. + +##### Usage +``` +weather-dl-v2 download add /path/to/example.cfg –l L1 -l L2 [--force-download] +``` + +### List Downloads + weather-dl-v2 download list
+ List all the active downloads. +
+ +The list can also be filtered out by client_names. +Available filters: +``` +Filter Key: client_name +Values: cds, mars, ecpublic + +Filter Key: status +Values: completed, failed, in-progress +``` + +##### Options +> `--filter` : Filter the list by some key and value. Format of filter filter_key=filter_value + +##### Usage +``` +weather-dl-v2 download list +weather-dl-v2 download list --filter client_name=cds +weather-dl-v2 download list --filter status=success +weather-dl-v2 download list --filter status=failed +weather-dl-v2 download list --filter status=in-progress +weather-dl-v2 download list --filter client_name=cds --filter status=success +``` + +### Download Get + weather-dl-v2 download get
+ Get a particular download by config name. +
+ +##### Arguments +> `CONFIG_NAME` : Name of the download config. + +##### Usage +``` +weather-dl-v2 download get example.cfg +``` + +### Download Show + weather-dl-v2 download show
+ Get contents of a particular config by config name. +
+ +##### Arguments +> `CONFIG_NAME` : Name of the download config. + +##### Usage +``` +weather-dl-v2 download show example.cfg +``` + +### Download Remove + weather-dl-v2 download remove
+ Remove a download by config name. +
+ +##### Arguments +> `CONFIG_NAME` : Name of the download config. + +##### Usage +``` +weather-dl-v2 download remove example.cfg +``` + +### Download Refetch + weather-dl-v2 download refetch
+ Refetch all non-successful partitions of a config. +
+ +##### Arguments +> `CONFIG_NAME` : Name of the download config. + +##### Options +> `-l/--license` (Required): License ID to which this download has to be added to. + +##### Usage +``` +weather-dl-v2 download refetch example.cfg -l L1 -l L2 +``` + +
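Taken together, the download commands above form a simple submit → monitor → retry workflow. A hypothetical end-to-end sketch (the config path and license IDs are placeholders):
```
weather-dl-v2 download add /path/to/example.cfg -l L1 -l L2   # submit a config to two licenses
weather-dl-v2 download list --filter status=in-progress       # watch active downloads
weather-dl-v2 download show example.cfg                       # inspect the config contents
weather-dl-v2 download refetch example.cfg -l L1              # retry non-successful partitions
```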
+ +## License +Manage licenses. + +### License Add + weather-dl-v2 license add
+ Add a new license. New licenses are added using a json file. +
+ +The json file should be in this format: +``` +{ + "license_id: , + "client_name": , + "number_of_requests": , + "secret_id": +} +``` +NOTE: `license_id` is case insensitive and has to be unique for each license. + + +##### Arguments +> `FILE_PATH` : Path to the license json. + +##### Usage +``` +weather-dl-v2 license add /path/to/new-license.json +``` + +### License Get + weather-dl-v2 license get
+ Get a particular license by license ID. +
+ +##### Arguments +> `LICENSE` : License ID of the license to be fetched. + +##### Usage +``` +weather-dl-v2 license get L1 +``` + +### License Remove + weather-dl-v2 license remove
+ Remove a particular license by license ID. +
+ +##### Arguments +> `LICENSE` : License ID of the license to be removed. + +##### Usage +``` +weather-dl-v2 license remove L1 +``` + +### License List + weather-dl-v2 license list
+ List all the licenses available. +
+ + The list can also be filtered by client name. + +##### Options +> `--filter` : Filter the list by some key and value. Format of filter filter_key=filter_value. + +##### Usage +``` +weather-dl-v2 license list +weather-dl-v2 license list --filter client_name=cds +``` + +### License Update + weather-dl-v2 license update
+ Update an existing license using its license ID and a license json file. +
+ + The json should be of the same format used to add a new license. + +##### Arguments +> `LICENSE` : License ID of the license to be edited. +> `FILE_PATH` : Path to the license json. + +##### Usage +``` +weather-dl-v2 license update L1 /path/to/license.json +``` + +
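For reference, a filled-in JSON for `license add` might look like the sketch below; every value is hypothetical, and `client_name` should be one of the supported clients (`cds`, `mars`, `ecpublic`):
```
{
    "license_id": "L1",
    "client_name": "cds",
    "number_of_requests": 5,
    "secret_id": "projects/my-project/secrets/cds-api-key/versions/1"
}
```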
+ +## Queue +Manage all the license queues. + +### Queue List + weather-dl-v2 queue list
+ List all the queues. +
+ + The list can also be filtered by client name. + +##### Options +> `--filter` : Filter the list by some key and value. Format of filter filter_key=filter_value. + +##### Usage +``` +weather-dl-v2 queue list +weather-dl-v2 queue list --filter client_name=cds +``` + +### Queue Get + weather-dl-v2 queue get
+ Get a queue by license ID. +
+ +##### Arguments +> `LICENSE` : License ID of the queue to be fetched. + +##### Usage +``` +weather-dl-v2 queue get L1 +``` + +### Queue Edit + weather-dl-v2 queue edit
+ Edit the priority of configs inside a license queue. +
+ +Priority can be edited in two ways (see the sketch after this section): +1. Pass a new priority order for the whole queue via a priority json file in the following format: +``` +{ + "priority": ["c1.cfg", "c3.cfg", "c2.cfg"] +} +``` +2. Pass a config file name and its absolute priority to update the priority of that particular config in the given license queue. + +##### Arguments +> `LICENSE` : License ID of the queue to be edited. + +##### Options +> `-f/--file` : Path of the new priority json file. +> `-c/--config` : Config name for absolute priority. +> `-p/--priority`: Absolute priority for the config in a license queue. Priority increases in ascending order, with 0 having the highest priority. + +##### Usage +``` +weather-dl-v2 queue edit L1 --file /path/to/priority.json +weather-dl-v2 queue edit L1 --config example.cfg --priority 0 +```
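Assuming the stated semantics (priority 0 is the highest and priorities ascend down the queue), passing `--priority 0` should move a config to the front of its license queue. A hypothetical before/after sketch:
```
# queue for L1 before: ["c1.cfg", "c2.cfg", "c3.cfg"]
weather-dl-v2 queue edit L1 --config c3.cfg --priority 0
# queue for L1 after:  ["c3.cfg", "c1.cfg", "c2.cfg"]
```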
+ +## Config +Configurations for cli. + +### Config Show IP + weather-dl-v2 config show-ip
+See the current server IP address. +
+ +##### Usage +``` +weather-dl-v2 config show-ip +``` + +### Config Set IP + weather-dl-v2 config set-ip
+Update the server IP address. +
+ +##### Arguments +> `NEW_IP` : New IP address. (Do not add port or protocol). + +##### Usage +``` +weather-dl-v2 config set-ip 127.0.0.1 +``` + diff --git a/weather_dl_v2/cli/Dockerfile b/weather_dl_v2/cli/Dockerfile new file mode 100644 index 00000000..ec3536be --- /dev/null +++ b/weather_dl_v2/cli/Dockerfile @@ -0,0 +1,43 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +FROM continuumio/miniconda3:latest + +COPY . . + +# Add the mamba solver for faster builds +RUN conda install -n base conda-libmamba-solver +RUN conda config --set solver libmamba + +# Create conda env using environment.yml +RUN conda update conda -y +RUN conda env create --name weather-dl-v2-cli --file=environment.yml + +# Activate the conda env and update the PATH +ARG CONDA_ENV_NAME=weather-dl-v2-cli +RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc +ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH + +RUN apt-get update -y +RUN apt-get install nano -y +RUN apt-get install vim -y +RUN apt-get install curl -y + +# Install gsutil +RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-443.0.0-linux-arm.tar.gz +RUN tar -xf google-cloud-cli-443.0.0-linux-arm.tar.gz +RUN ./google-cloud-sdk/install.sh --quiet +RUN echo "if [ -f '/google-cloud-sdk/path.bash.inc' ]; then . '/google-cloud-sdk/path.bash.inc'; fi" >> /root/.bashrc +RUN echo "if [ -f '/google-cloud-sdk/completion.bash.inc' ]; then . '/google-cloud-sdk/completion.bash.inc'; fi" >> /root/.bashrc diff --git a/weather_dl_v2/cli/README.md b/weather_dl_v2/cli/README.md new file mode 100644 index 00000000..a4f4932f --- /dev/null +++ b/weather_dl_v2/cli/README.md @@ -0,0 +1,58 @@ +# weather-dl-cli +This is a command line interface for talking to the weather-dl-v2 FastAPI server. + +- Due to our org level policy we can't expose external-ip using LoadBalancer Service +while deploying our FastAPI server. Hence we need to deploy the CLI on a VM to interact +through our fastapi server. + +Replace the FastAPI server pod's IP in cli_config.json. +``` +Please make approriate changes in cli_config.json, if required. +``` +> Note: Command to get the Pod IP : `kubectl get pods -o wide`. +> +> Though note that in case of Pod restart IP might get change. So we need to look +> for better solution for the same. + +## Create docker image for weather-dl-cli + +``` +export PROJECT_ID= +export REPO= eg:weather-tools + +gcloud builds submit . 
--tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-cli" --timeout=79200 --machine-type=e2-highcpu-32 +``` + +## Create a VM using above created docker-image +``` +export ZONE= eg: us-west1-a +export SERVICE_ACCOUNT= # Let's keep this as Compute Engine Default Service Account +export IMAGE_PATH= # The above created image-path + +gcloud compute instances create-with-container weather-dl-v2-cli \ + --project=$PROJECT_ID \ + --zone=$ZONE \ + --machine-type=e2-medium \ + --network-interface=network-tier=PREMIUM,subnet=default \ + --maintenance-policy=MIGRATE \ + --provisioning-model=STANDARD \ + --service-account=$SERVICE_ACCOUNT \ + --scopes=https://www.googleapis.com/auth/cloud-platform \ + --tags=http-server,https-server \ + --image=projects/cos-cloud/global/images/cos-stable-105-17412-101-24 \ + --boot-disk-size=10GB \ + --boot-disk-type=pd-balanced \ + --boot-disk-device-name=weather-dl-v2-cli \ + --container-image=$IMAGE_PATH \ + --container-restart-policy=on-failure \ + --container-tty \ + --no-shielded-secure-boot \ + --shielded-vtpm \ + --labels=goog-ec-src=vm_add-gcloud,container-vm=cos-stable-105-17412-101-24 \ + --metadata-from-file=startup-script=vm-startup.sh +``` + +## Use the cli after doing ssh in the above created VM +``` +weather-dl-v2 --help +``` diff --git a/weather_dl_v2/cli/app/__init__.py b/weather_dl_v2/cli/app/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/cli/app/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/cli/app/cli_config.py b/weather_dl_v2/cli/app/cli_config.py new file mode 100644 index 00000000..9bfeb1de --- /dev/null +++ b/weather_dl_v2/cli/app/cli_config.py @@ -0,0 +1,62 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import dataclasses +import typing as t +import json +import os + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class CliConfig: + pod_ip: str = "" + port: str = "" + + @property + def BASE_URI(self) -> str: + return f"http://{self.pod_ip}:{self.port}" + + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict): + config_instance = cls() + + for key, value in config.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + + return config_instance + + +cli_config = None + + +def get_config(): + global cli_config + # TODO: Update this so cli can work from any folder level. + # Right now it only works in folder where cli_config.json is present. + cli_config_json = os.path.join(os.getcwd(), "cli_config.json") + + if cli_config is None: + with open(cli_config_json) as file: + firestore_dict = json.load(file) + cli_config = CliConfig.from_dict(firestore_dict) + + return cli_config diff --git a/weather_dl_v2/cli/app/main.py b/weather_dl_v2/cli/app/main.py new file mode 100644 index 00000000..03a52577 --- /dev/null +++ b/weather_dl_v2/cli/app/main.py @@ -0,0 +1,48 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +import logging +from app.cli_config import get_config +import requests +from app.subcommands import download, queue, license, config +from app.utils import Loader + +logger = logging.getLogger(__name__) + +app = typer.Typer( + help="weather-dl-v2 is a cli tool for communicating with FastAPI server." +) + +app.add_typer(download.app, name="download", help="Manage downloads.") +app.add_typer(queue.app, name="queue", help="Manage queues.") +app.add_typer(license.app, name="license", help="Manage licenses.") +app.add_typer(config.app, name="config", help="Configurations for cli.") + + +@app.command("ping", help="Check if FastAPI server is live and rechable.") +def ping(): + uri = f"{get_config().BASE_URI}/" + try: + with Loader("Sending request..."): + x = requests.get(uri) + except Exception as e: + print(f"error {e}") + return + print(x.text) + + +if __name__ == "__main__": + app() diff --git a/weather_dl_v2/cli/app/services/__init__.py b/weather_dl_v2/cli/app/services/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/cli/app/services/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/cli/app/services/download_service.py b/weather_dl_v2/cli/app/services/download_service.py new file mode 100644 index 00000000..4d467271 --- /dev/null +++ b/weather_dl_v2/cli/app/services/download_service.py @@ -0,0 +1,128 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import logging +import json +import typing as t +from app.services.network_service import network_service +from app.cli_config import get_config + +logger = logging.getLogger(__name__) + + +class DownloadService(abc.ABC): + + @abc.abstractmethod + def _list_all_downloads(self): + pass + + @abc.abstractmethod + def _list_all_downloads_by_filter(self, filter_dict: dict): + pass + + @abc.abstractmethod + def _get_download_by_config(self, config_name: str): + pass + + @abc.abstractmethod + def _show_config_content(self, config_name: str): + pass + + @abc.abstractmethod + def _add_new_download( + self, file_path: str, licenses: t.List[str], force_download: bool + ): + pass + + @abc.abstractmethod + def _remove_download(self, config_name: str): + pass + + @abc.abstractmethod + def _refetch_config_partitions(self, config_name: str, licenses: t.List[str]): + pass + + +class DownloadServiceNetwork(DownloadService): + + def __init__(self): + self.endpoint = f"{get_config().BASE_URI}/download" + + def _list_all_downloads(self): + return network_service.get( + uri=self.endpoint, header={"accept": "application/json"} + ) + + def _list_all_downloads_by_filter(self, filter_dict: dict): + return network_service.get( + uri=self.endpoint, + header={"accept": "application/json"}, + query=filter_dict, + ) + + def _get_download_by_config(self, config_name: str): + return network_service.get( + uri=f"{self.endpoint}/{config_name}", + header={"accept": "application/json"}, + ) + + def _show_config_content(self, config_name: str): + return network_service.get( + uri=f"{self.endpoint}/show/{config_name}", + header={"accept": "application/json"}, + ) + + def _add_new_download( + self, file_path: str, licenses: t.List[str], force_download: bool + ): + try: + file = {"file": open(file_path, "rb")} + except FileNotFoundError: + return "File not found." 
+ + return network_service.post( + uri=self.endpoint, + header={"accept": "application/json"}, + file=file, + payload={"licenses": licenses}, + query={"force_download": force_download}, + ) + + def _remove_download(self, config_name: str): + return network_service.delete( + uri=f"{self.endpoint}/{config_name}", header={"accept": "application/json"} + ) + + def _refetch_config_partitions(self, config_name: str, licenses: t.List[str]): + return network_service.post( + uri=f"{self.endpoint}/retry/{config_name}", + header={"accept": "application/json"}, + payload=json.dumps({"licenses": licenses}), + ) + + +class DownloadServiceMock(DownloadService): + pass + + +def get_download_service(test: bool = False): + if test: + return DownloadServiceMock() + else: + return DownloadServiceNetwork() + + +download_service = get_download_service() diff --git a/weather_dl_v2/cli/app/services/license_service.py b/weather_dl_v2/cli/app/services/license_service.py new file mode 100644 index 00000000..09ff4f3c --- /dev/null +++ b/weather_dl_v2/cli/app/services/license_service.py @@ -0,0 +1,107 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import logging +import json +from app.services.network_service import network_service +from app.cli_config import get_config + +logger = logging.getLogger(__name__) + + +class LicenseService(abc.ABC): + + @abc.abstractmethod + def _get_all_license(self): + pass + + @abc.abstractmethod + def _get_all_license_by_client_name(self, client_name: str): + pass + + @abc.abstractmethod + def _get_license_by_license_id(self, license_id: str): + pass + + @abc.abstractmethod + def _add_license(self, license_dict: dict): + pass + + @abc.abstractmethod + def _remove_license(self, license_id: str): + pass + + @abc.abstractmethod + def _update_license(self, license_id: str, license_dict: dict): + pass + + +class LicenseServiceNetwork(LicenseService): + + def __init__(self): + self.endpoint = f"{get_config().BASE_URI}/license" + + def _get_all_license(self): + return network_service.get( + uri=self.endpoint, header={"accept": "application/json"} + ) + + def _get_all_license_by_client_name(self, client_name: str): + return network_service.get( + uri=self.endpoint, + header={"accept": "application/json"}, + query={"client_name": client_name}, + ) + + def _get_license_by_license_id(self, license_id: str): + return network_service.get( + uri=f"{self.endpoint}/{license_id}", + header={"accept": "application/json"}, + ) + + def _add_license(self, license_dict: dict): + return network_service.post( + uri=self.endpoint, + header={"accept": "application/json"}, + payload=json.dumps(license_dict), + ) + + def _remove_license(self, license_id: str): + return network_service.delete( + uri=f"{self.endpoint}/{license_id}", + header={"accept": "application/json"}, + ) + + def _update_license(self, license_id: str, license_dict: dict): + return network_service.put( + uri=f"{self.endpoint}/{license_id}", + header={"accept": "application/json"}, + 
payload=json.dumps(license_dict), + ) + + +class LicenseServiceMock(LicenseService): + pass + + +def get_license_service(test: bool = False): + if test: + return LicenseServiceMock() + else: + return LicenseServiceNetwork() + + +license_service = get_license_service() diff --git a/weather_dl_v2/cli/app/services/network_service.py b/weather_dl_v2/cli/app/services/network_service.py new file mode 100644 index 00000000..4406d91b --- /dev/null +++ b/weather_dl_v2/cli/app/services/network_service.py @@ -0,0 +1,85 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import requests +import json +import logging +from app.utils import Loader, timeit + +logger = logging.getLogger(__name__) + + +class NetworkService: + + def parse_response(self, response: requests.Response): + try: + parsed = json.loads(response.text) + except Exception as e: + logger.info(f"Parsing error: {e}.") + logger.info(f"Status code {response.status_code}") + logger.info(f"Response {response.text}") + return + + if isinstance(parsed, list): + print(f"[Total {len(parsed)} items.]") + + return json.dumps(parsed, indent=3) + + @timeit + def get(self, uri, header, query=None, payload=None): + try: + with Loader("Sending request..."): + x = requests.get(uri, params=query, headers=header, data=payload) + return self.parse_response(x) + except requests.exceptions.RequestException as e: + logger.error(f"request error: {e}") + raise SystemExit(e) + + @timeit + def post(self, uri, header, query=None, payload=None, file=None): + try: + with Loader("Sending request..."): + x = requests.post( + uri, params=query, headers=header, data=payload, files=file + ) + return self.parse_response(x) + except requests.exceptions.RequestException as e: + logger.error(f"request error: {e}") + raise SystemExit(e) + + @timeit + def put(self, uri, header, query=None, payload=None, file=None): + try: + with Loader("Sending request..."): + x = requests.put( + uri, params=query, headers=header, data=payload, files=file + ) + return self.parse_response(x) + except requests.exceptions.RequestException as e: + logger.error(f"request error: {e}") + raise SystemExit(e) + + @timeit + def delete(self, uri, header, query=None): + try: + with Loader("Sending request..."): + x = requests.delete(uri, params=query, headers=header) + return self.parse_response(x) + except requests.exceptions.RequestException as e: + logger.error(f"request error: {e}") + raise SystemExit(e) + + +network_service = NetworkService() diff --git a/weather_dl_v2/cli/app/services/queue_service.py b/weather_dl_v2/cli/app/services/queue_service.py new file mode 100644 index 00000000..f6824934 --- /dev/null +++ b/weather_dl_v2/cli/app/services/queue_service.py @@ -0,0 +1,101 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import logging +import json +import typing as t +from app.services.network_service import network_service +from app.cli_config import get_config + +logger = logging.getLogger(__name__) + + +class QueueService(abc.ABC): + + @abc.abstractmethod + def _get_all_license_queues(self): + pass + + @abc.abstractmethod + def _get_license_queue_by_client_name(self, client_name: str): + pass + + @abc.abstractmethod + def _get_queue_by_license(self, license_id: str): + pass + + @abc.abstractmethod + def _edit_license_queue(self, license_id: str, priority_list: t.List[str]): + pass + + @abc.abstractmethod + def _edit_config_absolute_priority( + self, license_id: str, config_name: str, priority: int + ): + pass + + +class QueueServiceNetwork(QueueService): + + def __init__(self): + self.endpoint = f"{get_config().BASE_URI}/queues" + + def _get_all_license_queues(self): + return network_service.get( + uri=self.endpoint, header={"accept": "application/json"} + ) + + def _get_license_queue_by_client_name(self, client_name: str): + return network_service.get( + uri=self.endpoint, + header={"accept": "application/json"}, + query={"client_name": client_name}, + ) + + def _get_queue_by_license(self, license_id: str): + return network_service.get( + uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"} + ) + + def _edit_license_queue(self, license_id: str, priority_list: t.List[str]): + return network_service.post( + uri=f"{self.endpoint}/{license_id}", + header={"accept": "application/json", "Content-Type": "application/json"}, + payload=json.dumps(priority_list), + ) + + def _edit_config_absolute_priority( + self, license_id: str, config_name: str, priority: int + ): + return network_service.put( + uri=f"{self.endpoint}/priority/{license_id}", + header={"accept": "application/json"}, + query={"config_name": config_name, "priority": priority}, + ) + + +class QueueServiceMock(QueueService): + pass + + +def get_queue_service(test: bool = False): + if test: + return QueueServiceMock() + else: + return QueueServiceNetwork() + + +queue_service = get_queue_service() diff --git a/weather_dl_v2/cli/app/subcommands/__init__.py b/weather_dl_v2/cli/app/subcommands/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/cli/app/subcommands/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
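For illustration, the requests issued by `QueueServiceNetwork` above map onto the FastAPI server roughly as in this curl sketch (the pod IP, port, license ID, and config names are placeholders based on the `cli_config.json` defaults and the examples elsewhere in this PR):
```
curl -H "accept: application/json" "http://POD_IP:8080/queues"                     # all license queues
curl -H "accept: application/json" "http://POD_IP:8080/queues?client_name=cds"     # filter by client
curl -H "accept: application/json" "http://POD_IP:8080/queues/L1"                  # one queue by license ID
curl -X POST -H "Content-Type: application/json" \
     -d '["c1.cfg", "c2.cfg"]' "http://POD_IP:8080/queues/L1"                      # reorder a queue
curl -X PUT "http://POD_IP:8080/queues/priority/L1?config_name=c1.cfg&priority=0"  # set absolute priority
```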
diff --git a/weather_dl_v2/cli/app/subcommands/config.py b/weather_dl_v2/cli/app/subcommands/config.py new file mode 100644 index 00000000..b2a03aaf --- /dev/null +++ b/weather_dl_v2/cli/app/subcommands/config.py @@ -0,0 +1,60 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +import json +import os +from typing_extensions import Annotated +from app.cli_config import get_config +from app.utils import Validator + +app = typer.Typer() + + +class ConfigValidator(Validator): + pass + + +@app.command("show-ip", help="See the current server IP address.") +def show_server_ip(): + print(f"Current pod IP: {get_config().pod_ip}") + + +@app.command("set-ip", help="Update the server IP address.") +def update_server_ip( + new_ip: Annotated[ + str, typer.Argument(help="New IP address. (Do not add port or protocol).") + ], +): + file_path = os.path.join(os.getcwd(), "cli_config.json") + cli_config = {} + with open(file_path, "r") as file: + cli_config = json.load(file) + + old_ip = cli_config["pod_ip"] + cli_config["pod_ip"] = new_ip + + with open(file_path, "w") as file: + json.dump(cli_config, file) + + validator = ConfigValidator(valid_keys=["pod_ip", "port"]) + + try: + cli_config = validator.validate_json(file_path=file_path) + except Exception as e: + print(f"payload error: {e}") + return + + print(f"Pod IP Updated {old_ip} -> {new_ip} .") diff --git a/weather_dl_v2/cli/app/subcommands/download.py b/weather_dl_v2/cli/app/subcommands/download.py new file mode 100644 index 00000000..b16a26e8 --- /dev/null +++ b/weather_dl_v2/cli/app/subcommands/download.py @@ -0,0 +1,102 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +from typing_extensions import Annotated +from app.services.download_service import download_service +from app.utils import Validator, as_table +from typing import List + +app = typer.Typer(rich_markup_mode="markdown") + + +class DowloadFilterValidator(Validator): + pass + + +@app.command("list", help="List out all the configs.") +def get_downloads( + filter: Annotated[ + List[str], + typer.Option( + help="""Filter by some value. Format: filter_key=filter_value. 
Available filters """ + """[key: client_name, values: cds, mars, ecpublic] """ + """[key: status, values: completed, failed, in-progress]""" + ), + ] = [] +): + if len(filter) > 0: + validator = DowloadFilterValidator(valid_keys=["client_name", "status"]) + + try: + filter_dict = validator.validate(filters=filter, allow_missing=True) + except Exception as e: + print(f"filter error: {e}") + return + + print(as_table(download_service._list_all_downloads_by_filter(filter_dict))) + return + + print(as_table(download_service._list_all_downloads())) + + +# TODO: Add support for submitting multiple configs using *.cfg notation. +@app.command("add", help="Submit new config to download.") +def submit_download( + file_path: Annotated[ + str, typer.Argument(help="File path of config to be uploaded.") + ], + license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")], + force_download: Annotated[ + bool, + typer.Option( + "-f", + "--force-download", + help="Force redownload of partitions that were previously downloaded.", + ), + ] = False, +): + print(download_service._add_new_download(file_path, license, force_download)) + + +@app.command("get", help="Get a particular config.") +def get_download_by_config( + config_name: Annotated[str, typer.Argument(help="Config file name.")] +): + print(as_table(download_service._get_download_by_config(config_name))) + + +@app.command("show", help="Show contents of a particular config.") +def show_config( + config_name: Annotated[str, typer.Argument(help="Config file name.")] +): + print(download_service._show_config_content(config_name)) + + +@app.command("remove", help="Remove existing config.") +def remove_download( + config_name: Annotated[str, typer.Argument(help="Config file name.")] +): + print(download_service._remove_download(config_name)) + + +@app.command( + "refetch", help="Reschedule all partitions of a config that are not successful." +) +def refetch_config( + config_name: Annotated[str, typer.Argument(help="Config file name.")], + license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")], +): + print(download_service._refetch_config_partitions(config_name, license)) diff --git a/weather_dl_v2/cli/app/subcommands/license.py b/weather_dl_v2/cli/app/subcommands/license.py new file mode 100644 index 00000000..68dccd1d --- /dev/null +++ b/weather_dl_v2/cli/app/subcommands/license.py @@ -0,0 +1,104 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +from typing_extensions import Annotated +from app.services.license_service import license_service +from app.utils import Validator, as_table + +app = typer.Typer() + + +class LicenseValidator(Validator): + pass + + +@app.command("list", help="List all licenses.") +def get_all_license( + filter: Annotated[ + str, typer.Option(help="Filter by some value. 
Format: filter_key=filter_value") + ] = None +): + if filter: + validator = LicenseValidator(valid_keys=["client_name"]) + + try: + data = validator.validate(filters=[filter]) + client_name = data["client_name"] + except Exception as e: + print(f"filter error: {e}") + return + + print(as_table(license_service._get_all_license_by_client_name(client_name))) + return + + print(as_table(license_service._get_all_license())) + + +@app.command("get", help="Get a particular license by ID.") +def get_license(license: Annotated[str, typer.Argument(help="License ID.")]): + print(as_table(license_service._get_license_by_license_id(license))) + + +@app.command("add", help="Add new license.") +def add_license( + file_path: Annotated[ + str, + typer.Argument( + help="""Input json file. Example json for new license-""" + """{"license_id" : , "client_name" : , "number_of_requests" : , "secret_id" : }""" + """\nNOTE: license_id is case insensitive and has to be unique for each license.""" + ), + ], +): + validator = LicenseValidator( + valid_keys=["license_id", "client_name", "number_of_requests", "secret_id"] + ) + + try: + license_dict = validator.validate_json(file_path=file_path) + except Exception as e: + print(f"payload error: {e}") + return + + print(license_service._add_license(license_dict)) + + +@app.command("remove", help="Remove a license.") +def remove_license(license: Annotated[str, typer.Argument(help="License ID.")]): + print(license_service._remove_license(license)) + + +@app.command("update", help="Update existing license.") +def update_license( + license: Annotated[str, typer.Argument(help="License ID.")], + file_path: Annotated[ + str, + typer.Argument( + help="""Input json file. Example json for updated license- """ + """{"client_id": , "client_name" : , "number_of_requests" : , "secret_id" : }""" + ), + ], # noqa +): + validator = LicenseValidator( + valid_keys=["client_id", "client_name", "number_of_requests", "secret_id"] + ) + try: + license_dict = validator.validate_json(file_path=file_path) + except Exception as e: + print(f"payload error: {e}") + return + + print(license_service._update_license(license, license_dict)) diff --git a/weather_dl_v2/cli/app/subcommands/queue.py b/weather_dl_v2/cli/app/subcommands/queue.py new file mode 100644 index 00000000..816564ca --- /dev/null +++ b/weather_dl_v2/cli/app/subcommands/queue.py @@ -0,0 +1,111 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +from typing_extensions import Annotated +from app.services.queue_service import queue_service +from app.utils import Validator, as_table + +app = typer.Typer() + + +class QueueValidator(Validator): + pass + + +@app.command("list", help="List all the license queues.") +def get_all_license_queue( + filter: Annotated[ + str, typer.Option(help="Filter by some value. 
Format: filter_key=filter_value") + ] = None +): + if filter: + validator = QueueValidator(valid_keys=["client_name"]) + + try: + data = validator.validate(filters=[filter]) + client_name = data["client_name"] + except Exception as e: + print(f"filter error: {e}") + return + + print(as_table(queue_service._get_license_queue_by_client_name(client_name))) + return + + print(as_table(queue_service._get_all_license_queues())) + + +@app.command("get", help="Get queue of particular license.") +def get_license_queue(license: Annotated[str, typer.Argument(help="License ID")]): + print(as_table(queue_service._get_queue_by_license(license))) + + +@app.command( + "edit", + help="Edit existing license queue. Queue can edited via a priority" + "file or my moving a single config to a given priority.", +) # noqa +def modify_license_queue( + license: Annotated[str, typer.Argument(help="License ID.")], + file: Annotated[ + str, + typer.Option( + "--file", + "-f", + help="""File path of priority json file. Example json: {"priority": ["c1.cfg", "c2.cfg",...]}""", + ), + ] = None, # noqa + config: Annotated[ + str, typer.Option("--config", "-c", help="Config name for absolute priority.") + ] = None, + priority: Annotated[ + int, + typer.Option( + "--priority", + "-p", + help="Absolute priority for the config in a license queue." + "Priority increases in ascending order with 0 having highest priority.", + ), + ] = None, # noqa +): + if file is None and (config is None and priority is None): + print("Priority file or config name with absolute priority must be passed.") + return + + if file is not None and (config is not None or priority is not None): + print("--config & --priority can't be used along with --file argument.") + return + + if file is not None: + validator = QueueValidator(valid_keys=["priority"]) + + try: + data = validator.validate_json(file_path=file) + priority_list = data["priority"] + except Exception as e: + print(f"key error: {e}") + return + print(queue_service._edit_license_queue(license, priority_list)) + return + elif config is not None and priority is not None: + if priority < 0: + print("Priority can not be negative.") + return + + print(queue_service._edit_config_absolute_priority(license, config, priority)) + return + else: + print("--config & --priority arguments should be used together.") + return diff --git a/weather_dl_v2/cli/app/utils.py b/weather_dl_v2/cli/app/utils.py new file mode 100644 index 00000000..1ced5c7b --- /dev/null +++ b/weather_dl_v2/cli/app/utils.py @@ -0,0 +1,168 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import abc +import logging +import dataclasses +import typing as t +import json +from time import time +from itertools import cycle +from shutil import get_terminal_size +from threading import Thread +from time import sleep +from tabulate import tabulate + +logger = logging.getLogger(__name__) + + +def timeit(func): + def wrap_func(*args, **kwargs): + t1 = time() + result = func(*args, **kwargs) + t2 = time() + print(f"[executed in {(t2-t1):.4f}s.]") + return result + + return wrap_func + + +# TODO: Add a flag (may be -j/--json) to support raw response. +def as_table(response: str): + data = json.loads(response) + + if not isinstance(data, list): + # convert response to list if not a list. + data = [data] + + if len(data) == 0: + return "" + + header = data[0].keys() + # if any column has lists, convert that to a string. + rows = [ + [ + ",\n".join([f"{i} {ele}" for i, ele in enumerate(val)]) + if isinstance(val, list) + else val + for val in x.values() + ] + for x in data + ] + rows.insert(0, list(header)) + return tabulate( + rows, showindex=True, tablefmt="grid", maxcolwidths=[16] * len(header) + ) + + +class Loader: + + def __init__(self, desc="Loading...", end="", timeout=0.1): + """ + A loader-like context manager + + Args: + desc (str, optional): The loader's description. Defaults to "Loading...". + end (str, optional): Final print. Defaults to "Done!". + timeout (float, optional): Sleep time between prints. Defaults to 0.1. + """ + self.desc = desc + self.end = end + self.timeout = timeout + + self._thread = Thread(target=self._animate, daemon=True) + self.steps = ["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"] + self.done = False + + def start(self): + self._thread.start() + return self + + def _animate(self): + for c in cycle(self.steps): + if self.done: + break + print(f"\r{self.desc} {c}", flush=True, end="") + sleep(self.timeout) + + def __enter__(self): + self.start() + + def stop(self): + self.done = True + cols = get_terminal_size((80, 20)).columns + print("\r" + " " * cols, end="", flush=True) + + def __exit__(self, exc_type, exc_value, tb): + # handle exceptions with those variables ^ + self.stop() + + +@dataclasses.dataclass +class Validator(abc.ABC): + valid_keys: t.List[str] + + def validate( + self, filters: t.List[str], show_valid_filters=True, allow_missing: bool = False + ): + filter_dict = {} + + for filter in filters: + _filter = filter.split("=") + + if len(_filter) != 2: + if show_valid_filters: + logger.info(f"valid filters are: {self.valid_keys}.") + raise ValueError("Incorrect Filter. 
Please Try again.") + + key, value = _filter + filter_dict[key] = value + + data_set = set(filter_dict.keys()) + valid_set = set(self.valid_keys) + + if self._validate_keys(data_set, valid_set, allow_missing): + return filter_dict + + def validate_json(self, file_path, allow_missing: bool = False): + try: + with open(file_path) as f: + data: dict = json.load(f) + data_keys = data.keys() + + data_set = set(data_keys) + valid_set = set(self.valid_keys) + + if self._validate_keys(data_set, valid_set, allow_missing): + return data + + except FileNotFoundError: + logger.info("file not found.") + raise FileNotFoundError + + def _validate_keys(self, data_set: set, valid_set: set, allow_missing: bool): + missing_keys = valid_set.difference(data_set) + invalid_keys = data_set.difference(valid_set) + + if not allow_missing and len(missing_keys) > 0: + raise ValueError(f"keys {missing_keys} are missing in file.") + + if len(invalid_keys) > 0: + raise ValueError(f"keys {invalid_keys} are invalid keys.") + + if allow_missing or data_set == valid_set: + return True + + return False diff --git a/weather_dl_v2/cli/cli_config.json b/weather_dl_v2/cli/cli_config.json new file mode 100644 index 00000000..076ed641 --- /dev/null +++ b/weather_dl_v2/cli/cli_config.json @@ -0,0 +1,4 @@ +{ + "pod_ip": "", + "port": 8080 +} \ No newline at end of file diff --git a/weather_dl_v2/cli/environment.yml b/weather_dl_v2/cli/environment.yml new file mode 100644 index 00000000..f2ffec62 --- /dev/null +++ b/weather_dl_v2/cli/environment.yml @@ -0,0 +1,14 @@ +name: weather-dl-v2-cli +channels: + - conda-forge +dependencies: + - python=3.10 + - pip=23.0.1 + - typer=0.9.0 + - tabulate=0.9.0 + - pip: + - requests + - ruff + - pytype + - pytest + - . diff --git a/weather_dl_v2/cli/setup.py b/weather_dl_v2/cli/setup.py new file mode 100644 index 00000000..509f42fc --- /dev/null +++ b/weather_dl_v2/cli/setup.py @@ -0,0 +1,30 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from setuptools import setup + +requirements = ["typer", "requests", "tabulate"] + +setup( + name="weather-dl-v2", + packages=["app", "app.subcommands", "app.services"], + install_requires=requirements, + version="0.0.1", + author="aniket", + description=( + "This cli tools helps in interacting with weather dl v2 fast API server." + ), + entry_points={"console_scripts": ["weather-dl-v2=app.main:app"]}, +) diff --git a/weather_dl_v2/cli/vm-startup.sh b/weather_dl_v2/cli/vm-startup.sh new file mode 100644 index 00000000..e36f6edc --- /dev/null +++ b/weather_dl_v2/cli/vm-startup.sh @@ -0,0 +1,4 @@ +#! 
/bin/bash + +command="docker exec -it \\\$(docker ps -qf name=weather-dl-v2-cli) /bin/bash" +sudo sh -c "echo \"$command\" >> /etc/profile" \ No newline at end of file diff --git a/weather_dl_v2/config.json b/weather_dl_v2/config.json new file mode 100644 index 00000000..f5afae8b --- /dev/null +++ b/weather_dl_v2/config.json @@ -0,0 +1,11 @@ +{ + "download_collection": "download", + "queues_collection": "queues", + "license_collection": "license", + "manifest_collection": "manifest", + "storage_bucket": "XXXXXXX", + "gcs_project": "XXXXXXX", + "license_deployment_image": "XXXXXXX", + "downloader_k8_image": "XXXXXXX", + "welcome_message": "Greetings from weather-dl v2!" +} \ No newline at end of file diff --git a/weather_dl_v2/downloader_kubernetes/Dockerfile b/weather_dl_v2/downloader_kubernetes/Dockerfile new file mode 100644 index 00000000..74084030 --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/Dockerfile @@ -0,0 +1,46 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +FROM continuumio/miniconda3:latest + +# Update miniconda +RUN conda update conda -y + +# Add the mamba solver for faster builds +RUN conda install -n base conda-libmamba-solver +RUN conda config --set solver libmamba + +# Create conda env using environment.yml +COPY . . 
+RUN conda env create -f environment.yml --debug + +# Activate the conda env and update the PATH +ARG CONDA_ENV_NAME=weather-dl-v2-downloader +RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc +ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH diff --git a/weather_dl_v2/downloader_kubernetes/README.md b/weather_dl_v2/downloader_kubernetes/README.md new file mode 100644 index 00000000..b0d865f8 --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/README.md @@ -0,0 +1,23 @@ +# Deployment / Usage Instruction + +### User authorization required to set up the environment: +* roles/container.admin + +### Authorization needed for the tool to operate: +We are not configuring any service account here hence make sure that compute engine default service account have roles: +* roles/storage.admin +* roles/bigquery.dataEditor +* roles/bigquery.jobUser + +### Make changes in weather_dl_v2/config.json, if required [for running locally] +``` +export CONFIG_PATH=/path/to/weather_dl_v2/config.json +``` + +### Create docker image for downloader: +``` +export PROJECT_ID= +export REPO= eg:weather-tools + +gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-downloader" --timeout=79200 --machine-type=e2-highcpu-32 +``` diff --git a/weather_dl_v2/downloader_kubernetes/downloader.py b/weather_dl_v2/downloader_kubernetes/downloader.py new file mode 100644 index 00000000..c8a5c7dc --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/downloader.py @@ -0,0 +1,73 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This program downloads ECMWF data & upload it into GCS. +""" +import tempfile +import os +import sys +from manifest import FirestoreManifest, Stage +from util import copy, download_with_aria2 +import datetime + + +def download(url: str, path: str) -> None: + """Download data from client, with retries.""" + if path: + if os.path.exists(path): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. 
+ open(path, "w").close() + download_with_aria2(url, path) + + +def main( + config_name, dataset, selection, user_id, url, target_path, license_id +) -> None: + """Download data from a client to a temp file.""" + + manifest = FirestoreManifest(license_id=license_id) + temp_name = "" + with manifest.transact(config_name, dataset, selection, target_path, user_id): + with tempfile.NamedTemporaryFile(delete=False) as temp: + temp_name = temp.name + manifest.set_stage(Stage.DOWNLOAD) + precise_download_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_download_start_time + print(f"Downloading data for {target_path!r}.") + download(url, temp_name) + print(f"Download completed for {target_path!r}.") + + manifest.set_stage(Stage.UPLOAD) + precise_upload_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_upload_start_time + print(f"Uploading to store for {target_path!r}.") + copy(temp_name, target_path) + print(f"Upload to store complete for {target_path!r}.") + os.unlink(temp_name) + + +if __name__ == "__main__": + main(*sys.argv[1:]) diff --git a/weather_dl_v2/downloader_kubernetes/downloader_config.py b/weather_dl_v2/downloader_kubernetes/downloader_config.py new file mode 100644 index 00000000..247ae664 --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/downloader_config.py @@ -0,0 +1,65 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
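The `utcnow().replace(tzinfo=datetime.timezone.utc).isoformat(timespec="seconds")` expression recurs in `downloader.py` above and again in the manifests later in this diff. A small helper, shown here only to make the produced format explicit (it is not part of the diff):

```
import datetime


def utc_now_iso() -> str:
    # Timezone-aware UTC timestamp truncated to seconds, e.g. '2023-07-01T12:34:56+00:00'.
    return (
        datetime.datetime.utcnow()
        .replace(tzinfo=datetime.timezone.utc)
        .isoformat(timespec="seconds")
    )


print(utc_now_iso())
```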
+ + +import dataclasses +import typing as t +import json +import os +import logging + +logger = logging.getLogger(__name__) + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class DownloaderConfig: + manifest_collection: str = "" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict): + config_instance = cls() + + for key, value in config.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + + return config_instance + + +downloader_config = None + + +def get_config(): + global downloader_config + if downloader_config: + return downloader_config + + downloader_config_json = "config/config.json" + if not os.path.exists(downloader_config_json): + downloader_config_json = os.environ.get("CONFIG_PATH", None) + + if downloader_config_json is None: + logger.error("Couldn't load config file for downloader.") + raise FileNotFoundError("Couldn't load config file for downloader.") + + with open(downloader_config_json) as file: + config_dict = json.load(file) + downloader_config = DownloaderConfig.from_dict(config_dict) + + return downloader_config diff --git a/weather_dl_v2/downloader_kubernetes/environment.yml b/weather_dl_v2/downloader_kubernetes/environment.yml new file mode 100644 index 00000000..79e75565 --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/environment.yml @@ -0,0 +1,17 @@ +name: weather-dl-v2-downloader +channels: + - conda-forge +dependencies: + - python=3.10 + - google-cloud-sdk=410.0.0 + - aria2=1.36.0 + - geojson=2.5.0=py_0 + - xarray=2022.11.0 + - google-apitools + - pip=22.3 + - pip: + - apache_beam[gcp]==2.40.0 + - firebase-admin + - google-cloud-pubsub + - kubernetes + - psutil diff --git a/weather_dl_v2/downloader_kubernetes/manifest.py b/weather_dl_v2/downloader_kubernetes/manifest.py new file mode 100644 index 00000000..0bc82264 --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/manifest.py @@ -0,0 +1,503 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
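A quick sketch of how `downloader_config.get_config()` above resolves its file and maps keys: declared dataclass fields (only `manifest_collection` here) become attributes, while every other key from `config.json` lands in `kwargs`. The path below is an assumption for a local run; in the cluster the file is expected to come from a mounted copy of `weather_dl_v2/config.json`.

```
import os

# Point the downloader at a local copy of the shared config when
# config/config.json is not present in the working directory.
os.environ["CONFIG_PATH"] = "/path/to/weather_dl_v2/config.json"  # assumed path

from downloader_config import get_config

cfg = get_config()
print(cfg.manifest_collection)            # "manifest" in the sample config.json
print(cfg.kwargs["downloader_k8_image"])  # unrecognized keys are collected in kwargs
```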
+ + +"""Client interface for connecting to a manifest.""" + +import abc +import dataclasses +import datetime +import enum +import json +import pandas as pd +import time +import traceback +import typing as t + +from util import ( + to_json_serializable_type, + fetch_geo_polygon, + get_file_size, + get_wait_interval, + generate_md5_hash, + GLOBAL_COVERAGE_AREA, +) + +import firebase_admin +from firebase_admin import credentials +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentReference +from google.cloud.firestore_v1.types import WriteResult +from downloader_config import get_config + +"""An implementation-dependent Manifest URI.""" +Location = t.NewType("Location", str) + + +class ManifestException(Exception): + """Errors that occur in Manifest Clients.""" + + pass + + +class Stage(enum.Enum): + """A request can be either in one of the following stages at a time: + + fetch : This represents request is currently in fetch stage i.e. request placed on the client's server + & waiting for some result before starting download (eg. MARS client). + download : This represents request is currently in download stage i.e. data is being downloading from client's + server to the worker's local file system. + upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local + file system to target location (GCS path). + retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), + request will be in the retrieve stage i.e. fetch + download. + """ + + RETRIEVE = "retrieve" + FETCH = "fetch" + DOWNLOAD = "download" + UPLOAD = "upload" + + +class Status(enum.Enum): + """Depicts the request's state status: + + scheduled : A request partition is created & scheduled for processing. + Note: Its corresponding state can be None only. + in-progress : This represents the request state is currently in-progress (i.e. running). + The next status would be "success" or "failure". + success : This represents the request state execution completed successfully without any error. + failure : This represents the request state execution failed. 
+ """ + + SCHEDULED = "scheduled" + IN_PROGRESS = "in-progress" + SUCCESS = "success" + FAILURE = "failure" + + +@dataclasses.dataclass +class DownloadStatus: + """Data recorded in `Manifest`s reflecting the status of a download.""" + + """The name of the config file associated with the request.""" + config_name: str = "" + + """Represents the dataset field of the configuration.""" + dataset: t.Optional[str] = "" + + """Copy of selection section of the configuration.""" + selection: t.Dict = dataclasses.field(default_factory=dict) + + """Location of the downloaded data.""" + location: str = "" + + """Represents area covered by the shard.""" + area: str = "" + + """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" + stage: t.Optional[Stage] = None + + """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" + status: t.Optional[Status] = None + + """Cause of error, if any.""" + error: t.Optional[str] = "" + + """Identifier for the user running the download.""" + username: str = "" + + """Shard size in GB.""" + size: t.Optional[float] = 0 + + """A UTC datetime when download was scheduled.""" + scheduled_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve stage starts.""" + retrieve_start_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve state ends.""" + retrieve_end_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state starts.""" + fetch_start_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state ends.""" + fetch_end_time: t.Optional[str] = "" + + """A UTC datetime when the download state starts.""" + download_start_time: t.Optional[str] = "" + + """A UTC datetime when the download state ends.""" + download_end_time: t.Optional[str] = "" + + """A UTC datetime when the upload state starts.""" + upload_start_time: t.Optional[str] = "" + + """A UTC datetime when the upload state ends.""" + upload_end_time: t.Optional[str] = "" + + @classmethod + def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": + """Instantiate DownloadStatus dataclass from dict.""" + download_status_instance = cls() + for key, value in download_status.items(): + if key == "status": + setattr(download_status_instance, key, Status(value)) + elif key == "stage" and value is not None: + setattr(download_status_instance, key, Stage(value)) + else: + setattr(download_status_instance, key, value) + return download_status_instance + + @classmethod + def to_dict(cls, instance) -> t.Dict: + """Return the fields of a dataclass instance as a manifest ingestible + dictionary mapping of field names to field values.""" + download_status_dict = {} + for field in dataclasses.fields(instance): + key = field.name + value = getattr(instance, field.name) + if isinstance(value, Status) or isinstance(value, Stage): + download_status_dict[key] = value.value + elif isinstance(value, pd.Timestamp): + download_status_dict[key] = value.isoformat() + elif key == "selection" and value is not None: + download_status_dict[key] = json.dumps(value) + else: + download_status_dict[key] = value + return download_status_dict + + +@dataclasses.dataclass +class Manifest(abc.ABC): + """Abstract manifest of download statuses. + + Update download statuses to some storage medium. + + This class lets one indicate that a download is `scheduled` or in a transaction process. + In the event of a transaction, a download will be updated with an `in-progress`, `success` + or `failure` status (with accompanying metadata). 
+ + Example: + ``` + my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) + + # Schedule data for download + my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') + + # ... + + # Initiate a transaction – it will record that the download is `in-progess` + with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: + # download logic here + pass + + # ... + + # on error, will record the download as a `failure` before propagating the error. By default, it will + # record download as a `success`. + ``` + + Attributes: + status: The current `DownloadStatus` of the Manifest. + """ + + # To reduce the impact of _read() and _update() calls + # on the start time of the stage. + license_id: str = "" + prev_stage_precise_start_time: t.Optional[str] = None + status: t.Optional[DownloadStatus] = None + + # This is overridden in subclass. + def __post_init__(self): + """Initialize the manifest.""" + pass + + def schedule( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Indicate that a job has been scheduled for download. + + 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. + """ + scheduled_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + self.status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=None, + status=Status.SCHEDULED, + error=None, + size=None, + scheduled_time=scheduled_time, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=None, + upload_end_time=None, + ) + self._update(self.status) + + def skip( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Updates the manifest to mark the shards that were skipped in the current job + as 'upload' stage and 'success' status, indicating that they have already been downloaded. + """ + old_status = self._read(location) + # The manifest needs to be updated for a skipped shard if its entry is not present, or + # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. + if ( + old_status.location != location + or old_status.stage != Stage.UPLOAD + or old_status.status != Status.SUCCESS + ): + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + size = get_file_size(location) + + status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=Stage.UPLOAD, + status=Status.SUCCESS, + error=None, + size=size, + scheduled_time=None, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=current_utc_time, + upload_end_time=current_utc_time, + ) + self._update(status) + print( + f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." 
+ ) + + def _set_for_transaction( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Reset Manifest state in preparation for a new transaction.""" + self.status = dataclasses.replace(self._read(location)) + self.status.config_name = config_name + self.status.dataset = dataset if dataset else None + self.status.selection = selection + self.status.location = location + self.status.username = user + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, exc_inst, exc_tb) -> None: + """Record end status of a transaction as either 'success' or 'failure'.""" + if exc_type is None: + status = Status.SUCCESS + error = None + else: + status = Status.FAILURE + # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception + error = f"license_id: {self.license_id} " + error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) + + new_status = dataclasses.replace(self.status) + new_status.error = error + new_status.status = status + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + # This is necessary for setting the precise start time of the previous stage + # and end time of the final stage, as well as handling the case of Status.FAILURE. + if new_status.stage == Stage.FETCH: + new_status.fetch_start_time = self.prev_stage_precise_start_time + new_status.fetch_end_time = current_utc_time + elif new_status.stage == Stage.RETRIEVE: + new_status.retrieve_start_time = self.prev_stage_precise_start_time + new_status.retrieve_end_time = current_utc_time + elif new_status.stage == Stage.DOWNLOAD: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + else: + new_status.upload_start_time = self.prev_stage_precise_start_time + new_status.upload_end_time = current_utc_time + + new_status.size = get_file_size(new_status.location) + + self.status = new_status + + self._update(self.status) + + def transact( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> "Manifest": + """Create a download transaction.""" + self._set_for_transaction(config_name, dataset, selection, location, user) + return self + + def set_stage(self, stage: Stage) -> None: + """Sets the current stage in manifest.""" + new_status = dataclasses.replace(self.status) + new_status.stage = stage + new_status.status = Status.IN_PROGRESS + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + if stage == Stage.DOWNLOAD: + new_status.download_start_time = current_utc_time + else: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + new_status.upload_start_time = current_utc_time + + self.status = new_status + self._update(self.status) + + @abc.abstractmethod + def _read(self, location: str) -> DownloadStatus: + pass + + @abc.abstractmethod + def _update(self, download_status: DownloadStatus) -> None: + pass + + +class FirestoreManifest(Manifest): + """A Firestore Manifest. + This Manifest implementation stores DownloadStatuses in a Firebase document store. + The document hierarchy for the manifest is as follows: + [manifest ] + ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } + └── etc... 
+ Where `[]` indicates a collection and ` {...}` indicates a document. + """ + + def _get_db(self) -> firestore.firestore.Client: + """Acquire a firestore client, initializing the firebase app if necessary. + Will attempt to get the db client five times. If it's still unsuccessful, a + `ManifestException` will be raised. + """ + db = None + attempts = 0 + + while db is None: + try: + db = firestore.client() + except ValueError as e: + # The above call will fail with a value error when the firebase app is not initialized. + # Initialize the app here, and try again. + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + print("Initialized Firebase App.") + + if attempts > 4: + raise ManifestException( + "Exceeded number of retries to get firestore client." + ) from e + + time.sleep(get_wait_interval(attempts)) + + attempts += 1 + + return db + + def _read(self, location: str) -> DownloadStatus: + """Reads the JSON data from a manifest.""" + + doc_id = generate_md5_hash(location) + + # Update document with download status + download_doc_ref = self.root_document_for_store(doc_id) + + result = download_doc_ref.get() + row = {} + if result.exists: + records = result.to_dict() + row = {n: to_json_serializable_type(v) for n, v in records.items()} + return DownloadStatus.from_dict(row) + + def _update(self, download_status: DownloadStatus) -> None: + """Update or create a download status record.""" + print("Updating Firestore Manifest.") + + status = DownloadStatus.to_dict(download_status) + doc_id = generate_md5_hash(status["location"]) + + # Update document with download status + download_doc_ref = self.root_document_for_store(doc_id) + + result: WriteResult = download_doc_ref.set(status) + + print( + f"Firestore manifest updated. " + f"update_time={result.update_time}, " + f"filename={download_status.location}." + ) + + def root_document_for_store(self, store_scheme: str) -> DocumentReference: + """Get the root manifest document given the user's config and current document's storage location.""" + return ( + self._get_db() + .collection(get_config().manifest_collection) + .document(store_scheme) + ) diff --git a/weather_dl_v2/downloader_kubernetes/util.py b/weather_dl_v2/downloader_kubernetes/util.py new file mode 100644 index 00000000..5777234f --- /dev/null +++ b/weather_dl_v2/downloader_kubernetes/util.py @@ -0,0 +1,226 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
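Worth calling out from `_read`/`_update` above: every shard gets exactly one manifest document, keyed by the md5 of its target path under the configured `manifest_collection`. A minimal sketch (the bucket path is made up):

```
import hashlib


def manifest_doc_id(location: str) -> str:
    # Mirrors generate_md5_hash in util.py: one manifest document per target path.
    return hashlib.md5(location.encode("utf-8")).hexdigest()


doc_id = manifest_doc_id("gs://example-bucket/era5/2020-01-01.nc")
# FirestoreManifest then reads/writes: <manifest_collection>/<doc_id>
print(doc_id)
```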
+ + +import datetime +import geojson +import hashlib +import itertools +import os +import socket +import subprocess +import sys +import typing as t + +import numpy as np +import pandas as pd +from apache_beam.io.gcp import gcsio +from apache_beam.utils import retry +from xarray.core.utils import ensure_us_time_resolution +from urllib.parse import urlparse +from google.api_core.exceptions import BadRequest + + +LATITUDE_RANGE = (-90, 90) +LONGITUDE_RANGE = (-180, 180) +GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] + + +def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( + exception, +) -> bool: + if isinstance(exception, socket.timeout): + return True + if isinstance(exception, TimeoutError): + return True + # To handle the concurrency issue in BigQuery. + if isinstance(exception, BadRequest): + return True + return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) + + +class _FakeClock: + + def sleep(self, value): + pass + + +def retry_with_exponential_backoff(fun): + """A retry decorator that doesn't apply during test time.""" + clock = retry.Clock() + + # Use a fake clock only during test time... + if "unittest" in sys.modules.keys(): + clock = _FakeClock() + + return retry.with_exponential_backoff( + retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, + clock=clock, + )(fun) + + +# TODO(#245): Group with common utilities (duplicated) +def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: + """Yield evenly-sized chunks from an iterable.""" + input_ = iter(iterable) + try: + while True: + it = itertools.islice(input_, n) + # peek to check if 'it' has next item. + first = next(it) + yield itertools.chain([first], it) + except StopIteration: + pass + + +# TODO(#245): Group with common utilities (duplicated) +def copy(src: str, dst: str) -> None: + """Copy data via `gsutil cp`.""" + try: + subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True) + except subprocess.CalledProcessError as e: + print( + f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}' + ) + raise + + +# TODO(#245): Group with common utilities (duplicated) +def to_json_serializable_type(value: t.Any) -> t.Any: + """Returns the value with a type serializable to JSON""" + # Note: The order of processing is significant. + print("Serializing to JSON") + + if pd.isna(value) or value is None: + return None + elif np.issubdtype(type(value), np.floating): + return float(value) + elif isinstance(value, np.ndarray): + # Will return a scaler if array is of size 1, else will return a list. + return value.tolist() + elif ( + isinstance(value, datetime.datetime) + or isinstance(value, str) + or isinstance(value, np.datetime64) + ): + # Assume strings are ISO format timestamps... + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + # ... if they are not, assume serialization is already correct. + return value + except TypeError: + # ... maybe value is a numpy datetime ... + try: + value = ensure_us_time_resolution(value).astype(datetime.datetime) + except AttributeError: + # ... value is a datetime object, continue. + pass + + # We use a string timestamp representation. + if value.tzname(): + return value.isoformat() + + # We assume here that naive timestamps are in UTC timezone. + return value.replace(tzinfo=datetime.timezone.utc).isoformat() + elif isinstance(value, np.timedelta64): + # Return time delta in seconds. 
+ return float(value / np.timedelta64(1, "s")) + # This check must happen after processing np.timedelta64 and np.datetime64. + elif np.issubdtype(type(value), np.integer): + return int(value) + + return value + + +def fetch_geo_polygon(area: t.Union[list, str]) -> str: + """Calculates a geography polygon from an input area.""" + # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 + if isinstance(area, str): + # European area + if area == "E": + area = [73.5, -27, 33, 45] + # Global area + elif area == "G": + area = GLOBAL_COVERAGE_AREA + else: + raise RuntimeError(f"Not a valid value for area in config: {area}.") + + n, w, s, e = [float(x) for x in area] + if s < LATITUDE_RANGE[0]: + raise ValueError(f"Invalid latitude value for south: '{s}'") + if n > LATITUDE_RANGE[1]: + raise ValueError(f"Invalid latitude value for north: '{n}'") + if w < LONGITUDE_RANGE[0]: + raise ValueError(f"Invalid longitude value for west: '{w}'") + if e > LONGITUDE_RANGE[1]: + raise ValueError(f"Invalid longitude value for east: '{e}'") + + # Define the coordinates of the bounding box. + coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] + + # Create the GeoJSON polygon object. + polygon = geojson.dumps(geojson.Polygon([coords])) + return polygon + + +def get_file_size(path: str) -> float: + parsed_gcs_path = urlparse(path) + if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": + return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 + else: + return ( + gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 + ) + + +def get_wait_interval(num_retries: int = 0) -> float: + """Returns next wait interval in seconds, using an exponential backoff algorithm.""" + if 0 == num_retries: + return 0 + return 2**num_retries + + +def generate_md5_hash(input: str) -> str: + """Generates md5 hash for the input string.""" + return hashlib.md5(input.encode("utf-8")).hexdigest() + + +def download_with_aria2(url: str, path: str) -> None: + """Downloads a file from the given URL using the `aria2c` command-line utility, + with options set to improve download speed and reliability.""" + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + [ + "aria2c", + "-x", + "16", + "-s", + "16", + url, + "-d", + dir_path, + "-o", + file_name, + "--allow-overwrite", + ], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + print( + f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}' + ) + raise diff --git a/weather_dl_v2/fastapi-server/API-Interactions.md b/weather_dl_v2/fastapi-server/API-Interactions.md new file mode 100644 index 00000000..3ea4eece --- /dev/null +++ b/weather_dl_v2/fastapi-server/API-Interactions.md @@ -0,0 +1,25 @@ +# API Interactions +| Command | Type | Endpoint | +|---|---|---| +| `weather-dl-v2 ping` | `get` | `/` +| Download | | | +| `weather-dl-v2 download add –l [--force-download]` | `post` | `/download?force_download={value}` | +| `weather-dl-v2 download list` | `get` | `/download/` | +| `weather-dl-v2 download list --filter client_name=` | `get` | `/download?client_name={name}` | +| `weather-dl-v2 download get ` | `get` | `/download/{config_name}` | +| `weather-dl-v2 download show ` | `get` | `/download/show/{config_name}` | +| `weather-dl-v2 download remove ` | `delete` | `/download/{config_name}` | +| `weather-dl-v2 download refetch -l ` | `post` | `/download/refetch/{config_name}` | +| License | | | +| `weather-dl-v2 license add ` | `post` | `/license/` | +| 
`weather-dl-v2 license get ` | `get` | `/license/{license_id}` | +| `weather-dl-v2 license remove ` | `delete` | `/license/{license_id}` | +| `weather-dl-v2 license list` | `get` | `/license/` | +| `weather-dl-v2 license list --filter client_name=` | `get` | `/license?client_name={name}` | +| `weather-dl-v2 license edit ` | `put` | `/license/{license_id}` | +| Queue | | | +| `weather-dl-v2 queue list` | `get` | `/queues/` | +| `weather-dl-v2 queue list --filter client_name=` | `get` | `/queues?client_name={name}` | +| `weather-dl-v2 queue get ` | `get` | `/queues/{license_id}` | +| `queue edit --config --priority ` | `post` | `/queues/{license_id}` | +| `queue edit --file ` | `put` | `/queues/priority/{license_id}` | \ No newline at end of file diff --git a/weather_dl_v2/fastapi-server/Dockerfile b/weather_dl_v2/fastapi-server/Dockerfile new file mode 100644 index 00000000..b54e41c0 --- /dev/null +++ b/weather_dl_v2/fastapi-server/Dockerfile @@ -0,0 +1,40 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +FROM continuumio/miniconda3:latest + +EXPOSE 8080 + +# Update miniconda +RUN conda update conda -y + +# Add the mamba solver for faster builds +RUN conda install -n base conda-libmamba-solver +RUN conda config --set solver libmamba + +COPY . . 
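For orientation, the command-to-endpoint table in `API-Interactions.md` above translates directly into plain HTTP calls against the FastAPI server. A hedged sketch with `requests` (host, port, and the config name are assumptions; the real values come from `cli_config.json`):

```
import requests

BASE = "http://localhost:8080"  # substitute pod_ip/port from cli_config.json

requests.get(f"{BASE}/")                                         # weather-dl-v2 ping
requests.get(f"{BASE}/download", params={"client_name": "cds"})  # download list --filter
requests.get(f"{BASE}/license/")                                 # license list
requests.delete(f"{BASE}/download/example.cfg")                  # download remove <config_name>
```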
+# Create conda env using environment.yml +RUN conda env create -f environment.yml --debug + +# Activate the conda env and update the PATH +ARG CONDA_ENV_NAME=weather-dl-v2-server +RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc +ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH + +# Use the ping endpoint as a healthcheck, +# so Docker knows if the API is still running ok or needs to be restarted +HEALTHCHECK --interval=21s --timeout=3s --start-period=10s CMD curl --fail http://localhost:8080/ping || exit 1 + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"] diff --git a/weather_dl_v2/fastapi-server/README.md b/weather_dl_v2/fastapi-server/README.md new file mode 100644 index 00000000..2debb563 --- /dev/null +++ b/weather_dl_v2/fastapi-server/README.md @@ -0,0 +1,91 @@ +# Deployment Instructions & General Notes + +### User authorization required to set up the environment: +* roles/container.admin + +### Authorization needed for the tool to operate: +We are not configuring any service account here hence make sure that compute engine default service account have roles: +* roles/pubsub.subscriber +* roles/storage.admin +* roles/bigquery.dataEditor +* roles/bigquery.jobUser + +### Install kubectl: +``` +apt-get update + +apt-get install -y kubectl +``` + +### Create cluster: +``` +export PROJECT_ID= +export REGION= eg: us-west1 +export ZONE= eg: us-west1-a +export CLUSTER_NAME= eg: weather-dl-v2-cluster +export DOWNLOAD_NODE_POOL=downloader-pool + +gcloud beta container --project $PROJECT_ID clusters create $CLUSTER_NAME --zone $ZONE --no-enable-basic-auth --cluster-version "1.27.2-gke.1200" --release-channel "regular" --machine-type "e2-standard-8" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "1100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/cloud-platform" --max-pods-per-node "16" --num-nodes "4" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/$PROJECT_ID/global/networks/default" --subnetwork "projects/$PROJECT_ID/regions/$REGION/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "16" --enable-autoscaling --min-nodes "4" --max-nodes "100" --location-policy "BALANCED" --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations $ZONE --node-labels preemptible=false && gcloud beta container --project $PROJECT_ID node-pools create $DOWNLOAD_NODE_POOL --cluster $CLUSTER_NAME --zone $ZONE --machine-type "e2-standard-8" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "1100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/cloud-platform" --max-pods-per-node "16" --num-nodes "1" --enable-autoscaling --min-nodes "1" --max-nodes "100" --location-policy "BALANCED" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations $ZONE --node-labels preemptible=false +``` + +### Connect to Cluster: +``` +gcloud container clusters get-credentials $CLUSTER_NAME --zone $ZONE --project $PROJECT_ID +``` + +### How to create environment: +``` +conda env create --name weather-dl-v2-server --file=environment.yml + +conda activate weather-dl-v2-server +``` + +### Make changes in weather_dl_v2/config.json, if required [for running locally] +``` 
+export CONFIG_PATH=/path/to/weather_dl_v2/config.json +``` + +### To run fastapi server: +``` +uvicorn main:app --reload +``` + +* Open your browser at http://127.0.0.1:8000. + + +### Create docker image for server: +``` +export PROJECT_ID= +export REPO= eg:weather-tools + +gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-server" --timeout=79200 --machine-type=e2-highcpu-32 +``` + +### Add path of created server image in server.yaml: +``` +Please write down the fastAPI server's docker image path at Line 42 of server.yaml. +``` + +### Create ConfigMap of common configurations for services: +Make necessary changes to weather_dl_v2/config.json and run following command. +ConfigMap is used for: +- Having a common configuration file for all services. +- Decoupling docker image and config files. +``` +kubectl create configmap dl-v2-config --from-file=/path/to/weather_dl_v2/config.json +``` + +### Deploy fastapi server on kubernetes: +``` +kubectl apply -f server.yaml --force +``` + +## General Commands +### For viewing the current pods: +``` +kubectl get pods +``` + +### For deleting existing deployment: +``` +kubectl delete -f server.yaml --force \ No newline at end of file diff --git a/weather_dl_v2/fastapi-server/__init__.py b/weather_dl_v2/fastapi-server/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/fastapi-server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/fastapi-server/config_processing/config.py b/weather_dl_v2/fastapi-server/config_processing/config.py new file mode 100644 index 00000000..fe2199b8 --- /dev/null +++ b/weather_dl_v2/fastapi-server/config_processing/config.py @@ -0,0 +1,120 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import calendar +import copy +import dataclasses +import typing as t + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class Config: + """Contains pipeline parameters. + + Attributes: + config_name: + Name of the config file. + client: + Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable. + dataset (optional): + Name of the target dataset. Allowed options are dictated by the client. + partition_keys (optional): + Choose the keys from the selection section to partition the data request. 
+ This will compute a cartesian cross product of the selected keys + and assign each as their own download. + target_path: + Download artifact filename template. Can make use of Python's standard string formatting. + It can contain format symbols to be replaced by partition keys; + if this is used, the total number of format symbols must match the number of partition keys. + subsection_name: + Name of the particular subsection. 'default' if there is no subsection. + force_download: + Force redownload of partitions that were previously downloaded. + user_id: + Username from the environment variables. + kwargs (optional): + For representing subsections or any other parameters. + selection: + Contains parameters used to select desired data. + """ + + config_name: str = "" + client: str = "" + dataset: t.Optional[str] = "" + target_path: str = "" + partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) + subsection_name: str = "default" + force_download: bool = False + user_id: str = "unknown" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict) -> "Config": + config_instance = cls() + for section_key, section_value in config.items(): + if section_key == "parameters": + for key, value in section_value.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + if section_key == "selection": + config_instance.selection = section_value + return config_instance + + +def optimize_selection_partition(selection: t.Dict) -> t.Dict: + """Compute right-hand-side values for the selection section of a single partition. + + Used to support custom syntax and optimizations, such as 'all'. + """ + selection_ = copy.deepcopy(selection) + + if "day" in selection_.keys() and selection_["day"] == "all": + year, month = selection_["year"], selection_["month"] + + multiples_error = ( + "Cannot use keyword 'all' on selections with multiple '{type}'s." + ) + + if isinstance(year, list): + assert len(year) == 1, multiples_error.format(type="year") + year = year[0] + + if isinstance(month, list): + assert len(month) == 1, multiples_error.format(type="month") + month = month[0] + + if isinstance(year, str): + assert "/" not in year, multiples_error.format(type="year") + + if isinstance(month, str): + assert "/" not in month, multiples_error.format(type="month") + + year, month = int(year), int(month) + + _, n_days_in_month = calendar.monthrange(year, month) + + selection_[ + "date" + ] = f"{year:04d}-{month:02d}-01/to/{year:04d}-{month:02d}-{n_days_in_month:02d}" + del selection_["day"] + del selection_["month"] + del selection_["year"] + + return selection_ diff --git a/weather_dl_v2/fastapi-server/config_processing/manifest.py b/weather_dl_v2/fastapi-server/config_processing/manifest.py new file mode 100644 index 00000000..35a8bf7b --- /dev/null +++ b/weather_dl_v2/fastapi-server/config_processing/manifest.py @@ -0,0 +1,513 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
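To make the `'all'` shorthand concrete, this is what `optimize_selection_partition` above produces for a single-month selection (the input values are illustrative; the import path follows the layout in this diff):

```
from config_processing.config import optimize_selection_partition

# 2020-02 has 29 days, so day/month/year collapse into one inclusive date range.
selection = {"year": "2020", "month": "02", "day": "all", "param": "2t"}
print(optimize_selection_partition(selection))
# {'param': '2t', 'date': '2020-02-01/to/2020-02-29'}
```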
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Client interface for connecting to a manifest.""" + +import abc +import dataclasses +import logging +import datetime +import enum +import json +import pandas as pd +import time +import traceback +import typing as t + +from .util import ( + to_json_serializable_type, + fetch_geo_polygon, + get_file_size, + get_wait_interval, + generate_md5_hash, + GLOBAL_COVERAGE_AREA, +) + +import firebase_admin +from firebase_admin import credentials +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentReference +from google.cloud.firestore_v1.types import WriteResult +from server_config import get_config +from database.session import Database + +"""An implementation-dependent Manifest URI.""" +Location = t.NewType("Location", str) + +logger = logging.getLogger(__name__) + + +class ManifestException(Exception): + """Errors that occur in Manifest Clients.""" + + pass + + +class Stage(enum.Enum): + """A request can be either in one of the following stages at a time: + + fetch : This represents request is currently in fetch stage i.e. request placed on the client's server + & waiting for some result before starting download (eg. MARS client). + download : This represents request is currently in download stage i.e. data is being downloading from client's + server to the worker's local file system. + upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local + file system to target location (GCS path). + retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), + request will be in the retrieve stage i.e. fetch + download. + """ + + RETRIEVE = "retrieve" + FETCH = "fetch" + DOWNLOAD = "download" + UPLOAD = "upload" + + +class Status(enum.Enum): + """Depicts the request's state status: + + scheduled : A request partition is created & scheduled for processing. + Note: Its corresponding state can be None only. + in-progress : This represents the request state is currently in-progress (i.e. running). + The next status would be "success" or "failure". + success : This represents the request state execution completed successfully without any error. + failure : This represents the request state execution failed. 
+ """ + + SCHEDULED = "scheduled" + IN_PROGRESS = "in-progress" + SUCCESS = "success" + FAILURE = "failure" + + +@dataclasses.dataclass +class DownloadStatus: + """Data recorded in `Manifest`s reflecting the status of a download.""" + + """The name of the config file associated with the request.""" + config_name: str = "" + + """Represents the dataset field of the configuration.""" + dataset: t.Optional[str] = "" + + """Copy of selection section of the configuration.""" + selection: t.Dict = dataclasses.field(default_factory=dict) + + """Location of the downloaded data.""" + location: str = "" + + """Represents area covered by the shard.""" + area: str = "" + + """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" + stage: t.Optional[Stage] = None + + """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" + status: t.Optional[Status] = None + + """Cause of error, if any.""" + error: t.Optional[str] = "" + + """Identifier for the user running the download.""" + username: str = "" + + """Shard size in GB.""" + size: t.Optional[float] = 0 + + """A UTC datetime when download was scheduled.""" + scheduled_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve stage starts.""" + retrieve_start_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve state ends.""" + retrieve_end_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state starts.""" + fetch_start_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state ends.""" + fetch_end_time: t.Optional[str] = "" + + """A UTC datetime when the download state starts.""" + download_start_time: t.Optional[str] = "" + + """A UTC datetime when the download state ends.""" + download_end_time: t.Optional[str] = "" + + """A UTC datetime when the upload state starts.""" + upload_start_time: t.Optional[str] = "" + + """A UTC datetime when the upload state ends.""" + upload_end_time: t.Optional[str] = "" + + @classmethod + def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": + """Instantiate DownloadStatus dataclass from dict.""" + download_status_instance = cls() + for key, value in download_status.items(): + if key == "status": + setattr(download_status_instance, key, Status(value)) + elif key == "stage" and value is not None: + setattr(download_status_instance, key, Stage(value)) + else: + setattr(download_status_instance, key, value) + return download_status_instance + + @classmethod + def to_dict(cls, instance) -> t.Dict: + """Return the fields of a dataclass instance as a manifest ingestible + dictionary mapping of field names to field values.""" + download_status_dict = {} + for field in dataclasses.fields(instance): + key = field.name + value = getattr(instance, field.name) + if isinstance(value, Status) or isinstance(value, Stage): + download_status_dict[key] = value.value + elif isinstance(value, pd.Timestamp): + download_status_dict[key] = value.isoformat() + elif key == "selection" and value is not None: + download_status_dict[key] = json.dumps(value) + else: + download_status_dict[key] = value + return download_status_dict + + +@dataclasses.dataclass +class Manifest(abc.ABC): + """Abstract manifest of download statuses. + + Update download statuses to some storage medium. + + This class lets one indicate that a download is `scheduled` or in a transaction process. + In the event of a transaction, a download will be updated with an `in-progress`, `success` + or `failure` status (with accompanying metadata). 
+ + Example: + ``` + my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) + + # Schedule data for download + my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') + + # ... + + # Initiate a transaction – it will record that the download is `in-progess` + with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: + # download logic here + pass + + # ... + + # on error, will record the download as a `failure` before propagating the error. By default, it will + # record download as a `success`. + ``` + + Attributes: + status: The current `DownloadStatus` of the Manifest. + """ + + # To reduce the impact of _read() and _update() calls + # on the start time of the stage. + prev_stage_precise_start_time: t.Optional[str] = None + status: t.Optional[DownloadStatus] = None + + # This is overridden in subclass. + def __post_init__(self): + """Initialize the manifest.""" + pass + + def schedule( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Indicate that a job has been scheduled for download. + + 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. + """ + scheduled_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + self.status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=None, + status=Status.SCHEDULED, + error=None, + size=None, + scheduled_time=scheduled_time, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=None, + upload_end_time=None, + ) + self._update(self.status) + + def skip( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Updates the manifest to mark the shards that were skipped in the current job + as 'upload' stage and 'success' status, indicating that they have already been downloaded. + """ + old_status = self._read(location) + # The manifest needs to be updated for a skipped shard if its entry is not present, or + # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. + if ( + old_status.location != location + or old_status.stage != Stage.UPLOAD + or old_status.status != Status.SUCCESS + ): + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + size = get_file_size(location) + + status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=Stage.UPLOAD, + status=Status.SUCCESS, + error=None, + size=size, + scheduled_time=None, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=current_utc_time, + upload_end_time=current_utc_time, + ) + self._update(status) + logger.info( + f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." 
+ ) + + def _set_for_transaction( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Reset Manifest state in preparation for a new transaction.""" + self.status = dataclasses.replace(self._read(location)) + self.status.config_name = config_name + self.status.dataset = dataset if dataset else None + self.status.selection = selection + self.status.location = location + self.status.username = user + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, exc_inst, exc_tb) -> None: + """Record end status of a transaction as either 'success' or 'failure'.""" + if exc_type is None: + status = Status.SUCCESS + error = None + else: + status = Status.FAILURE + # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception + error = "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) + + new_status = dataclasses.replace(self.status) + new_status.error = error + new_status.status = status + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + # This is necessary for setting the precise start time of the previous stage + # and end time of the final stage, as well as handling the case of Status.FAILURE. + if new_status.stage == Stage.FETCH: + new_status.fetch_start_time = self.prev_stage_precise_start_time + new_status.fetch_end_time = current_utc_time + elif new_status.stage == Stage.RETRIEVE: + new_status.retrieve_start_time = self.prev_stage_precise_start_time + new_status.retrieve_end_time = current_utc_time + elif new_status.stage == Stage.DOWNLOAD: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + else: + new_status.upload_start_time = self.prev_stage_precise_start_time + new_status.upload_end_time = current_utc_time + + new_status.size = get_file_size(new_status.location) + + self.status = new_status + + self._update(self.status) + + def transact( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> "Manifest": + """Create a download transaction.""" + self._set_for_transaction(config_name, dataset, selection, location, user) + return self + + def set_stage(self, stage: Stage) -> None: + """Sets the current stage in manifest.""" + prev_stage = self.status.stage + new_status = dataclasses.replace(self.status) + new_status.stage = stage + new_status.status = Status.IN_PROGRESS + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + if stage == Stage.FETCH: + new_status.fetch_start_time = current_utc_time + elif stage == Stage.RETRIEVE: + new_status.retrieve_start_time = current_utc_time + elif stage == Stage.DOWNLOAD: + new_status.fetch_start_time = self.prev_stage_precise_start_time + new_status.fetch_end_time = current_utc_time + new_status.download_start_time = current_utc_time + else: + if prev_stage == Stage.DOWNLOAD: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + else: + new_status.retrieve_start_time = self.prev_stage_precise_start_time + new_status.retrieve_end_time = current_utc_time + new_status.upload_start_time = current_utc_time + + self.status = new_status + self._update(self.status) + + @abc.abstractmethod + def _read(self, location: str) -> DownloadStatus: + pass + + @abc.abstractmethod + def _update(self, 
download_status: DownloadStatus) -> None: + pass + + +class FirestoreManifest(Manifest, Database): + """A Firestore Manifest. + This Manifest implementation stores DownloadStatuses in a Firebase document store. + The document hierarchy for the manifest is as follows: + [manifest ] + ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } + └── etc... + Where `[]` indicates a collection and ` {...}` indicates a document. + """ + + def _get_db(self) -> firestore.firestore.Client: + """Acquire a firestore client, initializing the firebase app if necessary. + Will attempt to get the db client five times. If it's still unsuccessful, a + `ManifestException` will be raised. + """ + db = None + attempts = 0 + + while db is None: + try: + db = firestore.client() + except ValueError as e: + # The above call will fail with a value error when the firebase app is not initialized. + # Initialize the app here, and try again. + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + logger.info("Initialized Firebase App.") + + if attempts > 4: + raise ManifestException( + "Exceeded number of retries to get firestore client." + ) from e + + time.sleep(get_wait_interval(attempts)) + + attempts += 1 + + return db + + def _read(self, location: str) -> DownloadStatus: + """Reads the JSON data from a manifest.""" + + doc_id = generate_md5_hash(location) + + # Update document with download status + download_doc_ref = self.root_document_for_store(doc_id) + + result = download_doc_ref.get() + row = {} + if result.exists: + records = result.to_dict() + row = {n: to_json_serializable_type(v) for n, v in records.items()} + return DownloadStatus.from_dict(row) + + def _update(self, download_status: DownloadStatus) -> None: + """Update or create a download status record.""" + logger.info("Updating Firestore Manifest.") + + status = DownloadStatus.to_dict(download_status) + doc_id = generate_md5_hash(status["location"]) + + # Update document with download status + download_doc_ref = self.root_document_for_store(doc_id) + + result: WriteResult = download_doc_ref.set(status) + + logger.info( + f"Firestore manifest updated. " + f"update_time={result.update_time}, " + f"filename={download_status.location}." + ) + + def root_document_for_store(self, store_scheme: str) -> DocumentReference: + """Get the root manifest document given the user's config and current document's storage location.""" + root_collection = get_config().manifest_collection + return self._get_db().collection(root_collection).document(store_scheme) diff --git a/weather_dl_v2/fastapi-server/config_processing/parsers.py b/weather_dl_v2/fastapi-server/config_processing/parsers.py new file mode 100644 index 00000000..5f9e1f5c --- /dev/null +++ b/weather_dl_v2/fastapi-server/config_processing/parsers.py @@ -0,0 +1,507 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
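The `_get_db` retry loop above sleeps for `get_wait_interval(attempts)` seconds between attempts; with the exponential rule defined in `util.py`, the five tries are spaced as shown by this small sanity-check sketch:

```
def get_wait_interval(num_retries: int = 0) -> float:
    # Same rule as the util helper: no wait on the first try, then 2**n seconds.
    return 0 if num_retries == 0 else 2 ** num_retries


print([get_wait_interval(n) for n in range(5)])  # [0, 2, 4, 8, 16]
```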
+ + +"""Parsers for ECMWF download configuration.""" + +import ast +import configparser +import copy as cp +import datetime +import json +import string +import textwrap +import typing as t +import numpy as np +from collections import OrderedDict +from .config import Config + +CLIENTS = ["cds", "mars", "ecpublic"] + + +def date(candidate: str) -> datetime.date: + """Converts ECMWF-format date strings into a `datetime.date`. + + Accepted absolute date formats: + - YYYY-MM-DD + - YYYYMMDD + - YYYY-DDD, where DDD refers to the day of the year + + For example: + - 2021-10-31 + - 19700101 + - 1950-007 + + See https://confluence.ecmwf.int/pages/viewpage.action?pageId=118817289 for date format spec. + Note: Name of month is not supported. + """ + converted = None + + # Parse relative day value. + if candidate.startswith("-"): + return datetime.date.today() + datetime.timedelta(days=int(candidate)) + + accepted_formats = ["%Y-%m-%d", "%Y%m%d", "%Y-%j"] + + for fmt in accepted_formats: + try: + converted = datetime.datetime.strptime(candidate, fmt).date() + break + except ValueError: + pass + + if converted is None: + raise ValueError( + f"Not a valid date: '{candidate}'. Please use valid relative or absolute format." + ) + + return converted + + +def time(candidate: str) -> datetime.time: + """Converts ECMWF-format time strings into a `datetime.time`. + + Accepted time formats: + - HH:MM + - HHMM + - HH + + For example: + - 18:00 + - 1820 + - 18 + + Note: If MM is omitted it defaults to 00. + """ + converted = None + + accepted_formats = ["%H", "%H:%M", "%H%M"] + + for fmt in accepted_formats: + try: + converted = datetime.datetime.strptime(candidate, fmt).time() + break + except ValueError: + pass + + if converted is None: + raise ValueError(f"Not a valid time: '{candidate}'. Please use valid format.") + + return converted + + +def day_month_year(candidate: t.Any) -> int: + """Converts day, month and year strings into 'int'.""" + try: + if isinstance(candidate, str) or isinstance(candidate, int): + return int(candidate) + raise ValueError("must be a str or int.") + except ValueError as e: + raise ValueError( + f"Not a valid day, month, or year value: {candidate}. Please use valid value." + ) from e + + +def parse_literal(candidate: t.Any) -> t.Any: + try: + # Support parsing ints with leading zeros, e.g. '01' + if isinstance(candidate, str) and candidate.isdigit(): + return int(candidate) + return ast.literal_eval(candidate) + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + return candidate + + +def validate(key: str, value: int) -> None: + """Validates value based on the key.""" + if key == "day": + assert 1 <= value <= 31, "Day value must be between 1 to 31." + if key == "month": + assert 1 <= value <= 12, "Month value must be between 1 to 12." 
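For reference, a few concrete conversions these parsers perform (illustrative examples based on the accepted formats above; not part of the diff):

```python
>>> date("2021-10-31")
datetime.date(2021, 10, 31)
>>> date("19700101")
datetime.date(1970, 1, 1)
>>> date("1950-007")    # day-of-year form
datetime.date(1950, 1, 7)
>>> date("-2")          # relative form: two days before today (value varies)
>>> time("18")
datetime.time(18, 0)
>>> time("18:30")
datetime.time(18, 30)
```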
+ + +def typecast(key: str, value: t.Any) -> t.Any: + """Type the value to its appropriate datatype.""" + SWITCHER = { + "date": date, + "time": time, + "day": day_month_year, + "month": day_month_year, + "year": day_month_year, + } + converted = SWITCHER.get(key, parse_literal)(value) + validate(key, converted) + return converted + + +def _read_config_file(file: t.IO) -> t.Dict: + """Reads `*.json` or `*.cfg` files.""" + try: + return json.load(file) + except json.JSONDecodeError: + pass + + file.seek(0) + + try: + config = configparser.ConfigParser() + config.read_file(file) + config = {s: dict(config.items(s)) for s in config.sections()} + return config + except configparser.ParsingError: + return {} + + +def parse_config(file: t.IO) -> t.Dict: + """Parses a `*.json` or `*.cfg` file into a configuration dictionary.""" + config = _read_config_file(file) + config_by_section = {s: _parse_lists(v, s) for s, v in config.items()} + config_with_nesting = parse_subsections(config_by_section) + return config_with_nesting + + +def _splitlines(block: str) -> t.List[str]: + """Converts a multi-line block into a list of strings.""" + return [line.strip() for line in block.strip().splitlines()] + + +def mars_range_value(token: str) -> t.Union[datetime.date, int, float]: + """Converts a range token into either a date, int, or float.""" + try: + return date(token) + except ValueError: + pass + + if token.isdecimal(): + return int(token) + + try: + return float(token) + except ValueError: + raise ValueError( + "Token string must be an 'int', 'float', or 'datetime.date()'." + ) + + +def mars_increment_value(token: str) -> t.Union[int, float]: + """Converts an increment token into either an int or a float.""" + try: + return int(token) + except ValueError: + pass + + try: + return float(token) + except ValueError: + raise ValueError("Token string must be an 'int' or a 'float'.") + + +def parse_mars_syntax(block: str) -> t.List[str]: + """Parses MARS list or range into a list of arguments; ranges are inclusive. + + Types for the range and value are inferred. + + Examples: + >>> parse_mars_syntax("10/to/12") + ['10', '11', '12'] + >>> parse_mars_syntax("12/to/10/by/-1") + ['12', '11', '10'] + >>> parse_mars_syntax("0.0/to/0.5/by/0.1") + ['0.0', '0.1', '0.2', '0.30000000000000004', '0.4', '0.5'] + >>> parse_mars_syntax("2020-01-07/to/2020-01-14/by/2") + ['2020-01-07', '2020-01-09', '2020-01-11', '2020-01-13'] + >>> parse_mars_syntax("2020-01-14/to/2020-01-07/by/-2") + ['2020-01-14', '2020-01-12', '2020-01-10', '2020-01-08'] + + Returns: + A list of strings representing a range from start to finish, based on the + type of the values in the range. + If all range values are integers, it will return a list of strings of integers. + If range values are floats, it will return a list of strings of floats. + If the range values are dates, it will return a list of strings of dates in + YYYY-MM-DD format. (Note: here, the increment value should be an integer). + """ + + # Split into tokens, omitting empty strings. + tokens = [b.strip() for b in block.split("/") if b != ""] + + # Return list if no range operators are present. + if "to" not in tokens and "by" not in tokens: + return tokens + + # Parse range values, honoring 'to' and 'by' operators. + try: + to_idx = tokens.index("to") + assert to_idx != 0, "There must be a start token." 
+ start_token, end_token = tokens[to_idx - 1], tokens[to_idx + 1] + start, end = mars_range_value(start_token), mars_range_value(end_token) + + # Parse increment token, or choose default increment. + increment_token = "1" + increment = 1 + if "by" in tokens: + increment_token = tokens[tokens.index("by") + 1] + increment = mars_increment_value(increment_token) + except (AssertionError, IndexError, ValueError): + raise SyntaxError(f"Improper range syntax in '{block}'.") + + # Return a range of values with appropriate data type. + if isinstance(start, datetime.date) and isinstance(end, datetime.date): + if not isinstance(increment, int): + raise ValueError( + f"Increments on a date range must be integer number of days, '{increment_token}' is invalid." + ) + return [d.strftime("%Y-%m-%d") for d in date_range(start, end, increment)] + elif (isinstance(start, float) or isinstance(end, float)) and not isinstance( + increment, datetime.date + ): + # Increment can be either an int or a float. + _round_places = 4 + return [ + str(round(x, _round_places)).zfill(len(start_token)) + for x in np.arange(start, end + increment, increment) + ] + elif isinstance(start, int) and isinstance(end, int) and isinstance(increment, int): + # Honor leading zeros. + offset = 1 if start <= end else -1 + return [ + str(x).zfill(len(start_token)) + for x in range(start, end + offset, increment) + ] + else: + raise ValueError( + f"Range tokens (start='{start_token}', end='{end_token}', increment='{increment_token}')" + f" are inconsistent types." + ) + + +def date_range( + start: datetime.date, end: datetime.date, increment: int = 1 +) -> t.Iterable[datetime.date]: + """Gets a range of dates, inclusive.""" + offset = 1 if start <= end else -1 + return ( + start + datetime.timedelta(days=x) + for x in range(0, (end - start).days + offset, increment) + ) + + +def _parse_lists(config: dict, section: str = "") -> t.Dict: + """Parses multiline blocks in *.cfg and *.json files as lists.""" + for key, val in config.items(): + # Checks str type for backward compatibility since it also support "padding": 0 in json config + if not isinstance(val, str): + continue + + if "/" in val and "parameters" not in section: + config[key] = parse_mars_syntax(val) + elif "\n" in val: + config[key] = _splitlines(val) + + return config + + +def _number_of_replacements(s: t.Text): + format_names = [v[1] for v in string.Formatter().parse(s) if v[1] is not None] + num_empty_names = len([empty for empty in format_names if empty == ""]) + if num_empty_names != 0: + num_empty_names -= 1 + return len(set(format_names)) + num_empty_names + + +def parse_subsections(config: t.Dict) -> t.Dict: + """Interprets [section.subsection] as nested dictionaries in `.cfg` files.""" + copy = cp.deepcopy(config) + for key, val in copy.items(): + path = key.split(".") + runner = copy + parent = {} + p = None + for p in path: + if p not in runner: + runner[p] = {} + parent = runner + runner = runner[p] + parent[p] = val + + for_cleanup = [key for key, _ in copy.items() if "." 
in key] + for target in for_cleanup: + del copy[target] + return copy + + +def require( + condition: bool, message: str, error_type: t.Type[Exception] = ValueError +) -> None: + """A assert-like helper that wraps text and throws an error.""" + if not condition: + raise error_type(textwrap.dedent(message)) + + +def process_config(file: t.IO, config_name: str) -> Config: + """Read the config file and prompt the user if it is improperly structured.""" + config = parse_config(file) + + require(bool(config), "Unable to parse configuration file.") + require( + "parameters" in config, + """ + 'parameters' section required in configuration file. + + The 'parameters' section specifies the 'client', 'dataset', 'target_path', and + 'partition_key' for the API client. + + Please consult the documentation for more information.""", + ) + + params = config.get("parameters", {}) + require( + "target_template" not in params, + """ + 'target_template' is deprecated, use 'target_path' instead. + + Please consult the documentation for more information.""", + ) + require( + "target_path" in params, + """ + 'parameters' section requires a 'target_path' key. + + The 'target_path' is used to format the name of the output files. It + accepts Python 3.5+ string format symbols (e.g. '{}'). The number of symbols + should match the length of the 'partition_keys', as the 'partition_keys' args + are used to create the templates.""", + ) + require( + "client" in params, + """ + 'parameters' section requires a 'client' key. + + Supported clients are {} + """.format( + str(CLIENTS) + ), + ) + require( + params.get("client") in CLIENTS, + """ + Invalid 'client' parameter. + + Supported clients are {} + """.format( + str(CLIENTS) + ), + ) + require( + "append_date_dirs" not in params, + """ + The current version of 'google-weather-tools' no longer supports 'append_date_dirs'! + + Please refer to documentation for creating date-based directory hierarchy : + https://weather-tools.readthedocs.io/en/latest/Configuration.html#""" + """creating-a-date-based-directory-hierarchy.""", + NotImplementedError, + ) + require( + "target_filename" not in params, + """ + The current version of 'google-weather-tools' no longer supports 'target_filename'! + + Please refer to documentation : + https://weather-tools.readthedocs.io/en/latest/Configuration.html#parameters-section.""", + NotImplementedError, + ) + + partition_keys = params.get("partition_keys", list()) + if isinstance(partition_keys, str): + partition_keys = [partition_keys.strip()] + + selection = config.get("selection", dict()) + require( + all((key in selection for key in partition_keys)), + """ + All 'partition_keys' must appear in the 'selection' section. + + 'partition_keys' specify how to split data for workers. Please consult + documentation for more information.""", + ) + + num_template_replacements = _number_of_replacements(params["target_path"]) + num_partition_keys = len(partition_keys) + + require( + num_template_replacements == num_partition_keys, + """ + 'target_path' has {0} replacements. Expected {1}, since there are {1} + partition keys. + """.format( + num_template_replacements, num_partition_keys + ), + ) + + if "day" in partition_keys: + require( + selection["day"] != "all", + """If 'all' is used for a selection value, it cannot appear as a partition key.""", + ) + + # Ensure consistent lookup. + config["parameters"]["partition_keys"] = partition_keys + # Add config file name. 
+    config["parameters"]["config_name"] = config_name
+
+    # Ensure the cartesian-cross can be taken on singleton values for the partition.
+    for key in partition_keys:
+        if not isinstance(selection[key], list):
+            selection[key] = [selection[key]]
+
+    return Config.from_dict(config)
+
+
+def prepare_target_name(config: Config) -> str:
+    """Returns name of target location."""
+    partition_dict = OrderedDict(
+        (key, typecast(key, config.selection[key][0])) for key in config.partition_keys
+    )
+    target = config.target_path.format(*partition_dict.values(), **partition_dict)
+
+    return target
+
+
+def get_subsections(config: Config) -> t.List[t.Tuple[str, t.Dict]]:
+    """Collect parameter subsections from main configuration.
+
+    If the `parameters` section contains subsections (e.g. '[parameters.1]',
+    '[parameters.2]'), collect the subsection key-value pairs. Otherwise,
+    return a single default subsection (i.e. there are no subsections).
+
+    This is useful for specifying multiple API keys for your configuration.
+    For example:
+    ```
+    [parameters.alice]
+    api_key=KKKKK1
+    api_url=UUUUU1
+    [parameters.bob]
+    api_key=KKKKK2
+    api_url=UUUUU2
+    [parameters.eve]
+    api_key=KKKKK3
+    api_url=UUUUU3
+    ```
+    """
+    return [
+        (name, params)
+        for name, params in config.kwargs.items()
+        if isinstance(params, dict)
+    ] or [("default", {})]
diff --git a/weather_dl_v2/fastapi-server/config_processing/partition.py b/weather_dl_v2/fastapi-server/config_processing/partition.py
new file mode 100644
index 00000000..a9f6a9e2
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/partition.py
@@ -0,0 +1,129 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import copy as cp
+import dataclasses
+import itertools
+import typing as t
+
+from .manifest import Manifest
+from .parsers import prepare_target_name
+from .config import Config
+from .stores import Store, FSStore
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class PartitionConfig:
+    """Partition a config into multiple data requests.
+
+    Partitioning involves three main operations: First, we fan-out shards based on
+    partition keys (a cross product of the values). Second, we filter out existing
+    downloads (unless we want to force downloads). Last, we assemble each partition
+    into a single Config.
+
+    Attributes:
+        store: A cloud storage system, used for checking the existence of downloads.
+        manifest: A download manifest to register preparation state.
+    """
+
+    config: Config
+    store: Store
+    manifest: Manifest
+
+    def _create_partition_config(self, option: t.Tuple) -> Config:
+        """Create a config for a single partition option.
+
+        Output a config dictionary, overriding the range of values for
+        each key with the partition instance in 'selection'.
+        Continuing the example from prepare_partitions, the selection section
+        would be:
+        { 'foo': ..., 'year': ['2020'], 'month': ['01'], ... }
+        { 'foo': ..., 'year': ['2020'], 'month': ['02'], ...
} + { 'foo': ..., 'year': ['2020'], 'month': ['03'], ... } + + Args: + option: A single item in the range of partition_keys. + config: The download config, including the parameters and selection sections. + + Returns: + A configuration with that selects a single download partition. + """ + copy = cp.deepcopy(self.config.selection) + out = cp.deepcopy(self.config) + for idx, key in enumerate(self.config.partition_keys): + copy[key] = [option[idx]] + + out.selection = copy + return out + + def skip_partition(self, config: Config) -> bool: + """Return true if partition should be skipped.""" + + if config.force_download: + return False + + target = prepare_target_name(config) + if self.store.exists(target): + logger.info(f"file {target} found, skipping.") + self.manifest.skip( + config.config_name, + config.dataset, + config.selection, + target, + config.user_id, + ) + return True + + return False + + def prepare_partitions(self) -> t.Iterator[Config]: + """Iterate over client parameters, partitioning over `partition_keys`. + + This produces a Cartesian-Cross over the range of keys. + + For example, if the keys were 'year' and 'month', it would produce + an iterable like: + ( ('2020', '01'), ('2020', '02'), ('2020', '03'), ...) + + Returns: + An iterator of `Config`s. + """ + for option in itertools.product( + *[self.config.selection[key] for key in self.config.partition_keys] + ): + yield self._create_partition_config(option) + + def new_downloads_only(self, candidate: Config) -> bool: + """Predicate function to skip already downloaded partitions.""" + if self.store is None: + self.store = FSStore() + should_skip = self.skip_partition(candidate) + + return not should_skip + + def update_manifest_collection(self, partition: Config) -> Config: + """Updates the DB.""" + location = prepare_target_name(partition) + self.manifest.schedule( + partition.config_name, + partition.dataset, + partition.selection, + location, + partition.user_id, + ) + logger.info(f"Created partition {location!r}.") diff --git a/weather_dl_v2/fastapi-server/config_processing/pipeline.py b/weather_dl_v2/fastapi-server/config_processing/pipeline.py new file mode 100644 index 00000000..175dd798 --- /dev/null +++ b/weather_dl_v2/fastapi-server/config_processing/pipeline.py @@ -0,0 +1,69 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
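To make the fan-out concrete: `prepare_partitions` above takes the Cartesian product of the partition-key values in `selection`, and `_create_partition_config` then pins each key to a single value. A standalone sketch of that mechanism (the `selection` values here are made up, not taken from this PR):

```python
import itertools

selection = {"year": ["2020"], "month": ["01", "02"], "day": ["01", "15"]}
partition_keys = ["year", "month", "day"]

options = list(itertools.product(*[selection[k] for k in partition_keys]))
# [('2020', '01', '01'), ('2020', '01', '15'), ('2020', '02', '01'), ('2020', '02', '15')]

# Each option becomes its own single-value selection, as _create_partition_config does:
partitions = [
    {**selection, **{k: [option[i]] for i, k in enumerate(partition_keys)}}
    for option in options
]
# e.g. partitions[0] == {'year': ['2020'], 'month': ['01'], 'day': ['01']}
```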
+ + +import getpass +import logging +import os +from .parsers import process_config +from .partition import PartitionConfig +from .manifest import FirestoreManifest +from database.download_handler import get_download_handler +from database.queue_handler import get_queue_handler +from fastapi.concurrency import run_in_threadpool + +logger = logging.getLogger(__name__) + +download_handler = get_download_handler() +queue_handler = get_queue_handler() + + +def _do_partitions(partition_obj: PartitionConfig): + for partition in partition_obj.prepare_partitions(): + # Skip existing downloads + if partition_obj.new_downloads_only(partition): + partition_obj.update_manifest_collection(partition) + + +# TODO: Make partitioning faster. +async def start_processing_config(config_file, licenses, force_download): + config = {} + manifest = FirestoreManifest() + + with open(config_file, "r", encoding="utf-8") as f: + # configs/example.cfg -> example.cfg + config_name = os.path.split(config_file)[1] + config = process_config(f, config_name) + + config.force_download = force_download + config.user_id = getpass.getuser() + + partition_obj = PartitionConfig(config, None, manifest) + + # Make entry in 'download' & 'queues' collection. + await download_handler._start_download(config_name, config.client) + await download_handler._mark_partitioning_status( + config_name, "Partitioning in-progress." + ) + try: + # Prepare partitions + await run_in_threadpool(_do_partitions, partition_obj) + await download_handler._mark_partitioning_status( + config_name, "Partitioning completed." + ) + await queue_handler._update_queues_on_start_download(config_name, licenses) + except Exception as e: + error_str = f"Partitioning failed for {config_name} due to {e}." + logger.error(error_str) + await download_handler._mark_partitioning_status(config_name, error_str) diff --git a/weather_dl_v2/fastapi-server/config_processing/stores.py b/weather_dl_v2/fastapi-server/config_processing/stores.py new file mode 100644 index 00000000..4f60e337 --- /dev/null +++ b/weather_dl_v2/fastapi-server/config_processing/stores.py @@ -0,0 +1,122 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Download destinations, or `Store`s.""" + +import abc +import io +import os +import tempfile +import typing as t + +from apache_beam.io.filesystems import FileSystems + + +class Store(abc.ABC): + """A interface to represent where downloads are stored. + + Default implementation uses Apache Beam's Filesystems. 
+    """
+
+    @abc.abstractmethod
+    def open(self, filename: str, mode: str = "r") -> t.IO:
+        pass
+
+    @abc.abstractmethod
+    def exists(self, filename: str) -> bool:
+        pass
+
+
+class InMemoryStore(Store):
+    """Store file data in memory."""
+
+    def __init__(self):
+        self.store = {}
+
+    def open(self, filename: str, mode: str = "r") -> t.IO:
+        """Create or read in-memory data."""
+        if "b" in mode:
+            file = io.BytesIO()
+        else:
+            file = io.StringIO()
+        self.store[filename] = file
+        return file
+
+    def exists(self, filename: str) -> bool:
+        """Return true if the 'file' exists in memory."""
+        return filename in self.store
+
+
+class TempFileStore(Store):
+    """Store data into temporary files."""
+
+    def __init__(self, directory: t.Optional[str] = None) -> None:
+        """Optionally specify the directory that contains all temporary files."""
+        self.dir = directory
+        if self.dir and not os.path.exists(self.dir):
+            os.makedirs(self.dir)
+
+    def open(self, filename: str, mode: str = "r") -> t.IO:
+        """Create a temporary file in the store directory."""
+        return tempfile.TemporaryFile(mode, dir=self.dir)
+
+    def exists(self, filename: str) -> bool:
+        """Return true if file exists."""
+        return os.path.exists(filename)
+
+
+class LocalFileStore(Store):
+    """Store data into local files."""
+
+    def __init__(self, directory: t.Optional[str] = None) -> None:
+        """Optionally specify the directory that contains all downloaded files."""
+        self.dir = directory
+        if self.dir and not os.path.exists(self.dir):
+            os.makedirs(self.dir)
+
+    def open(self, filename: str, mode: str = "r") -> t.IO:
+        """Open a local file from the store directory."""
+        return open(os.sep.join([self.dir, filename]), mode)
+
+    def exists(self, filename: str) -> bool:
+        """Returns true if local file exists."""
+        return os.path.exists(os.sep.join([self.dir, filename]))
+
+
+class FSStore(Store):
+    """Store data into any store supported by Apache Beam's FileSystems."""
+
+    def open(self, filename: str, mode: str = "r") -> t.IO:
+        """Open object in cloud bucket (or local file system) as a read or write channel.
+
+        To work with cloud storage systems, only a read or write channel can be opened
+        at one time. Data will be treated as bytes, not text (equivalent to `rb` or `wb`).
+
+        Further, append operations, or writes on existing objects, are disallowed (the
+        error thrown will depend on the implementation of the underlying cloud provider).
+        """
+        if "r" in mode and "w" not in mode:
+            return FileSystems().open(filename)
+
+        if "w" in mode and "r" not in mode:
+            return FileSystems().create(filename)
+
+        raise ValueError(
+            f"invalid mode {mode!r}: mode must have either 'r' or 'w', but not both."
+        )
+
+    def exists(self, filename: str) -> bool:
+        """Returns true if object exists."""
+        return FileSystems().exists(filename)
diff --git a/weather_dl_v2/fastapi-server/config_processing/util.py b/weather_dl_v2/fastapi-server/config_processing/util.py
new file mode 100644
index 00000000..765a9c47
--- /dev/null
+++ b/weather_dl_v2/fastapi-server/config_processing/util.py
@@ -0,0 +1,229 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import datetime +import geojson +import hashlib +import itertools +import os +import socket +import subprocess +import sys +import typing as t + +import numpy as np +import pandas as pd +from apache_beam.io.gcp import gcsio +from apache_beam.utils import retry +from xarray.core.utils import ensure_us_time_resolution +from urllib.parse import urlparse +from google.api_core.exceptions import BadRequest + + +LATITUDE_RANGE = (-90, 90) +LONGITUDE_RANGE = (-180, 180) +GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] + +logger = logging.getLogger(__name__) + + +def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( + exception, +) -> bool: + if isinstance(exception, socket.timeout): + return True + if isinstance(exception, TimeoutError): + return True + # To handle the concurrency issue in BigQuery. + if isinstance(exception, BadRequest): + return True + return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) + + +class _FakeClock: + + def sleep(self, value): + pass + + +def retry_with_exponential_backoff(fun): + """A retry decorator that doesn't apply during test time.""" + clock = retry.Clock() + + # Use a fake clock only during test time... + if "unittest" in sys.modules.keys(): + clock = _FakeClock() + + return retry.with_exponential_backoff( + retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, + clock=clock, + )(fun) + + +# TODO(#245): Group with common utilities (duplicated) +def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: + """Yield evenly-sized chunks from an iterable.""" + input_ = iter(iterable) + try: + while True: + it = itertools.islice(input_, n) + # peek to check if 'it' has next item. + first = next(it) + yield itertools.chain([first], it) + except StopIteration: + pass + + +# TODO(#245): Group with common utilities (duplicated) +def copy(src: str, dst: str) -> None: + """Copy data via `gsutil cp`.""" + try: + subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True) + except subprocess.CalledProcessError as e: + logger.info( + f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.' + ) + raise + + +# TODO(#245): Group with common utilities (duplicated) +def to_json_serializable_type(value: t.Any) -> t.Any: + """Returns the value with a type serializable to JSON""" + # Note: The order of processing is significant. + logger.info("Serializing to JSON.") + + if pd.isna(value) or value is None: + return None + elif np.issubdtype(type(value), np.floating): + return float(value) + elif isinstance(value, np.ndarray): + # Will return a scaler if array is of size 1, else will return a list. + return value.tolist() + elif ( + isinstance(value, datetime.datetime) + or isinstance(value, str) + or isinstance(value, np.datetime64) + ): + # Assume strings are ISO format timestamps... + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + # ... if they are not, assume serialization is already correct. + return value + except TypeError: + # ... maybe value is a numpy datetime ... + try: + value = ensure_us_time_resolution(value).astype(datetime.datetime) + except AttributeError: + # ... value is a datetime object, continue. + pass + + # We use a string timestamp representation. + if value.tzname(): + return value.isoformat() + + # We assume here that naive timestamps are in UTC timezone. 
+ return value.replace(tzinfo=datetime.timezone.utc).isoformat() + elif isinstance(value, np.timedelta64): + # Return time delta in seconds. + return float(value / np.timedelta64(1, "s")) + # This check must happen after processing np.timedelta64 and np.datetime64. + elif np.issubdtype(type(value), np.integer): + return int(value) + + return value + + +def fetch_geo_polygon(area: t.Union[list, str]) -> str: + """Calculates a geography polygon from an input area.""" + # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 + if isinstance(area, str): + # European area + if area == "E": + area = [73.5, -27, 33, 45] + # Global area + elif area == "G": + area = GLOBAL_COVERAGE_AREA + else: + raise RuntimeError(f"Not a valid value for area in config: {area}.") + + n, w, s, e = [float(x) for x in area] + if s < LATITUDE_RANGE[0]: + raise ValueError(f"Invalid latitude value for south: '{s}'") + if n > LATITUDE_RANGE[1]: + raise ValueError(f"Invalid latitude value for north: '{n}'") + if w < LONGITUDE_RANGE[0]: + raise ValueError(f"Invalid longitude value for west: '{w}'") + if e > LONGITUDE_RANGE[1]: + raise ValueError(f"Invalid longitude value for east: '{e}'") + + # Define the coordinates of the bounding box. + coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] + + # Create the GeoJSON polygon object. + polygon = geojson.dumps(geojson.Polygon([coords])) + return polygon + + +def get_file_size(path: str) -> float: + parsed_gcs_path = urlparse(path) + if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": + return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 + else: + return ( + gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 + ) + + +def get_wait_interval(num_retries: int = 0) -> float: + """Returns next wait interval in seconds, using an exponential backoff algorithm.""" + if 0 == num_retries: + return 0 + return 2**num_retries + + +def generate_md5_hash(input: str) -> str: + """Generates md5 hash for the input string.""" + return hashlib.md5(input.encode("utf-8")).hexdigest() + + +def download_with_aria2(url: str, path: str) -> None: + """Downloads a file from the given URL using the `aria2c` command-line utility, + with options set to improve download speed and reliability.""" + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + [ + "aria2c", + "-x", + "16", + "-s", + "16", + url, + "-d", + dir_path, + "-o", + file_name, + "--allow-overwrite", + ], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + logger.info( + f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.' + ) + raise diff --git a/weather_dl_v2/fastapi-server/database/__init__.py b/weather_dl_v2/fastapi-server/database/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
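As a small illustration of the retry pacing used when acquiring Firestore clients elsewhere in this PR, `get_wait_interval` produces a plain exponential backoff (sketch, not part of the diff):

```python
>>> [get_wait_interval(n) for n in range(5)]
[0, 2, 4, 8, 16]
```

`generate_md5_hash` similarly backs the manifest document ids seen earlier: the download's target location is hashed to produce a stable Firestore document id.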
diff --git a/weather_dl_v2/fastapi-server/database/download_handler.py b/weather_dl_v2/fastapi-server/database/download_handler.py new file mode 100644 index 00000000..1377e5b4 --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/download_handler.py @@ -0,0 +1,160 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import logging +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter +from google.cloud.firestore_v1.types import WriteResult +from database.session import get_async_client +from server_config import get_config + +logger = logging.getLogger(__name__) + + +def get_download_handler(): + return DownloadHandlerFirestore(db=get_async_client()) + + +def get_mock_download_handler(): + return DownloadHandlerMock() + + +class DownloadHandler(abc.ABC): + + @abc.abstractmethod + async def _start_download(self, config_name: str, client_name: str) -> None: + pass + + @abc.abstractmethod + async def _stop_download(self, config_name: str) -> None: + pass + + @abc.abstractmethod + async def _mark_partitioning_status(self, config_name: str, status: str) -> None: + pass + + @abc.abstractmethod + async def _check_download_exists(self, config_name: str) -> bool: + pass + + @abc.abstractmethod + async def _get_downloads(self, client_name: str) -> list: + pass + + @abc.abstractmethod + async def _get_download_by_config_name(self, config_name: str): + pass + + +class DownloadHandlerMock(DownloadHandler): + + def __init__(self): + pass + + async def _start_download(self, config_name: str, client_name: str) -> None: + logger.info( + f"Added {config_name} in 'download' collection. Update_time: 000000." + ) + + async def _stop_download(self, config_name: str) -> None: + logger.info( + f"Removed {config_name} in 'download' collection. Update_time: 000000." + ) + + async def _mark_partitioning_status(self, config_name: str, status: str) -> None: + logger.info( + f"Updated {config_name} in 'download' collection. Update_time: 000000." 
+ ) + + async def _check_download_exists(self, config_name: str) -> bool: + if config_name == "not_exist": + return False + elif config_name == "not_exist.cfg": + return False + else: + return True + + async def _get_downloads(self, client_name: str) -> list: + return [{"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."}] + + async def _get_download_by_config_name(self, config_name: str): + if config_name == "not_exist": + return None + return {"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."} + + +class DownloadHandlerFirestore(DownloadHandler): + + def __init__(self, db: firestore.firestore.Client): + self.db = db + self.collection = get_config().download_collection + + async def _start_download(self, config_name: str, client_name: str) -> None: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(config_name) + .set({"config_name": config_name, "client_name": client_name}) + ) + + logger.info( + f"Added {config_name} in 'download' collection. Update_time: {result.update_time}." + ) + + async def _stop_download(self, config_name: str) -> None: + timestamp = ( + await self.db.collection(self.collection).document(config_name).delete() + ) + logger.info( + f"Removed {config_name} in 'download' collection. Update_time: {timestamp}." + ) + + async def _mark_partitioning_status(self, config_name: str, status: str) -> None: + timestamp = ( + await self.db.collection(self.collection) + .document(config_name) + .update({"status": status}) + ) + logger.info( + f"Updated {config_name} in 'download' collection. Update_time: {timestamp}." + ) + + async def _check_download_exists(self, config_name: str) -> bool: + result: DocumentSnapshot = ( + await self.db.collection(self.collection).document(config_name).get() + ) + return result.exists + + async def _get_downloads(self, client_name: str) -> list: + docs = [] + if client_name: + docs = ( + self.db.collection(self.collection) + .where(filter=FieldFilter("client_name", "==", client_name)) + .stream() + ) + else: + docs = self.db.collection(self.collection).stream() + + return [doc.to_dict() async for doc in docs] + + async def _get_download_by_config_name(self, config_name: str): + result: DocumentSnapshot = ( + await self.db.collection(self.collection).document(config_name).get() + ) + if result.exists: + return result.to_dict() + else: + return None diff --git a/weather_dl_v2/fastapi-server/database/license_handler.py b/weather_dl_v2/fastapi-server/database/license_handler.py new file mode 100644 index 00000000..d4878e25 --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/license_handler.py @@ -0,0 +1,200 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
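The paired `get_download_handler` / `get_mock_download_handler` factories (and the analogous license, manifest, and queue pairs below) fit FastAPI's dependency-override pattern for tests. A hypothetical sketch, assuming a route that depends on the factory; the route itself and its path are not part of this diff:

```python
from fastapi import Depends, FastAPI
from database.download_handler import get_download_handler, get_mock_download_handler

app = FastAPI()


@app.get("/download/{config_name}")
async def describe_download(config_name: str, handler=Depends(get_download_handler)):
    # Delegates to the Firestore-backed handler in production.
    return await handler._get_download_by_config_name(config_name)


# In unit tests, swap the Firestore-backed handler for the in-memory mock:
app.dependency_overrides[get_download_handler] = get_mock_download_handler
```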
+ + +import abc +import logging +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter +from google.cloud.firestore_v1.types import WriteResult +from database.session import get_async_client +from server_config import get_config + + +logger = logging.getLogger(__name__) + + +def get_license_handler(): + return LicenseHandlerFirestore(db=get_async_client()) + + +def get_mock_license_handler(): + return LicenseHandlerMock() + + +class LicenseHandler(abc.ABC): + + @abc.abstractmethod + async def _add_license(self, license_dict: dict) -> str: + pass + + @abc.abstractmethod + async def _delete_license(self, license_id: str) -> None: + pass + + @abc.abstractmethod + async def _check_license_exists(self, license_id: str) -> bool: + pass + + @abc.abstractmethod + async def _get_license_by_license_id(self, license_id: str) -> dict: + pass + + @abc.abstractmethod + async def _get_license_by_client_name(self, client_name: str) -> list: + pass + + @abc.abstractmethod + async def _get_licenses(self) -> list: + pass + + @abc.abstractmethod + async def _update_license(self, license_id: str, license_dict: dict) -> None: + pass + + @abc.abstractmethod + async def _get_license_without_deployment(self) -> list: + pass + + +class LicenseHandlerMock(LicenseHandler): + + def __init__(self): + pass + + async def _add_license(self, license_dict: dict) -> str: + license_id = "L1" + logger.info(f"Added {license_id} in 'license' collection. Update_time: 00000.") + return license_id + + async def _delete_license(self, license_id: str) -> None: + logger.info( + f"Removed {license_id} in 'license' collection. Update_time: 00000." + ) + + async def _update_license(self, license_id: str, license_dict: dict) -> None: + logger.info( + f"Updated {license_id} in 'license' collection. Update_time: 00000." + ) + + async def _check_license_exists(self, license_id: str) -> bool: + if license_id == "not_exist": + return False + elif license_id == "no-exists": + return False + else: + return True + + async def _get_license_by_license_id(self, license_id: str) -> dict: + if license_id == "not_exist": + return None + return { + "license_id": license_id, + "secret_id": "xxxx", + "client_name": "dummy_client", + "k8s_deployment_id": "k1", + "number_of_requets": 100, + } + + async def _get_license_by_client_name(self, client_name: str) -> list: + return [{ + "license_id": "L1", + "secret_id": "xxxx", + "client_name": client_name, + "k8s_deployment_id": "k1", + "number_of_requets": 100, + }] + + async def _get_licenses(self) -> list: + return [{ + "license_id": "L1", + "secret_id": "xxxx", + "client_name": "dummy_client", + "k8s_deployment_id": "k1", + "number_of_requets": 100, + }] + + async def _get_license_without_deployment(self) -> list: + return [] + + +class LicenseHandlerFirestore(LicenseHandler): + + def __init__(self, db: firestore.firestore.AsyncClient): + self.db = db + self.collection = get_config().license_collection + + async def _add_license(self, license_dict: dict) -> str: + license_dict["license_id"] = license_dict["license_id"].lower() + license_id = license_dict["license_id"] + + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .set(license_dict) + ) + logger.info( + f"Added {license_id} in 'license' collection. Update_time: {result.update_time}." 
+ ) + return license_id + + async def _delete_license(self, license_id: str) -> None: + timestamp = ( + await self.db.collection(self.collection).document(license_id).delete() + ) + logger.info( + f"Removed {license_id} in 'license' collection. Update_time: {timestamp}." + ) + + async def _update_license(self, license_id: str, license_dict: dict) -> None: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .update(license_dict) + ) + logger.info( + f"Updated {license_id} in 'license' collection. Update_time: {result.update_time}." + ) + + async def _check_license_exists(self, license_id: str) -> bool: + result: DocumentSnapshot = ( + await self.db.collection(self.collection).document(license_id).get() + ) + return result.exists + + async def _get_license_by_license_id(self, license_id: str) -> dict: + result: DocumentSnapshot = ( + await self.db.collection(self.collection).document(license_id).get() + ) + return result.to_dict() + + async def _get_license_by_client_name(self, client_name: str) -> list: + docs = ( + self.db.collection(self.collection) + .where(filter=FieldFilter("client_name", "==", client_name)) + .stream() + ) + return [doc.to_dict() async for doc in docs] + + async def _get_licenses(self) -> list: + docs = self.db.collection(self.collection).stream() + return [doc.to_dict() async for doc in docs] + + async def _get_license_without_deployment(self) -> list: + docs = ( + self.db.collection(self.collection) + .where(filter=FieldFilter("k8s_deployment_id", "==", "")) + .stream() + ) + return [doc.to_dict() async for doc in docs] diff --git a/weather_dl_v2/fastapi-server/database/manifest_handler.py b/weather_dl_v2/fastapi-server/database/manifest_handler.py new file mode 100644 index 00000000..d5facfab --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/manifest_handler.py @@ -0,0 +1,181 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
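For orientation, a hypothetical end-to-end call against the license handler above. The field names mirror the mock handlers in this diff (including the `number_of_requets` spelling used there); the project and secret path are placeholders:

```python
import asyncio

from database.license_handler import get_license_handler


async def main():
    handler = get_license_handler()
    license_id = await handler._add_license({
        "license_id": "L1",
        "secret_id": "projects/my-project/secrets/cds-key",  # placeholder
        "client_name": "cds",
        "k8s_deployment_id": "",
        "number_of_requets": 5,  # field name as spelled elsewhere in this PR
    })
    print(await handler._get_license_by_license_id(license_id))


asyncio.run(main())
```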
+ + +import abc +import logging +from firebase_admin import firestore +from google.cloud.firestore_v1.base_query import FieldFilter, Or, And +from server_config import get_config +from database.session import get_async_client + +logger = logging.getLogger(__name__) + + +def get_manifest_handler(): + return ManifestHandlerFirestore(db=get_async_client()) + + +def get_mock_manifest_handler(): + return ManifestHandlerMock() + + +class ManifestHandler(abc.ABC): + + @abc.abstractmethod + async def _get_download_success_count(self, config_name: str) -> int: + pass + + @abc.abstractmethod + async def _get_download_failure_count(self, config_name: str) -> int: + pass + + @abc.abstractmethod + async def _get_download_scheduled_count(self, config_name: str) -> int: + pass + + @abc.abstractmethod + async def _get_download_inprogress_count(self, config_name: str) -> int: + pass + + @abc.abstractmethod + async def _get_download_total_count(self, config_name: str) -> int: + pass + + @abc.abstractmethod + async def _get_non_successfull_downloads(self, config_name: str) -> list: + pass + + +class ManifestHandlerMock(ManifestHandler): + + async def _get_download_failure_count(self, config_name: str) -> int: + return 0 + + async def _get_download_inprogress_count(self, config_name: str) -> int: + return 0 + + async def _get_download_scheduled_count(self, config_name: str) -> int: + return 0 + + async def _get_download_success_count(self, config_name: str) -> int: + return 0 + + async def _get_download_total_count(self, config_name: str) -> int: + return 0 + + async def _get_non_successfull_downloads(self, config_name: str) -> list: + return [] + + +class ManifestHandlerFirestore(ManifestHandler): + + def __init__(self, db: firestore.firestore.Client): + self.db = db + self.collection = get_config().manifest_collection + + async def _get_download_success_count(self, config_name: str) -> int: + result = ( + await self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + .where(filter=FieldFilter("stage", "==", "upload")) + .where(filter=FieldFilter("status", "==", "success")) + .count() + .get() + ) + + count = result[0][0].value + + return count + + async def _get_download_failure_count(self, config_name: str) -> int: + result = ( + await self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + .where(filter=FieldFilter("status", "==", "failure")) + .count() + .get() + ) + + count = result[0][0].value + + return count + + async def _get_download_scheduled_count(self, config_name: str) -> int: + result = ( + await self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + .where(filter=FieldFilter("status", "==", "scheduled")) + .count() + .get() + ) + + count = result[0][0].value + + return count + + async def _get_download_inprogress_count(self, config_name: str) -> int: + and_filter = And( + filters=[ + FieldFilter("status", "==", "success"), + FieldFilter("stage", "!=", "upload"), + ] + ) + or_filter = Or(filters=[FieldFilter("status", "==", "in-progress"), and_filter]) + + result = ( + await self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + .where(filter=or_filter) + .count() + .get() + ) + + count = result[0][0].value + + return count + + async def _get_download_total_count(self, config_name: str) -> int: + result = ( + await self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + 
.count() + .get() + ) + + count = result[0][0].value + + return count + + async def _get_non_successfull_downloads(self, config_name: str) -> list: + or_filter = Or( + filters=[ + FieldFilter("stage", "==", "fetch"), + FieldFilter("stage", "==", "download"), + And( + filters=[ + FieldFilter("status", "!=", "success"), + FieldFilter("stage", "==", "upload"), + ] + ), + ] + ) + + docs = ( + self.db.collection(self.collection) + .where(filter=FieldFilter("config_name", "==", config_name)) + .where(filter=or_filter) + .stream() + ) + return [doc.to_dict() async for doc in docs] diff --git a/weather_dl_v2/fastapi-server/database/queue_handler.py b/weather_dl_v2/fastapi-server/database/queue_handler.py new file mode 100644 index 00000000..1909d583 --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/queue_handler.py @@ -0,0 +1,247 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import logging +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter +from google.cloud.firestore_v1.types import WriteResult +from database.session import get_async_client +from server_config import get_config + +logger = logging.getLogger(__name__) + + +def get_queue_handler(): + return QueueHandlerFirestore(db=get_async_client()) + + +def get_mock_queue_handler(): + return QueueHandlerMock() + + +class QueueHandler(abc.ABC): + + @abc.abstractmethod + async def _create_license_queue(self, license_id: str, client_name: str) -> None: + pass + + @abc.abstractmethod + async def _remove_license_queue(self, license_id: str) -> None: + pass + + @abc.abstractmethod + async def _get_queues(self) -> list: + pass + + @abc.abstractmethod + async def _get_queue_by_license_id(self, license_id: str) -> dict: + pass + + @abc.abstractmethod + async def _get_queue_by_client_name(self, client_name: str) -> list: + pass + + @abc.abstractmethod + async def _update_license_queue(self, license_id: str, priority_list: list) -> None: + pass + + @abc.abstractmethod + async def _update_queues_on_start_download( + self, config_name: str, licenses: list + ) -> None: + pass + + @abc.abstractmethod + async def _update_queues_on_stop_download(self, config_name: str) -> None: + pass + + @abc.abstractmethod + async def _update_config_priority_in_license( + self, license_id: str, config_name: str, priority: int + ) -> None: + pass + + @abc.abstractmethod + async def _update_client_name_in_license_queue( + self, license_id: str, client_name: str + ) -> None: + pass + + +class QueueHandlerMock(QueueHandler): + + def __init__(self): + pass + + async def _create_license_queue(self, license_id: str, client_name: str) -> None: + logger.info( + f"Added {license_id} queue in 'queues' collection. Update_time: 000000." + ) + + async def _remove_license_queue(self, license_id: str) -> None: + logger.info( + f"Removed {license_id} queue in 'queues' collection. Update_time: 000000." 
+ ) + + async def _get_queues(self) -> list: + return [{"client_name": "dummy_client", "license_id": "L1", "queue": []}] + + async def _get_queue_by_license_id(self, license_id: str) -> dict: + if license_id == "not_exist": + return None + return {"client_name": "dummy_client", "license_id": license_id, "queue": []} + + async def _get_queue_by_client_name(self, client_name: str) -> list: + return [{"client_name": client_name, "license_id": "L1", "queue": []}] + + async def _update_license_queue(self, license_id: str, priority_list: list) -> None: + logger.info( + f"Updated {license_id} queue in 'queues' collection. Update_time: 00000." + ) + + async def _update_queues_on_start_download( + self, config_name: str, licenses: list + ) -> None: + logger.info( + f"Updated {license} queue in 'queues' collection. Update_time: 00000." + ) + + async def _update_queues_on_stop_download(self, config_name: str) -> None: + logger.info( + "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." + ) + + async def _update_config_priority_in_license( + self, license_id: str, config_name: str, priority: int + ) -> None: + logger.info( + "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." + ) + + async def _update_client_name_in_license_queue( + self, license_id: str, client_name: str + ) -> None: + logger.info( + "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." + ) + + +class QueueHandlerFirestore(QueueHandler): + + def __init__(self, db: firestore.firestore.Client): + self.db = db + self.collection = get_config().queues_collection + + async def _create_license_queue(self, license_id: str, client_name: str) -> None: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .set({"license_id": license_id, "client_name": client_name, "queue": []}) + ) + logger.info( + f"Added {license_id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + async def _remove_license_queue(self, license_id: str) -> None: + timestamp = ( + await self.db.collection(self.collection).document(license_id).delete() + ) + logger.info( + f"Removed {license_id} queue in 'queues' collection. Update_time: {timestamp}." + ) + + async def _get_queues(self) -> list: + docs = self.db.collection(self.collection).stream() + return [doc.to_dict() async for doc in docs] + + async def _get_queue_by_license_id(self, license_id: str) -> dict: + result: DocumentSnapshot = ( + await self.db.collection(self.collection).document(license_id).get() + ) + return result.to_dict() + + async def _get_queue_by_client_name(self, client_name: str) -> list: + docs = ( + self.db.collection(self.collection) + .where(filter=FieldFilter("client_name", "==", client_name)) + .stream() + ) + return [doc.to_dict() async for doc in docs] + + async def _update_license_queue(self, license_id: str, priority_list: list) -> None: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .update({"queue": priority_list}) + ) + logger.info( + f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + async def _update_queues_on_start_download( + self, config_name: str, licenses: list + ) -> None: + for license in licenses: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license) + .update({"queue": firestore.ArrayUnion([config_name])}) + ) + logger.info( + f"Updated {license} queue in 'queues' collection. Update_time: {result.update_time}." 
+ ) + + async def _update_queues_on_stop_download(self, config_name: str) -> None: + snapshot_list = await self.db.collection(self.collection).get() + for snapshot in snapshot_list: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(snapshot.id) + .update({"queue": firestore.ArrayRemove([config_name])}) + ) + logger.info( + f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + async def _update_config_priority_in_license( + self, license_id: str, config_name: str, priority: int + ) -> None: + snapshot: DocumentSnapshot = ( + await self.db.collection(self.collection).document(license_id).get() + ) + priority_list = snapshot.to_dict()["queue"] + new_priority_list = [c for c in priority_list if c != config_name] + new_priority_list.insert(priority, config_name) + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .update({"queue": new_priority_list}) + ) + logger.info( + f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + async def _update_client_name_in_license_queue( + self, license_id: str, client_name: str + ) -> None: + result: WriteResult = ( + await self.db.collection(self.collection) + .document(license_id) + .update({"client_name": client_name}) + ) + logger.info( + f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." + ) diff --git a/weather_dl_v2/fastapi-server/database/session.py b/weather_dl_v2/fastapi-server/database/session.py new file mode 100644 index 00000000..85dbc8be --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/session.py @@ -0,0 +1,79 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +import abc +import logging +import firebase_admin +from google.cloud import firestore +from firebase_admin import credentials +from config_processing.util import get_wait_interval +from server_config import get_config +from gcloud import storage + +logger = logging.getLogger(__name__) + + +class Database(abc.ABC): + + @abc.abstractmethod + def _get_db(self): + pass + + +db: firestore.AsyncClient = None +gcs: storage.Client = None + + +def get_async_client() -> firestore.AsyncClient: + global db + attempts = 0 + + while db is None: + try: + db = firestore.AsyncClient() + except ValueError as e: + # The above call will fail with a value error when the firebase app is not initialized. + # Initialize the app here, and try again. + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + logger.info("Initialized Firebase App.") + + if attempts > 4: + raise RuntimeError( + "Exceeded number of retries to get firestore client." 
+ ) from e + + time.sleep(get_wait_interval(attempts)) + + attempts += 1 + + return db + + +def get_gcs_client() -> storage.Client: + global gcs + + if gcs: + return gcs + + try: + gcs = storage.Client(project=get_config().gcs_project) + except ValueError as e: + logger.error(f"Error initializing GCS client: {e}.") + + return gcs diff --git a/weather_dl_v2/fastapi-server/database/storage_handler.py b/weather_dl_v2/fastapi-server/database/storage_handler.py new file mode 100644 index 00000000..fcdf6a1a --- /dev/null +++ b/weather_dl_v2/fastapi-server/database/storage_handler.py @@ -0,0 +1,77 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import os +import logging +import tempfile +import contextlib +import typing as t +from google.cloud import storage +from database.session import get_gcs_client +from server_config import get_config + + +logger = logging.getLogger(__name__) + + +def get_storage_handler(): + return StorageHandlerGCS(client=get_gcs_client()) + + +class StorageHandler(abc.ABC): + + @abc.abstractmethod + def _upload_file(self, file_path) -> str: + pass + + @abc.abstractmethod + def _open_local(self, file_name) -> t.Iterator[str]: + pass + + +class StorageHandlerMock(StorageHandler): + + def __init__(self) -> None: + pass + + def _upload_file(self, file_path) -> None: + pass + + def _open_local(self, file_name) -> t.Iterator[str]: + pass + + +class StorageHandlerGCS(StorageHandler): + + def __init__(self, client: storage.Client) -> None: + self.client = client + self.bucket = self.client.get_bucket(get_config().storage_bucket) + + def _upload_file(self, file_path) -> str: + filename = os.path.basename(file_path).split("/")[-1] + + blob = self.bucket.blob(filename) + blob.upload_from_filename(file_path) + + logger.info(f"Uploaded {filename} to {self.bucket}.") + return blob.public_url + + @contextlib.contextmanager + def _open_local(self, file_name) -> t.Iterator[str]: + blob = self.bucket.blob(file_name) + with tempfile.NamedTemporaryFile() as dest_file: + blob.download_to_filename(dest_file.name) + yield dest_file.name diff --git a/weather_dl_v2/fastapi-server/environment.yml b/weather_dl_v2/fastapi-server/environment.yml new file mode 100644 index 00000000..a6ce07fb --- /dev/null +++ b/weather_dl_v2/fastapi-server/environment.yml @@ -0,0 +1,18 @@ +name: weather-dl-v2-server +channels: + - conda-forge +dependencies: + - python=3.10 + - xarray + - geojson + - pip=22.3 + - google-cloud-sdk=410.0.0 + - pip: + - kubernetes + - fastapi[all]==0.97.0 + - python-multipart + - numpy + - apache-beam[gcp] + - aiohttp + - firebase-admin + - gcloud diff --git a/weather_dl_v2/fastapi-server/example.cfg b/weather_dl_v2/fastapi-server/example.cfg new file mode 100644 index 00000000..6747012c --- /dev/null +++ b/weather_dl_v2/fastapi-server/example.cfg @@ -0,0 +1,32 @@ +[parameters] +client=mars + +target_path=gs:///test-weather-dl-v2/{date}T00z.gb +partition_keys= + date + # step + +# API Keys & Subsections go here... 
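+
+# The {date} placeholder in target_path is filled in from the 'date' partition
+# key, so each partitioned date lands in its own output file.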
+ +[selection] +class=od +type=pf +stream=enfo +expver=0001 +levtype=pl +levelist=100 +# params: +# (z) Geopotential 129, (t) Temperature 130, +# (u) U component of wind 131, (v) V component of wind 132, +# (q) Specific humidity 133, (w) vertical velocity 135, +# (vo) Vorticity (relative) 138, (d) Divergence 155, +# (r) Relative humidity 157 +param=129.128 +# +# next: 2019-01-01/to/existing +# +date=2019-07-18/to/2019-07-20 +time=0000 +step=0/to/2 +number=1/to/2 +grid=F640 diff --git a/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py b/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py new file mode 100644 index 00000000..f79521d2 --- /dev/null +++ b/weather_dl_v2/fastapi-server/license_dep/deployment_creator.py @@ -0,0 +1,67 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from os import path +import yaml +from kubernetes import client, config +from server_config import get_config + +logger = logging.getLogger(__name__) + + +def create_license_deployment(license_id: str) -> str: + """Creates a kubernetes workflow of type Job for downloading the data.""" + config.load_config() + + with open(path.join(path.dirname(__file__), "license_deployment.yaml")) as f: + deployment_manifest = yaml.safe_load(f) + deployment_name = f"weather-dl-v2-license-dep-{license_id}".lower() + + # Update the deployment name with a unique identifier + deployment_manifest["metadata"]["name"] = deployment_name + deployment_manifest["spec"]["template"]["spec"]["containers"][0]["args"] = [ + "--license", + license_id, + ] + deployment_manifest["spec"]["template"]["spec"]["containers"][0][ + "image" + ] = get_config().license_deployment_image + + # Create an instance of the Kubernetes API client + api_instance = client.AppsV1Api() + # Create the deployment in the specified namespace + response = api_instance.create_namespaced_deployment( + body=deployment_manifest, namespace="default" + ) + + logger.info(f"Deployment created successfully: {response.metadata.name}.") + return deployment_name + + +def terminate_license_deployment(license_id: str) -> None: + # Load Kubernetes configuration + config.load_config() + + # Create an instance of the Kubernetes API client + api_instance = client.AppsV1Api() + + # Specify the name and namespace of the deployment to delete + deployment_name = f"weather-dl-v2-license-dep-{license_id}".lower() + + # Delete the deployment + api_instance.delete_namespaced_deployment(name=deployment_name, namespace="default") + + logger.info(f"Deployment '{deployment_name}' deleted successfully.") diff --git a/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml b/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml new file mode 100644 index 00000000..707e5b91 --- /dev/null +++ b/weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml @@ -0,0 +1,35 @@ +# weather-dl-v2-license-dep Deployment +# Defines the deployment of the app running in a pod on any worker node +apiVersion: apps/v1 +kind: 
Deployment +metadata: + name: weather-dl-v2-license-dep + labels: + app: weather-dl-v2-license-dep +spec: + replicas: 1 + selector: + matchLabels: + app: weather-dl-v2-license-dep + template: + metadata: + labels: + app: weather-dl-v2-license-dep + spec: + containers: + - name: weather-dl-v2-license-dep + image: XXXXXXX + imagePullPolicy: Always + args: [] + volumeMounts: + - name: config-volume + mountPath: ./config + volumes: + - name: config-volume + configMap: + name: dl-v2-config + # resources: + # # You must specify requests for CPU to autoscale + # # based on CPU utilization + # requests: + # cpu: "250m" \ No newline at end of file diff --git a/weather_dl_v2/fastapi-server/logging.conf b/weather_dl_v2/fastapi-server/logging.conf new file mode 100644 index 00000000..ed0a5e29 --- /dev/null +++ b/weather_dl_v2/fastapi-server/logging.conf @@ -0,0 +1,36 @@ +[loggers] +keys=root,server + +[handlers] +keys=consoleHandler,detailedConsoleHandler + +[formatters] +keys=normalFormatter,detailedFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[logger_server] +level=DEBUG +handlers=detailedConsoleHandler +qualname=server +propagate=0 + +[handler_consoleHandler] +class=StreamHandler +level=DEBUG +formatter=normalFormatter +args=(sys.stdout,) + +[handler_detailedConsoleHandler] +class=StreamHandler +level=DEBUG +formatter=detailedFormatter +args=(sys.stdout,) + +[formatter_normalFormatter] +format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s + +[formatter_detailedFormatter] +format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s call_trace=%(pathname)s L%(lineno)-4d \ No newline at end of file diff --git a/weather_dl_v2/fastapi-server/main.py b/weather_dl_v2/fastapi-server/main.py new file mode 100644 index 00000000..05124123 --- /dev/null +++ b/weather_dl_v2/fastapi-server/main.py @@ -0,0 +1,70 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import logging.config +from contextlib import asynccontextmanager +from fastapi import FastAPI +from routers import license, download, queues +from database.license_handler import get_license_handler +from routers.license import get_create_deployment +from server_config import get_config + +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# set up logger. 
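+# logging.conf (in this directory) routes both the root logger and the more
+# verbose 'server' logger to stdout.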
+logging.config.fileConfig("logging.conf", disable_existing_loggers=False) +logger = logging.getLogger(__name__) + + +async def create_pending_license_deployments(): + """Creates license deployments for Licenses whose deployments does not exist.""" + license_handler = get_license_handler() + create_deployment = get_create_deployment() + license_list = await license_handler._get_license_without_deployment() + + for _license in license_list: + license_id = _license["license_id"] + try: + logger.info(f"Creating license deployment for {license_id}.") + await create_deployment(license_id, license_handler) + except Exception as e: + logger.error(f"License deployment failed for {license_id}. Exception: {e}.") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Started FastAPI server.") + # Boot up + # Make directory to store the uploaded config files. + os.makedirs(os.path.join(os.getcwd(), "config_files"), exist_ok=True) + # Retrieve license information & create license deployment if needed. + await create_pending_license_deployments() + # TODO: Automatically create required indexes on firestore collections on server startup. + yield + # Clean up + + +app = FastAPI(lifespan=lifespan) + +app.include_router(license.router) +app.include_router(download.router) +app.include_router(queues.router) + + +@app.get("/") +async def main(): + return {"msg": get_config().welcome_message} diff --git a/weather_dl_v2/fastapi-server/routers/download.py b/weather_dl_v2/fastapi-server/routers/download.py new file mode 100644 index 00000000..e3de4b57 --- /dev/null +++ b/weather_dl_v2/fastapi-server/routers/download.py @@ -0,0 +1,386 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
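+
+# Routes for managing download configs: submit a config, list downloads along
+# with their shard counts, show or retry a particular config, and stop & remove
+# an ongoing download.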
+ + +import asyncio +import logging +import os +import shutil +import json + +from enum import Enum +from config_processing.parsers import parse_config, process_config +from config_processing.config import Config +from fastapi import APIRouter, HTTPException, BackgroundTasks, UploadFile, Depends, Body +from config_processing.pipeline import start_processing_config +from database.download_handler import DownloadHandler, get_download_handler +from database.queue_handler import QueueHandler, get_queue_handler +from database.license_handler import LicenseHandler, get_license_handler +from database.manifest_handler import ManifestHandler, get_manifest_handler +from database.storage_handler import StorageHandler, get_storage_handler +from config_processing.manifest import FirestoreManifest, Manifest +from fastapi.concurrency import run_in_threadpool + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/download", + tags=["download"], + responses={404: {"description": "Not found"}}, +) + + +async def fetch_config_stats( + config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler +): + """Get all the config stats parallely.""" + + success_coroutine = manifest_handler._get_download_success_count(config_name) + scheduled_coroutine = manifest_handler._get_download_scheduled_count(config_name) + failure_coroutine = manifest_handler._get_download_failure_count(config_name) + inprogress_coroutine = manifest_handler._get_download_inprogress_count(config_name) + total_coroutine = manifest_handler._get_download_total_count(config_name) + + ( + success_count, + scheduled_count, + failure_count, + inprogress_count, + total_count, + ) = await asyncio.gather( + success_coroutine, + scheduled_coroutine, + failure_coroutine, + inprogress_coroutine, + total_coroutine, + ) + + return { + "config_name": config_name, + "client_name": client_name, + "partitioning_status": status, + "downloaded_shards": success_count, + "scheduled_shards": scheduled_count, + "failed_shards": failure_count, + "in-progress_shards": inprogress_count, + "total_shards": total_count, + } + + +def get_fetch_config_stats(): + return fetch_config_stats + + +def get_fetch_config_stats_mock(): + async def fetch_config_stats( + config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler + ): + return { + "config_name": config_name, + "client_name": client_name, + "downloaded_shards": 0, + "scheduled_shards": 0, + "failed_shards": 0, + "in-progress_shards": 0, + "total_shards": 0, + } + + return fetch_config_stats + + +def get_upload(): + def upload(file: UploadFile): + dest = os.path.join(os.getcwd(), "config_files", file.filename) + with open(dest, "wb+") as dest_: + shutil.copyfileobj(file.file, dest_) + + logger.info(f"Uploading {file.filename} to gcs bucket.") + storage_handler: StorageHandler = get_storage_handler() + storage_handler._upload_file(dest) + return dest + + return upload + + +def get_upload_mock(): + def upload(file: UploadFile): + return f"{os.getcwd()}/tests/test_data/{file.filename}" + + return upload + + +def get_reschedule_partitions(): + def invoke_manifest_schedule( + partition_list: list, config: Config, manifest: Manifest + ): + for partition in partition_list: + logger.info(f"Rescheduling partition {partition}.") + manifest.schedule( + config.config_name, + config.dataset, + json.loads(partition["selection"]), + partition["location"], + partition["username"], + ) + + async def reschedule_partitions(config_name: str, licenses: list): + manifest_handler: 
ManifestHandler = get_manifest_handler() + download_handler: DownloadHandler = get_download_handler() + queue_handler: QueueHandler = get_queue_handler() + storage_handler: StorageHandler = get_storage_handler() + + partition_list = await manifest_handler._get_non_successfull_downloads( + config_name + ) + + config = None + manifest = FirestoreManifest() + + with storage_handler._open_local(config_name) as local_path: + with open(local_path, "r", encoding="utf-8") as f: + config = process_config(f, config_name) + + await download_handler._mark_partitioning_status( + config_name, "Partitioning in-progress." + ) + + try: + if config is None: + logger.error( + f"Failed reschedule_partitions. Could not open {config_name}." + ) + raise FileNotFoundError( + f"Failed reschedule_partitions. Could not open {config_name}." + ) + + await run_in_threadpool( + invoke_manifest_schedule, partition_list, config, manifest + ) + await download_handler._mark_partitioning_status( + config_name, "Partitioning completed." + ) + await queue_handler._update_queues_on_start_download(config_name, licenses) + except Exception as e: + error_str = f"Partitioning failed for {config_name} due to {e}." + logger.error(error_str) + await download_handler._mark_partitioning_status(config_name, error_str) + + return reschedule_partitions + + +def get_reschedule_partitions_mock(): + def reschedule_partitions(config_name: str, licenses: list): + pass + + return reschedule_partitions + + +# Can submit a config to the server. +@router.post("/") +async def submit_download( + file: UploadFile | None = None, + licenses: list = [], + force_download: bool = False, + background_tasks: BackgroundTasks = BackgroundTasks(), + download_handler: DownloadHandler = Depends(get_download_handler), + license_handler: LicenseHandler = Depends(get_license_handler), + upload=Depends(get_upload), +): + if not file: + logger.error("No upload file sent.") + raise HTTPException(status_code=404, detail="No upload file sent.") + else: + if await download_handler._check_download_exists(file.filename): + logger.error( + f"Please stop the ongoing download of the config file '{file.filename}' " + "before attempting to start a new download." + ) + raise HTTPException( + status_code=400, + detail=f"Please stop the ongoing download of the config file '{file.filename}' " + "before attempting to start a new download.", + ) + + for license_id in licenses: + if not await license_handler._check_license_exists(license_id): + logger.info(f"No such license {license_id}.") + raise HTTPException( + status_code=404, detail=f"No such license {license_id}." + ) + try: + dest = upload(file) + # Start processing config. + background_tasks.add_task( + start_processing_config, dest, licenses, force_download + ) + return { + "message": f"file '{file.filename}' saved at '{dest}' successfully." + } + except Exception as e: + logger.error(f"Failed to save file '{file.filename} due to {e}.") + raise HTTPException( + status_code=500, detail=f"Failed to save file '{file.filename}'." 
+ ) + + +class DownloadStatus(str, Enum): + COMPLETED = "completed" + FAILED = "failed" + IN_PROGRESS = "in-progress" + + +@router.get("/show/{config_name}") +async def show_download_config( + config_name: str, + download_handler: DownloadHandler = Depends(get_download_handler), + storage_handler: StorageHandler = Depends(get_storage_handler), +): + if not await download_handler._check_download_exists(config_name): + logger.error(f"No such download config {config_name} to show.") + raise HTTPException( + status_code=404, + detail=f"No such download config {config_name} to show.", + ) + + contents = None + + with storage_handler._open_local(config_name) as local_path: + with open(local_path, "r", encoding="utf-8") as f: + contents = parse_config(f) + logger.info(f"Contents of {config_name}: {contents}.") + + return {"config_name": config_name, "contents": contents} + + +# Can check the current status of the submitted config. +# List status for all the downloads + handle filters +@router.get("/") +async def get_downloads( + client_name: str | None = None, + status: DownloadStatus | None = None, + download_handler: DownloadHandler = Depends(get_download_handler), + manifest_handler: ManifestHandler = Depends(get_manifest_handler), + fetch_config_stats=Depends(get_fetch_config_stats), +): + downloads = await download_handler._get_downloads(client_name) + coroutines = [] + + for download in downloads: + coroutines.append( + fetch_config_stats( + download["config_name"], + download["client_name"], + download["status"], + manifest_handler, + ) + ) + + config_details = await asyncio.gather(*coroutines) + + if status is None: + return config_details + + if status.value == DownloadStatus.COMPLETED: + return list( + filter( + lambda detail: detail["downloaded_shards"] == detail["total_shards"], + config_details, + ) + ) + elif status.value == DownloadStatus.FAILED: + return list(filter(lambda detail: detail["failed_shards"] > 0, config_details)) + elif status.value == DownloadStatus.IN_PROGRESS: + return list( + filter( + lambda detail: detail["downloaded_shards"] != detail["total_shards"], + config_details, + ) + ) + else: + return config_details + + +# Get status of particular download +@router.get("/{config_name}") +async def get_download_by_config_name( + config_name: str, + download_handler: DownloadHandler = Depends(get_download_handler), + manifest_handler: ManifestHandler = Depends(get_manifest_handler), + fetch_config_stats=Depends(get_fetch_config_stats), +): + download = await download_handler._get_download_by_config_name(config_name) + + if download is None: + logger.error(f"Download config {config_name} not found in weather-dl v2.") + raise HTTPException( + status_code=404, + detail=f"Download config {config_name} not found in weather-dl v2.", + ) + + return await fetch_config_stats( + download["config_name"], + download["client_name"], + download["status"], + manifest_handler, + ) + + +# Stop & remove the execution of the config. 
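+# For example (illustrative), DELETE /download/example.cfg responds with
+# {"config_name": "example.cfg", "message": "Download config stopped & removed successfully."}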
+@router.delete("/{config_name}") +async def delete_download( + config_name: str, + download_handler: DownloadHandler = Depends(get_download_handler), + queue_handler: QueueHandler = Depends(get_queue_handler), +): + if not await download_handler._check_download_exists(config_name): + logger.error(f"No such download config {config_name} to stop & remove.") + raise HTTPException( + status_code=404, + detail=f"No such download config {config_name} to stop & remove.", + ) + + await download_handler._stop_download(config_name) + await queue_handler._update_queues_on_stop_download(config_name) + return { + "config_name": config_name, + "message": "Download config stopped & removed successfully.", + } + + +@router.post("/retry/{config_name}") +async def retry_config( + config_name: str, + licenses: list = Body(embed=True), + background_tasks: BackgroundTasks = BackgroundTasks(), + download_handler: DownloadHandler = Depends(get_download_handler), + license_handler: LicenseHandler = Depends(get_license_handler), + reschedule_partitions=Depends(get_reschedule_partitions), +): + if not await download_handler._check_download_exists(config_name): + logger.error(f"No such download config {config_name} to retry.") + raise HTTPException( + status_code=404, + detail=f"No such download config {config_name} to retry.", + ) + + for license_id in licenses: + if not await license_handler._check_license_exists(license_id): + logger.info(f"No such license {license_id}.") + raise HTTPException( + status_code=404, detail=f"No such license {license_id}." + ) + + background_tasks.add_task(reschedule_partitions, config_name, licenses) + + return {"msg": "Refetch initiated successfully."} diff --git a/weather_dl_v2/fastapi-server/routers/license.py b/weather_dl_v2/fastapi-server/routers/license.py new file mode 100644 index 00000000..05ac5139 --- /dev/null +++ b/weather_dl_v2/fastapi-server/routers/license.py @@ -0,0 +1,202 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import re +from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends +from pydantic import BaseModel +from license_dep.deployment_creator import create_license_deployment, terminate_license_deployment +from database.license_handler import LicenseHandler, get_license_handler +from database.queue_handler import QueueHandler, get_queue_handler + +logger = logging.getLogger(__name__) + + +class License(BaseModel): + license_id: str + client_name: str + number_of_requests: int + secret_id: str + + +class LicenseInternal(License): + k8s_deployment_id: str + + +# Can perform CRUD on license table -- helps in handling API KEY expiry. +router = APIRouter( + prefix="/license", + tags=["license"], + responses={404: {"description": "Not found"}}, +) + + +# Add/Update k8s deployment ID for existing license (intenally). 
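+# Called by the create_deployment callable once a Kubernetes deployment exists,
+# so that the license document records its k8s_deployment_id.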
+async def update_license_internal( + license_id: str, + k8s_deployment_id: str, + license_handler: LicenseHandler, +): + if not await license_handler._check_license_exists(license_id): + logger.info(f"No such license {license_id} to update.") + raise HTTPException( + status_code=404, detail=f"No such license {license_id} to update." + ) + license_dict = {"k8s_deployment_id": k8s_deployment_id} + + await license_handler._update_license(license_id, license_dict) + return {"license_id": license_id, "message": "License updated successfully."} + + +def get_create_deployment(): + async def create_deployment(license_id: str, license_handler: LicenseHandler): + k8s_deployment_id = create_license_deployment(license_id) + await update_license_internal(license_id, k8s_deployment_id, license_handler) + + return create_deployment + + +def get_create_deployment_mock(): + async def create_deployment_mock(license_id: str, license_handler: LicenseHandler): + logger.info("create deployment mock.") + + return create_deployment_mock + + +def get_terminate_license_deployment(): + return terminate_license_deployment + + +def get_terminate_license_deployment_mock(): + def get_terminate_license_deployment_mock(license_id): + logger.info(f"terminating license deployment for {license_id}.") + + return get_terminate_license_deployment_mock + + +# List all the license + handle filters of {client_name} +@router.get("/") +async def get_licenses( + client_name: str | None = None, + license_handler: LicenseHandler = Depends(get_license_handler), +): + if client_name: + result = await license_handler._get_license_by_client_name(client_name) + else: + result = await license_handler._get_licenses() + return result + + +# Get particular license +@router.get("/{license_id}") +async def get_license_by_license_id( + license_id: str, license_handler: LicenseHandler = Depends(get_license_handler) +): + result = await license_handler._get_license_by_license_id(license_id) + if not result: + logger.info(f"License {license_id} not found.") + raise HTTPException(status_code=404, detail=f"License {license_id} not found.") + return result + + +# Update existing license +@router.put("/{license_id}") +async def update_license( + license_id: str, + license: License, + license_handler: LicenseHandler = Depends(get_license_handler), + queue_handler: QueueHandler = Depends(get_queue_handler), + create_deployment=Depends(get_create_deployment), + terminate_license_deployment=Depends(get_terminate_license_deployment), +): + if not await license_handler._check_license_exists(license_id): + logger.error(f"No such license {license_id} to update.") + raise HTTPException( + status_code=404, detail=f"No such license {license_id} to update." + ) + + license_dict = license.dict() + await license_handler._update_license(license_id, license_dict) + await queue_handler._update_client_name_in_license_queue( + license_id, license_dict["client_name"] + ) + + terminate_license_deployment(license_id) + await create_deployment(license_id, license_handler) + return {"license_id": license_id, "name": "License updated successfully."} + + +# Add new license +@router.post("/") +async def add_license( + license: License, + background_tasks: BackgroundTasks = BackgroundTasks(), + license_handler: LicenseHandler = Depends(get_license_handler), + queue_handler: QueueHandler = Depends(get_queue_handler), + create_deployment=Depends(get_create_deployment), +): + license_id = license.license_id.lower() + + # Check if license id is in correct format. 
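+    # The license_id becomes part of the Kubernetes deployment name, which must
+    # be a valid RFC 1123 subdomain: lowercase alphanumerics, '-' or '.', and it
+    # must start and end with an alphanumeric character.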
+ LICENSE_REGEX = re.compile( + r"[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*" + ) + if not bool(LICENSE_REGEX.fullmatch(license_id)): + logger.error( + """Invalid format for license_id. License id must consist of lower case alphanumeric""" + """ characters, '-' or '.', and must start and end with an alphanumeric character""" + ) + raise HTTPException( + status_code=400, + detail="""Invalid format for license_id. License id must consist of lower case alphanumeric""" + """ characters, '-' or '.', and must start and end with an alphanumeric character""", + ) + + if await license_handler._check_license_exists(license_id): + logger.error(f"License with license_id {license_id} already exist.") + raise HTTPException( + status_code=409, + detail=f"License with license_id {license_id} already exist.", + ) + + license_dict = license.dict() + license_dict["k8s_deployment_id"] = "" + license_id = await license_handler._add_license(license_dict) + await queue_handler._create_license_queue(license_id, license_dict["client_name"]) + background_tasks.add_task(create_deployment, license_id, license_handler) + return {"license_id": license_id, "message": "License added successfully."} + + +# Remove license +@router.delete("/{license_id}") +async def delete_license( + license_id: str, + background_tasks: BackgroundTasks = BackgroundTasks(), + license_handler: LicenseHandler = Depends(get_license_handler), + queue_handler: QueueHandler = Depends(get_queue_handler), + terminate_license_deployment=Depends(get_terminate_license_deployment), +): + if not await license_handler._check_license_exists(license_id): + logger.error(f"No such license {license_id} to delete.") + raise HTTPException( + status_code=404, detail=f"No such license {license_id} to delete." + ) + await license_handler._delete_license(license_id) + await queue_handler._remove_license_queue(license_id) + background_tasks.add_task(terminate_license_deployment, license_id) + return {"license_id": license_id, "message": "License removed successfully."} + + +# TODO: Add route to re-deploy license deployments. diff --git a/weather_dl_v2/fastapi-server/routers/queues.py b/weather_dl_v2/fastapi-server/routers/queues.py new file mode 100644 index 00000000..eda6a7c5 --- /dev/null +++ b/weather_dl_v2/fastapi-server/routers/queues.py @@ -0,0 +1,124 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from fastapi import APIRouter, HTTPException, Depends +from database.queue_handler import QueueHandler, get_queue_handler +from database.license_handler import LicenseHandler, get_license_handler +from database.download_handler import DownloadHandler, get_download_handler + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/queues", + tags=["queues"], + responses={404: {"description": "Not found"}}, +) + + +# Users can change the execution order of config per license basis. 
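+#
+# Illustrative usage:
+#   GET  /queues                       -> priority queues for every license
+#   GET  /queues?client_name=<name>    -> queues for a single client
+#   GET  /queues/{license_id}          -> one license's queue
+#   POST /queues/{license_id}          -> replace a license's priority list
+#   PUT  /queues/priority/{license_id}?config_name=<cfg>&priority=<n>
+#                                      -> move one config to position <n>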
+# List the licenses priority + {client_name} filter +@router.get("/") +async def get_all_license_queue( + client_name: str | None = None, + queue_handler: QueueHandler = Depends(get_queue_handler), +): + if client_name: + result = await queue_handler._get_queue_by_client_name(client_name) + else: + result = await queue_handler._get_queues() + return result + + +# Get particular license priority +@router.get("/{license_id}") +async def get_license_queue( + license_id: str, queue_handler: QueueHandler = Depends(get_queue_handler) +): + result = await queue_handler._get_queue_by_license_id(license_id) + if not result: + logger.error(f"License priority for {license_id} not found.") + raise HTTPException( + status_code=404, detail=f"License priority for {license_id} not found." + ) + return result + + +# Change priority queue of particular license +@router.post("/{license_id}") +async def modify_license_queue( + license_id: str, + priority_list: list | None = [], + queue_handler: QueueHandler = Depends(get_queue_handler), + license_handler: LicenseHandler = Depends(get_license_handler), + download_handler: DownloadHandler = Depends(get_download_handler), +): + if not await license_handler._check_license_exists(license_id): + logger.error(f"License {license_id} not found.") + raise HTTPException(status_code=404, detail=f"License {license_id} not found.") + + for config_name in priority_list: + config = await download_handler._get_download_by_config_name(config_name) + if config is None: + logger.error(f"Download config {config_name} not found in weather-dl v2.") + raise HTTPException( + status_code=404, + detail=f"Download config {config_name} not found in weather-dl v2.", + ) + try: + await queue_handler._update_license_queue(license_id, priority_list) + return {"message": f"'{license_id}' license priority updated successfully."} + except Exception as e: + logger.error(f"Failed to update '{license_id}' license priority due to {e}.") + raise HTTPException( + status_code=404, detail=f"Failed to update '{license_id}' license priority." + ) + + +# Change config's priority in particular license +@router.put("/priority/{license_id}") +async def modify_config_priority_in_license( + license_id: str, + config_name: str, + priority: int, + queue_handler: QueueHandler = Depends(get_queue_handler), + license_handler: LicenseHandler = Depends(get_license_handler), + download_handler: DownloadHandler = Depends(get_download_handler), +): + if not await license_handler._check_license_exists(license_id): + logger.error(f"License {license_id} not found.") + raise HTTPException(status_code=404, detail=f"License {license_id} not found.") + + config = await download_handler._get_download_by_config_name(config_name) + if config is None: + logger.error(f"Download config {config_name} not found in weather-dl v2.") + raise HTTPException( + status_code=404, + detail=f"Download config {config_name} not found in weather-dl v2.", + ) + + try: + await queue_handler._update_config_priority_in_license( + license_id, config_name, priority + ) + return { + "message": f"'{license_id}' license -- '{config_name}' priority updated successfully." + } + except Exception as e: + logger.error(f"Failed to update '{license_id}' license priority due to {e}.") + raise HTTPException( + status_code=404, detail=f"Failed to update '{license_id}' license priority." 
+ ) diff --git a/weather_dl_v2/fastapi-server/server.yaml b/weather_dl_v2/fastapi-server/server.yaml new file mode 100644 index 00000000..b8a2f40d --- /dev/null +++ b/weather_dl_v2/fastapi-server/server.yaml @@ -0,0 +1,93 @@ +# Due to our org level policy we can't expose external-ip. +# In case your project don't have any such restriction a +# then no need to create a nginx-server on VM to access this fastapi server +# instead create the LoadBalancer Service given below. +# +# # weather-dl server LoadBalancer Service +# # Enables the pods in a deployment to be accessible from outside the cluster +# apiVersion: v1 +# kind: Service +# metadata: +# name: weather-dl-v2-server-service +# spec: +# selector: +# app: weather-dl-v2-server-api +# ports: +# - protocol: "TCP" +# port: 8080 +# targetPort: 8080 +# type: LoadBalancer + +--- +# weather-dl-server-api Deployment +# Defines the deployment of the app running in a pod on any worker node +apiVersion: apps/v1 +kind: Deployment +metadata: + name: weather-dl-v2-server-api + labels: + app: weather-dl-v2-server-api +spec: + replicas: 1 + selector: + matchLabels: + app: weather-dl-v2-server-api + template: + metadata: + labels: + app: weather-dl-v2-server-api + spec: + containers: + - name: weather-dl-v2-server-api + image: XXXXXXX + ports: + - containerPort: 8080 + imagePullPolicy: Always + volumeMounts: + - name: config-volume + mountPath: ./config + volumes: + - name: config-volume + configMap: + name: dl-v2-config + # resources: + # # You must specify requests for CPU to autoscale + # # based on CPU utilization + # requests: + # cpu: "250m" +--- +kind: Role +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: weather-dl-v2-server-api +rules: + - apiGroups: + - "" + - "apps" + - "batch" + resources: + - endpoints + - deployments + - pods + - jobs + verbs: + - get + - list + - watch + - create + - delete +--- +kind: RoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: weather-dl-v2-server-api + namespace: default +subjects: + - kind: ServiceAccount + name: default + namespace: default +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: weather-dl-v2-server-api +--- \ No newline at end of file diff --git a/weather_dl_v2/fastapi-server/server_config.py b/weather_dl_v2/fastapi-server/server_config.py new file mode 100644 index 00000000..4ca8c21b --- /dev/null +++ b/weather_dl_v2/fastapi-server/server_config.py @@ -0,0 +1,72 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
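+
+# Server-side configuration. get_config() reads config/config.json, which the
+# Kubernetes deployments mount from the dl-v2-config ConfigMap (see server.yaml);
+# for local runs the path can instead be supplied via the CONFIG_PATH environment
+# variable.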
+ + +import dataclasses +import typing as t +import json +import os +import logging + +logger = logging.getLogger(__name__) + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class ServerConfig: + download_collection: str = "" + queues_collection: str = "" + license_collection: str = "" + manifest_collection: str = "" + storage_bucket: str = "" + gcs_project: str = "" + license_deployment_image: str = "" + welcome_message: str = "" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict): + config_instance = cls() + + for key, value in config.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + + return config_instance + + +server_config = None + + +def get_config(): + global server_config + if server_config: + return server_config + + server_config_json = "config/config.json" + if not os.path.exists(server_config_json): + server_config_json = os.environ.get("CONFIG_PATH", None) + + if server_config_json is None: + logger.error("Couldn't load config file for fastAPI server.") + raise FileNotFoundError("Couldn't load config file for fastAPI server.") + + with open(server_config_json) as file: + config_dict = json.load(file) + server_config = ServerConfig.from_dict(config_dict) + + return server_config diff --git a/weather_dl_v2/fastapi-server/tests/__init__.py b/weather_dl_v2/fastapi-server/tests/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/fastapi-server/tests/integration/__init__.py b/weather_dl_v2/fastapi-server/tests/integration/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/integration/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
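The JSON consumed by ServerConfig.from_dict above only needs the keys the server actually reads. A minimal sketch of such a config, with field names taken from the ServerConfig dataclass and every value a placeholder rather than a real project resource:

```python
from server_config import ServerConfig

# All values below are placeholders; the collections, bucket, project and image
# are deployment-specific.
sample = {
    "download_collection": "download",
    "queues_collection": "queues",
    "license_collection": "license",
    "manifest_collection": "manifest",
    "storage_bucket": "my-config-bucket",
    "gcs_project": "my-gcp-project",
    "license_deployment_image": "gcr.io/my-gcp-project/weather-tools:weather-dl-v2-license-dep",
    "welcome_message": "Greetings from weather-dl v2!",
}

config = ServerConfig.from_dict(sample)
assert config.queues_collection == "queues"
assert config.kwargs == {}  # unknown keys would be collected here
```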
diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_download.py b/weather_dl_v2/fastapi-server/tests/integration/test_download.py new file mode 100644 index 00000000..fc707d10 --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/integration/test_download.py @@ -0,0 +1,175 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +from fastapi.testclient import TestClient +from main import app, ROOT_DIR +from database.download_handler import get_download_handler, get_mock_download_handler +from database.license_handler import get_license_handler, get_mock_license_handler +from database.queue_handler import get_queue_handler, get_mock_queue_handler +from routers.download import get_upload, get_upload_mock, get_fetch_config_stats, get_fetch_config_stats_mock + +client = TestClient(app) + +logger = logging.getLogger(__name__) + +app.dependency_overrides[get_download_handler] = get_mock_download_handler +app.dependency_overrides[get_license_handler] = get_mock_license_handler +app.dependency_overrides[get_queue_handler] = get_mock_queue_handler +app.dependency_overrides[get_upload] = get_upload_mock +app.dependency_overrides[get_fetch_config_stats] = get_fetch_config_stats_mock + + +def _get_download(headers, query, code, expected): + response = client.get("/download", headers=headers, params=query) + + assert response.status_code == code + assert response.json() == expected + + +def test_get_downloads_basic(): + headers = {} + query = {} + code = 200 + expected = [{ + "config_name": "example.cfg", + "client_name": "client", + "downloaded_shards": 0, + "scheduled_shards": 0, + "failed_shards": 0, + "in-progress_shards": 0, + "total_shards": 0, + }] + + _get_download(headers, query, code, expected) + + +def _submit_download(headers, file_path, licenses, code, expected): + file = None + try: + file = {"file": open(file_path, "rb")} + except FileNotFoundError: + logger.info("file not found.") + + payload = {"licenses": licenses} + + response = client.post("/download", headers=headers, files=file, data=payload) + + logger.info(f"resp {response.json()}") + + assert response.status_code == code + assert response.json() == expected + + +def test_submit_download_basic(): + header = { + "accept": "application/json", + } + file_path = os.path.join(ROOT_DIR, "tests/test_data/not_exist.cfg") + licenses = ["L1"] + code = 200 + expected = { + "message": f"file 'not_exist.cfg' saved at '{os.getcwd()}/tests/test_data/not_exist.cfg' " + "successfully." 
+ } + + _submit_download(header, file_path, licenses, code, expected) + + +def test_submit_download_file_not_uploaded(): + header = { + "accept": "application/json", + } + file_path = os.path.join(ROOT_DIR, "tests/test_data/wrong_file.cfg") + licenses = ["L1"] + code = 404 + expected = {"detail": "No upload file sent."} + + _submit_download(header, file_path, licenses, code, expected) + + +def test_submit_download_file_alreadys_exist(): + header = { + "accept": "application/json", + } + file_path = os.path.join(ROOT_DIR, "tests/test_data/example.cfg") + licenses = ["L1"] + code = 400 + expected = { + "detail": "Please stop the ongoing download of the config file 'example.cfg' before attempting to start a new download." # noqa: E501 + } + + _submit_download(header, file_path, licenses, code, expected) + + +def _get_download_by_config(headers, config_name, code, expected): + response = client.get(f"/download/{config_name}", headers=headers) + + assert response.status_code == code + assert response.json() == expected + + +def test_get_download_by_config_basic(): + headers = {} + config_name = "example.cfg" + code = 200 + expected = { + "config_name": config_name, + "client_name": "client", + "downloaded_shards": 0, + "scheduled_shards": 0, + "failed_shards": 0, + "in-progress_shards": 0, + "total_shards": 0, + } + + _get_download_by_config(headers, config_name, code, expected) + + +def test_get_download_by_config_wrong_config(): + headers = {} + config_name = "not_exist" + code = 404 + expected = {"detail": "Download config not_exist not found in weather-dl v2."} + + _get_download_by_config(headers, config_name, code, expected) + + +def _delete_download_by_config(headers, config_name, code, expected): + response = client.delete(f"/download/{config_name}", headers=headers) + assert response.status_code == code + assert response.json() == expected + + +def test_delete_download_by_config_basic(): + headers = {} + config_name = "dummy_config" + code = 200 + expected = { + "config_name": "dummy_config", + "message": "Download config stopped & removed successfully.", + } + + _delete_download_by_config(headers, config_name, code, expected) + + +def test_delete_download_by_config_wrong_config(): + headers = {} + config_name = "not_exist" + code = 404 + expected = {"detail": "No such download config not_exist to stop & remove."} + + _delete_download_by_config(headers, config_name, code, expected) diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_license.py b/weather_dl_v2/fastapi-server/tests/integration/test_license.py new file mode 100644 index 00000000..f4a5dea7 --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/integration/test_license.py @@ -0,0 +1,207 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
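+
+# These tests exercise the /license routes on the real FastAPI app, with the
+# Firestore, queue and Kubernetes dependencies swapped for in-memory mocks via
+# app.dependency_overrides, so no external services are needed.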
+ + +import logging +import json +from fastapi.testclient import TestClient +from main import app +from database.download_handler import get_download_handler, get_mock_download_handler +from database.license_handler import get_license_handler, get_mock_license_handler +from routers.license import ( + get_create_deployment, + get_create_deployment_mock, + get_terminate_license_deployment, + get_terminate_license_deployment_mock, +) +from database.queue_handler import get_queue_handler, get_mock_queue_handler + +client = TestClient(app) + +logger = logging.getLogger(__name__) + +app.dependency_overrides[get_download_handler] = get_mock_download_handler +app.dependency_overrides[get_license_handler] = get_mock_license_handler +app.dependency_overrides[get_queue_handler] = get_mock_queue_handler +app.dependency_overrides[get_create_deployment] = get_create_deployment_mock +app.dependency_overrides[ + get_terminate_license_deployment +] = get_terminate_license_deployment_mock + + +def _get_license(headers, query, code, expected): + response = client.get("/license", headers=headers, params=query) + + assert response.status_code == code + assert response.json() == expected + + +def test_get_license_basic(): + headers = {} + query = {} + code = 200 + expected = [{ + "license_id": "L1", + "secret_id": "xxxx", + "client_name": "dummy_client", + "k8s_deployment_id": "k1", + "number_of_requets": 100, + }] + + _get_license(headers, query, code, expected) + + +def test_get_license_client_name(): + headers = {} + client_name = "dummy_client" + query = {"client_name": client_name} + code = 200 + expected = [{ + "license_id": "L1", + "secret_id": "xxxx", + "client_name": client_name, + "k8s_deployment_id": "k1", + "number_of_requets": 100, + }] + + _get_license(headers, query, code, expected) + + +def _add_license(headers, payload, code, expected): + response = client.post( + "/license", + headers=headers, + data=json.dumps(payload), + params={"license_id": "L1"}, + ) + + print(f"test add license {response.json()}") + + assert response.status_code == code + assert response.json() == expected + + +def test_add_license_basic(): + headers = {"accept": "application/json", "Content-Type": "application/json"} + license = { + "license_id": "no-exists", + "client_name": "dummy_client", + "number_of_requests": 0, + "secret_id": "xxxx", + } + payload = license + code = 200 + expected = {"license_id": "L1", "message": "License added successfully."} + + _add_license(headers, payload, code, expected) + + +def _get_license_by_license_id(headers, license_id, code, expected): + response = client.get(f"/license/{license_id}", headers=headers) + + logger.info(f"response {response.json()}") + assert response.status_code == code + assert response.json() == expected + + +def test_get_license_by_license_id(): + headers = {"accept": "application/json", "Content-Type": "application/json"} + license_id = "L1" + code = 200 + expected = { + "license_id": license_id, + "secret_id": "xxxx", + "client_name": "dummy_client", + "k8s_deployment_id": "k1", + "number_of_requets": 100, + } + + _get_license_by_license_id(headers, license_id, code, expected) + + +def test_get_license_wrong_license(): + headers = {} + license_id = "not_exist" + code = 404 + expected = { + "detail": "License not_exist not found.", + } + + _get_license_by_license_id(headers, license_id, code, expected) + + +def _update_license(headers, license_id, license, code, expected): + response = client.put( + f"/license/{license_id}", headers=headers, 
data=json.dumps(license) + ) + + print(f"_update license {response.json()}") + + assert response.status_code == code + assert response.json() == expected + + +def test_update_license_basic(): + headers = {} + license_id = "L1" + license = { + "license_id": "L1", + "client_name": "dummy_client", + "number_of_requests": 0, + "secret_id": "xxxx", + } + code = 200 + expected = {"license_id": license_id, "name": "License updated successfully."} + + _update_license(headers, license_id, license, code, expected) + + +def test_update_license_wrong_license_id(): + headers = {} + license_id = "no-exists" + license = { + "license_id": "no-exists", + "client_name": "dummy_client", + "number_of_requests": 0, + "secret_id": "xxxx", + } + code = 404 + expected = {"detail": "No such license no-exists to update."} + + _update_license(headers, license_id, license, code, expected) + + +def _delete_license(headers, license_id, code, expected): + response = client.delete(f"/license/{license_id}", headers=headers) + + assert response.status_code == code + assert response.json() == expected + + +def test_delete_license_basic(): + headers = {} + license_id = "L1" + code = 200 + expected = {"license_id": license_id, "message": "License removed successfully."} + + _delete_license(headers, license_id, code, expected) + + +def test_delete_license_wrong_license(): + headers = {} + license_id = "not_exist" + code = 404 + expected = {"detail": "No such license not_exist to delete."} + + _delete_license(headers, license_id, code, expected) diff --git a/weather_dl_v2/fastapi-server/tests/integration/test_queues.py b/weather_dl_v2/fastapi-server/tests/integration/test_queues.py new file mode 100644 index 00000000..5fa7855a --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/integration/test_queues.py @@ -0,0 +1,148 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import logging +from main import app +from fastapi.testclient import TestClient +from database.download_handler import get_download_handler, get_mock_download_handler +from database.license_handler import get_license_handler, get_mock_license_handler +from database.queue_handler import get_queue_handler, get_mock_queue_handler + +client = TestClient(app) + +logger = logging.getLogger(__name__) + +app.dependency_overrides[get_download_handler] = get_mock_download_handler +app.dependency_overrides[get_license_handler] = get_mock_license_handler +app.dependency_overrides[get_queue_handler] = get_mock_queue_handler + + +def _get_all_queue(headers, query, code, expected): + response = client.get("/queues", headers=headers, params=query) + + assert response.status_code == code + assert response.json() == expected + + +def test_get_all_queues(): + headers = {} + query = {} + code = 200 + expected = [{"client_name": "dummy_client", "license_id": "L1", "queue": []}] + + _get_all_queue(headers, query, code, expected) + + +def test_get_client_queues(): + headers = {} + client_name = "dummy_client" + query = {"client_name": client_name} + code = 200 + expected = [{"client_name": client_name, "license_id": "L1", "queue": []}] + + _get_all_queue(headers, query, code, expected) + + +def _get_queue_by_license(headers, license_id, code, expected): + response = client.get(f"/queues/{license_id}", headers=headers) + + assert response.status_code == code + assert response.json() == expected + + +def test_get_queue_by_license_basic(): + headers = {} + license_id = "L1" + code = 200 + expected = {"client_name": "dummy_client", "license_id": license_id, "queue": []} + + _get_queue_by_license(headers, license_id, code, expected) + + +def test_get_queue_by_license_wrong_license(): + headers = {} + license_id = "not_exist" + code = 404 + expected = {"detail": 'License priority for not_exist not found.'} + + _get_queue_by_license(headers, license_id, code, expected) + + +def _modify_license_queue(headers, license_id, priority_list, code, expected): + response = client.post(f"/queues/{license_id}", headers=headers, data=priority_list) + + assert response.status_code == code + assert response.json() == expected + + +def test_modify_license_queue_basic(): + headers = {} + license_id = "L1" + priority_list = [] + code = 200 + expected = {"message": f"'{license_id}' license priority updated successfully."} + + _modify_license_queue(headers, license_id, priority_list, code, expected) + + +def test_modify_license_queue_wrong_license_id(): + headers = {} + license_id = "not_exist" + priority_list = [] + code = 404 + expected = {"detail": 'License not_exist not found.'} + + _modify_license_queue(headers, license_id, priority_list, code, expected) + + +def _modify_config_priority_in_license(headers, license_id, query, code, expected): + response = client.put(f"/queues/priority/{license_id}", params=query) + + logger.info(f"response {response.json()}") + + assert response.status_code == code + assert response.json() == expected + + +def test_modify_config_priority_in_license_basic(): + headers = {} + license_id = "L1" + query = {"config_name": "example.cfg", "priority": 0} + code = 200 + expected = { + "message": f"'{license_id}' license -- 'example.cfg' priority updated successfully." 
+ } + + _modify_config_priority_in_license(headers, license_id, query, code, expected) + + +def test_modify_config_priority_in_license_wrong_license(): + headers = {} + license_id = "not_exist" + query = {"config_name": "example.cfg", "priority": 0} + code = 404 + expected = {"detail": 'License not_exist not found.'} + + _modify_config_priority_in_license(headers, license_id, query, code, expected) + + +def test_modify_config_priority_in_license_wrong_config(): + headers = {} + license_id = "not_exist" + query = {"config_name": "wrong.cfg", "priority": 0} + code = 404 + expected = {"detail": 'License not_exist not found.'} + + _modify_config_priority_in_license(headers, license_id, query, code, expected) diff --git a/weather_dl_v2/fastapi-server/tests/test_data/example.cfg b/weather_dl_v2/fastapi-server/tests/test_data/example.cfg new file mode 100644 index 00000000..6747012c --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/test_data/example.cfg @@ -0,0 +1,32 @@ +[parameters] +client=mars + +target_path=gs:///test-weather-dl-v2/{date}T00z.gb +partition_keys= + date + # step + +# API Keys & Subsections go here... + +[selection] +class=od +type=pf +stream=enfo +expver=0001 +levtype=pl +levelist=100 +# params: +# (z) Geopotential 129, (t) Temperature 130, +# (u) U component of wind 131, (v) V component of wind 132, +# (q) Specific humidity 133, (w) vertical velocity 135, +# (vo) Vorticity (relative) 138, (d) Divergence 155, +# (r) Relative humidity 157 +param=129.128 +# +# next: 2019-01-01/to/existing +# +date=2019-07-18/to/2019-07-20 +time=0000 +step=0/to/2 +number=1/to/2 +grid=F640 diff --git a/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg b/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg new file mode 100644 index 00000000..6747012c --- /dev/null +++ b/weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg @@ -0,0 +1,32 @@ +[parameters] +client=mars + +target_path=gs:///test-weather-dl-v2/{date}T00z.gb +partition_keys= + date + # step + +# API Keys & Subsections go here... + +[selection] +class=od +type=pf +stream=enfo +expver=0001 +levtype=pl +levelist=100 +# params: +# (z) Geopotential 129, (t) Temperature 130, +# (u) U component of wind 131, (v) V component of wind 132, +# (q) Specific humidity 133, (w) vertical velocity 135, +# (vo) Vorticity (relative) 138, (d) Divergence 155, +# (r) Relative humidity 157 +param=129.128 +# +# next: 2019-01-01/to/existing +# +date=2019-07-18/to/2019-07-20 +time=0000 +step=0/to/2 +number=1/to/2 +grid=F640 diff --git a/weather_dl_v2/license_deployment/Dockerfile b/weather_dl_v2/license_deployment/Dockerfile new file mode 100644 index 00000000..68388f78 --- /dev/null +++ b/weather_dl_v2/license_deployment/Dockerfile @@ -0,0 +1,34 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
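The `example.cfg` and `not_exist.cfg` fixtures added above follow the standard weather-dl layout: a `[parameters]` section (client, target path, partition keys) and a `[selection]` section holding the MARS request. The server has its own parser, but a rough sketch of how the two sections split apart, assuming plain `configparser` semantics, looks like this:

```python
import configparser

parser = configparser.ConfigParser()
parser.read("weather_dl_v2/fastapi-server/tests/test_data/example.cfg")

parameters = dict(parser["parameters"])   # client, target_path, partition_keys, ...
selection = dict(parser["selection"])     # class, type, stream, date, time, step, ...

# Multi-line values such as partition_keys keep their commented-out entries, so filter them.
partition_keys = [
    key.strip()
    for key in parameters.get("partition_keys", "").splitlines()
    if key.strip() and not key.strip().startswith("#")
]

print(parameters["client"], partition_keys, selection["date"])
# -> mars ['date'] 2019-07-18/to/2019-07-20
```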
+ + +FROM continuumio/miniconda3:latest + +# Update miniconda +RUN conda update conda -y + +# Add the mamba solver for faster builds +RUN conda install -n base conda-libmamba-solver +RUN conda config --set solver libmamba + +COPY . . +# Create conda env using environment.yml +RUN conda env create -f environment.yml --debug + +# Activate the conda env and update the PATH +ARG CONDA_ENV_NAME=weather-dl-v2-license-dep +RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc +ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH + +ENTRYPOINT ["python", "-u", "fetch.py"] diff --git a/weather_dl_v2/license_deployment/README.md b/weather_dl_v2/license_deployment/README.md new file mode 100644 index 00000000..4c5cc6a1 --- /dev/null +++ b/weather_dl_v2/license_deployment/README.md @@ -0,0 +1,21 @@ +# Deployment Instructions & General Notes + +### How to create environment +``` +conda env create --name weather-dl-v2-license-dep --file=environment.yml + +conda activate weather-dl-v2-license-dep +``` + +### Make changes in weather_dl_v2/config.json, if required [for running locally] +``` +export CONFIG_PATH=/path/to/weather_dl_v2/config.json +``` + +### Create docker image for license deployment +``` +export PROJECT_ID= +export REPO= eg:weather-tools + +gcloud builds submit . --tag "gcr.io/$PROJECT_ID/$REPO:weather-dl-v2-license-dep" --timeout=79200 --machine-type=e2-highcpu-32 +``` \ No newline at end of file diff --git a/weather_dl_v2/license_deployment/__init__.py b/weather_dl_v2/license_deployment/__init__.py new file mode 100644 index 00000000..5678014c --- /dev/null +++ b/weather_dl_v2/license_deployment/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/weather_dl_v2/license_deployment/clients.py b/weather_dl_v2/license_deployment/clients.py new file mode 100644 index 00000000..331888ea --- /dev/null +++ b/weather_dl_v2/license_deployment/clients.py @@ -0,0 +1,417 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
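The README's `CONFIG_PATH` step refers to the shared `weather_dl_v2/config.json` that `deployment_config.py` loads at startup: the container first looks for `config/config.json` and falls back to `CONFIG_PATH` for local runs. A hedged sketch of that lookup is below; the field names match `DeploymentConfig`, while every value shown is a placeholder rather than a real resource.

```python
import json
import os

# Placeholder values only -- the real collection names and downloader image URI live in
# the operator-managed weather_dl_v2/config.json.
example_config = {
    "download_collection": "download",
    "queues_collection": "queues",
    "license_collection": "license",
    "manifest_collection": "manifest",
    "downloader_k8_image": "gcr.io/PROJECT_ID/REPO:weather-dl-v2-downloader",
}

config_path = "config/config.json"                   # default path inside the container
if not os.path.exists(config_path):
    config_path = os.environ.get("CONFIG_PATH", "")  # local runs point this at the repo copy

print(f"would load {config_path or '<missing>'}, for example:")
print(json.dumps(example_config, indent=2))
```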
+ + +"""ECMWF Downloader Clients.""" + +import abc +import collections +import contextlib +import datetime +import io +import logging +import os +import time +import typing as t +import warnings +from urllib.parse import urljoin + +from cdsapi import api as cds_api +import urllib3 +from ecmwfapi import api + +from config import optimize_selection_partition +from manifest import Manifest, Stage +from util import download_with_aria2, retry_with_exponential_backoff + +warnings.simplefilter("ignore", category=urllib3.connectionpool.InsecureRequestWarning) + + +class Client(abc.ABC): + """Weather data provider client interface. + + Defines methods and properties required to efficiently interact with weather + data providers. + + Attributes: + config: A config that contains pipeline parameters, such as API keys. + level: Default log level for the client. + """ + + def __init__(self, dataset: str, level: int = logging.INFO) -> None: + """Clients are initialized with the general CLI configuration.""" + self.dataset = dataset + self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + self.logger.setLevel(level) + + @abc.abstractmethod + def retrieve( + self, dataset: str, selection: t.Dict, output: str, manifest: Manifest + ) -> None: + """Download from data source.""" + pass + + @classmethod + @abc.abstractmethod + def num_requests_per_key(cls, dataset: str) -> int: + """Specifies the number of workers to be used per api key for the dataset.""" + pass + + @property + @abc.abstractmethod + def license_url(self): + """Specifies the License URL.""" + pass + + +class SplitCDSRequest(cds_api.Client): + """Extended CDS class that separates fetch and download stage.""" + + @retry_with_exponential_backoff + def _download(self, url, path: str, size: int) -> None: + self.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size)) + start = time.time() + + download_with_aria2(url, path) + + elapsed = time.time() - start + if elapsed: + self.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed)) + + def fetch(self, request: t.Dict, dataset: str) -> t.Dict: + result = self.retrieve(dataset, request) + return {"href": result.location, "size": result.content_length} + + def download(self, result: cds_api.Result, target: t.Optional[str] = None) -> None: + if target: + if os.path.exists(target): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. + open(target, "w").close() + + self._download(result["href"], target, result["size"]) + + +class CdsClient(Client): + """A client to access weather data from the Cloud Data Store (CDS). + + Datasets on CDS can be found at: + https://cds.climate.copernicus.eu/cdsapp#!/search?type=dataset + + The parameters section of the input `config` requires two values: `api_url` and + `api_key`. Or, these values can be set as the environment variables: `CDSAPI_URL` + and `CDSAPI_KEY`. These can be acquired from the following URL, which requires + creating a free account: https://cds.climate.copernicus.eu/api-how-to + + The CDS global queues for data access has dynamic rate limits. These can be viewed + live here: https://cds.climate.copernicus.eu/live/limits. + + Attributes: + config: A config that contains pipeline parameters, such as API keys. + level: Default log level for the client. 
+ """ + + """Name patterns of datasets that are hosted internally on CDS servers.""" + cds_hosted_datasets = {"reanalysis-era"} + + def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: + c = CDSClientExtended( + url=os.environ.get("CLIENT_URL"), + key=os.environ.get("CLIENT_KEY"), + debug_callback=self.logger.debug, + info_callback=self.logger.info, + warning_callback=self.logger.warning, + error_callback=self.logger.error, + ) + selection_ = optimize_selection_partition(selection) + with StdoutLogger(self.logger, level=logging.DEBUG): + manifest.set_stage(Stage.FETCH) + precise_fetch_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_fetch_start_time + result = c.fetch(selection_, dataset) + return result + + @property + def license_url(self): + return "https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf" + + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: + """Number of requests per key from the CDS API. + + CDS has dynamic, data-specific limits, defined here: + https://cds.climate.copernicus.eu/live/limits + + Typically, the reanalysis dataset allows for 3-5 simultaneous requets. + For all standard CDS data (backed on disk drives), it's common that 2 + requests are allowed, though this is dynamically set, too. + + If the Beam pipeline encounters a user request limit error, please cancel + all outstanding requests (per each user account) at the following link: + https://cds.climate.copernicus.eu/cdsapp#!/yourrequests + """ + # TODO(#15): Parse live CDS limits API to set data-specific limits. + for internal_set in cls.cds_hosted_datasets: + if dataset.startswith(internal_set): + return 5 + return 2 + + +class StdoutLogger(io.StringIO): + """Special logger to redirect stdout to logs.""" + + def __init__(self, logger_: logging.Logger, level: int = logging.INFO): + super().__init__() + self.logger = logger_ + self.level = level + self._redirector = contextlib.redirect_stdout(self) + + def log(self, msg) -> None: + self.logger.log(self.level, msg) + + def write(self, msg): + if msg and not msg.isspace(): + self.logger.log(self.level, msg) + + def __enter__(self): + self._redirector.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + # let contextlib do any exception handling here + self._redirector.__exit__(exc_type, exc_value, traceback) + + +class SplitMARSRequest(api.APIRequest): + """Extended MARS APIRequest class that separates fetch and download stage.""" + + @retry_with_exponential_backoff + def _download(self, url, path: str, size: int) -> None: + self.log("Transferring %s into %s" % (self._bytename(size), path)) + self.log("From %s" % (url,)) + + download_with_aria2(url, path) + + def fetch(self, request: t.Dict, dataset: str) -> t.Dict: + status = None + + self.connection.submit("%s/%s/requests" % (self.url, self.service), request) + self.log("Request submitted") + self.log("Request id: " + self.connection.last.get("name")) + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + while not self.connection.ready(): + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + self.connection.wait() + + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + result = 
self.connection.result() + return result + + def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: + if target: + if os.path.exists(target): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. + open(target, "w").close() + + self._download(urljoin(self.url, result["href"]), target, result["size"]) + self.connection.cleanup() + + +class SplitRequestMixin: + c = None + + def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict: + return self.c.fetch(req, dataset) + + def download(self, res: t.Dict, target: str) -> None: + self.c.download(res, target) + + +class CDSClientExtended(SplitRequestMixin): + """Extended CDS Client class that separates fetch and download stage.""" + + def __init__(self, *args, **kwargs): + self.c = SplitCDSRequest(*args, **kwargs) + + +class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin): + """Extended MARS ECMFService class that separates fetch and download stage.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.c = SplitMARSRequest( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) + + +class PublicECMWFServerExtended(api.ECMWFDataServer, SplitRequestMixin): + + def __init__(self, *args, dataset="", **kwargs): + super().__init__(*args, **kwargs) + self.c = SplitMARSRequest( + self.url, + "datasets/%s" % (dataset,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + ) + + +class MarsClient(Client): + """A client to access data from the Meteorological Archival and Retrieval System (MARS). + + See https://www.ecmwf.int/en/forecasts/datasets for a summary of datasets available + on MARS. Most notable, MARS provides access to ECMWF's Operational Archive + https://www.ecmwf.int/en/forecasts/dataset/operational-archive. + + The client config must contain three parameters to autheticate access to the MARS archive: + `api_key`, `api_url`, and `api_email`. These can also be configued by setting the + commensurate environment variables: `MARSAPI_KEY`, `MARSAPI_URL`, and `MARSAPI_EMAIL`. + These credentials can be looked up by after registering for an ECMWF account + (https://apps.ecmwf.int/registration/) and visitng: https://api.ecmwf.int/v1/key/. + + MARS server activity can be observed at https://apps.ecmwf.int/mars-activity/. + + Attributes: + config: A config that contains pipeline parameters, such as API keys. + level: Default log level for the client. + """ + + def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: + c = MARSECMWFServiceExtended( + "mars", + key=os.environ.get("CLIENT_KEY"), + url=os.environ.get("CLIENT_URL"), + email=os.environ.get("CLIENT_EMAIL"), + log=self.logger.debug, + verbose=True, + ) + selection_ = optimize_selection_partition(selection) + with StdoutLogger(self.logger, level=logging.DEBUG): + manifest.set_stage(Stage.FETCH) + precise_fetch_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_fetch_start_time + result = c.fetch(req=selection_) + return result + + @property + def license_url(self): + return "https://apps.ecmwf.int/datasets/licences/general/" + + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: + """Number of requests per key (or user) for the Mars API. 
+ + Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021. + To ensure we never hit a rate limit error during download, we only make use of the active + requests. + See: https://confluence.ecmwf.int/display/UDOC/Total+number+of+requests+a+user+can+submit+-+Web+API+FAQ + + Queued requests can _only_ be canceled manually from a web dashboard. If the + `ERROR 101 (USER_QUEUED_LIMIT_EXCEEDED)` error occurs in the Beam pipeline, then go to + http://apps.ecmwf.int/webmars/joblist/ and cancel queued jobs. + """ + return 2 + + +class ECMWFPublicClient(Client): + """A client for ECMWF's public datasets, like TIGGE.""" + + def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: + c = PublicECMWFServerExtended( + url=os.environ.get("CLIENT_URL"), + key=os.environ.get("CLIENT_KEY"), + email=os.environ.get("CLIENT_EMAIL"), + log=self.logger.debug, + verbose=True, + dataset=dataset, + ) + selection_ = optimize_selection_partition(selection) + with StdoutLogger(self.logger, level=logging.DEBUG): + manifest.set_stage(Stage.FETCH) + precise_fetch_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_fetch_start_time + result = c.fetch(req=selection_) + return result + + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: + # Experimentally validated request limit. + return 5 + + @property + def license_url(self): + if not self.dataset: + raise ValueError("must specify a dataset for this client!") + return f"https://apps.ecmwf.int/datasets/data/{self.dataset.lower()}/licence/" + + +class FakeClient(Client): + """A client that writes the selection arguments to the output file.""" + + def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: + manifest.set_stage(Stage.RETRIEVE) + precise_retrieve_start_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + manifest.prev_stage_precise_start_time = precise_retrieve_start_time + self.logger.debug(f"Downloading {dataset}.") + + @property + def license_url(self): + return "lorem ipsum" + + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: + return 1 + + +CLIENTS = collections.OrderedDict( + cds=CdsClient, + mars=MarsClient, + ecpublic=ECMWFPublicClient, + fake=FakeClient, +) diff --git a/weather_dl_v2/license_deployment/config.py b/weather_dl_v2/license_deployment/config.py new file mode 100644 index 00000000..fe2199b8 --- /dev/null +++ b/weather_dl_v2/license_deployment/config.py @@ -0,0 +1,120 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import calendar +import copy +import dataclasses +import typing as t + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class Config: + """Contains pipeline parameters. 
+ + Attributes: + config_name: + Name of the config file. + client: + Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable. + dataset (optional): + Name of the target dataset. Allowed options are dictated by the client. + partition_keys (optional): + Choose the keys from the selection section to partition the data request. + This will compute a cartesian cross product of the selected keys + and assign each as their own download. + target_path: + Download artifact filename template. Can make use of Python's standard string formatting. + It can contain format symbols to be replaced by partition keys; + if this is used, the total number of format symbols must match the number of partition keys. + subsection_name: + Name of the particular subsection. 'default' if there is no subsection. + force_download: + Force redownload of partitions that were previously downloaded. + user_id: + Username from the environment variables. + kwargs (optional): + For representing subsections or any other parameters. + selection: + Contains parameters used to select desired data. + """ + + config_name: str = "" + client: str = "" + dataset: t.Optional[str] = "" + target_path: str = "" + partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) + subsection_name: str = "default" + force_download: bool = False + user_id: str = "unknown" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict) -> "Config": + config_instance = cls() + for section_key, section_value in config.items(): + if section_key == "parameters": + for key, value in section_value.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + if section_key == "selection": + config_instance.selection = section_value + return config_instance + + +def optimize_selection_partition(selection: t.Dict) -> t.Dict: + """Compute right-hand-side values for the selection section of a single partition. + + Used to support custom syntax and optimizations, such as 'all'. + """ + selection_ = copy.deepcopy(selection) + + if "day" in selection_.keys() and selection_["day"] == "all": + year, month = selection_["year"], selection_["month"] + + multiples_error = ( + "Cannot use keyword 'all' on selections with multiple '{type}'s." 
+ ) + + if isinstance(year, list): + assert len(year) == 1, multiples_error.format(type="year") + year = year[0] + + if isinstance(month, list): + assert len(month) == 1, multiples_error.format(type="month") + month = month[0] + + if isinstance(year, str): + assert "/" not in year, multiples_error.format(type="year") + + if isinstance(month, str): + assert "/" not in month, multiples_error.format(type="month") + + year, month = int(year), int(month) + + _, n_days_in_month = calendar.monthrange(year, month) + + selection_[ + "date" + ] = f"{year:04d}-{month:02d}-01/to/{year:04d}-{month:02d}-{n_days_in_month:02d}" + del selection_["day"] + del selection_["month"] + del selection_["year"] + + return selection_ diff --git a/weather_dl_v2/license_deployment/database.py b/weather_dl_v2/license_deployment/database.py new file mode 100644 index 00000000..24206561 --- /dev/null +++ b/weather_dl_v2/license_deployment/database.py @@ -0,0 +1,176 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import time +import logging +import firebase_admin +from firebase_admin import firestore +from firebase_admin import credentials +from google.cloud.firestore_v1 import DocumentSnapshot, DocumentReference +from google.cloud.firestore_v1.types import WriteResult +from google.cloud.firestore_v1.base_query import FieldFilter, And +from util import get_wait_interval +from deployment_config import get_config + +logger = logging.getLogger(__name__) + + +class Database(abc.ABC): + + @abc.abstractmethod + def _get_db(self): + pass + + +class CRUDOperations(abc.ABC): + + @abc.abstractmethod + def _initialize_license_deployment(self, license_id: str) -> dict: + pass + + @abc.abstractmethod + def _get_config_from_queue_by_license_id(self, license_id: str) -> dict: + pass + + @abc.abstractmethod + def _remove_config_from_license_queue( + self, license_id: str, config_name: str + ) -> None: + pass + + @abc.abstractmethod + def _empty_license_queue(self, license_id: str) -> None: + pass + + @abc.abstractmethod + def _get_partition_from_manifest(self, config_name: str) -> str: + pass + + +class FirestoreClient(Database, CRUDOperations): + + def _get_db(self) -> firestore.firestore.Client: + """Acquire a firestore client, initializing the firebase app if necessary. + Will attempt to get the db client five times. If it's still unsuccessful, a + `ManifestException` will be raised. + """ + db = None + attempts = 0 + + while db is None: + try: + db = firestore.client() + except ValueError as e: + # The above call will fail with a value error when the firebase app is not initialized. + # Initialize the app here, and try again. + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + logger.info("Initialized Firebase App.") + + if attempts > 4: + raise RuntimeError( + "Exceeded number of retries to get firestore client." 
+ ) from e + + time.sleep(get_wait_interval(attempts)) + + attempts += 1 + + return db + + def _initialize_license_deployment(self, license_id: str) -> dict: + result: DocumentSnapshot = ( + self._get_db() + .collection(get_config().license_collection) + .document(license_id) + .get() + ) + return result.to_dict() + + def _get_config_from_queue_by_license_id(self, license_id: str) -> str | None: + result: DocumentSnapshot = ( + self._get_db() + .collection(get_config().queues_collection) + .document(license_id) + .get(["queue"]) + ) + if result.exists: + queue = result.to_dict()["queue"] + if len(queue) > 0: + return queue[0] + return None + + def _get_partition_from_manifest(self, config_name: str) -> str | None: + transaction = self._get_db().transaction() + return get_partition_from_manifest(transaction, config_name) + + def _remove_config_from_license_queue( + self, license_id: str, config_name: str + ) -> None: + result: WriteResult = ( + self._get_db() + .collection(get_config().queues_collection) + .document(license_id) + .update({"queue": firestore.ArrayRemove([config_name])}) + ) + logger.info( + f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + def _empty_license_queue(self, license_id: str) -> None: + result: WriteResult = ( + self._get_db() + .collection(get_config().queues_collection) + .document(license_id) + .update({"queue": []}) + ) + logger.info( + f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." + ) + + +# TODO: Firestore transcational fails after reading a document 20 times with roll over. +# This happens when too many licenses try to access the same partition document. +# Find some alternative approach to handle this. +@firestore.transactional +def get_partition_from_manifest(transaction, config_name: str) -> str | None: + db_client = FirestoreClient() + filter_1 = FieldFilter("config_name", "==", config_name) + filter_2 = FieldFilter("status", "==", "scheduled") + and_filter = And(filters=[filter_1, filter_2]) + + snapshot = ( + db_client._get_db() + .collection(get_config().manifest_collection) + .where(filter=and_filter) + .limit(1) + .get(transaction=transaction) + ) + if len(snapshot) > 0: + snapshot = snapshot[0] + else: + return None + + ref: DocumentReference = ( + db_client._get_db() + .collection(get_config().manifest_collection) + .document(snapshot.id) + ) + transaction.update(ref, {"status": "processing"}) + + return snapshot.to_dict() diff --git a/weather_dl_v2/license_deployment/deployment_config.py b/weather_dl_v2/license_deployment/deployment_config.py new file mode 100644 index 00000000..8ae162ea --- /dev/null +++ b/weather_dl_v2/license_deployment/deployment_config.py @@ -0,0 +1,69 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
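`FirestoreClient._get_db()` above retries client acquisition with the exponential wait from `util.get_wait_interval` (0 seconds on the first retry, then 2, 4, 8 and 16) before giving up. A condensed, runnable sketch of that loop follows, with the Firebase app initialization step omitted and a flaky stand-in connection used purely for illustration:

```python
import time

def get_wait_interval(num_retries: int = 0) -> float:
    """Same backoff as util.get_wait_interval: 0 on the first try, then 2**n seconds."""
    return 0 if num_retries == 0 else 2 ** num_retries

def acquire_client(connect, max_attempts: int = 5):
    """Condensed version of the retry loop in FirestoreClient._get_db()."""
    attempts = 0
    while True:
        try:
            return connect()
        except ValueError as error:  # firestore.client() raises ValueError before the app is initialized
            if attempts >= max_attempts:
                raise RuntimeError("Exceeded number of retries to get firestore client.") from error
            time.sleep(get_wait_interval(attempts))
            attempts += 1

# Stand-in connection that succeeds on the third call.
state = {"calls": 0}

def flaky_connect():
    state["calls"] += 1
    if state["calls"] < 3:
        raise ValueError("app not initialized yet")
    return "firestore-client"

assert acquire_client(flaky_connect) == "firestore-client"
```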
+ + +import dataclasses +import typing as t +import json +import os +import logging + +logger = logging.getLogger(__name__) + +Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet + + +@dataclasses.dataclass +class DeploymentConfig: + download_collection: str = "" + queues_collection: str = "" + license_collection: str = "" + manifest_collection: str = "" + downloader_k8_image: str = "" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict): + config_instance = cls() + + for key, value in config.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + + return config_instance + + +deployment_config = None + + +def get_config(): + global deployment_config + if deployment_config: + return deployment_config + + deployment_config_json = "config/config.json" + if not os.path.exists(deployment_config_json): + deployment_config_json = os.environ.get("CONFIG_PATH", None) + + if deployment_config_json is None: + logger.error("Couldn't load config file for license deployment.") + raise FileNotFoundError("Couldn't load config file for license deployment.") + + with open(deployment_config_json) as file: + config_dict = json.load(file) + deployment_config = DeploymentConfig.from_dict(config_dict) + + return deployment_config diff --git a/weather_dl_v2/license_deployment/downloader.yaml b/weather_dl_v2/license_deployment/downloader.yaml new file mode 100644 index 00000000..361c2b36 --- /dev/null +++ b/weather_dl_v2/license_deployment/downloader.yaml @@ -0,0 +1,33 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: downloader-with-ttl +spec: + ttlSecondsAfterFinished: 0 + template: + spec: + nodeSelector: + cloud.google.com/gke-nodepool: downloader-pool + containers: + - name: downloader + image: XXXXXXX + imagePullPolicy: Always + command: [] + resources: + requests: + cpu: "1000m" # CPU: 1 vCPU + memory: "2Gi" # RAM: 2 GiB + ephemeral-storage: "100Gi" # Storage: 100 GiB + volumeMounts: + - name: data + mountPath: /data + - name: config-volume + mountPath: ./config + restartPolicy: Never + volumes: + - name: data + emptyDir: + sizeLimit: 100Gi + - name: config-volume + configMap: + name: dl-v2-config \ No newline at end of file diff --git a/weather_dl_v2/license_deployment/environment.yml b/weather_dl_v2/license_deployment/environment.yml new file mode 100644 index 00000000..4848fafd --- /dev/null +++ b/weather_dl_v2/license_deployment/environment.yml @@ -0,0 +1,17 @@ +name: weather-dl-v2-license-dep +channels: + - conda-forge +dependencies: + - python=3.10 + - geojson + - cdsapi=0.5.1 + - ecmwf-api-client=1.6.3 + - pip=22.3 + - pip: + - kubernetes + - google-cloud-secret-manager + - aiohttp + - numpy + - xarray + - apache-beam[gcp] + - firebase-admin diff --git a/weather_dl_v2/license_deployment/fetch.py b/weather_dl_v2/license_deployment/fetch.py new file mode 100644 index 00000000..3e69a56f --- /dev/null +++ b/weather_dl_v2/license_deployment/fetch.py @@ -0,0 +1,184 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from concurrent.futures import ThreadPoolExecutor +from google.cloud import secretmanager +import json +import logging +import time +import sys +import os + +from database import FirestoreClient +from job_creator import create_download_job +from clients import CLIENTS +from manifest import FirestoreManifest +from util import exceptionit, ThreadSafeDict + +db_client = FirestoreClient() +secretmanager_client = secretmanager.SecretManagerServiceClient() +CONFIG_MAX_ERROR_COUNT = 10 + +def create_job(request, result): + res = { + "config_name": request["config_name"], + "dataset": request["dataset"], + "selection": json.loads(request["selection"]), + "user_id": request["username"], + "url": result["href"], + "target_path": request["location"], + "license_id": license_id, + } + + data_str = json.dumps(res) + logger.info(f"Creating download job for res: {data_str}") + create_download_job(data_str) + + +@exceptionit +def make_fetch_request(request, error_map: ThreadSafeDict): + client = CLIENTS[client_name](request["dataset"]) + manifest = FirestoreManifest(license_id=license_id) + logger.info( + f"By using {client_name} datasets, " + f"users agree to the terms and conditions specified in {client.license_url!r}." + ) + + target = request["location"] + selection = json.loads(request["selection"]) + + logger.info(f"Fetching data for {target!r}.") + + config_name = request["config_name"] + + if not error_map.has_key(config_name): + error_map[config_name] = 0 + + if error_map[config_name] >= CONFIG_MAX_ERROR_COUNT: + logger.info(f"Error count for config {config_name} exceeded CONFIG_MAX_ERROR_COUNT ({CONFIG_MAX_ERROR_COUNT}).") + error_map.remove(config_name) + logger.info(f"Removing config {config_name} from license queue.") + # Remove config from this license queue. + db_client._remove_config_from_license_queue(license_id=license_id, config_name=config_name) + return + + # Wait for exponential time based on error count. + if error_map[config_name] > 0: + logger.info(f"Error count for config {config_name}: {error_map[config_name]}.") + time = error_map.exponential_time(config_name) + logger.info(f"Sleeping for {time} mins.") + time.sleep(time) + + try: + with manifest.transact( + request["config_name"], + request["dataset"], + selection, + target, + request["username"], + ): + result = client.retrieve(request["dataset"], selection, manifest) + except Exception as e: + # We are handling this as generic case as CDS client throws generic exceptions. + + # License expired. + if "Access token expired" in str(e): + logger.error(f"{license_id} expired. Emptying queue! error: {e}.") + db_client._empty_license_queue(license_id=license_id) + return + + # Increment error count for a config. + logger.error(f"Partition fetching failed. Error {e}.") + error_map.increment(config_name) + return + + # If any partition in successful reset the error count. 
+ error_map[config_name] = 0 + create_job(request, result) + + +def fetch_request_from_db(): + request = None + config_name = db_client._get_config_from_queue_by_license_id(license_id) + if config_name: + try: + logger.info(f"Fetching partition for {config_name}.") + request = db_client._get_partition_from_manifest(config_name) + if not request: + db_client._remove_config_from_license_queue(license_id, config_name) + except Exception as e: + logger.error( + f"Error in fetch_request_from_db for {config_name}. error: {e}." + ) + return request + + +def main(): + logger.info("Started looking at the request.") + error_map = ThreadSafeDict() + with ThreadPoolExecutor(concurrency_limit) as executor: + # Disclaimer: A license will pick always pick concurrency_limit + 1 + # parition. One extra parition will be kept in threadpool task queue. + + while True: + # Fetch a request from the database + request = fetch_request_from_db() + + if request is not None: + executor.submit(make_fetch_request, request, error_map) + else: + logger.info("No request available. Waiting...") + time.sleep(5) + + # Each license should not pick more partitions than it's + # concurrency_limit. We limit the threadpool queue size to just 1 + # to prevent the license from picking more partitions than + # it's concurrency_limit. When an executor is freed up, the task + # in queue is picked and license fetches another task. + while executor._work_queue.qsize() >= 1: + logger.info("Worker busy. Waiting...") + time.sleep(1) + + +def boot_up(license: str) -> None: + global license_id, client_name, concurrency_limit + + result = db_client._initialize_license_deployment(license) + license_id = license + client_name = result["client_name"] + concurrency_limit = result["number_of_requests"] + + response = secretmanager_client.access_secret_version( + request={"name": result["secret_id"]} + ) + payload = response.payload.data.decode("UTF-8") + secret_dict = json.loads(payload) + + os.environ.setdefault("CLIENT_URL", secret_dict.get("api_url", "")) + os.environ.setdefault("CLIENT_KEY", secret_dict.get("api_key", "")) + os.environ.setdefault("CLIENT_EMAIL", secret_dict.get("api_email", "")) + + +if __name__ == "__main__": + license = sys.argv[2] + global logger + logging.basicConfig( + level=logging.INFO, format=f"[{license}] %(levelname)s - %(message)s" + ) + logger = logging.getLogger(__name__) + + logger.info(f"Deployment for license: {license}.") + boot_up(license) + main() diff --git a/weather_dl_v2/license_deployment/job_creator.py b/weather_dl_v2/license_deployment/job_creator.py new file mode 100644 index 00000000..f0acd802 --- /dev/null +++ b/weather_dl_v2/license_deployment/job_creator.py @@ -0,0 +1,58 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
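The `while executor._work_queue.qsize() >= 1` check in `main()` above is what enforces the "concurrency_limit + 1 partitions" behavior: at most `concurrency_limit` fetches run in the pool while a single extra partition waits in the queue, and the loop only asks Firestore for another config once that slot frees up. A self-contained sketch of the same throttling idea, with a sleep standing in for `client.retrieve()`:

```python
import time
from concurrent.futures import ThreadPoolExecutor

concurrency_limit = 2              # per-license value; the real one comes from Firestore in boot_up()

def fake_fetch(partition: int) -> None:
    time.sleep(0.5)                # stands in for a slow client.retrieve() call
    print(f"partition {partition} done")

with ThreadPoolExecutor(concurrency_limit) as executor:
    for partition in range(6):     # stands in for successive fetch_request_from_db() hits
        executor.submit(fake_fetch, partition)
        # Poll the same (private) work queue the deployment polls: never leave more than one
        # task waiting, so a license holds at most concurrency_limit running + 1 pending partitions.
        while executor._work_queue.qsize() >= 1:
            time.sleep(0.1)
```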
+ + +from os import path +import yaml +import json +import uuid +from kubernetes import client, config +from deployment_config import get_config + + +def create_download_job(message): + """Creates a kubernetes workflow of type Job for downloading the data.""" + parsed_message = json.loads(message) + ( + config_name, + dataset, + selection, + user_id, + url, + target_path, + license_id, + ) = parsed_message.values() + selection = str(selection).replace(" ", "") + config.load_config() + + with open(path.join(path.dirname(__file__), "downloader.yaml")) as f: + dep = yaml.safe_load(f) + uid = uuid.uuid4() + dep["metadata"]["name"] = f"downloader-job-id-{uid}" + dep["spec"]["template"]["spec"]["containers"][0]["command"] = [ + "python", + "downloader.py", + config_name, + dataset, + selection, + user_id, + url, + target_path, + license_id, + ] + dep["spec"]["template"]["spec"]["containers"][0][ + "image" + ] = get_config().downloader_k8_image + batch_api = client.BatchV1Api() + batch_api.create_namespaced_job(body=dep, namespace="default") diff --git a/weather_dl_v2/license_deployment/manifest.py b/weather_dl_v2/license_deployment/manifest.py new file mode 100644 index 00000000..3119de9e --- /dev/null +++ b/weather_dl_v2/license_deployment/manifest.py @@ -0,0 +1,522 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Client interface for connecting to a manifest.""" + +import abc +import logging +import dataclasses +import datetime +import enum +import json +import pandas as pd +import time +import traceback +import typing as t + +from util import ( + to_json_serializable_type, + fetch_geo_polygon, + get_file_size, + get_wait_interval, + generate_md5_hash, + GLOBAL_COVERAGE_AREA, +) + +import firebase_admin +from firebase_admin import credentials +from firebase_admin import firestore +from google.cloud.firestore_v1 import DocumentReference +from google.cloud.firestore_v1.types import WriteResult +from deployment_config import get_config +from database import Database + +logger = logging.getLogger(__name__) + +"""An implementation-dependent Manifest URI.""" +Location = t.NewType("Location", str) + + +class ManifestException(Exception): + """Errors that occur in Manifest Clients.""" + + pass + + +class Stage(enum.Enum): + """A request can be either in one of the following stages at a time: + + fetch : This represents request is currently in fetch stage i.e. request placed on the client's server + & waiting for some result before starting download (eg. MARS client). + download : This represents request is currently in download stage i.e. data is being downloading from client's + server to the worker's local file system. + upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local + file system to target location (GCS path). + retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), + request will be in the retrieve stage i.e. fetch + download. 
+ """ + + RETRIEVE = "retrieve" + FETCH = "fetch" + DOWNLOAD = "download" + UPLOAD = "upload" + + +class Status(enum.Enum): + """Depicts the request's state status: + + scheduled : A request partition is created & scheduled for processing. + Note: Its corresponding state can be None only. + processing: This represents that the request picked by license deployment. + in-progress : This represents the request state is currently in-progress (i.e. running). + The next status would be "success" or "failure". + success : This represents the request state execution completed successfully without any error. + failure : This represents the request state execution failed. + """ + + PROCESSING = "processing" + SCHEDULED = "scheduled" + IN_PROGRESS = "in-progress" + SUCCESS = "success" + FAILURE = "failure" + + +@dataclasses.dataclass +class DownloadStatus: + """Data recorded in `Manifest`s reflecting the status of a download.""" + + """The name of the config file associated with the request.""" + config_name: str = "" + + """Represents the dataset field of the configuration.""" + dataset: t.Optional[str] = "" + + """Copy of selection section of the configuration.""" + selection: t.Dict = dataclasses.field(default_factory=dict) + + """Location of the downloaded data.""" + location: str = "" + + """Represents area covered by the shard.""" + area: str = "" + + """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" + stage: t.Optional[Stage] = None + + """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" + status: t.Optional[Status] = None + + """Cause of error, if any.""" + error: t.Optional[str] = "" + + """Identifier for the user running the download.""" + username: str = "" + + """Shard size in GB.""" + size: t.Optional[float] = 0 + + """A UTC datetime when download was scheduled.""" + scheduled_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve stage starts.""" + retrieve_start_time: t.Optional[str] = "" + + """A UTC datetime when the retrieve state ends.""" + retrieve_end_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state starts.""" + fetch_start_time: t.Optional[str] = "" + + """A UTC datetime when the fetch state ends.""" + fetch_end_time: t.Optional[str] = "" + + """A UTC datetime when the download state starts.""" + download_start_time: t.Optional[str] = "" + + """A UTC datetime when the download state ends.""" + download_end_time: t.Optional[str] = "" + + """A UTC datetime when the upload state starts.""" + upload_start_time: t.Optional[str] = "" + + """A UTC datetime when the upload state ends.""" + upload_end_time: t.Optional[str] = "" + + @classmethod + def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": + """Instantiate DownloadStatus dataclass from dict.""" + download_status_instance = cls() + for key, value in download_status.items(): + if key == "status": + setattr(download_status_instance, key, Status(value)) + elif key == "stage" and value is not None: + setattr(download_status_instance, key, Stage(value)) + else: + setattr(download_status_instance, key, value) + return download_status_instance + + @classmethod + def to_dict(cls, instance) -> t.Dict: + """Return the fields of a dataclass instance as a manifest ingestible + dictionary mapping of field names to field values.""" + download_status_dict = {} + for field in dataclasses.fields(instance): + key = field.name + value = getattr(instance, field.name) + if isinstance(value, Status) or isinstance(value, Stage): + 
download_status_dict[key] = value.value + elif isinstance(value, pd.Timestamp): + download_status_dict[key] = value.isoformat() + elif key == "selection" and value is not None: + download_status_dict[key] = json.dumps(value) + else: + download_status_dict[key] = value + return download_status_dict + + +@dataclasses.dataclass +class Manifest(abc.ABC): + """Abstract manifest of download statuses. + + Update download statuses to some storage medium. + + This class lets one indicate that a download is `scheduled` or in a transaction process. + In the event of a transaction, a download will be updated with an `in-progress`, `success` + or `failure` status (with accompanying metadata). + + Example: + ``` + my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) + + # Schedule data for download + my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') + + # ... + + # Initiate a transaction – it will record that the download is `in-progess` + with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: + # download logic here + pass + + # ... + + # on error, will record the download as a `failure` before propagating the error. By default, it will + # record download as a `success`. + ``` + + Attributes: + status: The current `DownloadStatus` of the Manifest. + """ + + # To reduce the impact of _read() and _update() calls + # on the start time of the stage. + license_id: str = "" + prev_stage_precise_start_time: t.Optional[str] = None + status: t.Optional[DownloadStatus] = None + + # This is overridden in subclass. + def __post_init__(self): + """Initialize the manifest.""" + pass + + def schedule( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Indicate that a job has been scheduled for download. + + 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. + """ + scheduled_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + self.status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=None, + status=Status.SCHEDULED, + error=None, + size=None, + scheduled_time=scheduled_time, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=None, + upload_end_time=None, + ) + self._update(self.status) + + def skip( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Updates the manifest to mark the shards that were skipped in the current job + as 'upload' stage and 'success' status, indicating that they have already been downloaded. + """ + old_status = self._read(location) + # The manifest needs to be updated for a skipped shard if its entry is not present, or + # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. 
+ if ( + old_status.location != location + or old_status.stage != Stage.UPLOAD + or old_status.status != Status.SUCCESS + ): + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + size = get_file_size(location) + + status = DownloadStatus( + config_name=config_name, + dataset=dataset if dataset else None, + selection=selection, + location=location, + area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), + username=user, + stage=Stage.UPLOAD, + status=Status.SUCCESS, + error=None, + size=size, + scheduled_time=None, + retrieve_start_time=None, + retrieve_end_time=None, + fetch_start_time=None, + fetch_end_time=None, + download_start_time=None, + download_end_time=None, + upload_start_time=current_utc_time, + upload_end_time=current_utc_time, + ) + self._update(status) + logger.info( + f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." + ) + + def _set_for_transaction( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> None: + """Reset Manifest state in preparation for a new transaction.""" + self.status = dataclasses.replace(self._read(location)) + self.status.config_name = config_name + self.status.dataset = dataset if dataset else None + self.status.selection = selection + self.status.location = location + self.status.username = user + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, exc_inst, exc_tb) -> None: + """Record end status of a transaction as either 'success' or 'failure'.""" + if exc_type is None: + status = Status.SUCCESS + error = None + else: + status = Status.FAILURE + # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception + error = f"license_id: {self.license_id} " + error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) + + new_status = dataclasses.replace(self.status) + new_status.error = error + new_status.status = status + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + # This is necessary for setting the precise start time of the previous stage + # and end time of the final stage, as well as handling the case of Status.FAILURE. 
+ if new_status.stage == Stage.FETCH: + new_status.fetch_start_time = self.prev_stage_precise_start_time + new_status.fetch_end_time = current_utc_time + elif new_status.stage == Stage.RETRIEVE: + new_status.retrieve_start_time = self.prev_stage_precise_start_time + new_status.retrieve_end_time = current_utc_time + elif new_status.stage == Stage.DOWNLOAD: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + else: + new_status.upload_start_time = self.prev_stage_precise_start_time + new_status.upload_end_time = current_utc_time + + new_status.size = get_file_size(new_status.location) + + self.status = new_status + + self._update(self.status) + + def transact( + self, + config_name: str, + dataset: str, + selection: t.Dict, + location: str, + user: str, + ) -> "Manifest": + """Create a download transaction.""" + self._set_for_transaction(config_name, dataset, selection, location, user) + return self + + def set_stage(self, stage: Stage) -> None: + """Sets the current stage in manifest.""" + prev_stage = self.status.stage + new_status = dataclasses.replace(self.status) + new_status.stage = stage + new_status.status = Status.IN_PROGRESS + current_utc_time = ( + datetime.datetime.utcnow() + .replace(tzinfo=datetime.timezone.utc) + .isoformat(timespec="seconds") + ) + + if stage == Stage.FETCH: + new_status.fetch_start_time = current_utc_time + elif stage == Stage.RETRIEVE: + new_status.retrieve_start_time = current_utc_time + elif stage == Stage.DOWNLOAD: + new_status.fetch_start_time = self.prev_stage_precise_start_time + new_status.fetch_end_time = current_utc_time + new_status.download_start_time = current_utc_time + else: + if prev_stage == Stage.DOWNLOAD: + new_status.download_start_time = self.prev_stage_precise_start_time + new_status.download_end_time = current_utc_time + else: + new_status.retrieve_start_time = self.prev_stage_precise_start_time + new_status.retrieve_end_time = current_utc_time + new_status.upload_start_time = current_utc_time + + self.status = new_status + self._update(self.status) + + @abc.abstractmethod + def _read(self, location: str) -> DownloadStatus: + pass + + @abc.abstractmethod + def _update(self, download_status: DownloadStatus) -> None: + pass + + +class FirestoreManifest(Manifest, Database): + """A Firestore Manifest. + This Manifest implementation stores DownloadStatuses in a Firebase document store. + The document hierarchy for the manifest is as follows: + [manifest ] + ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } + └── etc... + Where `[]` indicates a collection and ` {...}` indicates a document. + """ + + def _get_db(self) -> firestore.firestore.Client: + """Acquire a firestore client, initializing the firebase app if necessary. + Will attempt to get the db client five times. If it's still unsuccessful, a + `ManifestException` will be raised. + """ + db = None + attempts = 0 + + while db is None: + try: + db = firestore.client() + except ValueError as e: + # The above call will fail with a value error when the firebase app is not initialized. + # Initialize the app here, and try again. + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + logger.info("Initialized Firebase App.") + + if attempts > 4: + raise ManifestException( + "Exceeded number of retries to get firestore client." 
+ ) from e + + time.sleep(get_wait_interval(attempts)) + + attempts += 1 + + return db + + def _read(self, location: str) -> DownloadStatus: + """Reads the JSON data from a manifest.""" + + doc_id = generate_md5_hash(location) + + # Update document with download status + download_doc_ref = self.root_document_for_store(doc_id) + + result = download_doc_ref.get() + row = {} + if result.exists: + records = result.to_dict() + row = {n: to_json_serializable_type(v) for n, v in records.items()} + return DownloadStatus.from_dict(row) + + def _update(self, download_status: DownloadStatus) -> None: + """Update or create a download status record.""" + logger.info("Updating Firestore Manifest.") + + status = DownloadStatus.to_dict(download_status) + doc_id = generate_md5_hash(status["location"]) + + # Update document with download status. + download_doc_ref = self.root_document_for_store(doc_id) + + result: WriteResult = download_doc_ref.set(status) + + logger.info( + "Firestore manifest updated. " + + f"update_time={result.update_time}, " + + f"status={status['status']} " + + f"stage={status['stage']} " + + f"filename={download_status.location}." + ) + + def root_document_for_store(self, store_scheme: str) -> DocumentReference: + """Get the root manifest document given the user's config and current document's storage location.""" + return ( + self._get_db() + .collection(get_config().manifest_collection) + .document(store_scheme) + ) diff --git a/weather_dl_v2/license_deployment/util.py b/weather_dl_v2/license_deployment/util.py new file mode 100644 index 00000000..d24a1405 --- /dev/null +++ b/weather_dl_v2/license_deployment/util.py @@ -0,0 +1,302 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime +import logging +import geojson +import hashlib +import itertools +import os +import socket +import subprocess +import sys +import typing as t + +import numpy as np +import pandas as pd +from apache_beam.io.gcp import gcsio +from apache_beam.utils import retry +from xarray.core.utils import ensure_us_time_resolution +from urllib.parse import urlparse +from google.api_core.exceptions import BadRequest +from threading import Lock + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +LATITUDE_RANGE = (-90, 90) +LONGITUDE_RANGE = (-180, 180) +GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] + + +def exceptionit(func): + def inner_function(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception as e: + logger.error(f"exception in {func.__name__} {e.__class__.__name__} {e}.") + + return inner_function + + +def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( + exception, +) -> bool: + if isinstance(exception, socket.timeout): + return True + if isinstance(exception, TimeoutError): + return True + # To handle the concurrency issue in BigQuery. 
+ if isinstance(exception, BadRequest): + return True + return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) + + +class _FakeClock: + + def sleep(self, value): + pass + + +def retry_with_exponential_backoff(fun): + """A retry decorator that doesn't apply during test time.""" + clock = retry.Clock() + + # Use a fake clock only during test time... + if "unittest" in sys.modules.keys(): + clock = _FakeClock() + + return retry.with_exponential_backoff( + retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, + clock=clock, + )(fun) + + +# TODO(#245): Group with common utilities (duplicated) +def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: + """Yield evenly-sized chunks from an iterable.""" + input_ = iter(iterable) + try: + while True: + it = itertools.islice(input_, n) + # peek to check if 'it' has next item. + first = next(it) + yield itertools.chain([first], it) + except StopIteration: + pass + + +# TODO(#245): Group with common utilities (duplicated) +def copy(src: str, dst: str) -> None: + """Copy data via `gsutil cp`.""" + try: + subprocess.run(["gsutil", "cp", src, dst], check=True, capture_output=True) + except subprocess.CalledProcessError as e: + logger.info( + f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.' + ) + raise + + +# TODO(#245): Group with common utilities (duplicated) +def to_json_serializable_type(value: t.Any) -> t.Any: + """Returns the value with a type serializable to JSON""" + # Note: The order of processing is significant. + logger.info("Serializing to JSON.") + + if pd.isna(value) or value is None: + return None + elif np.issubdtype(type(value), np.floating): + return float(value) + elif isinstance(value, np.ndarray): + # Will return a scaler if array is of size 1, else will return a list. + return value.tolist() + elif ( + isinstance(value, datetime.datetime) + or isinstance(value, str) + or isinstance(value, np.datetime64) + ): + # Assume strings are ISO format timestamps... + try: + value = datetime.datetime.fromisoformat(value) + except ValueError: + # ... if they are not, assume serialization is already correct. + return value + except TypeError: + # ... maybe value is a numpy datetime ... + try: + value = ensure_us_time_resolution(value).astype(datetime.datetime) + except AttributeError: + # ... value is a datetime object, continue. + pass + + # We use a string timestamp representation. + if value.tzname(): + return value.isoformat() + + # We assume here that naive timestamps are in UTC timezone. + return value.replace(tzinfo=datetime.timezone.utc).isoformat() + elif isinstance(value, np.timedelta64): + # Return time delta in seconds. + return float(value / np.timedelta64(1, "s")) + # This check must happen after processing np.timedelta64 and np.datetime64. 
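To make the serialization rules above concrete, here is a small self-contained sketch of the conversions `to_json_serializable_type` applies in the common cases (numpy scalars and arrays, naive datetimes assumed to be UTC, and `np.timedelta64` values expressed in seconds); the sample values are illustrative only:

```python
import datetime
import numpy as np

# numpy scalars and arrays become plain Python types.
print(float(np.float32(1.5)))                    # 1.5
print(np.array([281.0, 282.5]).tolist())         # [281.0, 282.5]

# Naive timestamps are assumed to be UTC and emitted as ISO-8601 strings.
naive = datetime.datetime(2020, 7, 1)
print(naive.replace(tzinfo=datetime.timezone.utc).isoformat())
# 2020-07-01T00:00:00+00:00

# Time deltas are reported as a number of seconds.
print(float(np.timedelta64(6, "h") / np.timedelta64(1, "s")))  # 21600.0
```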
+ elif np.issubdtype(type(value), np.integer): + return int(value) + + return value + + +def fetch_geo_polygon(area: t.Union[list, str]) -> str: + """Calculates a geography polygon from an input area.""" + # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 + if isinstance(area, str): + # European area + if area == "E": + area = [73.5, -27, 33, 45] + # Global area + elif area == "G": + area = GLOBAL_COVERAGE_AREA + else: + raise RuntimeError(f"Not a valid value for area in config: {area}.") + + n, w, s, e = [float(x) for x in area] + if s < LATITUDE_RANGE[0]: + raise ValueError(f"Invalid latitude value for south: '{s}'") + if n > LATITUDE_RANGE[1]: + raise ValueError(f"Invalid latitude value for north: '{n}'") + if w < LONGITUDE_RANGE[0]: + raise ValueError(f"Invalid longitude value for west: '{w}'") + if e > LONGITUDE_RANGE[1]: + raise ValueError(f"Invalid longitude value for east: '{e}'") + + # Define the coordinates of the bounding box. + coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] + + # Create the GeoJSON polygon object. + polygon = geojson.dumps(geojson.Polygon([coords])) + return polygon + + +def get_file_size(path: str) -> float: + parsed_gcs_path = urlparse(path) + if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": + return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 + else: + return ( + gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 + ) + + +def get_wait_interval(num_retries: int = 0) -> float: + """Returns next wait interval in seconds, using an exponential backoff algorithm.""" + if 0 == num_retries: + return 0 + return 2**num_retries + + +def generate_md5_hash(input: str) -> str: + """Generates md5 hash for the input string.""" + return hashlib.md5(input.encode("utf-8")).hexdigest() + + +def download_with_aria2(url: str, path: str) -> None: + """Downloads a file from the given URL using the `aria2c` command-line utility, + with options set to improve download speed and reliability.""" + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + [ + "aria2c", + "-x", + "16", + "-s", + "16", + url, + "-d", + dir_path, + "-o", + file_name, + "--allow-overwrite", + ], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + logger.info( + f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.' + ) + raise + +class ThreadSafeDict: + """A thread safe dict with crud operations.""" + + + def __init__(self) -> None: + self._dict = {} + self._lock = Lock() + self.initial_delay = 1 + self.factor = 0.5 + + + def __getitem__(self, key): + val = None + with self._lock: + val = self._dict[key] + return val + + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + + def remove(self, key): + with self._lock: + self._dict.__delitem__(key) + + + def has_key(self, key): + present = False + with self._lock: + present = key in self._dict + return present + + + def increment(self, key, delta=1): + with self._lock: + if key in self._dict: + self._dict[key] += delta + + + def decrement(self, key, delta=1): + with self._lock: + if key in self._dict: + self._dict[key] -= delta + + + def find_exponential_delay(self, n: int) -> int: + delay = self.initial_delay + for _ in range(n): + delay += delay*self.factor + return delay + + + def exponential_time(self, key): + """Returns exponential time based on dict value. 
Time in seconds.""" + delay = 0 + with self._lock: + if key in self._dict: + delay = self.find_exponential_delay(self._dict[key]) + return delay * 60 diff --git a/weather_mv/README.md b/weather_mv/README.md index ce39fbfb..a1045eab 100644 --- a/weather_mv/README.md +++ b/weather_mv/README.md @@ -61,7 +61,8 @@ usage: weather-mv bigquery [-h] -i URIS [--topic TOPIC] [--window_size WINDOW_SI -o OUTPUT_TABLE [-v variables [variables ...]] [-a area [area ...]] [--import_time IMPORT_TIME] [--infer_schema] [--xarray_open_dataset_kwargs XARRAY_OPEN_DATASET_KWARGS] - [--tif_metadata_for_datetime TIF_METADATA_FOR_DATETIME] [-s] + [--tif_metadata_for_start_time TIF_METADATA_FOR_START_TIME] + [--tif_metadata_for_end_time TIF_METADATA_FOR_END_TIME] [-s] [--coordinate_chunk_size COORDINATE_CHUNK_SIZE] ['--skip_creating_polygon'] ``` @@ -80,7 +81,8 @@ _Command options_: * `--xarray_open_dataset_kwargs`: Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string. * `--coordinate_chunk_size`: The size of the chunk of coordinates used for extracting vector data into BigQuery. Used to tune parallel uploads. -* `--tif_metadata_for_datetime` : Metadata that contains tif file's timestamp. Applicable only for tif files. +* `--tif_metadata_for_start_time` : Metadata that contains tif file's start/initialization time. Applicable only for tif files. +* `--tif_metadata_for_end_time` : Metadata that contains tif file's end/forecast time. Applicable only for tif files (optional). * `-s, --skip-region-validation` : Skip validation of regions for data migration. Default: off. * `--disable_grib_schema_normalization` : To disable grib's schema normalization. Default: off. * `--skip_creating_polygon` : Not ingest grid points as polygons in BigQuery. Default: Ingest grid points as Polygon in @@ -139,7 +141,8 @@ weather-mv bq --uris "gs://your-bucket/*.tif" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 \ - --tif_metadata_for_datetime start_time + --tif_metadata_for_start_time start_time \ + --tif_metadata_for_end_time end_time ``` Upload only a subset of variables: @@ -162,6 +165,39 @@ weather-mv bq --uris "gs://your-bucket/*.nc" \ --direct_num_workers 2 ``` +Upload a zarr file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.zarr" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --zarr \ + --direct_num_workers 2 +``` + +Upload a specific date range's data from the zarr file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.zarr" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --zarr \ + --zarr_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ + --direct_num_workers 2 +``` + +Upload a specific date range's data from the file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.nc" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --xarray_open_dataset_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ +``` + Control how weather data is opened with XArray: ```bash diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py index c23a7f44..3b2e15ba 100644 --- a/weather_mv/loader_pipeline/bq.py +++ b/weather_mv/loader_pipeline/bq.py @@ -24,8 +24,10 @@ import geojson import numpy as np import xarray as xr +import xarray_beam as xbeam from apache_beam.io 
import WriteToBigQuery, BigQueryDisposition from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms import window from google.cloud import bigquery from xarray.core.utils import ensure_us_time_resolution @@ -75,8 +77,10 @@ class ToBigQuery(ToDataSink): infer_schema: If true, this sink will attempt to read in an example data file read all its variables, and generate a BigQuery schema. xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset(). - tif_metadata_for_datetime: If the input is a .tif file, parse the tif metadata at - this location for a timestamp. + tif_metadata_for_start_time: If the input is a .tif file, parse the tif metadata at + this location for a start time / initialization time. + tif_metadata_for_end_time: If the input is a .tif file, parse the tif metadata at + this location for a end/forecast time. skip_region_validation: Turn off validation that checks if all Cloud resources are in the same region. disable_grib_schema_normalization: Turn off grib's schema normalization; Default: normalization enabled. @@ -92,7 +96,8 @@ class ToBigQuery(ToDataSink): import_time: t.Optional[datetime.datetime] infer_schema: bool xarray_open_dataset_kwargs: t.Dict - tif_metadata_for_datetime: t.Optional[str] + tif_metadata_for_start_time: t.Optional[str] + tif_metadata_for_end_time: t.Optional[str] skip_region_validation: bool disable_grib_schema_normalization: bool coordinate_chunk_size: int = 10_000 @@ -123,8 +128,11 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser): 'off') subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}', help='Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.') - subparser.add_argument('--tif_metadata_for_datetime', type=str, default=None, - help='Metadata that contains tif file\'s timestamp. ' + subparser.add_argument('--tif_metadata_for_start_time', type=str, default=None, + help='Metadata that contains tif file\'s start/initialization time. ' + 'Applicable only for tif files.') + subparser.add_argument('--tif_metadata_for_end_time', type=str, default=None, + help='Metadata that contains tif file\'s end/forecast time. ' 'Applicable only for tif files.') subparser.add_argument('-s', '--skip-region-validation', action='store_true', default=False, help='Skip validation of regions for data migration. Default: off') @@ -148,10 +156,14 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis # Check that all arguments are supplied for COG input. _, uri_extension = os.path.splitext(known_args.uris) - if uri_extension == '.tif' and not known_args.tif_metadata_for_datetime: - raise RuntimeError("'--tif_metadata_for_datetime' is required for tif files.") - elif uri_extension != '.tif' and known_args.tif_metadata_for_datetime: - raise RuntimeError("'--tif_metadata_for_datetime' can be specified only for tif files.") + if (uri_extension in ['.tif', '.tiff'] and not known_args.tif_metadata_for_start_time): + raise RuntimeError("'--tif_metadata_for_start_time' is required for tif files.") + elif uri_extension not in ['.tif', '.tiff'] and ( + known_args.tif_metadata_for_start_time + or known_args.tif_metadata_for_end_time + ): + raise RuntimeError("'--tif_metadata_for_start_time' and " + "'--tif_metadata_for_end_time' can be specified only for tif files.") # Check that Cloud resource regions are consistent. 
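Condensing the validation above: `--tif_metadata_for_start_time` is mandatory for `.tif`/`.tiff` inputs, and neither tif flag is accepted for any other extension. A standalone sketch of that rule (the function name is illustrative, not part of the pipeline):

```python
import os

def check_tif_flags(uri: str, start_time_key=None, end_time_key=None) -> None:
    """Raise if the tif metadata flags are missing or used with non-tif inputs."""
    _, ext = os.path.splitext(uri)
    if ext in ('.tif', '.tiff') and not start_time_key:
        raise RuntimeError("'--tif_metadata_for_start_time' is required for tif files.")
    if ext not in ('.tif', '.tiff') and (start_time_key or end_time_key):
        raise RuntimeError("'--tif_metadata_for_start_time' and "
                           "'--tif_metadata_for_end_time' can be specified only for tif files.")

check_tif_flags('gs://bucket/data.tif', start_time_key='start_time')  # ok
check_tif_flags('gs://bucket/data.nc')                                # ok
```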
if not (known_args.dry_run or known_args.skip_region_validation): @@ -166,8 +178,8 @@ def __post_init__(self): if self.zarr: self.xarray_open_dataset_kwargs = self.zarr_kwargs with open_dataset(self.first_uri, self.xarray_open_dataset_kwargs, - self.disable_grib_schema_normalization, self.tif_metadata_for_datetime, - is_zarr=self.zarr) as open_ds: + self.disable_grib_schema_normalization, self.tif_metadata_for_start_time, + self.tif_metadata_for_end_time, is_zarr=self.zarr) as open_ds: if not self.skip_creating_polygon: logger.warning("Assumes that equal distance between consecutive points of latitude " @@ -219,7 +231,7 @@ def prepare_coordinates(self, uri: str) -> t.Iterator[t.Tuple[str, t.List[t.Dict logger.info(f'Preparing coordinates for: {uri!r}.') with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, - self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds: + self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) if self.area: n, w, s, e = self.area @@ -238,75 +250,100 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, - self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds: + self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) + yield from self.to_rows(coordinates, data_ds, uri) - first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values, - np.ndarray) else data_ds.time.values - first_time_step = to_json_serializable_type(first_ts_raw) - - for it in coordinates: - # Use those index values to select a Dataset containing one row of data. - row_ds = data_ds.loc[it] - - # Create a Name-Value map for data columns. Result looks like: - # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} - row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) - for n, v in row_ds.data_vars.items()} - - # Serialize coordinates. - it = {k: to_json_serializable_type(v) for k, v in it.items()} - - # Add indexed coordinates. - row.update(it) - # Add un-indexed coordinates. - for c in row_ds.coords: - if c not in it and (not self.variables or c in self.variables): - row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) - - # Add import metadata. 
- row[DATA_IMPORT_TIME_COLUMN] = self.import_time - row[DATA_URI_COLUMN] = uri - row[DATA_FIRST_STEP] = first_time_step - - longitude = ((row['longitude'] + 180) % 360) - 180 - row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude) - row[GEO_POLYGON_COLUMN] = ( - fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution) - if not self.skip_creating_polygon - else None - ) - # 'row' ends up looking like: - # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, - # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} - beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() - yield row + def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]: + first_ts_raw = ( + ds.time[0].values if isinstance(ds.time.values, np.ndarray) + else ds.time.values + ) + first_time_step = to_json_serializable_type(first_ts_raw) + for it in coordinates: + # Use those index values to select a Dataset containing one row of data. + row_ds = ds.loc[it] + + # Create a Name-Value map for data columns. Result looks like: + # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} + row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) + for n, v in row_ds.data_vars.items()} + + # Serialize coordinates. + it = {k: to_json_serializable_type(v) for k, v in it.items()} + + # Add indexed coordinates. + row.update(it) + # Add un-indexed coordinates. + for c in row_ds.coords: + if c not in it and (not self.variables or c in self.variables): + row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) + + # Add import metadata. + row[DATA_IMPORT_TIME_COLUMN] = self.import_time + row[DATA_URI_COLUMN] = uri + row[DATA_FIRST_STEP] = first_time_step + + longitude = ((row['longitude'] + 180) % 360) - 180 + row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude) + row[GEO_POLYGON_COLUMN] = ( + fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution) + if not self.skip_creating_polygon + else None + ) + # 'row' ends up looking like: + # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, + # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} + beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() + yield row + + def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]: + uri = ds.attrs.get(DATA_URI_COLUMN, '') + # Re-calculate import time for streaming extractions. 
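One detail worth calling out from `to_rows` above: longitudes are normalized into the [-180, 180) range before the geo point and polygon are computed, using a modulo trick. A short illustration with a few sample values:

```python
def normalize_longitude(lon: float) -> float:
    """Wrap a longitude given on a 0-360 grid into the [-180, 180) range."""
    return ((lon + 180) % 360) - 180

print([normalize_longitude(lon) for lon in (0, 90, 180, 350)])
# [0, 90, -180, -10]
```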
+ if not self.import_time or self.zarr: + self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + yield from self.to_rows(get_coordinates(ds, uri), ds, uri) def expand(self, paths): """Extract rows of variables from data paths into a BigQuery table.""" - extracted_rows = ( + if not self.zarr: + extracted_rows = ( paths | 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates) | beam.Reshuffle() | 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows) - ) - - if not self.dry_run: - ( - extracted_rows - | 'WriteToBigQuery' >> WriteToBigQuery( - project=self.table.project, - dataset=self.table.dataset_id, - table=self.table.table_id, - write_disposition=BigQueryDisposition.WRITE_APPEND, - create_disposition=BigQueryDisposition.CREATE_NEVER) ) else: - ( - extracted_rows - | 'Log Extracted Rows' >> beam.Map(logger.debug) + xarray_open_dataset_kwargs = self.xarray_open_dataset_kwargs.copy() + xarray_open_dataset_kwargs.pop('chunks') + start_date = xarray_open_dataset_kwargs.pop('start_date', None) + end_date = xarray_open_dataset_kwargs.pop('end_date', None) + ds, chunks = xbeam.open_zarr(self.first_uri, **xarray_open_dataset_kwargs) + + if start_date is not None and end_date is not None: + ds = ds.sel(time=slice(start_date, end_date)) + + ds.attrs[DATA_URI_COLUMN] = self.first_uri + extracted_rows = ( + paths + | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks) + | 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows) + | 'Window' >> beam.WindowInto(window.FixedWindows(60)) + | 'AddTimestamp' >> beam.Map(timestamp_row) ) + if self.dry_run: + return extracted_rows | 'Log Rows' >> beam.Map(logger.info) + return ( + extracted_rows + | 'WriteToBigQuery' >> WriteToBigQuery( + project=self.table.project, + dataset=self.table.dataset_id, + table=self.table.table_id, + write_disposition=BigQueryDisposition.WRITE_APPEND, + create_disposition=BigQueryDisposition.CREATE_NEVER) + ) + def map_dtype_to_sql_type(var_type: np.dtype) -> str: """Maps a np.dtype to a suitable BigQuery column type.""" @@ -342,11 +379,17 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.Schem fields.append(bigquery.SchemaField(DATA_URI_COLUMN, 'STRING', mode='NULLABLE')) fields.append(bigquery.SchemaField(DATA_FIRST_STEP, 'TIMESTAMP', mode='NULLABLE')) fields.append(bigquery.SchemaField(GEO_POINT_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) - fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'STRING', mode='NULLABLE')) + fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) return fields +def timestamp_row(it: t.Dict) -> window.TimestampedValue: + """Associate an extracted row with the import_time timestamp.""" + timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp() + return window.TimestampedValue(it, timestamp) + + def fetch_geo_point(lat: float, long: float) -> str: """Calculates a geography point from an input latitude and longitude.""" if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]: @@ -371,13 +414,13 @@ def fetch_geo_polygon(latitude: float, longitude: float, lat_grid_resolution: fl The `get_lat_lon_range` function gives the `.` point and `bound_point` gives the `*` point. 
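The change to `fetch_geo_polygon` below wraps the coordinate list in an extra pair of brackets because a GeoJSON `Polygon` is defined as a list of linear rings, with the exterior ring first; this matters now that the `geo_polygon` column is loaded as `GEOGRAPHY` rather than `STRING`. A small sketch of the corrected shape:

```python
import geojson

# A linear ring: a closed sequence of (lon, lat) pairs (first == last).
ring = [(-1, 89), (-1, -89), (1, -89), (1, 89), (-1, 89)]

# geojson.Polygon expects a *list of rings* (exterior ring first),
# hence the extra level of nesting in the coordinates.
polygon = geojson.dumps(geojson.Polygon([ring]))
print(polygon)
# {"type": "Polygon", "coordinates": [[[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]]}
```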
""" lat_lon_bound = bound_point(latitude, longitude, lat_grid_resolution, lon_grid_resolution) - polygon = geojson.dumps(geojson.Polygon([ + polygon = geojson.dumps(geojson.Polygon([[ (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left (lat_lon_bound[1][0], lat_lon_bound[1][1]), # upper_left (lat_lon_bound[2][0], lat_lon_bound[2][1]), # upper_right (lat_lon_bound[3][0], lat_lon_bound[3][1]), # lower_right (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left - ])) + ]])) return polygon diff --git a/weather_mv/loader_pipeline/bq_test.py b/weather_mv/loader_pipeline/bq_test.py index ed96cd9d..fae7ab31 100644 --- a/weather_mv/loader_pipeline/bq_test.py +++ b/weather_mv/loader_pipeline/bq_test.py @@ -15,6 +15,7 @@ import json import logging import os +import tempfile import typing as t import unittest @@ -23,6 +24,8 @@ import pandas as pd import simplejson import xarray as xr +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, is_not_empty from google.cloud.bigquery import SchemaField from .bq import ( @@ -75,7 +78,7 @@ def test_schema_generation(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -91,7 +94,7 @@ def test_schema_generation__with_schema_normalization(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -107,7 +110,7 @@ def test_schema_generation__with_target_columns(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -123,7 +126,7 @@ def test_schema_generation__with_target_columns__with_schema_normalization(self) SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -140,7 +143,7 @@ def test_schema_generation__no_targets_specified(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -157,7 +160,7 @@ def test_schema_generation__no_targets_specified__with_schema_normalization(self SchemaField('data_uri', 'STRING', 'NULLABLE', 
None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -191,7 +194,7 @@ def test_schema_generation__non_index_coords(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), - SchemaField('geo_polygon', 'STRING', 'NULLABLE', None, (), None) + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -201,17 +204,18 @@ class ExtractRowsTestBase(TestDataBase): def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=None, import_time=DEFAULT_IMPORT_TIME, disable_grib_schema_normalization=False, - tif_metadata_for_datetime=None, zarr: bool = False, zarr_kwargs=None, + tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, zarr: bool = False, zarr_kwargs=None, skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]: if zarr_kwargs is None: zarr_kwargs = {} - op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, - output_table='foo.bar.baz', variables=variables, area=area, - xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, - infer_schema=False, tif_metadata_for_datetime=tif_metadata_for_datetime, - skip_region_validation=True, - disable_grib_schema_normalization=disable_grib_schema_normalization, - coordinate_chunk_size=1000, skip_creating_polygon=skip_creating_polygon) + op = ToBigQuery.from_kwargs( + first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, + output_table='foo.bar.baz', variables=variables, area=area, + xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False, + tif_metadata_for_start_time=tif_metadata_for_start_time, + tif_metadata_for_end_time=tif_metadata_for_end_time, skip_region_validation=True, + disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000, + skip_creating_polygon=skip_creating_polygon) coords = op.prepare_coordinates(data_path) for uri, chunk in coords: yield from op.extract_rows(uri, chunk) @@ -398,18 +402,18 @@ def test_extract_rows__with_valid_lat_long_with_polygon(self): valid_lat_long = [[-90, 0], [-90, -180], [-45, -180], [-45, 180], [0, 0], [90, 180], [45, -180], [-90, 180], [90, 1], [0, 180], [1, -180], [90, -180]] actual_val = [ - '{"type": "Polygon", "coordinates": [[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]}', - '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}', - '{"type": "Polygon", "coordinates": [[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]}', - '{"type": "Polygon", "coordinates": [[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]}', - '{"type": "Polygon", "coordinates": [[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]}', - '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}', - '{"type": "Polygon", "coordinates": [[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]}', - '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}', - '{"type": "Polygon", "coordinates": [[0, 89], [0, -89], [2, 
-89], [2, 89], [0, 89]]}', - '{"type": "Polygon", "coordinates": [[179, -1], [179, 1], [-179, 1], [-179, -1], [179, -1]]}', - '{"type": "Polygon", "coordinates": [[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]}', - '{"type": "Polygon", "coordinates": [[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]}' + '{"type": "Polygon", "coordinates": [[[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', + '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', + '{"type": "Polygon", "coordinates": [[[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[0, 89], [0, -89], [2, -89], [2, 89], [0, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, -1], [179, 1], [-179, 1], [-179, -1], [179, -1]]]}', + '{"type": "Polygon", "coordinates": [[[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}' ] lat_grid_resolution = 1 lon_grid_resolution = 1 @@ -469,10 +473,35 @@ class ExtractRowsTifSupportTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() - self.test_data_path = f'{self.test_data_folder}/test_data_tif_start_time.tif' + self.test_data_path = f'{self.test_data_folder}/test_data_tif_time.tif' - def test_extract_rows(self): - actual = next(self.extract(self.test_data_path, tif_metadata_for_datetime='start_time')) + def test_extract_rows_with_end_time(self): + actual = next( + self.extract(self.test_data_path, tif_metadata_for_start_time='start_time', + tif_metadata_for_end_time='end_time') + ) + expected = { + 'dewpoint_temperature_2m': 281.09349060058594, + 'temperature_2m': 296.8329772949219, + 'data_import_time': '1970-01-01T00:00:00+00:00', + 'data_first_step': '2020-07-01T00:00:00+00:00', + 'data_uri': self.test_data_path, + 'latitude': 42.09783344918844, + 'longitude': -123.66686981141397, + 'time': '2020-07-01T00:00:00+00:00', + 'valid_time': '2020-07-01T00:00:00+00:00', + 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-123.669853, 42.095605), (-123.669853, 42.100066), + (-123.663885, 42.100066), (-123.663885, 42.095605), + (-123.669853, 42.095605)])) + } + self.assertRowsEqual(actual, expected) + + def test_extract_rows_without_end_time(self): + actual = next( + self.extract(self.test_data_path, tif_metadata_for_start_time='start_time') + ) expected = { 'dewpoint_temperature_2m': 281.09349060058594, 'temperature_2m': 296.8329772949219, @@ -737,5 +766,38 @@ def test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema self.assertRowsEqual(actual, expected) +class ExtractRowsFromZarrTest(ExtractRowsTestBase): + + def setUp(self) -> None: + super().setUp() + self.tmpdir = tempfile.TemporaryDirectory() + + def tearDown(self) -> None: + super().tearDown() + self.tmpdir.cleanup() + + def test_extracts_rows(self): + input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr') + + ds = ( + 
xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder) + .isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4)) + .rename(dict(lon='longitude', lat='latitude')) + ) + ds.to_zarr(input_zarr) + + op = ToBigQuery.from_kwargs( + first_uri=input_zarr, zarr_kwargs=dict(chunks=None, consolidated=True), dry_run=True, zarr=True, + output_table='foo.bar.baz', + variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False, + tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, skip_region_validation=True, + disable_grib_schema_normalization=False, + ) + + with TestPipeline() as p: + result = p | op + assert_that(result, is_not_empty()) + + if __name__ == '__main__': unittest.main() diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py index 9ac18327..a5d6bca0 100644 --- a/weather_mv/loader_pipeline/ee.py +++ b/weather_mv/loader_pipeline/ee.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import csv import dataclasses import json import logging +import math import os import re import shutil @@ -27,7 +29,6 @@ import apache_beam as beam import ee import numpy as np -import xarray as xr from apache_beam.io.filesystems import FileSystems from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE from apache_beam.options.pipeline_options import PipelineOptions @@ -36,8 +37,8 @@ from google.auth.transport import requests from rasterio.io import MemoryFile -from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin -from .util import make_attrs_ee_compatible, RateLimit, validate_region +from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload +from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp logger = logging.getLogger(__name__) @@ -51,6 +52,7 @@ 'IMAGE': '.tiff', 'TABLE': '.csv' } +ROWS_PER_WRITE = 10_000 # Number of rows per feature collection write. def is_compute_engine() -> bool: @@ -155,7 +157,12 @@ def setup(self): def check_setup(self): """Ensures that setup has been called.""" if not self._has_setup: - self.setup() + try: + # This throws an exception if ee is not initialized. + ee.data.getAlgorithms() + self._has_setup = True + except ee.EEException: + self.setup() def process(self, *args, **kwargs): """Checks that setup has been called then call the process implementation.""" @@ -438,6 +445,8 @@ def add_to_queue(self, queue: Queue, item: t.Any): def convert_to_asset(self, queue: Queue, uri: str): """Converts source data into EE asset (GeoTiff or CSV) and uploads it to the bucket.""" logger.info(f'Converting {uri!r} to COGs...') + job_start_time = get_utc_timestamp() + with open_dataset(uri, self.open_dataset_kwargs, self.disable_grib_schema_normalization, @@ -457,6 +466,8 @@ def convert_to_asset(self, queue: Queue, uri: str): ('start_time', 'end_time', 'is_normalized','forecast_hour')) dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform']) attrs.update({'is_normalized': str(is_normalized)}) # EE properties does not support bool. + # Adding job_start_time to properites. + attrs["job_start_time"] = job_start_time # Make attrs EE ingestable. 
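The CSV conversion in `convert_to_asset` that follows avoids materializing a full DataFrame: it flattens each variable and reconstructs the dimension values for every row from the flattened index. The arithmetic is the usual row-major unravelling (what `np.unravel_index` does); a small sketch with made-up dimension data:

```python
import math

# Hypothetical dimension sizes and values, e.g. (time, latitude, longitude).
dims_shape = [2, 3, 4]
dims_data = [
    ['2020-07-01', '2020-07-02'],
    [10.0, 20.0, 30.0],
    [-100.0, -101.0, -102.0, -103.0],
]

def get_dims_data(index: int) -> list:
    """Map a flattened (row-major) index back to one value per dimension."""
    return [
        dim[(index // math.prod(dims_shape[i + 1:])) % len(dim)]
        for i, dim in enumerate(dims_data)
    ]

print(get_dims_data(0))   # ['2020-07-01', 10.0, -100.0]
print(get_dims_data(23))  # ['2020-07-02', 30.0, -103.0]
```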
attrs = make_attrs_ee_compatible(attrs) @@ -507,21 +518,40 @@ def convert_to_asset(self, queue: Queue, uri: str): channel_names = [] file_name = f'{asset_name}.csv' - df = xr.Dataset.to_dataframe(ds) - df = df.reset_index() - # NULL and NaN create data-type mismatch issue in ee therefore replacing all of them. - # fillna fills in NaNs, NULLs, and NaTs but we have to exclude NaTs. - non_nat = df.select_dtypes(exclude=['datetime', 'timedelta', 'datetimetz']) - df[non_nat.columns] = non_nat.fillna(-9999) + shape = math.prod(list(ds.dims.values())) + # Names of dimesions, coordinates and data variables. + dims = list(ds.dims) + coords = [c for c in list(ds.coords) if c not in dims] + vars = list(ds.data_vars) + header = dims + coords + vars + + # Data of dimesions, coordinates and data variables. + dims_data = [ds[dim].data for dim in dims] + coords_data = [np.full((shape,), ds[coord].data) for coord in coords] + vars_data = [ds[var].data.flatten() for var in vars] + data = coords_data + vars_data + + dims_shape = [len(ds[dim].data) for dim in dims] - # Copy in-memory dataframe to gcs. + def get_dims_data(index: int) -> t.List[t.Any]: + """Returns dimensions for the given flattened index.""" + return [ + dim[int(index / math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data) + ] + + # Copy CSV to gcs. target_path = os.path.join(self.asset_location, file_name) - with tempfile.NamedTemporaryFile() as tmp_df: - df.to_csv(tmp_df.name, index=False) - tmp_df.flush() - tmp_df.seek(0) - with FileSystems().create(target_path) as dst: - shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE) + with tempfile.NamedTemporaryFile() as temp: + with open(temp.name, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerows([header]) + # Write rows in batches. + for i in range(0, shape, ROWS_PER_WRITE): + writer.writerows( + [get_dims_data(i) + list(row) for row in zip(*[d[i:i + ROWS_PER_WRITE] for d in data])] + ) + + upload(temp.name, target_path) asset_data = AssetData( name=asset_name, target_path=target_path, @@ -626,6 +656,8 @@ def start_ingestion(self, asset_request: t.Dict) -> str: """Creates COG-backed asset in earth engine. 
Returns the asset id.""" self.check_setup() + asset_request['properties']['ingestion_time'] = get_utc_timestamp() + try: if self.ee_asset_type == 'IMAGE': result = ee.data.createAsset(asset_request) diff --git a/weather_mv/loader_pipeline/pipeline.py b/weather_mv/loader_pipeline/pipeline.py index c12bd5f5..ef685473 100644 --- a/weather_mv/loader_pipeline/pipeline.py +++ b/weather_mv/loader_pipeline/pipeline.py @@ -17,6 +17,7 @@ import json import logging import typing as t +import warnings import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -27,7 +28,7 @@ from .streaming import GroupMessagesByFixedWindows, ParsePaths logger = logging.getLogger(__name__) -SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' +SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0' def configure_logger(verbosity: int) -> None: @@ -55,8 +56,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None known_args.first_uri = next(iter(all_uris)) with beam.Pipeline(argv=pipeline_args) as p: - if known_args.topic or known_args.subscription: - + if known_args.zarr: + paths = p + elif known_args.topic or known_args.subscription: paths = ( p # Windowing is based on this code sample: @@ -140,11 +142,19 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: # Validate Zarr arguments if known_args.uris.endswith('.zarr'): known_args.zarr = True - known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) if known_args.zarr_kwargs and not known_args.zarr: raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.') + if known_args.zarr_kwargs: + if not known_args.zarr_kwargs.get('start_date') or not known_args.zarr_kwargs.get('end_date'): + warnings.warn('`--zarr_kwargs` not contains both `start_date` and `end_date`' + 'so whole zarr-dataset will ingested.') + + if known_args.zarr: + known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) + known_args.zarr_kwargs['consolidated'] = known_args.zarr_kwargs.get('consolidated', True) + # Validate subcommand if known_args.subcommand == 'bigquery' or known_args.subcommand == 'bq': ToBigQuery.validate_arguments(known_args, pipeline_args) diff --git a/weather_mv/loader_pipeline/pipeline_test.py b/weather_mv/loader_pipeline/pipeline_test.py index 4d546192..3834b537 100644 --- a/weather_mv/loader_pipeline/pipeline_test.py +++ b/weather_mv/loader_pipeline/pipeline_test.py @@ -30,7 +30,7 @@ def setUp(self) -> None: ).split() self.tif_base_cli_args = ( 'weather-mv bq ' - f'-i {self.test_data_folder}/test_data_tif_start_time.tif ' + f'-i {self.test_data_folder}/test_data_tif_time.tif ' '-o myproject.mydataset.mytable ' '--import_time 2022-02-04T22:22:12.125893 ' '-s' @@ -62,7 +62,8 @@ def setUp(self) -> None: 'xarray_open_dataset_kwargs': {}, 'coordinate_chunk_size': 10_000, 'disable_grib_schema_normalization': False, - 'tif_metadata_for_datetime': None, + 'tif_metadata_for_start_time': None, + 'tif_metadata_for_end_time': None, 'zarr': False, 'zarr_kwargs': {}, 'log_level': 2, @@ -83,7 +84,8 @@ def test_log_level_arg(self): def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self): with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'): - run(self.base_cli_args + '--tif_metadata_for_datetime start_time'.split()) + run(self.base_cli_args + '--tif_metadata_for_start_time start_time ' + '--tif_metadata_for_end_time end_time'.split()) def 
test_tif_metadata_for_datetime_raise_error_if_flag_is_absent(self): with self.assertRaisesRegex(RuntimeError, 'is required for tif files.'): diff --git a/weather_mv/loader_pipeline/regrid_test.py b/weather_mv/loader_pipeline/regrid_test.py index 5cc5b2a1..87ffaad4 100644 --- a/weather_mv/loader_pipeline/regrid_test.py +++ b/weather_mv/loader_pipeline/regrid_test.py @@ -122,7 +122,7 @@ def test_zarr__coarsen(self): self.Op, first_uri=input_zarr, output_path=output_zarr, - zarr_input_chunks={"time": 5}, + zarr_input_chunks={"time": 25}, zarr=True ) diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py index 8d4f5681..bc064599 100644 --- a/weather_mv/loader_pipeline/sinks.py +++ b/weather_mv/loader_pipeline/sinks.py @@ -141,8 +141,9 @@ def rearrange_time_list(order_list: t.List, time_list: t.List) -> t.List: return datetime.datetime(*time_list) -def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_datetime: str, uri: str, - band_names_dict: t.Dict, initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset: +def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_start_time: str, + tif_metadata_for_end_time: str, uri: str, band_names_dict: t.Dict, + initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset: """Transforms (y, x) coordinates into (lat, long) and adds bands data in data variables. This also retrieves datetime from tif's metadata and stores it into dataset. @@ -165,6 +166,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset): ds = _replace_dataarray_names_with_long_names(ds) end_time = None + start_time = None if initialization_time_regex and forecast_time_regex: try: start_time = match_datetime(uri, initialization_time_regex) @@ -177,15 +179,40 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset): ds.attrs['start_time'] = start_time ds.attrs['end_time'] = end_time - datetime_value_ms = None + init_time = None + forecast_time = None + coords = {} try: - datetime_value_s = (int(end_time.timestamp()) if end_time is not None - else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0) - ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)}) - except KeyError: - raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.") + # if start_time/end_time is in integer milliseconds + init_time = (int(start_time.timestamp()) if start_time is not None + else int(ds.attrs[tif_metadata_for_start_time]) / 1000.0) + coords['time'] = datetime.datetime.utcfromtimestamp(init_time) + + if tif_metadata_for_end_time: + forecast_time = (int(end_time.timestamp()) if end_time is not None + else int(ds.attrs[tif_metadata_for_end_time]) / 1000.0) + coords['valid_time'] = datetime.datetime.utcfromtimestamp(forecast_time) + + ds = ds.assign_coords(coords) + except KeyError as e: + raise RuntimeError(f"Invalid datetime metadata of tif: {e}.") except ValueError: - raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.") + try: + # if start_time/end_time is in UTC string format + init_time = (int(start_time.timestamp()) if start_time is not None + else datetime.datetime.strptime(ds.attrs[tif_metadata_for_start_time], + '%Y-%m-%dT%H:%M:%SZ')) + coords['time'] = init_time + + if tif_metadata_for_end_time: + forecast_time = (int(end_time.timestamp()) if end_time is not None + else datetime.datetime.strptime(ds.attrs[tif_metadata_for_end_time], + '%Y-%m-%dT%H:%M:%SZ')) + coords['valid_time'] = forecast_time + + ds = 
ds.assign_coords(coords) + except ValueError as e: + raise RuntimeError(f"Invalid datetime value in tif's metadata: {e}.") return ds @@ -583,6 +610,11 @@ def __open_dataset_file(filename: str, False) +def upload(src: str, dst: str) -> None: + """Uploads a file to the specified GCS bucket destination.""" + subprocess.run(f'gsutil -m cp {src} {dst}'.split(), check=True, capture_output=True, text=True, input="n/n") + + def copy(src: str, dst: str) -> None: """Copy data via `gcloud alpha storage` or `gsutil`.""" errors: t.List[subprocess.CalledProcessError] = [] @@ -624,7 +656,8 @@ def open_local(uri: str) -> t.Iterator[str]: def open_dataset(uri: str, open_dataset_kwargs: t.Optional[t.Dict] = None, disable_grib_schema_normalization: bool = False, - tif_metadata_for_datetime: t.Optional[str] = None, + tif_metadata_for_start_time: t.Optional[str] = None, + tif_metadata_for_end_time: t.Optional[str] = None, band_names_dict: t.Optional[t.Dict] = None, initialization_time_regex: t.Optional[str] = None, forecast_time_regex: t.Optional[str] = None, @@ -632,8 +665,17 @@ def open_dataset(uri: str, tiff_config: t.Optional[t.Dict] = None,) -> t.Iterator[t.Union[xr.Dataset, t.List[xr.Dataset]]]: """Open the dataset at 'uri' and return a xarray.Dataset.""" try: + local_open_dataset_kwargs = start_date = end_date = None + if open_dataset_kwargs is not None: + local_open_dataset_kwargs = open_dataset_kwargs.copy() + start_date = local_open_dataset_kwargs.pop('start_date', None) + end_date = local_open_dataset_kwargs.pop('end_date', None) + if is_zarr: - ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs) + ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr', + **local_open_dataset_kwargs), False) + if start_date is not None and end_date is not None: + ds = ds.sel(time=slice(start_date, end_date)) beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc() yield ds ds.close() @@ -641,11 +683,11 @@ def open_dataset(uri: str, with open_local(uri) as local_path: _, uri_extension = os.path.splitext(uri) - xr_datasets: xr.Dataset = __open_dataset_file(local_path, - uri_extension, - disable_grib_schema_normalization, - open_dataset_kwargs, - tiff_config) + xr_dataset: xr.Dataset = __open_dataset_file(local_path, + uri_extension, + disable_grib_schema_normalization, + local_open_dataset_kwargs, + tiff_config) # Extracting dtype, crs and transform from the dataset & storing them as attributes. 
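The `start_date`/`end_date` keys popped from the open-dataset kwargs above are applied as a label-based time slice, so only the requested window is read from the dataset. A toy illustration of that selection, using a made-up daily dataset in place of a real zarr store:

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range('2021-07-15', periods=10, freq='D')
ds = xr.Dataset({'t2m': ('time', np.random.rand(10))}, coords={'time': times})

start_date, end_date = '2021-07-18', '2021-07-19'
if start_date is not None and end_date is not None:
    # Label-based selection keeps only the requested date range.
    ds = ds.sel(time=slice(start_date, end_date))

print(ds.time.values)  # the two timestamps inside the requested window
```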
try: with rasterio.open(local_path, 'r') as f: @@ -661,14 +703,17 @@ def open_dataset(uri: str, logger.info(f'opened dataset size: {total_size_in_bytes}') else: + if start_date is not None and end_date is not None: + xr_dataset = xr_dataset.sel(time=slice(start_date, end_date)) if uri_extension in ['.tif', '.tiff']: xr_dataset = _preprocess_tif(xr_datasets, - local_path, - tif_metadata_for_datetime, - uri, - band_names_dict, - initialization_time_regex, - forecast_time_regex) + local_path, + tif_metadata_for_start_time, + tif_metadata_for_end_time, + uri, + band_names_dict, + initialization_time_regex, + forecast_time_regex) else: xr_dataset = xr_datasets xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform}) diff --git a/weather_mv/loader_pipeline/sinks_test.py b/weather_mv/loader_pipeline/sinks_test.py index 01cdce9a..69150dda 100644 --- a/weather_mv/loader_pipeline/sinks_test.py +++ b/weather_mv/loader_pipeline/sinks_test.py @@ -84,7 +84,7 @@ def setUp(self) -> None: super().setUp() self.test_data_path = os.path.join(self.test_data_folder, 'test_data_20180101.nc') self.test_grib_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep') - self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_start_time.tif') + self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_time.tif') self.test_zarr_path = os.path.join(self.test_data_folder, 'test_data.zarr') def test_opens_grib_files(self): @@ -104,7 +104,8 @@ def test_accepts_xarray_kwargs(self): self.assertDictContainsSubset({'is_normalized': False}, ds2.attrs) def test_opens_tif_files(self): - with open_dataset(self.test_tif_path, tif_metadata_for_datetime='start_time') as ds: + with open_dataset(self.test_tif_path, tif_metadata_for_start_time='start_time', + tif_metadata_for_end_time='end_time') as ds: self.assertIsNotNone(ds) self.assertDictContainsSubset({'is_normalized': False}, ds.attrs) @@ -112,6 +113,7 @@ def test_opens_zarr(self): with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds: self.assertIsNotNone(ds) self.assertEqual(list(ds.data_vars), ['cape', 'd2m']) + def test_open_dataset__fits_memory_bounds(self): with write_netcdf() as test_netcdf_path: with limit_memory(max_memory=30): diff --git a/weather_mv/loader_pipeline/streaming.py b/weather_mv/loader_pipeline/streaming.py index 7210b2e7..3a7a8f49 100644 --- a/weather_mv/loader_pipeline/streaming.py +++ b/weather_mv/loader_pipeline/streaming.py @@ -84,7 +84,7 @@ def try_parse_message(cls, message_body: t.Union[str, t.Dict]) -> t.Dict: try: return json.loads(message_body) except (json.JSONDecodeError, TypeError): - if type(message_body) is dict: + if isinstance(message_body, dict): return message_body raise diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py index a31a06a9..079b86de 100644 --- a/weather_mv/loader_pipeline/util.py +++ b/weather_mv/loader_pipeline/util.py @@ -28,7 +28,6 @@ import uuid from functools import partial from urllib.parse import urlparse - import apache_beam as beam import numpy as np import pandas as pd @@ -134,6 +133,9 @@ def _check_for_coords_vars(ds_data_var: str, target_var: str) -> bool: specified by the user.""" return ds_data_var.endswith('_'+target_var) or ds_data_var.startswith(target_var+'_') +def get_utc_timestamp() -> float: + """Returns the current UTC Timestamp.""" + return datetime.datetime.now().timestamp() def _only_target_coordinate_vars(ds: xr.Dataset, data_vars: t.List[str]) -> t.List[str]: """If 
the user specifies target fields in the dataset, get all the matching coords & data vars.""" diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py index dae9c873..65d9169e 100644 --- a/weather_mv/loader_pipeline/util_test.py +++ b/weather_mv/loader_pipeline/util_test.py @@ -38,9 +38,10 @@ def test_gets_indexed_coordinates(self): ds = xr.open_dataset(self.test_data_path) self.assertEqual( next(get_coordinates(ds)), - {'latitude': 49.0, - 'longitude':-108.0, - 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)} + { + 'latitude': 49.0, + 'longitude': -108.0, + 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)} ) def test_no_duplicate_coordinates(self): @@ -91,24 +92,28 @@ def test_get_coordinates(self): actual, [ [ - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None) }, - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None) }, - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None) }, ], [ - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None) - } + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None) + } ] ] ) diff --git a/weather_mv/setup.py b/weather_mv/setup.py index bfe09713..4bdb4a0b 100644 --- a/weather_mv/setup.py +++ b/weather_mv/setup.py @@ -45,6 +45,7 @@ "numpy==1.22.4", "pandas==1.5.1", "xarray==2023.1.0", + "xarray-beam==0.6.2", "cfgrib==0.9.10.2", "netcdf4==1.6.1", "geojson==2.5.0", @@ -55,6 +56,8 @@ "earthengine-api>=0.1.263", "pyproj==3.4.0", # requires separate binary installation! "gdal==3.5.1", # requires separate binary installation! 
+ "gcsfs==2022.11.0", + "zarr==2.15.0", ] setup( @@ -62,7 +65,7 @@ packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', - version='0.2.17', + version='0.2.19', url='https://weather-tools.readthedocs.io/en/latest/weather_mv/', description='A tool to load weather data into BigQuery.', install_requires=beam_gcp_requirements + base_requirements, diff --git a/weather_mv/test_data/test_data_tif_start_time.tif b/weather_mv/test_data/test_data_tif_start_time.tif deleted file mode 100644 index 82f32dd7..00000000 Binary files a/weather_mv/test_data/test_data_tif_start_time.tif and /dev/null differ diff --git a/weather_mv/test_data/test_data_tif_time.tif b/weather_mv/test_data/test_data_tif_time.tif new file mode 100644 index 00000000..32b7f63b Binary files /dev/null and b/weather_mv/test_data/test_data_tif_time.tif differ diff --git a/weather_sp/setup.py b/weather_sp/setup.py index 59786c8e..e22279cd 100644 --- a/weather_sp/setup.py +++ b/weather_sp/setup.py @@ -44,7 +44,7 @@ packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', - version='0.3.1', + version='0.3.2', url='https://weather-tools.readthedocs.io/en/latest/weather_sp/', description='A tool to split weather data files into per-variable files.', install_requires=beam_gcp_requirements + base_requirements, diff --git a/weather_sp/splitter_pipeline/file_splitters.py b/weather_sp/splitter_pipeline/file_splitters.py index 6456c426..7a4d3c77 100644 --- a/weather_sp/splitter_pipeline/file_splitters.py +++ b/weather_sp/splitter_pipeline/file_splitters.py @@ -16,6 +16,7 @@ import itertools import logging import os +import re import shutil import string import subprocess @@ -158,6 +159,10 @@ class GribSplitterV2(GribSplitter): See https://confluence.ecmwf.int/display/ECC/grib_copy. """ + def replace_non_numeric_bracket(self, match: re.Match) -> str: + value = match.group(1) + return f"[{value}]" if not value.isdigit() else "{" + value + "}" + def split_data(self) -> None: if not self.output_info.split_dims(): raise ValueError('No splitting specified in template.') @@ -172,7 +177,10 @@ def split_data(self) -> None: unformatted_output_path = self.output_info.unformatted_output_path() prefix, _ = os.path.split(next(iter(string.Formatter().parse(unformatted_output_path)))[0]) _, tail = unformatted_output_path.split(prefix) - output_template = tail.replace('{', '[').replace('}', ']') + + # Replace { with [ and } with ] only for non-numeric values inside {} of tail + output_str = re.sub(r'\{(\w+)\}', self.replace_non_numeric_bracket, tail) + output_template = output_str.format(*self.output_info.template_folders) slash = '/' delimiter = 'DELIMITER'