Skip to content

Commit

Permalink
Add an updated offline domain check for satellite radiances (#210)
Browse files Browse the repository at this point in the history
This PR adds a new version of the offline domain check for satellite
radiances. The modifications were originally done by @delippi and
@xyzemc to account for satellite radiances being 2D variables in the
IODA file (channel, location) instead of 1D.

Modifications are also made here to speed up the section of the code
that slices the original 2D array to retrieve the locations contained
within the domain. The original version took over an hour to process
large dump files. Now we use the `itemgetter` function from the
`operator` library (e.g., `g.variables[var][:,idy] =
itemgetter(*inside_indices)(invar[:,idy])`) which is much faster. The
script can now process global satellite datasets in 15-20 seconds.

Example usage: 

`python offline_domain_check_satrad.py -g invariant.nc -o
gdas.t00z.atms_all_npp.tm00.nc -s 0.1 -f`
  • Loading branch information
SamuelDegelia-NOAA authored Oct 29, 2024
1 parent 380ad1a commit a6b2aff
Showing 1 changed file with 322 additions and 0 deletions.
322 changes: 322 additions & 0 deletions rrfs-test/IODA/offline_domain_check_satrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
#!/usr/bin/env python
import netCDF4 as nc
import numpy as np
from matplotlib.path import Path
from scipy.spatial import ConvexHull
from timeit import default_timer as timer
import argparse
import warnings
import matplotlib
import os
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.ticker as mticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from operator import itemgetter

"""
This program determines if observations are in/outside of a convex hull
computed via a lat/lon grid file (see note below about the grid file).
A convex hull is the smallest convex shape (or polygon) that can enclose a set
of points in a plane (or in higher dimensions). Imagine stretching a rubber band
around the outermost points in a set; the shape that the rubber band forms is
the convex hull. So, if there are any concave points between vertices,
then there would be whitespace between the red and blue box. I've shrunk the
convex hull such that there wouldn't be such whitespace which, of course,
in tern means that it is going to be not an exact match of the domain grid
(e.g., near corners). This can be tuned via the "hull_shrink_factor".
"""

# Disable warnings
warnings.filterwarnings('ignore')

# Set matplotlib backend
matplotlib.use('agg')
import matplotlib.pyplot as plt

# Functions for calculating run times.
def tic():
return timer()

def toc(tic=tic, label=""):
toc = timer()
elapsed = toc-tic
hrs = int(elapsed // 3600)
mins = int((elapsed % 3600) // 60)
secs = int(elapsed % 3600 % 60)
print(f"{label}({elapsed:.2f}s), {hrs:02}:{mins:02}:{secs:02}")

tic1 = tic()

# Parse command-line arguments
# Note:
# The grid file is what contains variables grid_lat/grid_lon
# OR latCell/lonCell for FV3 and MPAS respectively.
# Examples can be found in the following rrfs-test cases:
# - rrfs-data_fv3jedi_2022052619/Data/bkg/fv3_grid_spec.nc
# - mpas_2024052700/data/restart.2024-05-27_00.00.00.nc
parser = argparse.ArgumentParser()
parser.add_argument('-g', '--grid', type=str, help='grid file', required=True)
parser.add_argument('-o', '--obs', type=str, help='ioda observation file', required=True)
parser.add_argument('-f', '--fig', action='store_true', help='disable figure (default is False)', required=False)
parser.add_argument('-s', '--shrink', type=float, help='hull shrink factor', required=True)
args = parser.parse_args()

# Assign filenames
obs_filename = args.obs
grid_filename = args.grid # see note above.
make_fig = args.fig
hull_shrink_factor = args.shrink

print(f"Obs file: {obs_filename}")
print(f"Grid file: {grid_filename}")
print(f"Figure flag: {args.fig}")
print(f"Hull shrink factor: {hull_shrink_factor}")

# Plotting options
plot_box_width = 100. # define size of plot domain (units: lat/lon degrees)
plot_box_height = 50
cen_lat = 34.5
cen_lon = -97.5
#hull_shrink_factor = 0.10 #10% was found to work fairly well.

grid_ds = nc.Dataset(grid_filename, 'r')
obs_ds = nc.Dataset(obs_filename, 'r')

# Extract the grid latitude and longitude
if 'grid_lat' in grid_ds.variables and 'grid_lon' in grid_ds.variables: # FV3 grid
grid_lat = grid_ds.variables['grid_lat'][:, :]
grid_lon = grid_ds.variables['grid_lon'][:, :]
dycore = "FV3"
elif 'latCell' in grid_ds.variables and 'lonCell' in grid_ds.variables: # MPAS grid
grid_lat = np.degrees(grid_ds.variables['latCell'][:]) # Convert radians to degrees
grid_lon = np.degrees(grid_ds.variables['lonCell'][:]) # Convert radians to degrees
dycore = "MPAS"
else:
raise ValueError("Unrecognized grid format: 'grid_lat'/'grid_lon' or 'latCell'/'lonCell' not found.")

print(f"Max/Min grid Lat: {np.max(grid_lat)}, {np.min(grid_lat)}")
print(f"Max/Min grid Lon: {np.max(grid_lon)-360}, {np.min(grid_lon)-360}\n")

# Flatten the grid lat/lon arrays and pair them as coordinates
grid_polygon = np.vstack((grid_lon.flatten(), grid_lat.flatten())).T
grid_polygon = np.ma.filled(grid_polygon, np.nan)
grid_polygon = grid_polygon[~np.isnan(grid_polygon).any(axis=1)]

# Create convex hull from grid points
hull = ConvexHull(grid_polygon)
hull_points = grid_polygon[hull.vertices]

# Compute the centroid of the convex hull
centroid = np.mean(hull_points, axis=0)

# Function to shrink the boundary points
def shrink_boundary(points, centroid, factor=0.04):
new_points = []
for point in points:
direction = point - centroid
distance_to_centroid = np.linalg.norm(direction)
direction_normalized = direction / distance_to_centroid
new_point = point - factor * direction_normalized * distance_to_centroid
new_points.append(new_point)
return np.array(new_points)

# Shrink the hull boundary
shrunken_points = shrink_boundary(hull_points, centroid, factor=hull_shrink_factor)

# Ensure the boundary is closed
if not np.array_equal(shrunken_points[0], shrunken_points[-1]):
shrunken_points = np.vstack([shrunken_points, shrunken_points[0]])

# Create a Path object for the polygon domain
domain_path = Path(shrunken_points)

# Extract observation latitudes and longitudes
obs_lat = obs_ds.groups['MetaData'].variables['latitude'][:]
obs_lon = obs_ds.groups['MetaData'].variables['longitude'][:]
obs_lon = np.where(obs_lon < 0, obs_lon + 360, obs_lon)


#print(f"Max/Min obs Lat: {np.max(obs_lat)}, {np.min(obs_lat)}")
#print(f"Max/Min obs Lon: {np.max(obs_lon)}, {np.min(obs_lon)}\n")

# Pair the observation lat/lon as coordinates
obs_coords = np.vstack((obs_lon, obs_lat)).T

# Check if each observation is within the domain
inside_domain = domain_path.contains_points(obs_coords)

# Get indices of observations within the domain
inside_indices = np.where(inside_domain)[0]
toc(tic1,label="Time to find obs within domain: ")

tic2 = tic()
# Create a new NetCDF file to store the selected data using the more efficient method
try:
outfile = obs_filename.replace('.nc', '_dc.nc')
except:
outfile = obs_filename.replace('.nc4', '_dc.nc4')
fout = nc.Dataset(outfile, 'w')

# Create dimensions and variables in the new file
location_size = len(inside_indices)
channel_size = obs_ds.dimensions['Channel'].size if 'Channel' in obs_ds.dimensions else 0 # Use the second dimension's size if exists

# Channel variable
if '_FillValue' in obs_ds.variables['Channel'].ncattrs():
fill_value = obs_ds.variables['Channel'].getncattr('_FillValue')
else:
fill_value = 2147483647
if 'Channel' not in fout.dimensions and channel_size > 0:
fout.createDimension('Channel', channel_size)
fout.createVariable('Channel', 'int32', 'Channel', fill_value=fill_value)
fout.variables['Channel'][:] = np.arange(channel_size) + 1 #since python indicies start at 0
for attr in obs_ds.variables['Channel'].ncattrs(): # Attributes for Location variable
if attr != '_FillValue':
fout.variables['Channel'].setncattr(attr, obs_ds.variables['Channel'].getncattr(attr))

# Location variable
if '_FillValue' in obs_ds.variables['Channel'].ncattrs():
fill_value = obs_ds.variables['Channel'].getncattr('_FillValue')
else:
fill_value = 2147483647
if 'Location' not in fout.dimensions:
fout.createDimension('Location', location_size)
fout.createVariable('Location', 'int32', 'Location', fill_value=fill_value)
fout.variables['Location'][:] = 0
for attr in obs_ds.variables['Location'].ncattrs(): # Attributes for Location variable
if attr != '_FillValue':
fout.variables['Location'].setncattr(attr, obs_ds.variables['Location'].getncattr(attr))

# Copy all non-grouped attributes into the new file
for attr in obs_ds.ncattrs(): # Attributes for the main file
fout.setncattr(attr, obs_ds.getncattr(attr))

# Copy all groups and variables into the new file, keeping only the variables in range
groups = obs_ds.groups
for group in groups:
g = fout.createGroup(group)
for var in obs_ds.groups[group].variables:
invar = obs_ds.groups[group].variables[var]
vartype = invar.dtype
fill = invar.getncattr('_FillValue')
dimensions = invar.dimensions

# Create a new variable with the correct dimensions
if len(dimensions) == 1: # One-dimensional variable
try:
g.createVariable(var, vartype, dimensions, fill_value=fill)
except:
g.createVariable(var, 'str', dimensions, fill_value=fill)
g.variables[var][:] = invar[:][inside_indices]
# Copy attributes for this variable
for attr in invar.ncattrs():
if '_FillValue' in attr: continue
g.variables[var].setncattr(attr, invar.getncattr(attr))

elif len(dimensions) == 2: # Two-dimensional variable
try:
g.createVariable(var, vartype, dimensions, fill_value=fill)
except:
g.createVariable(var, 'str', dimensions, fill_value=fill)
for idy in range(0, len(invar[0,:])): # new method for slicing very large 2d arrays
g.variables[var][:,idy] = itemgetter(*inside_indices)(invar[:,idy])

# Copy attributes for this variable
for attr in invar.ncattrs():
if '_FillValue' in attr: continue
g.variables[var].setncattr(attr, invar.getncattr(attr))

else:
raise NotImplementedError("Handling for more than two dimensions not implemented.")

# Close the datasets
obs_ds.close()
fout.close()
grid_ds.close()
toc(tic2,label="Time to create new obs file: ")

tic3 = tic()

if not make_fig:
exit()

print("Generating figure...")

# Now create plot
# Set cartopy shapefile path
platform = os.getenv('HOSTNAME').upper()
if 'ORION' in platform:
cartopy.config['data_dir']='/work/noaa/fv3-cam/sdegelia/cartopy'
elif 'H' in platform: # Will need to improve this once Hercules is supported
cartopy.config['data_dir']='/home/Donald.E.Lippi/cartopy'

fig = plt.figure(figsize=(7,4))
m1 = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=0))
#m1 = fig.add_subplot(1, 1, 1, projection=ccrs.LambertConformal())
adjusted_lon = np.where(grid_lon > 180, grid_lon - 360, grid_lon)
adjusted_shrunken_points = np.copy(shrunken_points)
adjusted_shrunken_points[:, 0] = np.where(shrunken_points[:, 0] > 180, shrunken_points[:, 0] - 360, shrunken_points[:, 0])

# Determine extent for plot domain
half = plot_box_width / 2.
left = cen_lon - half
right = cen_lon + half
half = plot_box_height / 2.
bot = cen_lat - half
top = cen_lat + half

# Set extent for both plots
m1.set_extent([left, right, top, bot])

# Add features to the subplots
m1.add_feature(cfeature.COASTLINE)
m1.add_feature(cfeature.BORDERS)
m1.add_feature(cfeature.STATES)

# Gridlines for the subplots
gl1 = m1.gridlines(crs = ccrs.PlateCarree(), draw_labels = True, linewidth = 0.5, color = 'k', alpha = 0.25, linestyle = '-')
gl1.xlocator = mticker.FixedLocator([])
gl1.xlocator = mticker.FixedLocator(np.arange(-180., 181., 10.))
gl1.ylocator = mticker.FixedLocator(np.arange(-80., 91., 10.))
gl1.xformatter = LONGITUDE_FORMATTER
gl1.yformatter = LATITUDE_FORMATTER
gl1.xlabel_style = {'size': 5, 'color': 'gray'}
gl1.ylabel_style = {'size': 5, 'color': 'gray'}

# Plot the domain and the observations
#m1.fill(adjusted_lon.flatten(), grid_lat.flatten(), color='b', label='Domain Boundary', zorder=1, transform=ccrs.PlateCarree())
m1.scatter(adjusted_lon.flatten(), grid_lat.flatten(), c='b', s=1, label='Domain Boundary', zorder=2)
m1.plot(adjusted_shrunken_points[:, 0], shrunken_points[:, 1], 'r-', label='Convex Hull', zorder=10, transform=ccrs.PlateCarree())

# Plot included observations
included_lat = obs_lat[inside_indices]
included_lon = obs_lon[inside_indices]
included_count = len(included_lat)
plt.scatter(included_lon, included_lat, c='g', s=2, label=f'Included Observations ({included_count})', zorder=3, transform=ccrs.PlateCarree())

# Plot excluded observations
excluded_indices = np.setdiff1d(np.arange(len(obs_lat)), inside_indices)
excluded_lat = obs_lat[excluded_indices]
excluded_lon = obs_lon[excluded_indices]

excluded_count = len(excluded_lat)
total_count = len(obs_lat)

print(f"Ob counts:")
print(f" Excluded: {excluded_count}")
print(f" Included: {included_count}")
print(f" Total: {total_count}")
plt.scatter(excluded_lon, excluded_lat, c='r', s=2, label=f'Excluded Observations ({excluded_count})', zorder=4, transform=ccrs.PlateCarree())

plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.legend(loc='upper right')
#plt.legend(loc='upper left')
plt.title(f'{dycore} Domain and Observations ({hull_shrink_factor*100}%)')
plt.tight_layout()
plt.savefig(f'./domain_check_{dycore}.png')

toc(tic3,label="Time to create figure: ")
toc(tic1,label="Total elapsed time: ")

0 comments on commit a6b2aff

Please sign in to comment.