From c6f67538003185d6b2fe08fdfab5096aa75e9f8d Mon Sep 17 00:00:00 2001 From: Spencer Wong Date: Wed, 28 Aug 2024 14:33:53 +1000 Subject: [PATCH] Replace repeated asserts with helper functions --- test/test_um2netcdf.py | 57 +++++++++++++++++++++--------------------- umpost/um2netcdf.py | 8 +----- 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/test/test_um2netcdf.py b/test/test_um2netcdf.py index b4b1e5e..026c373 100644 --- a/test/test_um2netcdf.py +++ b/test/test_um2netcdf.py @@ -814,15 +814,26 @@ def coord(self, coordinate_name): return self.coordinate_dict[coordinate_name] -def assert_unmodified_coordinate(coord): +def assert_unmodified_coordinates(lat_coord, lon_coord): """ Helper function to check that a coordinate's attributes match those expected for a coordinate that has not yet been modified by fix_latlon_coords. """ - assert coord.points.dtype == np.dtype("float32") - assert not coord.has_bounds() - assert coord.var_name is None + for coord in [lat_coord, lon_coord]: + assert coord.points.dtype == np.dtype("float32") + assert not coord.has_bounds() + assert coord.var_name is None + + +def assert_dtype_float64(lat_coord, lon_coord): + assert lat_coord.points.dtype == np.dtype("float64") + assert lon_coord.points.dtype == np.dtype("float64") + + +def assert_has_bounds(lat_coord, lon_coord): + assert lat_coord.has_bounds() + assert lon_coord.has_bounds() # Tests of fix_latlon_coords. This function converts coordinate points @@ -848,8 +859,7 @@ def test_fix_latlon_coords_river(ua_plev_cube, cube_lon_coord = cube_with_river_coords.coord(um2nc.LONGITUDE) # Checks prior to modifications. - assert_unmodified_coordinate(cube_lat_coord) - assert_unmodified_coordinate(cube_lon_coord) + assert_unmodified_coordinates(cube_lat_coord, cube_lon_coord) um2nc.fix_latlon_coords(cube_with_river_coords, grid_type, D_LAT_N96, D_LON_N96) @@ -858,11 +868,8 @@ def test_fix_latlon_coords_river(ua_plev_cube, assert cube_lat_coord.var_name == um2nc.VAR_NAME_LAT_RIVER assert cube_lon_coord.var_name == um2nc.VAR_NAME_LON_RIVER - assert cube_lat_coord.points.dtype == np.dtype("float64") - assert cube_lon_coord.points.dtype == np.dtype("float64") - - assert cube_lat_coord.has_bounds() - assert cube_lon_coord.has_bounds() + assert_dtype_float64(cube_lat_coord, cube_lon_coord) + assert_has_bounds(cube_lat_coord, cube_lon_coord) def test_fix_latlon_coords_uv(ua_plev_cube, @@ -887,8 +894,7 @@ def test_fix_latlon_coords_uv(ua_plev_cube, ) # Checks prior to modifications. - assert_unmodified_coordinate(lat_coordinate) - assert_unmodified_coordinate(lon_coordinate) + assert_unmodified_coordinates(lat_coordinate, lon_coordinate) um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type, D_LAT_N96, D_LON_N96) @@ -896,11 +902,8 @@ def test_fix_latlon_coords_uv(ua_plev_cube, assert lat_coordinate.var_name == um2nc.VAR_NAME_LAT_V assert lon_coordinate.var_name == um2nc.VAR_NAME_LON_U - assert lat_coordinate.points.dtype == np.dtype("float64") - assert lon_coordinate.points.dtype == np.dtype("float64") - - assert lat_coordinate.has_bounds() - assert lat_coordinate.has_bounds() + assert_dtype_float64(lat_coordinate, lon_coordinate) + assert_has_bounds(lat_coordinate, lon_coordinate) def test_fix_latlon_coords_standard(ua_plev_cube, @@ -934,8 +937,7 @@ def test_fix_latlon_coords_standard(ua_plev_cube, ) # Checks prior to modifications. - assert_unmodified_coordinate(lat_coordinate) - assert_unmodified_coordinate(lon_coordinate) + assert_unmodified_coordinates(lat_coordinate, lon_coordinate) um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type, D_LAT_N96, D_LON_N96) @@ -943,11 +945,8 @@ def test_fix_latlon_coords_standard(ua_plev_cube, assert lat_coordinate.var_name == um2nc.VAR_NAME_LAT_STANDARD assert lon_coordinate.var_name == um2nc.VAR_NAME_LON_STANDARD - assert lat_coordinate.points.dtype == np.dtype("float64") - assert lon_coordinate.points.dtype == np.dtype("float64") - - assert lat_coordinate.has_bounds() - assert lat_coordinate.has_bounds() + assert_dtype_float64(lat_coordinate, lon_coordinate) + assert_has_bounds(lat_coordinate, lon_coordinate) def test_fix_latlon_coords_single_point(ua_plev_cube): @@ -977,8 +976,7 @@ def test_fix_latlon_coords_single_point(ua_plev_cube): um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type, D_LAT_N96, D_LON_N96) - assert lat_coord_single.has_bounds() - assert lon_coord_single.has_bounds() + assert_has_bounds(lat_coord_single, lon_coord_single) assert np.array_equal(lat_coord_single.bounds, expected_lat_bounds) assert np.array_equal(lon_coord_single.bounds, expected_lon_bounds) @@ -1006,8 +1004,7 @@ def test_fix_latlon_coords_has_bounds(ua_plev_cube): dummy_cube=ua_plev_cube, coords=[lat_coord, lon_coord] ) - assert lat_coord.has_bounds() - assert lon_coord.has_bounds() + assert_has_bounds(lat_coord, lon_coord) um2nc.fix_latlon_coords(cube_with_uv_coords, grid_type, D_LAT_N96, D_LON_N96) @@ -1036,6 +1033,8 @@ def _raise_CoordinateNotFoundError(coord_name): pytest.raises(um2nc.UnsupportedTimeSeriesError) ): um2nc.fix_latlon_coords(ua_plev_cube, grid_type, D_LAT_N96, D_LON_N96) + + def test_fix_cell_methods_drop_hours(): # ensure cell methods with "hour" in the interval name are translated to # empty intervals diff --git a/umpost/um2netcdf.py b/umpost/um2netcdf.py index d901c68..8d39d7b 100644 --- a/umpost/um2netcdf.py +++ b/umpost/um2netcdf.py @@ -524,13 +524,7 @@ def process(infile, outfile, args): # Interval in cell methods isn't reliable so better to remove it. c.cell_methods = fix_cell_methods(c.cell_methods) - try: - fix_latlon_coord(c, mv.grid_type, mv.d_lat, mv.d_lon) - except iris.exceptions.CoordinateNotFoundError: - print('\nMissing lat/lon coordinates for variable (possible timeseries?)\n') - print(c) - raise Exception("Variable can not be processed") - + fix_latlon_coords(c, mv.grid_type, mv.d_lat, mv.d_lon) fix_level_coord(c, mv.z_rho, mv.z_theta) if do_masking: