Skip to content

Commit

Permalink
skebs with correct streamfunction math
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Feb 10, 2025
1 parent 7f04c04 commit 0452d9a
Showing 1 changed file with 141 additions and 34 deletions.
175 changes: 141 additions & 34 deletions credit/postblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,19 +961,21 @@ def __init__(self,
sigma_max,
perturb_frac):
super().__init__()
self.backscatter_array = Parameter(torch.full((1,levels,1,1,1), 1.0))

std = xr.open_dataset(std_path)
std_wind = np.sqrt(std.U.values ** 2 + std.V.values ** 2)
self.register_buffer("max_perturb",
(torch.tensor(std_wind[::-1] * sigma_max * perturb_frac)
.view(1, levels, 1, 1, 1)),
persistent=False)
self.backscatter_array = Parameter(torch.full((1,levels,1,1,1), 0.01))

# std = xr.open_dataset(std_path)
# std_wind = np.sqrt(std.U.values ** 2 + std.V.values ** 2)[::-1]
# self.register_buffer("max_perturb",
# (torch.tensor(std_wind * sigma_max * perturb_frac)
# .view(1, levels, 1, 1, 1)),
# persistent=False)
# self.max_perturb = 1.0

def forward(self, x):
self.backscatter_array.data = self.backscatter_array.data.clamp(0., 10.)
return self.backscatter_array * self.max_perturb ** 2 # this will be inside sqrt
# return self.backscatter_array * self.max_perturb ** 2 # this will be inside sqrt
logger.info(torch.flatten(self.backscatter_array))
return self.backscatter_array # this will be inside sqrt


class SKEBS(nn.Module):
Expand Down Expand Up @@ -1075,9 +1077,32 @@ def initialize_sht(self):
# self.sht = harmonics.RealSHT(self.nlat, self.nlon, self.lmax, self.mmax, self.grid, csphase=False)
self.isht = harmonics.InverseRealSHT(self.nlat, self.nlon, self.lmax, self.mmax, self.grid, csphase=False)
# self.vsht = harmonics.RealVectorSHT(self.nlat, self.nlon, self.lmax, self.mmax, self.grid, csphase=False)
# self.ivsht = harmonics.InverseRealVectorSHT(self.nlat, self.nlon, self.lmax, self.mmax, self.grid, csphase=False)
self.ivsht = harmonics.InverseRealVectorSHT(
self.nlat, self.nlon, self.lmax, self.mmax, self.grid, csphase=False
)
self.lmax = self.isht.lmax
self.mmax = self.isht.mmax

# Compute quadrature weights and cosine of latitudes for the grid
# cost, quad_weights = harmonics.quadrature.legendre_gauss_weights(
# self.nlat, -1, 1
# )

## equiangular grid
cost, w = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
self.lats = -torch.as_tensor(np.arcsin(cost))
self.lons = torch.linspace(0, 2 * np.pi, self.nlon + 1, dtype=torch.float64)[
: self.nlon
]

l_arr = torch.arange(0, self.lmax).reshape(self.lmax, 1).double()
l_arr = l_arr.expand(self.lmax, self.mmax)
self.register_buffer("lap", -l_arr * (l_arr + 1) / RAD_EARTH**2,
persistent=False)
self.register_buffer("invlap", -(RAD_EARTH**2) / l_arr / (l_arr + 1),
persistent=False)
self.invlap[0] = 0.0 # Adjusting the first element to avoid division by zero

logging.info(f"lmax: {self.lmax}, mmax: {self.mmax}")

def initialize_skebs_parameters(self):
Expand Down Expand Up @@ -1203,28 +1228,27 @@ def forward(self, x):

self.spec_coef = self.cycle_pattern(self.spec_coef) # cycle from prev step
# b, 1, 1, lmax, mmax
pattern_on_grid = self.isht(self.spec_coef) # b, 1, 1, lat, lon
# pattern_on_grid = self.isht(self.spec_coef) # b, 1, 1, lat, lon
# pattern_on_grid = (-1. / RAD_EARTH
# * torch.gradient(pattern_on_grid,
# spacing= PI / self.nlat,
# axis=-2)[0]
# )

