diff --git a/nerfacc/cuda/csrc/scan_cub.cu b/nerfacc/cuda/csrc/scan_cub.cu index ac813b2..64af076 100644 --- a/nerfacc/cuda/csrc/scan_cub.cu +++ b/nerfacc/cuda/csrc/scan_cub.cu @@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub( #if CUB_SUPPORTS_SCAN_BY_KEY() if (backward) { inclusive_sum_by_key( - thrust::make_reverse_iterator(indices.data_ptr() + n_edges), + thrust::make_reverse_iterator(indices.data_ptr() + n_edges), thrust::make_reverse_iterator(inputs.data_ptr() + n_edges), thrust::make_reverse_iterator(outputs.data_ptr() + n_edges), n_edges); } else { inclusive_sum_by_key( - indices.data_ptr(), + indices.data_ptr(), inputs.data_ptr(), outputs.data_ptr(), n_edges); @@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub( #if CUB_SUPPORTS_SCAN_BY_KEY() if (backward) { exclusive_sum_by_key( - thrust::make_reverse_iterator(indices.data_ptr() + n_edges), + thrust::make_reverse_iterator(indices.data_ptr() + n_edges), thrust::make_reverse_iterator(inputs.data_ptr() + n_edges), thrust::make_reverse_iterator(outputs.data_ptr() + n_edges), n_edges); } else { exclusive_sum_by_key( - indices.data_ptr(), + indices.data_ptr(), inputs.data_ptr(), outputs.data_ptr(), n_edges); @@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward( #if CUB_SUPPORTS_SCAN_BY_KEY() inclusive_prod_by_key( - indices.data_ptr(), + indices.data_ptr(), inputs.data_ptr(), outputs.data_ptr(), n_edges); @@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward( } #if CUB_SUPPORTS_SCAN_BY_KEY() inclusive_sum_by_key( - thrust::make_reverse_iterator(indices.data_ptr() + n_edges), + thrust::make_reverse_iterator(indices.data_ptr() + n_edges), thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr() + n_edges), thrust::make_reverse_iterator(grad_inputs.data_ptr() + n_edges), n_edges); @@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward( } #if CUB_SUPPORTS_SCAN_BY_KEY() exclusive_prod_by_key( - indices.data_ptr(), + indices.data_ptr(), inputs.data_ptr(), outputs.data_ptr(), n_edges); @@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward( #if CUB_SUPPORTS_SCAN_BY_KEY() exclusive_sum_by_key( - thrust::make_reverse_iterator(indices.data_ptr() + n_edges), + thrust::make_reverse_iterator(indices.data_ptr() + n_edges), thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr() + n_edges), thrust::make_reverse_iterator(grad_inputs.data_ptr() + n_edges), n_edges);