diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 8330461..e08567b 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -213,7 +213,7 @@ def __init__( out_channels: int, in_shape: Tuple[int], out_shape: Tuple[int], - kernel_shape: Union[int, List[int]], + kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basis_type: Optional[str] = "piecewise linear", basis_norm_mode: Optional[str] = "mean", groups: Optional[int] = 1, @@ -354,7 +354,7 @@ def __init__( out_channels: int, in_shape: Tuple[int], out_shape: Tuple[int], - kernel_shape: Union[int, List[int]], + kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basis_type: Optional[str] = "piecewise linear", basis_norm_mode: Optional[str] = "mean", groups: Optional[int] = 1,