diff --git a/torch_harmonics/csrc/disco/disco_cuda_bwd.cu b/torch_harmonics/csrc/disco/disco_cuda_bwd.cu index 9252f99..d31fc28 100644 --- a/torch_harmonics/csrc/disco/disco_cuda_bwd.cu +++ b/torch_harmonics/csrc/disco/disco_cuda_bwd.cu @@ -286,7 +286,7 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor out = torch::zeros(out_dims, options); // get stream - auto stream = at::cuda::getCurrentCUDAStream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); // assert static_assert(0 == (ELXTH_MAX%2)); diff --git a/torch_harmonics/csrc/disco/disco_cuda_fwd.cu b/torch_harmonics/csrc/disco/disco_cuda_fwd.cu index 3163421..51bff72 100644 --- a/torch_harmonics/csrc/disco/disco_cuda_fwd.cu +++ b/torch_harmonics/csrc/disco/disco_cuda_fwd.cu @@ -264,7 +264,7 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor out = torch::zeros(out_dims, options); // get stream - auto stream = at::cuda::getCurrentCUDAStream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); // assert static_assert(0 == (ELXTH_MAX%2));