Changing channel dimension of distributed SHT from 1 to -3 #52

Merged: 9 commits, Sep 9, 2024

1 change: 1 addition & 0 deletions Changelog.md
@@ -5,6 +5,7 @@
### v0.7.2

* Added resampling modules for convenience
* Changing behavior of distributed SHT to use `dim=-3` as channel dimension

### v0.7.1

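For context, the convention this change settles on is that the channel dimension sits at `dim=-3`, i.e. inputs are laid out as `(batch, channels, nlat, nlon)`, matching the serial transforms. Below is a minimal sketch of that layout using the non-distributed `RealSHT`; the distributed classes and their process-group setup are omitted, and the shapes are purely illustrative.

import torch
from torch_harmonics import RealSHT

nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="equiangular")

# channels sit at dim=-3, i.e. the layout is (batch, channels, nlat, nlon)
x = torch.randn(4, 3, nlat, nlon)

# the transform acts on the trailing (nlat, nlon) dimensions; with this change
# the distributed SHT splits and gathers along the same dim=-3 channel axis
coeffs = sht(x)
print(coeffs.shape)  # torch.Size([4, 3, 64, 65]): (batch, channels, lmax, mmax), complex-valued
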
2 changes: 1 addition & 1 deletion examples/train_sfno.py
@@ -347,7 +347,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.cuda.manual_seed(333)

# set device
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(device.index)

76 changes: 76 additions & 0 deletions tests/run_tests.sh
@@ -0,0 +1,76 @@
#!/bin/bash

# Set default parameters
default_grid_size_lat=1
default_grid_size_lon=1
default_run_distributed=false

# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)

bold=$(tput bold)
normal=$(tput sgr0)

echo "Runs the torch-harmonics test suite."
echo "${bold}Arguments:${normal}"
echo " ${bold}-h | --help:${normal} Prints this text."
echo " ${bold}-d | --run_distributed:${normal} Run the distributed test suite."
echo " ${bold}-lat | --grid_size_lat:${normal} Number of ranks in latitudinal direction for distributed case."
echo " ${bold}-lon | --grid_size_lon:${normal} Number of ranks in longitudinal direction for distributed case."

shift
exit 0
;;
-lat|--grid_size_lat)
grid_size_lat="$2"
shift 2
;;
-lon|--grid_size_lon)
grid_size_lon="$2"
shift 2
;;
-d|--run_distributed)
run_distributed=true
shift
;;
*)
echo "Unknown argument: $1"
exit 1
;;
esac
done

# Use default values if arguments were not provided
grid_size_lat=${grid_size_lat:-$default_grid_size_lat}
grid_size_lon=${grid_size_lon:-$default_grid_size_lon}
run_distributed=${run_distributed:-$default_run_distributed}

echo "Running sequential tests:"
python3 -m pytest tests/test_convolution.py tests/test_sht.py

# Run distributed tests if requested
if [ "$run_distributed" = "true" ]; then

echo "Running distributed tests with the following parameters:"
echo "Grid size latitude: $grid_size_lat"
echo "Grid size longitude: $grid_size_lon"

ngpu=$(( ${grid_size_lat} * ${grid_size_lon} ))

mpirun --allow-run-as-root -np ${ngpu} bash -c "
export CUDA_LAUNCH_BLOCKING=1;
export WORLD_RANK=\${OMPI_COMM_WORLD_RANK};
export WORLD_SIZE=\${OMPI_COMM_WORLD_SIZE};
export RANK=\${OMPI_COMM_WORLD_RANK};
export MASTER_ADDR=localhost;
export MASTER_PORT=29501;
export GRID_H=${grid_size_lat};
export GRID_W=${grid_size_lon};
python3 -m pytest tests/test_distributed_sht.py
python3 -m pytest tests/test_distributed_convolution.py
"
else
echo "Skipping distributed tests."
fi
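
For reference, `bash tests/run_tests.sh` on its own runs only the sequential pytest suites, while an invocation such as `bash tests/run_tests.sh -d -lat 2 -lon 2` (an assumed example, run from the repository root) would additionally launch the distributed SHT and convolution tests under mpirun on 2 x 2 = 4 ranks.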
51 changes: 29 additions & 22 deletions tests/test_convolution.py
@@ -40,14 +40,15 @@

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes


def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
"""
helper routine to compute the values of the isotropic kernel densely
"""

kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
dr = 2 * r_cutoff / (nr + 1)

# compute the support
if nr % 2 == 1:
@@ -71,7 +72,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:

kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi

# disambiguate even and uneven cases and compute the support
@@ -87,7 +88,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) , 0.0)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
else:
@@ -99,8 +100,8 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
vals = r_vals * phi_vals

# in the even case, the inner basis functions overlap into areas with negative radius
rn = - r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
@@ -109,6 +110,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:

return vals


def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
@@ -120,7 +122,7 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1], dim=(1, 4), keepdim=True) / scale_factor
if merge_quadrature:
psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi
else:
@@ -131,7 +133,7 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad
return psi / (psi_norm + eps)


def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False):
def _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
quad_weights,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
merge_quadrature=False,
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""
@@ -143,7 +155,7 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad

if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = math.ceil( kernel_shape[0] / 2)
kernel_size = math.ceil(kernel_shape[0] / 2)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
@@ -250,30 +262,25 @@ def test_disco_convolution(
theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)

Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
in_channels,
out_channels,
in_shape,
out_shape,
kernel_shape,
groups=1,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff
).to(self.device)
conv = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, groups=1, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=theta_cutoff).to(
self.device
)

_, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in

if transpose:
psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(
out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
).to(self.device)

psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()

self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
else:
psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
).to(self.device)

psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()

4 changes: 2 additions & 2 deletions tests/test_distributed_convolution.py
@@ -114,8 +114,8 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
thd.finalize()
dist.destroy_process_group(None)
thd.finalize()
dist.destroy_process_group(None)

def _split_helper(self, tensor):
with torch.no_grad():