diff --git a/src/hf_hydrodata/gridded.py b/src/hf_hydrodata/gridded.py index e841438..109e42d 100644 --- a/src/hf_hydrodata/gridded.py +++ b/src/hf_hydrodata/gridded.py @@ -42,9 +42,10 @@ HYDRODATA = "/hydrodata" -URL = "https://hydro-dev.princeton.edu/api/data-file?" +HYDRODATA_URL = os.getenv("HYDRODATA_URL", "https://hydro-dev.princeton.edu") -def register_api_pin(email:str, pin:str): + +def register_api_pin(email: str, pin: str): """ Register the email and pin that was created with the website in the users home directory. @@ -63,7 +64,8 @@ def register_api_pin(email:str, pin:str): stream.write("}") os.chmod(pin_path, 0o700) -def get_registered_api_pin()->Tuple[str, str]: + +def get_registered_api_pin() -> Tuple[str, str]: """ Get the email and pin registered by the current user. @@ -76,7 +78,9 @@ def get_registered_api_pin()->Tuple[str, str]: pin_dir = os.path.expanduser("~/.hydrodata") pin_path = f"{pin_dir}/pin.json" if not os.path.exists(pin_path): - raise ValueError("No email/pin was registered. Use the register_api() method to register the pin you created at the website.") + raise ValueError( + "No email/pin was registered. Use the register_api() method to register the pin you created at the website." + ) try: with open(pin_path, "r") as stream: contents = stream.read() @@ -85,7 +89,10 @@ def get_registered_api_pin()->Tuple[str, str]: pin = parsed_contents.get("pin") return (email, pin) except Exception as e: - raise ValueError("No email/pin was registered. Use the register_api() method to register the pin you created at the website.") from e + raise ValueError( + "No email/pin was registered. Use the register_api() method to register the pin you created at the website." + ) from e + def get_catalog_entries(*args, **kwargs) -> List[ModelTableRow]: """ @@ -516,17 +523,18 @@ def _write_file_from_api(filepath, options): """ q_params = _construct_string_from_options(options) - url = f"{URL}{q_params}" + datafile_url = f"{HYDRODATA_URL}/api/data-file?{q_params}" try: - response = requests.get(url, timeout=180) + headers = _validate_user() + response = requests.get(datafile_url, headers=headers, timeout=180) if response.status_code != 200: raise ValueError( - f"The url {url} returned error code {response.status_code}." + f"The datafile_url {datafile_url} returned error code {response.status_code}." ) except requests.exceptions.Timeout as e: - raise ValueError(f"The url {url} has timed out.") from e + raise ValueError(f"The datafile_url {datafile_url} has timed out.") from e file_obj = io.BytesIO(response.content) with open(filepath, "wb") as output_file: @@ -749,17 +757,20 @@ def _get_ndarray_from_api(entry, options, time_values): if run_remote: options = _convert_json_to_strings(options) q_params = _construct_string_from_qparams(entry, options) - url = f"https://hydro-dev.princeton.edu/api/gridded-data?{q_params}" + gridded_data_url = f"{HYDRODATA_URL}/api/gridded-data?{q_params}" try: - response = requests.get(url, timeout=180) + headers = _validate_user() + response = requests.get(gridded_data_url, headers=headers, timeout=180) if response.status_code != 200: raise ValueError( - f"The url {url} returned error code {response.status_code}." + f"The {gridded_data_url} returned error code {response.status_code}." ) except requests.exceptions.Timeout as e: - raise ValueError(f"The url {url} has timed out.") from e + raise ValueError( + f"The gridded_data_url {gridded_data_url} has timed out." + ) from e file_obj = io.BytesIO(response.content) netcdf_dataset = xr.open_dataset(file_obj) @@ -781,6 +792,20 @@ def _get_ndarray_from_api(entry, options, time_values): return None +def _validate_user(): + email, pin = get_registered_api_pin() + url_security = f"{HYDRODATA_URL}/api/api_pins?pin={pin}&email={email}" + response = requests.get(url_security, timeout=15) + if not response.status_code == 200: + raise ValueError(f"No registered PIN for email '{email}' and PIN {pin}. See documentation to register with a URL.") + json_string = response.content.decode("utf-8") + jwt_json = json.loads(json_string) + jwt_token = jwt_json["jwt_token"] + headers = {} + headers["Authorization"] = f"Bearer {jwt_token}" + return headers + + def _adjust_dimensions(data: np.ndarray, entry: ModelTableRow) -> np.ndarray: """ Reshape the dimensions of the data array to match the conventions for the entry period and expected variable. diff --git a/tests/hf_hydrodata/test_gridded.py b/tests/hf_hydrodata/test_gridded.py index eb2ed24..1d808dd 100644 --- a/tests/hf_hydrodata/test_gridded.py +++ b/tests/hf_hydrodata/test_gridded.py @@ -21,11 +21,11 @@ class MockResponse: def __init__(self): self.headers = {} self.status_code = 200 - self.content = b"test content" + self.content = b'{"email": "dummy@email.com", "jwt_token":"foo"}' self.checksum = "" -def mock_requests_get(url, timeout): +def mock_requests_get(url, headers=None, timeout=5): """Create a mock streaming response.""" response = MockResponse()