Skip to content

Commit

Permalink
update name for NERSC ice age product and also update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tdcwilliams committed Apr 19, 2024
1 parent 4815f4d commit 821527c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
4 changes: 2 additions & 2 deletions geodataset/custom_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_lonlat_arrays(self, ij_range=(None,None,None,None), **kwargs):
i0, i1, j0, j1 = ij_range
x_grd, y_grd = np.meshgrid(self['x'][j0:j1], self['y'][i0:i1])
return self.projection(x_grd, y_grd, inverse=True)


class NERSCDeformation(NERSCProductBase):
pattern = re.compile(r'arctic_2km_deformation_\d{8}T\d{6}.nc')
Expand All @@ -87,7 +87,7 @@ class NERSCIceType(NERSCProductBase):


class NERSCSeaIceAge(NERSCProductBase):
pattern = re.compile(r'arctic25km_sea_ice_age_v2p0_\d{8}.nc')
pattern = re.compile(r'arctic25km_sea_ice_age_v2p1_\d{8}.nc')


class OsisafDriftersNextsim(CustomDatasetRead):
Expand Down
51 changes: 40 additions & 11 deletions geodataset/tests/test_custom_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,53 @@ def test_get_lonlat_arrays(self):
__init__=MagicMock(return_value=None),
filepath=DEFAULT,
)
def test_datetimes_1(self, **kwargs):
""" test for older filename """
dto = dt.datetime(2017,5,1)
def test_datetimes(self, **kwargs):
dto = dt.datetime(2023,5,1,12)
kwargs['filepath'].return_value = dto.strftime('a/b/mpd_%Y%m%d.nc')
obj = UniBremenAlbedoMPF()
self.assertEqual(obj.datetimes, [dto])


@patch.multiple(UniBremenAlbedoMPF,
class NERSCProductBaseTest(BaseForTests):

@property
def x(self):
return np.linspace(0.,1.,6)

@property
def y(self):
return np.linspace(1.,2.,8)

@patch.multiple(NERSCProductBase,
__init__=MagicMock(return_value=None),
filepath=DEFAULT,
__getitem__=DEFAULT,
projection=DEFAULT,
)
def test_datetimes_2(self, **kwargs):
""" test for newer filename """
dto = dt.datetime(2021,5,1)
kwargs['filepath'].return_value = dto.strftime('a/b/mpd_%Y%m%d_NR.nc')
obj = UniBremenAlbedoMPF()
self.assertEqual(obj.datetimes, [dto])
def test_get_lonlat_arrays(self, __getitem__, projection):
""" test for older filename """
def mock_getitem(key):
if key == "x":
return self.x
return self.y

obj = NERSCProductBase()
__getitem__.side_effect = mock_getitem
projection.return_value = ('lon', 'lat')

i0 = 2
i1 = 5
j0 = 1
j1 = 6
x0, y0 = np.meshgrid(self.x[j0:j1], self.y[i0:i1])

lon, lat = obj.get_lonlat_arrays(ij_range=(i0, i1, j0, j1))
self.assertEqual(lon, 'lon')
self.assertEqual(lat, 'lat')
self.assertEqual(__getitem__.mock_calls, [call('x'), call('y')])
x, y = projection.mock_calls[0][1]
self.assertTrue(np.allclose(x, x0))
self.assertTrue(np.allclose(y, y0))
self.assertEqual(projection.mock_calls[0][2], dict(inverse=True))


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion geodataset/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def test_open_netcdf(self):

def test_get_lonlat_arrays(self):
for nc_file in self.nc_files:
print(nc_file)
with self.subTest(nc_file=nc_file):
with open_netcdf(nc_file) as ds:
if not ds.is_lonlat_2d:
Expand Down

0 comments on commit 821527c

Please sign in to comment.