diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc.h b/apex/contrib/csrc/group_norm/group_norm_nhwc.h index dc1bd020..9fb017c9 100755 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc.h +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc.h @@ -123,7 +123,7 @@ struct Group_norm_nhwc_fwd_params { // The number of instances in the batch. int n; // The height and width of each activation map. The number of channels. - int h, w, c, hw, hwc; + int64_t h, w, c, hw, hwc; // The number of groups. int groups; // Do we apply the Swish activation function? @@ -138,7 +138,7 @@ struct Group_norm_nhwc_fwd_params { // The number of groups in each block. int groups_per_block; // The number of channels per group = c / groups. - int channels_per_group; + int channels_per_group; // The number of channels per block = groups_per_block * channels_per_group. int channels_per_block; // The inverse of hwc in floats (to compute mean/var). @@ -149,7 +149,7 @@ struct Group_norm_nhwc_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&, +void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&, size_t &red_buffer_elts); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -190,7 +190,7 @@ struct Group_norm_nhwc_bwd_params { // The number of instances in the batch. int n; // The height and width of each activation map. The number of channels. - int h, w, c, hw, hwc; + int64_t h, w, c, hw, hwc; // The number of groups. int groups; // Do we apply the Swish activation function? diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu index 1a8336e8..1fe904ce 100755 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu @@ -91,7 +91,7 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min(hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); // The gradients for gamma/beta. float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f); @@ -212,7 +212,7 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, // Define the number of blocks per activation map. That's a simple heuristic. int blocks_per_act_slice = 0; - if( params.c >= 1280 ) { + if( params.c >= 1280 ) { blocks_per_act_slice = 128 / params.n; } else if( params.c >= 640 ) { blocks_per_act_slice = 256 / params.n; @@ -267,13 +267,13 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params ¶ms, // Make sure a group does not span multiple blocks. assert(params.channels_per_block % params.channels_per_group == 0); - // The number of elements in the reduction buffer (for the sums and sums of squared). + // The number of elements in the reduction buffer (for the sums and sums of squared). zeroed_red_buffer_elts = params.n * params.groups * 2 + params.c * 2; } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params ¶ms, +void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. @@ -376,7 +376,7 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min(hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu index 83edc4cc..06f9ff67 100755 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu @@ -55,7 +55,7 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min(hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); // The sums. float sum = 0.f, sum_sq = 0.f; @@ -132,7 +132,7 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, // Define the number of blocks per activation map. That's a simple heuristic. int blocks_per_act_slice = 0; - if( params.c >= 1280 ) { + if( params.c >= 1280 ) { blocks_per_act_slice = 128 / params.n; } else if( params.c >= 640 ) { blocks_per_act_slice = 256 / params.n; @@ -186,13 +186,13 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params ¶ms, // Make sure a group does not span multiple blocks. assert(params.channels_per_block % params.channels_per_group == 0); - // The number of elements in the reduction buffer (for the sums and sums of squared). + // The number of elements in the reduction buffer (for the sums and sums of squared). zeroed_red_buffer_elts = params.n * params.groups * 2; } //////////////////////////////////////////////////////////////////////////////////////////////////// -void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params ¶ms, +void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params ¶ms, cudaStream_t stream) { // The dimension of the grid. @@ -285,7 +285,7 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params para // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. - int hw_end = min(hw_begin + params.acts_per_block, params.hw); + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. for( int hwi = hw_begin; hwi < hw_end; ++hwi ) { diff --git a/apex/contrib/test/group_norm/test_group_norm.py b/apex/contrib/test/group_norm/test_group_norm.py index f068b1bf..5675749d 100644 --- a/apex/contrib/test/group_norm/test_group_norm.py +++ b/apex/contrib/test/group_norm/test_group_norm.py @@ -89,10 +89,10 @@ def verify_group_norm(self, dx_tst, dw_tst, db_tst = [t.grad.clone() for t in [x, weight, bias]] # compare - torch.testing.assert_close(y_tst, y_ref, atol=4e-2, rtol=0) - torch.testing.assert_close(dx_tst, dx_ref, atol=4e-2, rtol=0) - torch.testing.assert_close(dw_tst, dw_ref, atol=4e-2, rtol=0) - torch.testing.assert_close(db_tst, db_ref, atol=4e-2, rtol=0) + torch.testing.assert_close(y_tst, y_ref, atol=7e-2, rtol=0) + torch.testing.assert_close(dx_tst, dx_ref, atol=7e-2, rtol=0) + torch.testing.assert_close(dw_tst, dw_ref, atol=7e-2, rtol=0) + torch.testing.assert_close(db_tst, db_ref, atol=7e-2, rtol=0) def test_fp16_one_pass_algo(self): self.verify_group_norm(cuda_group_norm_nhwc_one_pass, act="") @@ -177,6 +177,7 @@ def test_16_groups(self): [8, 1920, 32, 32], [8, 1920, 16, 16], [8, 2560, 8, 8], + [1, 128, 16128, 1200], ] for sz in sizes: n, c, h, w = sz