Skip to content

Commit

Permalink
Added logic to allow for custom band naming (#426)
Browse files Browse the repository at this point in the history
* Added logic for custom band naming, code cleanup.

* Addressed comments.
  • Loading branch information
j9sh264 authored Jan 16, 2024
1 parent a0b954e commit 2cfc0d7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
11 changes: 6 additions & 5 deletions weather_mv/loader_pipeline/ee.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ def convert_to_asset(self, queue: Queue, uri: str):
with open_dataset(uri,
self.open_dataset_kwargs,
self.disable_grib_schema_normalization,
band_names_dict=self.band_names_dict,
initialization_time_regex=self.initialization_time_regex,
forecast_time_regex=self.forecast_time_regex,
group_common_hypercubes=self.group_common_hypercubes) as ds_list:
Expand All @@ -459,15 +458,17 @@ def convert_to_asset(self, queue: Queue, uri: str):
attrs = ds.attrs
data = list(ds.values())
asset_name = get_ee_safe_name(uri)
channel_names = [da.name for da in data]
start_time, end_time, is_normalized = (attrs.get(key) for key in
('start_time', 'end_time', 'is_normalized'))
channel_names = [
self.band_names_dict.get(da.name, da.name) if self.band_names_dict
else da.name for da in data
]

dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
attrs.update({'is_normalized': str(is_normalized)}) # EE properties do not support bool.
# Adding job_start_time to properties.
attrs["job_start_time"] = job_start_time
# Make attrs EE ingestable.
attrs = make_attrs_ee_compatible(attrs)
start_time, end_time = (attrs.get(key) for key in ('start_time', 'end_time'))

if self.group_common_hypercubes:
level, height = (attrs.pop(key) for key in ['level', 'height'])
Expand Down
14 changes: 8 additions & 6 deletions weather_mv/loader_pipeline/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ def rearrange_time_list(order_list: t.List, time_list: t.List) -> t.List:
return datetime.datetime(*time_list)


def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_start_time: str,
tif_metadata_for_end_time: str, uri: str, band_names_dict: t.Dict,
initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset:
def _preprocess_tif(
ds: xr.Dataset,
tif_metadata_for_start_time: str,
tif_metadata_for_end_time: str,
uri: str,
initialization_time_regex: str,
forecast_time_regex: str
) -> xr.Dataset:
"""Transforms (y, x) coordinates into (lat, long) and adds bands data in data variables.
This also retrieves datetime from tif's metadata and stores it into dataset.
Expand Down Expand Up @@ -432,7 +437,6 @@ def open_dataset(uri: str,
disable_grib_schema_normalization: bool = False,
tif_metadata_for_start_time: t.Optional[str] = None,
tif_metadata_for_end_time: t.Optional[str] = None,
band_names_dict: t.Optional[t.Dict] = None,
initialization_time_regex: t.Optional[str] = None,
forecast_time_regex: t.Optional[str] = None,
group_common_hypercubes: t.Optional[bool] = False,
Expand Down Expand Up @@ -482,11 +486,9 @@ def open_dataset(uri: str,
xr_dataset = xr_datasets.sel(time=slice(start_date, end_date))
if uri_extension in ['.tif', '.tiff']:
xr_dataset = _preprocess_tif(xr_dataset,
local_path,
tif_metadata_for_start_time,
tif_metadata_for_end_time,
uri,
band_names_dict,
initialization_time_regex,
forecast_time_regex)

Expand Down

0 comments on commit 2cfc0d7

Please sign in to comment.