fixing amp stuff again (#48)
- doing manual type conversion in the custom autograd functions (sketched below)
- fixing stride issues in the backward pass by making some output tensors contiguous
azrael417 authored Aug 28, 2024
1 parent 5d7e9b0 commit f1a965b
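
The core of the first change, shown as a minimal standalone sketch rather than the actual torch_harmonics code: instead of letting custom_fwd cast inputs via cast_inputs=torch.float32, the forward and backward convert to float32 themselves and hand back the caller's dtype. It assumes a PyTorch version where torch.amp.custom_fwd/custom_bwd accept device_type (as the diff below already does); the arithmetic is a placeholder for the disco_cuda_extension calls.

import torch
from torch.amp import custom_fwd, custom_bwd

class _Float32KernelOp(torch.autograd.Function):
    """Toy stand-in for the DISCO contractions: the extension only runs in
    float32 and expects contiguous inputs, so cast manually instead of
    relying on custom_fwd(cast_inputs=...)."""

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor):
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()  # manual conversion, as in the commit
        output = x * 2.0                      # placeholder for the float32 CUDA kernel
        return output.to(xtype)               # restore the caller's dtype

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        grad_input = grad_output * 2.0        # placeholder for the backward kernel
        return grad_input.to(gtype)
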
Showing 2 changed files with 20 additions and 8 deletions.
24 changes: 18 additions & 6 deletions torch_harmonics/_disco_convolution.py
@@ -44,48 +44,60 @@

 class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        output = disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        xtype = x.dtype
+        x = x.to(torch.float32).contiguous()
+        output = disco_cuda_extension.forward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = output.to(xtype)

         return output

     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
-        grad_input = disco_cuda_extension.backward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
+        gtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float32).contiguous()
+        grad_input = disco_cuda_extension.backward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                    ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
+        grad_input = grad_input.to(gtype)

         return grad_input, None, None, None, None, None, None, None, None


 class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        output = disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        xtype = x.dtype
+        x = x.to(torch.float32).contiguous()
+        output = disco_cuda_extension.backward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = output.to(xtype)

         return output

     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
-        grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
+        gtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float32).contiguous()
+        grad_input = disco_cuda_extension.forward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                    ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
+        grad_input = grad_input.to(gtype)

         return grad_input, None, None, None, None, None, None, None, None

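A side note on the .contiguous() calls applied to grad_output above: autograd does not guarantee that the incoming gradient is densely packed, for example when the op's output is transposed further downstream. The following small probe is hypothetical (not part of this repository) and simply demonstrates that behavior on CPU:

import torch

class _Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # gradients arriving here can be non-contiguous
        print("incoming grad contiguous:", grad_output.is_contiguous())
        return grad_output.contiguous()

x = torch.randn(3, 4, requires_grad=True)
y = _Probe.apply(x)
# a transpose downstream of the op makes the incoming gradient a strided view
loss = (y.t() * torch.randn(4, 3)).sum()
loss.backward()   # prints: incoming grad contiguous: False
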
4 changes: 2 additions & 2 deletions torch_harmonics/convolution.py
@@ -417,7 +417,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.reshape(B, self.groups, self.groupsize, K, H, W)

         # do weight multiplication
-        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
+        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
         out = out.reshape(B, -1, H, W)

         if self.bias is not None:
@@ -508,7 +508,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.reshape(B, self.groups, self.groupsize, H, W)

         # do weight multiplication
-        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
+        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
         x = x.reshape(B, -1, x.shape[-3], H, W)

         if x.is_cuda and _cuda_extension_available:
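Why the trailing .contiguous() after the einsums above: depending on how einsum lowers the contraction, the result can come back with permuted strides, which the DISCO CUDA kernels downstream do not expect. A rough, self-contained illustration with made-up shapes (not code from this repository); whether the first check prints False depends on the einsum path taken:

import torch

B, g, c, o, k, H, W = 2, 3, 4, 5, 6, 8, 16
x = torch.randn(B, g, c, k, H, W)
w = torch.randn(g, o, c, k)   # stands in for the reshaped self.weight

out = torch.einsum("bgckxy,gock->bgoxy", x, w)
print("einsum output contiguous:", out.is_contiguous())  # may be False

out = out.contiguous()          # the fix applied in convolution.py
out = out.reshape(B, -1, H, W)  # now a cheap view over packed memory
print(out.is_contiguous())      # True
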
