diff --git a/csrc/multi_tensor_adam.cu b/csrc/multi_tensor_adam.cu
index 4f8cade40..5f0c6f271 100644
--- a/csrc/multi_tensor_adam.cu
+++ b/csrc/multi_tensor_adam.cu
@@ -26,7 +26,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     index_t chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4, index_t>& tl,
     const float beta1,
     const float beta2,
     const float beta1_correction,
@@ -399,7 +399,7 @@ void multi_tensor_adam_cuda(
   // Assume single type across p,g,m1,m2 now
   DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
     tensor_lists[0][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<4>(
+      multi_tensor_apply64<4>(
        (int64_t) BLOCK_SIZE,
        (int64_t) chunk_size,
        noop_flag,
diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh
index 4e98bc7d9..28ccc97b1 100644
--- a/csrc/multi_tensor_apply.cuh
+++ b/csrc/multi_tensor_apply.cuh
@@ -14,14 +14,17 @@
 
 // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
 constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
+constexpr int depth_to_max_tensors64[6] = {55, 32, 24, 18, 15, 12};
 constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
 
-template<int n> struct TensorListMetadata
+template<int n, typename T = int> struct TensorListMetadata
 {
-  void* addresses[n][depth_to_max_tensors[n-1]];
-  int sizes[depth_to_max_tensors[n-1]];
+  static constexpr int max_tensors = std::is_same<T, int64_t>::value ? depth_to_max_tensors64[n-1] : depth_to_max_tensors[n-1];
+
+  void* addresses[n][max_tensors];
+  T sizes[max_tensors];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
-  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
+  T block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
   int start_tensor_this_launch;
 };
 
@@ -131,3 +134,97 @@ void multi_tensor_apply(
     }
   }
 }
+
+template<int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply64(
+  int64_t block_size,
+  int64_t chunk_size,
+  const at::Tensor& noop_flag,
+  const std::vector<std::vector<at::Tensor>>& tensor_lists,
+  T callable,
+  ArgTypes... args)
+{
+  TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
+  int len0 = tensor_lists[0].size();
+  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
+  auto ref_device = tensor_lists[0][0].device();
+  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+  for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
+  {
+    TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
+    for(int t = 0; t < tensor_lists[l].size(); t++)
+    {
+      // TODO: Print which tensor fails.
+      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+#ifdef VERSION_GE_1_5
+      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
+#endif
+      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
+      TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
+      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
+    }
+  }
+
+  int ntensors = tensor_lists[0].size();
+
+  TensorListMetadata<depth, int64_t> tl;
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  tl.start_tensor_this_launch = 0;
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for(int t = 0; t < ntensors; t++)
+  {
+    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+    for(int d = 0; d < depth; d++)
+      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+    loc_tensor_info++;
+
+    auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
+
+    for(auto chunk = 0; chunk < chunks_this_tensor; chunk++)
+    {
+      // std::cout << chunks_this_tensor << std::endl;
+      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tl.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      bool tensors_full = (loc_tensor_info == depth_to_max_tensors64[depth-1] &&
+                           chunk == chunks_this_tensor - 1);
+      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
+      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
+      if(tensors_full || blocks_full || last_chunk)
+      {
+        // using accscalar_t = acc_type<scalar_t_0, true>;
+        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
+          chunk_size,
+          noop_flag.DATA_PTR<int>(),
+          tl,
+          callable,
+          args...);
+
+        AT_CUDA_CHECK(cudaGetLastError());
+
+        // Reset. The control flow possibilities here make my brain hurt.
+        loc_block_info = 0;
+        if(chunk == chunks_this_tensor - 1)
+        {
+          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
+          loc_tensor_info = 0;
+          tl.start_tensor_this_launch = t + 1;
+        }
+        else
+        {
+          // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
+          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
+          for(int d = 0; d < depth; d++)
+            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
+          loc_tensor_info = 1;
+          tl.start_tensor_this_launch = t;
+        }
+      }
+    }
+  }
+}
diff --git a/tests/L0/run_optimizers/test_adam.py b/tests/L0/run_optimizers/test_adam.py
index 23bf34f1a..a2422fd9a 100644
--- a/tests/L0/run_optimizers/test_adam.py
+++ b/tests/L0/run_optimizers/test_adam.py
@@ -233,11 +233,18 @@ def testNative(self):
 
         self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
 
-    @largeTensorTest('60GB', 'cuda')
+    @largeTensorTest('60GB', 'cuda')
     def testLargeTensor(self):
         t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
         t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
-        grad = torch.randn_like(t)
+
+        # Instead of using torch.randn_like, we use a combination of torch.zeros_like and uniform_
+        # to avoid creating gradients close to 0 (like 0.01 in this case), which torch.optim.Adam could handle improperly,
+        # potentially updating parameters to inf values.
+        grad = torch.zeros_like(t).uniform_(0.1, 1000).cuda()
+        signs = (torch.randint(0, 2, grad.shape, dtype=grad.dtype) * 2 - 1).cuda()
+        grad = grad * signs
+
         t.grad = grad
         t2.grad = grad
         params = [t]
@@ -245,6 +252,8 @@ def testLargeTensor(self):
         optimizer = apex.optimizers.FusedAdam(params, lr=self.lr)
         optimizer.step()
         optimizer2 = torch.optim.Adam(params2, lr=self.lr)
+        optimizer2.step()
+
         torch.testing.assert_close(t, t2)
         torch.cuda.synchronize()
 
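
For reviewers, a back-of-the-envelope check of why the tensor in testLargeTensor forces the 64-bit path (a standalone sketch, not part of the patch; the chunk size below is an assumed illustrative value, not taken from this diff):

```python
# Standalone sketch: why testLargeTensor needs int64 sizes in TensorListMetadata.
# The chunk size here is an assumed illustrative value, not the one used by FusedAdam.
INT32_MAX = 2**31 - 1          # 2,147,483,647
numel = 2359332864             # elements per tensor in testLargeTensor

print(numel > INT32_MAX)       # True: numel overflows the old `int sizes[...]` field

chunk_size = 2048 * 32         # assumed per-chunk element count, for illustration only
chunks = (numel + chunk_size - 1) // chunk_size
print(chunks)                  # 36001 with these numbers; the chunk index itself is far from 2**31
```

The tensor's element count is what overflows a 32-bit size, which is why the metadata struct is templated on the index type while the per-launch block bookkeeping stays the same.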