64-bit indexing Adam #1786

Open · wants to merge 1 commit into base: master
4 changes: 2 additions & 2 deletions csrc/multi_tensor_adam.cu
@@ -26,7 +26,7 @@ struct AdamFunctor
__device__ __forceinline__ void operator()(
index_t chunk_size,
volatile int* noop_gmem,
-TensorListMetadata<4>& tl,
+TensorListMetadata<4, index_t>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
@@ -399,7 +399,7 @@ void multi_tensor_adam_cuda(
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
-multi_tensor_apply<4>(
+multi_tensor_apply64<4>(
(int64_t) BLOCK_SIZE,
(int64_t) chunk_size,
noop_flag,
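For context on why the index_t template parameter matters here: the test added below allocates tensors of 2,359,332,864 half-precision elements, which exceeds INT_MAX (2,147,483,647), so 32-bit offset arithmetic over such a tensor wraps around. A minimal standalone sketch (not part of the PR) of the failure mode:

// Standalone illustration: a 32-bit signed index cannot address the
// 2,359,332,864-element tensor used in the new test below.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t numel = 2359332864LL;                   // > INT_MAX = 2147483647
    const int32_t wrapped = static_cast<int32_t>(numel);
    printf("numel as int64: %lld\n", (long long)numel);
    printf("numel as int32: %d\n", wrapped);              // negative on typical platforms
    // A kernel computing offsets in 32-bit ints would address the wrong
    // memory past element 2^31 - 1, hence the int64_t dispatch above.
    return 0;
}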
105 changes: 101 additions & 4 deletions csrc/multi_tensor_apply.cuh
@@ -14,14 +14,17 @@

// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_tensors64[6] = {55, 32, 24, 18, 15, 12};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};

-template<int n> struct TensorListMetadata
+template<int n, typename T = int32_t> struct TensorListMetadata
{
-void* addresses[n][depth_to_max_tensors[n-1]];
-int sizes[depth_to_max_tensors[n-1]];
+static constexpr int max_tensors = std::is_same<T, int64_t>::value ? depth_to_max_tensors64[n-1] : depth_to_max_tensors[n-1];
+
+void* addresses[n][max_tensors];
+T sizes[max_tensors];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
-int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
+T block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
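A note on the halved depth_to_max_tensors64 limits: the metadata struct is passed by value as a kernel argument, and the TODO above flags the ~4KB argument-size limit. With T = int64_t, both sizes and block_to_chunk double in width, so fewer tensors fit in the budget. A standalone host-side sketch (not part of the PR; exact byte counts depend on padding) that checks both instantiations:

// Mirrors the struct above and verifies both variants fit ~4KB kernel args.
#include <cstdint>
#include <cstdio>
#include <type_traits>

constexpr int depth_to_max_tensors[6]   = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_tensors64[6] = {55, 32, 24, 18, 15, 12};
constexpr int depth_to_max_blocks[6]    = {320, 320, 320, 320, 320, 320};

template<int n, typename T = int32_t> struct TensorListMetadata {
    static constexpr int max_tensors = std::is_same<T, int64_t>::value
        ? depth_to_max_tensors64[n-1] : depth_to_max_tensors[n-1];
    void* addresses[n][max_tensors];
    T sizes[max_tensors];
    unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
    T block_to_chunk[depth_to_max_blocks[n-1]];
    int start_tensor_this_launch;
};

int main() {
    // depth 4 (p, g, m1, m2): roughly 2.9KB with int32, 3.6KB with int64.
    printf("int32 metadata: %zu bytes\n", sizeof(TensorListMetadata<4>));
    printf("int64 metadata: %zu bytes\n", sizeof(TensorListMetadata<4, int64_t>));
    static_assert(sizeof(TensorListMetadata<4, int64_t>) < 4096, "over budget");
    return 0;
}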

@@ -131,3 +134,97 @@ void multi_tensor_apply(
}
}
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply64(
int64_t block_size,
int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}

int ntensors = tensor_lists[0].size();

TensorListMetadata<depth, int64_t> tl;

const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();

tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;

auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;

for(auto chunk = 0; chunk < chunks_this_tensor; chunk++)
{
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;

bool tensors_full = (loc_tensor_info == depth_to_max_tensors64[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if(tensors_full || blocks_full || last_chunk)
{
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size,
noop_flag.DATA_PTR<int>(),
tl,
callable,
args...);

AT_CUDA_CHECK(cudaGetLastError());

// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if(chunk == chunks_this_tensor - 1)
{
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
{
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info-1];
for(int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
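For a sense of scale, a standalone sketch of the chunking math multi_tensor_apply64 performs above; the 2048*32 chunk size is an assumption matching apex's usual default, not taken from this diff:

// Each CUDA block handles one (tensor, chunk) pair; a kernel launch is
// issued whenever the 320 block slots (depth_to_max_blocks) fill up.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t chunk_size = 2048 * 32;    // assumed default chunk size
    const int64_t numel = 2359332864LL;      // the large-tensor test case
    const int64_t chunks = (numel + chunk_size - 1) / chunk_size;  // ceil div
    const int64_t launches = (chunks + 320 - 1) / 320;             // blocks_full resets
    printf("%lld chunks across %lld kernel launches\n",
           (long long)chunks, (long long)launches);
    return 0;
}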
13 changes: 11 additions & 2 deletions tests/L0/run_optimizers/test_adam.py
@@ -233,18 +233,27 @@ def testNative(self):

self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

@largeTensorTest('60GB', 'cuda')
def testLargeTensor(self):
t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
-grad = torch.randn_like(t)
+
+# Instead of using torch.randn_like, we use a combination of torch.zeros_like and uniform_
+# to avoid creating gradients close to 0 (like 0.01 in this case), which torch.optim.Adam could handle improperly,
+# potentially updating parameters to inf values.
+grad = torch.zeros_like(t).uniform_(0.1, 1000).cuda()
+signs = (torch.randint(0, 2, grad.shape, dtype=grad.dtype) * 2 - 1).cuda()
+grad = grad * signs

t.grad = grad
t2.grad = grad
params = [t]
params2 = [t2]
optimizer = apex.optimizers.FusedAdam(params, lr=self.lr)
optimizer.step()
optimizer2 = torch.optim.Adam(params2, lr=self.lr)
optimizer2.step()

torch.testing.assert_close(t, t2)
torch.cuda.synchronize()
