Skip to content

Commit

Permalink
final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Aug 28, 2024
1 parent ba7a80c commit 603318c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resampling import UpsampleS2
from .resampling import ResampleS2
from . import quadrature
from . import random_fields
from . import examples
48 changes: 35 additions & 13 deletions torch_harmonics/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@
#

from typing import List, Tuple, Union, Optional
import math
import numpy as np

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes

class UpsampleS2(nn.Module):

class ResampleS2(nn.Module):
def __init__(
self,
nlat_in: int,
Expand All @@ -61,15 +63,17 @@ def __init__(

# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)

# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# to guarantee everything stays in bounds
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx-1, lat_idx)
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy( (self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx] ).float()
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
lat_weights = lat_weights.unsqueeze(-1)

# convert to tensor
Expand All @@ -79,11 +83,23 @@ def __init__(
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)

# for the longitudes we can use the fact that points are equidistant
# TODO: add mode modes for upscaling in longitude
assert nlon_out % nlon_in == 0
self.lon_scale_factor = nlon_out // nlon_in
self.lon_shift = (self.lon_scale_factor + 1) // 2 - 1
# get left and right indices but this time make sure periodicity in the longitude is handled
lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)

# get the difference
diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()

# convert to tensor
lon_idx_left = torch.LongTensor(lon_idx_left)
lon_idx_right = torch.LongTensor(lon_idx_right)

# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)

def extra_repr(self):
r"""
Expand All @@ -92,17 +108,23 @@ def extra_repr(self):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

def _upscale_longitudes(self, x: torch.Tensor):
# for artifact-free upsampling in the longitudinal direction
x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
x = torch.roll(x, - self.lon_shift, dims=-1)
# do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
return x

# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x

def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx+1, :], self.lat_weights)
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
return x

def forward(self, x: torch.Tensor):
x = self._upscale_latitudes(x)
x = self._upscale_longitudes(x)
return x
return x

0 comments on commit 603318c

Please sign in to comment.