Tkurth/torchification (#66)
* adding caching

* replacing many numpy calls with torch calls

* bumping up version number to 0.7.6
azrael417 authored Feb 21, 2025
1 parent 780fd14 commit 6730e5c
Showing 20 changed files with 569 additions and 300 deletions.
9 changes: 9 additions & 0 deletions Changelog.md
@@ -2,7 +2,16 @@

## Versioning

### v0.7.6

* Adding cache for precomputed tensors such as weight tensors for DISCO and SHT
* The cache returns copies of tensors rather than references. Users are still encouraged to re-use
  those tensors manually in their models, since doing so also saves memory; the cache itself mainly
  speeds up model setup. A minimal sketch of the pattern is shown after this list.
* Adding a test which ensures that the cache works correctly
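
As a loose illustration of the copy-on-return pattern described above (hypothetical helper names, not the actual torch-harmonics API):

from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def _cached_weight_tensor(nlat: int, nlon: int) -> torch.Tensor:
    # stand-in for an expensive precomputation (e.g. DISCO or SHT weight tensors)
    return torch.randn(nlat, nlon, dtype=torch.float64)


def get_weight_tensor(nlat: int, nlon: int) -> torch.Tensor:
    # hand out a clone so in-place edits by the caller cannot corrupt the cached original
    return _cached_weight_tensor(nlat, nlon).clone()

Callers that need the same tensor in several places should fetch it once and share the reference, which is the manual re-use the note above recommends.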

### v0.7.5

* New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function
162 changes: 147 additions & 15 deletions notebooks/resample_sphere.ipynb

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions tests/test_cache.py
@@ -0,0 +1,55 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized
import math
import torch


class TestCacheConsistency(unittest.TestCase):

    def test_consistency(self, verbose=False):
        if verbose:
            print("Testing that cached values do not get modified externally")
        from torch_harmonics.legendre import _precompute_legpoly

        with torch.no_grad():
            cost = torch.cos(torch.linspace(0.0, 2.0 * math.pi, 10, dtype=torch.float64))
            leg1 = _precompute_legpoly(10, 10, cost)
            # perform in-place modification of leg1
            leg1 *= -1.0
            leg2 = _precompute_legpoly(10, 10, cost)
            # if the cache handed out references rather than copies, leg2 would
            # alias the negated leg1 and the assertion below would fail
            self.assertFalse(torch.allclose(leg1, leg2))


if __name__ == "__main__":
    unittest.main()
70 changes: 37 additions & 33 deletions tests/test_convolution.py
@@ -38,7 +38,7 @@
from torch.autograd import gradcheck
from torch_harmonics import *

-from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
+from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes


def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
@@ -106,22 +106,20 @@ def _precompute_convolution_tensor_dense(
    nlat_out, nlon_out = out_shape

    lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
-    lats_in = torch.from_numpy(lats_in)
    lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
-    lats_out = torch.from_numpy(lats_out)

-    # compute the phi differences. We need to make the linspace exclusive to not double the last point
-    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
-    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
+    # compute the phi differences.
+    lons_in = _precompute_longitudes(nlon_in)
+    lons_out = _precompute_longitudes(nlon_out)

    # effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # compute quadrature weights that will be merged into the Psi tensor
    if transpose_normalization:
-        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
-        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # array for accumulating non-zero indices
    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
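
The new _precompute_longitudes helper presumably wraps the endpoint-exclusive linspace visible in the removed lines above; a sketch consistent with that deleted code (the actual torch_harmonics.quadrature implementation may differ):

import math

import torch


def _precompute_longitudes(nlon: int) -> torch.Tensor:
    # exclusive endpoint: the sample at 2*pi would duplicate the one at 0
    return torch.linspace(0, 2 * math.pi, nlon + 1, dtype=torch.float64)[:-1]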
@@ -172,34 +170,32 @@ def setUp(self):
        else:
            self.device = torch.device("cpu")

-        self.device = torch.device("cpu")
-
    @parameterized.expand(
        [
            # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), [3], "zernike", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), [3], "zernike", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ]
    )
    def test_disco_convolution(
@@ -216,11 +212,19 @@
        grid_out,
        transpose,
        tol,
+        verbose,
    ):

+        if verbose:
+            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

-        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
+        if isinstance(kernel_shape, int):
+            theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
+        else:
+            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = Conv(
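A note on the parameterizations above: in Python, (3) is just the integer 3, while (3,) would be a one-element tuple, which is presumably why the isinstance branch for theta_cutoff was added. A quick check:

kernel_shape = (3)
assert isinstance(kernel_shape, int)     # parentheses alone do not create a tuple
kernel_shape = (3,)
assert isinstance(kernel_shape, tuple)   # the trailing comma does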
24 changes: 12 additions & 12 deletions tests/test_distributed_convolution.py
@@ -183,18 +183,18 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):

    @parameterized.expand(
        [
-            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
-            [129, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 64, 128, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
-            [129, 256, 129, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
-            [64, 128, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
-            [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
        ]
    )
    def test_distributed_disco_conv(
14 changes: 7 additions & 7 deletions tests/test_distributed_resample.py
@@ -183,14 +183,14 @@ def _gather_helper_bwd(self, tensor, B, C, resampling_dist):

    @parameterized.expand(
        [
-            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
-            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
-            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
-            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
+            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
+            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
+            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
+            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
        ]
    )
    def test_distributed_resampling(
-        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
+        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
    ):

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
@@ -248,7 +248,7 @@ def test_distributed_resampling(
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
-            if self.world_rank == 0:
+            if verbose and (self.world_rank == 0):
                print(f"final relative error of output: {err.item()}")
            self.assertTrue(err.item() <= tol)

@@ -259,7 +259,7 @@ def test_distributed_resampling(
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)

            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
-            if self.world_rank == 0:
+            if verbose and (self.world_rank == 0):
                print(f"final relative error of gradients: {err.item()}")
            self.assertTrue(err.item() <= tol)
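
For reference, both checks compute the same batched relative Frobenius error between the single-process result $X$ and the gathered distributed result $\hat{X}$,

$$\mathrm{err} = \frac{1}{BC} \sum_{b=1}^{B} \sum_{c=1}^{C} \frac{\lVert X_{b,c} - \hat{X}_{b,c} \rVert_F}{\lVert X_{b,c} \rVert_F},$$

and assert that it stays below the tolerance from the parameterization.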
