Skip to content

Commit

Permalink
Replace repeated asserts with helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
blimlim committed Aug 28, 2024
1 parent 21dc1a4 commit c6f6753
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 36 deletions.
57 changes: 28 additions & 29 deletions test/test_um2netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,15 +814,26 @@ def coord(self, coordinate_name):
return self.coordinate_dict[coordinate_name]


def assert_unmodified_coordinate(coord):
def assert_unmodified_coordinates(lat_coord, lon_coord):
"""
Helper function to check that a coordinate's attributes match
those expected for a coordinate that has not yet been modified
by fix_latlon_coords.
"""
assert coord.points.dtype == np.dtype("float32")
assert not coord.has_bounds()
assert coord.var_name is None
for coord in [lat_coord, lon_coord]:
assert coord.points.dtype == np.dtype("float32")
assert not coord.has_bounds()
assert coord.var_name is None


def assert_dtype_float64(lat_coord, lon_coord):
assert lat_coord.points.dtype == np.dtype("float64")
assert lon_coord.points.dtype == np.dtype("float64")


def assert_has_bounds(lat_coord, lon_coord):
assert lat_coord.has_bounds()
assert lon_coord.has_bounds()


# Tests of fix_latlon_coords. This function converts coordinate points
Expand All @@ -848,8 +859,7 @@ def test_fix_latlon_coords_river(ua_plev_cube,
cube_lon_coord = cube_with_river_coords.coord(um2nc.LONGITUDE)

# Checks prior to modifications.
assert_unmodified_coordinate(cube_lat_coord)
assert_unmodified_coordinate(cube_lon_coord)
assert_unmodified_coordinates(cube_lat_coord, cube_lon_coord)

um2nc.fix_latlon_coords(cube_with_river_coords, grid_type,
D_LAT_N96, D_LON_N96)
Expand All @@ -858,11 +868,8 @@ def test_fix_latlon_coords_river(ua_plev_cube,
assert cube_lat_coord.var_name == um2nc.VAR_NAME_LAT_RIVER
assert cube_lon_coord.var_name == um2nc.VAR_NAME_LON_RIVER

assert cube_lat_coord.points.dtype == np.dtype("float64")
assert cube_lon_coord.points.dtype == np.dtype("float64")

assert cube_lat_coord.has_bounds()
assert cube_lon_coord.has_bounds()
assert_dtype_float64(cube_lat_coord, cube_lon_coord)
assert_has_bounds(cube_lat_coord, cube_lon_coord)


def test_fix_latlon_coords_uv(ua_plev_cube,
Expand All @@ -887,20 +894,16 @@ def test_fix_latlon_coords_uv(ua_plev_cube,
)

# Checks prior to modifications.
assert_unmodified_coordinate(lat_coordinate)
assert_unmodified_coordinate(lon_coordinate)
assert_unmodified_coordinates(lat_coordinate, lon_coordinate)

um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type,
D_LAT_N96, D_LON_N96)

assert lat_coordinate.var_name == um2nc.VAR_NAME_LAT_V
assert lon_coordinate.var_name == um2nc.VAR_NAME_LON_U

assert lat_coordinate.points.dtype == np.dtype("float64")
assert lon_coordinate.points.dtype == np.dtype("float64")

assert lat_coordinate.has_bounds()
assert lat_coordinate.has_bounds()
assert_dtype_float64(lat_coordinate, lon_coordinate)
assert_has_bounds(lat_coordinate, lon_coordinate)


def test_fix_latlon_coords_standard(ua_plev_cube,
Expand Down Expand Up @@ -934,20 +937,16 @@ def test_fix_latlon_coords_standard(ua_plev_cube,
)

# Checks prior to modifications.
assert_unmodified_coordinate(lat_coordinate)
assert_unmodified_coordinate(lon_coordinate)
assert_unmodified_coordinates(lat_coordinate, lon_coordinate)

um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type,
D_LAT_N96, D_LON_N96)

assert lat_coordinate.var_name == um2nc.VAR_NAME_LAT_STANDARD
assert lon_coordinate.var_name == um2nc.VAR_NAME_LON_STANDARD

assert lat_coordinate.points.dtype == np.dtype("float64")
assert lon_coordinate.points.dtype == np.dtype("float64")

assert lat_coordinate.has_bounds()
assert lat_coordinate.has_bounds()
assert_dtype_float64(lat_coordinate, lon_coordinate)
assert_has_bounds(lat_coordinate, lon_coordinate)


def test_fix_latlon_coords_single_point(ua_plev_cube):
Expand Down Expand Up @@ -977,8 +976,7 @@ def test_fix_latlon_coords_single_point(ua_plev_cube):
um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type,
D_LAT_N96, D_LON_N96)

assert lat_coord_single.has_bounds()
assert lon_coord_single.has_bounds()
assert_has_bounds(lat_coord_single, lon_coord_single)
assert np.array_equal(lat_coord_single.bounds, expected_lat_bounds)
assert np.array_equal(lon_coord_single.bounds, expected_lon_bounds)

Expand Down Expand Up @@ -1006,8 +1004,7 @@ def test_fix_latlon_coords_has_bounds(ua_plev_cube):
dummy_cube=ua_plev_cube,
coords=[lat_coord, lon_coord]
)
assert lat_coord.has_bounds()
assert lon_coord.has_bounds()
assert_has_bounds(lat_coord, lon_coord)

um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type,
D_LAT_N96, D_LON_N96)
Expand Down Expand Up @@ -1036,6 +1033,8 @@ def _raise_CoordinateNotFoundError(coord_name):
pytest.raises(um2nc.UnsupportedTimeSeriesError)
):
um2nc.fix_latlon_coords(ua_plev_cube, grid_type, D_LAT_N96, D_LON_N96)


def test_fix_cell_methods_drop_hours():
# ensure cell methods with "hour" in the interval name are translated to
# empty intervals
Expand Down
8 changes: 1 addition & 7 deletions umpost/um2netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,13 +524,7 @@ def process(infile, outfile, args):
# Interval in cell methods isn't reliable so better to remove it.
c.cell_methods = fix_cell_methods(c.cell_methods)

try:
fix_latlon_coord(c, mv.grid_type, mv.d_lat, mv.d_lon)
except iris.exceptions.CoordinateNotFoundError:
print('\nMissing lat/lon coordinates for variable (possible timeseries?)\n')
print(c)
raise Exception("Variable can not be processed")

fix_latlon_coords(c, mv.grid_type, mv.d_lat, mv.d_lon)
fix_level_coord(c, mv.z_rho, mv.z_theta)

if do_masking:
Expand Down

0 comments on commit c6f6753

Please sign in to comment.