spec_coef = self.spec_coef.squeeze()
u_chi, v_chi = self.getgrad(spec_coef)
u_chi, v_chi = u_chi.unsqueeze(1).unsqueeze(1), v_chi.unsqueeze(1).unsqueeze(1)
# logger.info(f"pattern max/min: {pattern_on_grid.max():.2f}, {pattern_on_grid.min():.2f}")
# debug pattern:
# logger.info("saving patterns")
# debug_save = "/glade/work/dkimpara/CREDIT_runs/test_skebs_density/debug_pattern"
# save_pattern = pattern_on_grid[0, 0, 0]
# torch.save(save_pattern, os.path.join(debug_save, f"iter_{self.iteration}"))
# save_coef = self.spec_coef[0, 0, 0]
# torch.save(save_coef, os.path.join(debug_save, f"coef_{self.iteration}"))
# torch.save(self.g_n, os.path.join(debug_save, f"g_n_{self.iteration}"))
# torch.save(self.b, os.path.join(debug_save, f"b_{self.iteration}"))


backscatter_pred = self.backscatter_network(x)

# with fixed col, we adjust the perturbations to make sense using
# min/max scaling
# forcing / forcing_max * perturb_frac * sigma_max * std
# perturb_max is the maximum perturbation fraction wrt to sigma_max*std
# where we have pre-calculated forcing_max
total_forcing = (torch.sqrt(self.r * backscatter_pred / self.dE) #taking out of sqrt so i can fix magnitude issue
* pattern_on_grid * self.spectral_adjustment )
# total_forcing = (torch.sqrt(self.r * backscatter_pred / self.dE) #taking out of sqrt so i can fix magnitude issue
# * pattern_on_grid * self.spectral_adjustment )
dissipation_term = torch.sqrt(self.r * backscatter_pred / self.dE)
# shape (b, levels, t, lat, lon)

# sp = torch.ones_like(x[:, self.sp_index : self.sp_index + 1], device = x.device) * 1013.
Expand All @@ -1236,28 +1260,29 @@ def forward(self, x):
# (b, levels, 1, lat, lon)

## compute component magnitudes of wind
u_squared, v_squared = x[:, self.U_inds] ** 2, x[:, self.V_inds] ** 2
wind_squared = u_squared + v_squared
u_frac = u_squared / wind_squared # (b, levels, 1, lat, lon)
v_frac = v_squared / wind_squared
# u_squared, v_squared = x[:, self.U_inds] ** 2, x[:, self.V_inds] ** 2
# wind_squared = u_squared + v_squared
# u_frac = u_squared / wind_squared # (b, levels, 1, lat, lon)
# v_frac = v_squared / wind_squared

# big forcing at top of atmosphere..
# skebs gives us an instantaneous forcing term, need to multiply by timestep (euler step)
# du/dt = 1 / rho * forcing
# euler step: u_1 = u_0 + dt * 1/rho * forcing
add_wind_magnitude = total_forcing * self.timestep
# add_wind_magnitude = total_forcing * self.timestep

# still debugging this part
# add_wind_magnitude = (1. / density) * total_forcing * self.timestep
add_wind_magnitude = torch.sqrt(dissipation_term ** 2 * (u_chi ** 2 + v_chi ** 2)) * self.timestep

## debug skebs, write out physical values
if self.write_debug_files:
torch.save(add_wind_magnitude, join(self.debug_save_loc, f"perturb_{self.iteration}"))
torch.save(pattern_on_grid, join(self.debug_save_loc, f"pattern_{self.iteration}"))
# torch.save(pattern_on_grid, join(self.debug_save_loc, f"pattern_{self.iteration}"))
# torch.save(x, join(self.debug_save_loc, f"x_{self.iteration}"))

