diff --git a/examples/eigenvalue.py b/examples/eigenvalue.py index e80b9eaa..2955efd6 100644 --- a/examples/eigenvalue.py +++ b/examples/eigenvalue.py @@ -4,6 +4,7 @@ problem to the Sphere """ import torch + try: from torch.linalg import eigvalsh except ImportError: @@ -12,6 +13,7 @@ def eigvalsh(X): return symeig(X, eigenvectors=False).eigenvalues + from torch import nn import geotorch diff --git a/geotorch/lowrank.py b/geotorch/lowrank.py index b2e92f85..1ac51619 100644 --- a/geotorch/lowrank.py +++ b/geotorch/lowrank.py @@ -1,7 +1,9 @@ import torch from functools import partial + try: from torch.linalg import svd + svd = partial(svd, full_matrices=False) except ImportError: from torch import svd diff --git a/geotorch/so.py b/geotorch/so.py index 8b6e471a..e7b58b9d 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -1,6 +1,7 @@ import math import torch from torch import nn + try: from torch.linalg import qr except ImportError: diff --git a/geotorch/symmetric.py b/geotorch/symmetric.py index 92662c49..da07ec12 100644 --- a/geotorch/symmetric.py +++ b/geotorch/symmetric.py @@ -1,16 +1,19 @@ import torch from torch import nn from functools import partial + try: from torch.linalg import eigh from torch.linalg import eigvalsh except ImportError: from torch import symeig + eigh = partial(symeig, eigenvectors=True) def eigvalsh(X): return symeig(X, eigenvectors=False).eigenvalues + from .product import ProductManifold from .stiefel import Stiefel from .reals import Rn