Skip to content

Commit

Permalink
getting rid of last issues in upsampling code
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Aug 28, 2024
1 parent 699135d commit ba7a80c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torch_harmonics/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

class UpsampleS2(nn.Module):
def __init__(
self,
nlat_in: int,
nlon_in: int,
nlat_out: int,
Expand All @@ -63,10 +64,12 @@ def __init__(
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)

# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out) - 1
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# to guarantee everything stays in bounds
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx-1, lat_idx)

# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy( (self.lats_out - self.lats_in[j]) / np.diff(self.lats_in)[j] )
lat_weights = torch.from_numpy( (self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx] ).float()
lat_weights = lat_weights.unsqueeze(-1)

# convert to tensor
Expand All @@ -92,6 +95,7 @@ def _upscale_longitudes(self, x: torch.Tensor):
# for artifact-free upsampling in the longitudinal direction
x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
x = torch.roll(x, - self.lon_shift, dims=-1)
return x

def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
Expand Down

0 comments on commit ba7a80c

Please sign in to comment.