x_u_wind = x[:, self.U_inds] + add_wind_magnitude * u_frac
x_v_wind = x[:, self.V_inds] + add_wind_magnitude * v_frac
x_u_wind = x[:, self.U_inds] + dissipation_term * u_chi * self.timestep
x_v_wind = x[:, self.V_inds] + dissipation_term * v_chi * self.timestep

x = concat_for_inplace_ops(x, x_u_wind, min(self.U_inds), max(self.U_inds))
x = concat_for_inplace_ops(x, x_v_wind, min(self.V_inds), max(self.V_inds))
Expand All @@ -1274,12 +1299,94 @@ def forward(self, x):
if torch.is_grad_enabled(): # means we are in a training script (not always true, but good enough for now for rolling out)
if self.training and self.steps >= self.forecast_len:
self.spec_coef_is_initialized = False
logger.info(f"skebs is reset after train step {self.steps} total iter {self.iteration}")
logger.info(f"pattern is reset after train step {self.steps} total iter {self.iteration}")
elif not self.training and self.steps >= self.valid_forecast_len:
self.spec_coef_is_initialized = False
logger.info(f"skebs is reset after valid step {self.steps} total iter {self.iteration}")
logger.info(f"pattern is reset after valid step {self.steps} total iter {self.iteration}")
return x
def spec2grid(self, uspec):
"""
spatial data from spectral coefficients
"""
return self.isht(uspec)
def getuv(self, vrtdivspec):
"""
compute wind vector from spectral coeffs of vorticity and divergence
"""
return self.ivsht(self.invlap * vrtdivspec / RAD_EARTH)

def getgrad(self, chispec):
"""
compute vector gradient on grid given complex spectral coefficients.
Args:
chispec: rank 1 or 2 or 3 tensor complex array with shape
`(ntrunc+1)*(ntrunc+2)/2 or ((ntrunc+1)*(ntrunc+2)/2,nt)` containing
complex spherical harmonic coefficients (where ntrunc is the
triangular truncation limit and nt is the number of spectral arrays
to be transformed). If chispec is rank 1, nt is assumed to be 1.
Returns:
C{B{uchi, vchi}} - rank 2 or 3 numpy float32 arrays containing
gridded zonal and meridional components of the vector gradient.
Shapes are either (nlat,nlon) or (nlat,nlon,nt).
"""
idim = chispec.ndim

if (
len(chispec.shape) != 1
and len(chispec.shape) != 2
and len(chispec.shape) != 3
):
msg = "getgrad needs rank one or two arrays!"
raise ValueError(msg)

ntrunc = int(
-1.5
+ 0.5
* torch.sqrt(
9.0 - 8.0 * (1.0 - torch.tensor(self.spec2grid(chispec).shape[0]))
)
)

if len(chispec.shape) == 1:
chispec = torch.reshape(chispec, ((ntrunc + 1) * (ntrunc + 2) // 2, 1))

divspec2 = self.lap * chispec

if idim == 1:
uchi, vchi = self.getuv(
torch.stack(
(
torch.zeros([divspec2.shape[0], divspec2.shape[1]]),
divspec2,
)
).to(divspec2.device)
)
return torch.squeeze(uchi), torch.squeeze(vchi)
elif idim == 2:
uchi, vchi = self.getuv(
torch.stack(
(
torch.zeros([divspec2.shape[0], divspec2.shape[1]]),
divspec2,
)
).to(divspec2.device)
)
return uchi, vchi
elif idim == 3:
new_shape = (divspec2.shape[0], 2, *divspec2.shape[1:])
stacked_divspec = torch.zeros(
new_shape, dtype=torch.complex64
).to(divspec2.device)
# Copy the original data into the second slice of the new dimension
stacked_divspec[:, 1, :, :] = divspec2
backy = self.getuv(stacked_divspec)
uchi = backy[:, 0, :, :]
vchi = backy[:, 1, :, :]
return uchi, vchi
else:
print("nothing happening here")

# import yaml
# import os
Expand Down

0 comments on commit 0452d9a

Please sign in to comment.