Site torch dataset update (#82)
* Update what Site Dataset returns

* Improve site dataset unit test

* Remove unused logic

* Update Site data loader
Sukh-P authored Dec 18, 2024
1 parent 19404a3 commit 88b310a
Showing 7 changed files with 130 additions and 62 deletions.
24 changes: 12 additions & 12 deletions ocf_data_sampler/load/site.py
@@ -8,23 +8,23 @@
def open_site(sites_config: Site) -> xr.DataArray:

# Load site generation xr.Dataset
data_ds = xr.open_dataset(sites_config.file_path)
site_generation_ds = xr.open_dataset(sites_config.file_path)

# Load site metadata
metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")

# Add coordinates
ds = data_ds.assign_coords(
latitude=(metadata_df.latitude.to_xarray()),
longitude=(metadata_df.longitude.to_xarray()),
capacity_kwp=data_ds.capacity_kwp,
# Ensure metadata aligns with the site_id dimension in data_ds
metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)

# Assign coordinates to the Dataset using the aligned metadata
site_generation_ds = site_generation_ds.assign_coords(
latitude=("site_id", metadata_df["latitude"].values),
longitude=("site_id", metadata_df["longitude"].values),
capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
)

# Sanity checks
assert np.isfinite(data_ds.capacity_kwp.values).all()
assert (data_ds.capacity_kwp.values > 0).all()
assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
assert (site_generation_ds.capacity_kwp.values > 0).all()
assert metadata_df.index.is_unique

return ds.generation_kw


return site_generation_ds.generation_kw
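The key change above is that site metadata is explicitly reindexed to match the site_id dimension of the generation data before being attached as coordinates. A minimal, self-contained sketch of that pattern (toy values, not part of the PR) looks like this:

import numpy as np
import pandas as pd
import xarray as xr

# Toy generation data over two sites and four half-hourly timestamps
site_generation_ds = xr.Dataset(
    {"generation_kw": (("time_utc", "site_id"), np.random.rand(4, 2))},
    coords={
        "time_utc": pd.date_range("2024-01-01", periods=4, freq="30min"),
        "site_id": [1, 2],
    },
)

# Metadata indexed by site_id, deliberately stored in a different order
metadata_df = pd.DataFrame(
    {"latitude": [52.0, 51.5], "longitude": [-1.3, -0.1], "capacity_kwp": [6.5, 4.0]},
    index=pd.Index([2, 1], name="site_id"),
)

# Reindex so metadata rows line up with the site_id dimension, then attach
# the per-site values as coordinates along that dimension
metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)
site_generation_ds = site_generation_ds.assign_coords(
    latitude=("site_id", metadata_df["latitude"].values),
    longitude=("site_id", metadata_df["longitude"].values),
    capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
)

assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
assert (site_generation_ds.capacity_kwp.values > 0).all()

Without the reindex step, metadata rows stored in a different order would be attached to the wrong sites.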
3 changes: 2 additions & 1 deletion ocf_data_sampler/torch_datasets/__init__.py
@@ -1 +1,2 @@

from .pvnet_uk_regional import PVNetUKRegionalDataset
from .site import SitesDataset
118 changes: 95 additions & 23 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -9,7 +10,6 @@
convert_satellite_to_numpy_batch,
convert_gsp_to_numpy_batch,
make_sun_position_numpy_batch,
convert_site_to_numpy_batch,
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
@@ -73,18 +73,6 @@ def process_and_combine_datasets(
}
)


if "site" in dataset_dict:
site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp

numpy_modalities.append(
convert_site_to_numpy_batch(
da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
)
)

if target_key == 'gsp':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
@@ -95,16 +83,6 @@

lon, lat = osgb_to_lon_lat(location.x, location.y)

elif target_key == 'site':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
t0+minutes(site_config.interval_start_minutes),
t0+minutes(site_config.interval_end_minutes),
freq=minutes(site_config.time_resolution_minutes),
)

lon, lat = location.x, location.y

numpy_modalities.append(
make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
)
@@ -115,6 +93,47 @@

return combined_sample

def process_and_combine_site_sample_dict(
dataset_dict: dict,
config: Configuration,
) -> xr.Dataset:
"""
Normalize and combine data into a single xr Dataset
Args:
dataset_dict: dict containing sliced xr DataArrays
config: Configuration for the model
Returns:
xr.Dataset: A merged Dataset with nans filled in.
"""

data_arrays = []

if "nwp" in dataset_dict:
for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
data_arrays.append((f"nwp-{provider}", da_nwp))

if "sat" in dataset_dict:
# TODO add some satellite normalisation
da_sat = dataset_dict["sat"]
data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
# site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp
data_arrays.append(("sites", da_sites))

combined_sample_dataset = merge_arrays(data_arrays)

# Fill any nan values
return combined_sample_dataset.fillna(0.0)
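The per-site normalisation in the "site" branch divides generation by the capacity_kwp coordinate, which xarray broadcasts along site_id. A toy illustration of that single step (values are made up, not from the PR):

import numpy as np
import xarray as xr

da_sites = xr.DataArray(
    np.array([[1.0, 3.0], [2.0, 6.0]]),
    dims=("time_utc", "site_id"),
    coords={"site_id": [1, 2], "capacity_kwp": ("site_id", [4.0, 6.0])},
)

# Dividing by the capacity coordinate scales each site's generation to [0, 1]
da_norm = da_sites / da_sites.capacity_kwp
print(da_norm.values)  # [[0.25 0.5 ], [0.5  1.  ]]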


def merge_dicts(list_of_dicts: list[dict]) -> dict:
"""Merge a list of dictionaries into a single dictionary"""
@@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
combined_dict.update(d)
return combined_dict

def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.
Args:
normalised_data_arrays: List of tuples where each tuple contains:
- A string (key name).
- An xarray.DataArray.
Returns:
xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
"""
datasets = []

for key, data_array in normalised_data_arrays:
# Ensure all attributes are strings for consistency
data_array = data_array.assign_attrs(
{attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
)

# Convert DataArray to Dataset with the variable name as the key
dataset = data_array.to_dataset(name=key)

# Prepend key name to all dimension and coordinate names for uniqueness
dataset = dataset.rename(
{dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
)
dataset = dataset.rename(
{coord: f"{key}__{coord}" for coord in dataset.coords}
)

# Handle concatenation dimension if applicable
concat_dim = (
f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
else f"{key}__time_utc"
)

if f"{key}__init_time_utc" in dataset.coords:
init_coord = f"{key}__init_time_utc"
if dataset[init_coord].ndim == 0: # Check if scalar
expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

datasets.append(dataset)

# Ensure all datasets are valid xarray.Dataset objects
for ds in datasets:
assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

# Merge all prepared datasets
combined_dataset = xr.merge(datasets)

return combined_dataset
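To make the renaming scheme concrete, here is a rough sketch (not from the diff) of what a single entry goes through inside merge_arrays: the DataArray becomes a variable named after its key, and its dimensions and coordinates are prefixed with key__ so that several modalities can share one Dataset without name clashes:

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2024-01-01", periods=4, freq="30min")
da_sites = xr.DataArray(np.random.rand(4), dims="time_utc", coords={"time_utc": times})

ds = da_sites.to_dataset(name="sites")
ds = ds.rename({dim: f"sites__{dim}" for dim in ds.dims if dim not in ds.coords})
ds = ds.rename({coord: f"sites__{coord}" for coord in ds.coords})

print(list(ds.data_vars))  # ['sites']
print(list(ds.coords))     # ['sites__time_utc']

Merging a second, similarly prefixed Dataset (e.g. "nwp-ukv") with xr.merge then produces the combined sample returned above.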

def fill_nans_in_arrays(batch: dict) -> dict:
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
7 changes: 3 additions & 4 deletions ocf_data_sampler/torch_datasets/site.py
@@ -15,7 +15,7 @@
slice_datasets_by_time, slice_datasets_by_space
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

xr.set_options(keep_attrs=True)
@@ -152,10 +152,9 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
"""
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
sample_dict = compute(sample_dict)

sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

sample = process_and_combine_site_sample_dict(sample_dict, self.config)
sample = sample.compute()
return sample

def get_location_from_site_id(self, site_id):
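With this change, indexing a SitesDataset yields the merged xarray Dataset rather than a dict of numpy batches. A hedged usage sketch follows; the config path is a placeholder and the constructor signature is assumed from the tests further down:

from ocf_data_sampler.torch_datasets import SitesDataset

dataset = SitesDataset("path/to/site_config.yaml")  # hypothetical config path
sample = dataset[0]
print(list(sample.data_vars))  # e.g. ['nwp-ukv', 'satellite', 'sites']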
2 changes: 1 addition & 1 deletion tests/select/test_select_time_slice.py
@@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
interval_start = pd.Timedelta(-6, "h")
interval_end = pd.Timedelta(3, "h")
freq = pd.Timedelta("1H")
freq = pd.Timedelta("1h")
dropout_timedelta = pd.Timedelta("-2h")

t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
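The "1H" to "1h" change follows pandas' move to lowercase frequency aliases; in recent pandas releases the uppercase "H" spelling is deprecated, while the lowercase form is equivalent:

import pandas as pd

# "1h" is the non-deprecated spelling of a one-hour offset in recent pandas
assert pd.Timedelta("1h") == pd.Timedelta(hours=1)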
2 changes: 1 addition & 1 deletion tests/torch_datasets/test_pvnet_uk_regional.py
@@ -1,7 +1,7 @@
import pytest
import tempfile

from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey

36 changes: 16 additions & 20 deletions tests/torch_datasets/test_site.py
@@ -1,11 +1,12 @@
import pandas as pd
import pytest

from ocf_data_sampler.torch_datasets.site import SitesDataset
from ocf_data_sampler.torch_datasets import SitesDataset
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.site import SiteBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from xarray import Dataset


@pytest.fixture()
@@ -34,31 +35,26 @@ def test_site(site_config_filename):
# Generate a sample
sample = dataset[0]

assert isinstance(sample, dict)
assert isinstance(sample, Dataset)

for key in [
NWPBatchKey.nwp,
SatelliteBatchKey.satellite_actual,
SiteBatchKey.generation,
SiteBatchKey.site_solar_azimuth,
SiteBatchKey.site_solar_elevation,
]:
assert key in sample
# Expected dimensions and data variables
expected_dims = {'satellite__x_geostationary', 'sites__time_utc', 'nwp-ukv__target_time_utc',
'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
expected_data_vars = {"nwp-ukv", "satellite", "sites"}

for nwp_source in ["ukv"]:
assert nwp_source in sample[NWPBatchKey.nwp]
# Check dimensions
assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
# Check data variables
assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"

# check the shape of the data is correct
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
assert sample["satellite"].values.shape == (7, 1, 2, 2)
# 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
# 3 hours of 30 minute data (inclusive)
assert sample[SiteBatchKey.generation].shape == (4,)
# Solar angles have same shape as GSP data
assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)

assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
# 1.5 hours of 30 minute data (inclusive)
assert sample["sites"].values.shape == (4,)

def test_site_time_filter_start(site_config_filename):

