[sync BN] (NVIDIA#792)
* [sync BN]

Support non-uniform batch sizes across the process group.

TODO: tests should be added once the code is cleaned up.

* updating unit tests

* new unit tests for different inputs

* cleaning
jjsjann123 authored Jul 6, 2020
1 parent 43a6f9f commit 1ff54b8
Showing 7 changed files with 290 additions and 83 deletions.
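
The key behavioral change is that each rank now contributes its own element count to the statistics reduction, so the merged mean and variance are correct even when batch sizes differ across ranks; previously the reduction implicitly assumed a uniform per-rank count. The fused CUDA kernel `syncbn.welford_parallel` now receives a tensor of per-rank counts. Below is a minimal, non-CUDA sketch of the count-weighted merge it is expected to perform; the helper name `welford_merge_reference` and the plain tensor-op implementation are illustrative assumptions, not the kernel itself.

```python
import torch

def welford_merge_reference(mean_all, var_biased_all, count_all, eps):
    """Count-weighted merge of per-rank BN statistics (reference sketch).

    mean_all:       [world_size, C] per-rank channel means
    var_biased_all: [world_size, C] per-rank biased channel variances
    count_all:      [world_size]    per-rank element counts (may differ per rank)
    """
    counts = count_all.to(mean_all.dtype).unsqueeze(1)   # [world_size, 1]
    total = counts.sum()                                  # total element count
    # Global mean: count-weighted average of per-rank means.
    mean = (mean_all * counts).sum(0) / total
    # Global biased variance via E[x^2] - mean^2, with E[x^2] recovered per rank.
    ex2 = ((var_biased_all + mean_all * mean_all) * counts).sum(0) / total
    var_biased = ex2 - mean * mean
    inv_std = torch.rsqrt(var_biased + eps)
    # Unbiased variance, as used for the running-variance update.
    var = var_biased * total / (total - 1)
    return mean, var, inv_std
```

With equal counts this reduces to the old uniform averaging; with unequal counts, larger shards carry proportionally more weight.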
36 changes: 21 additions & 15 deletions apex/parallel/optimized_sync_batchnorm_kernel.py
@@ -28,16 +28,24 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
device = mean.device
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
else:
device = mean.device
count_all = torch.cuda.IntTensor([count], device=device)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)

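As a quick sanity check on the gather-then-merge path above, the snippet below simulates two ranks with different batch sizes in a single process (no process group, no CUDA) and compares the count-weighted merge against statistics computed directly on the concatenated data; `welford_merge_reference` is the illustrative helper sketched earlier, not part of apex.

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(8, 4)   # "rank 0": 8 samples, 4 channels
x1 = torch.randn(3, 4)   # "rank 1": 3 samples, 4 channels

mean_all = torch.stack([x0.mean(0), x1.mean(0)])
var_all = torch.stack([x0.var(0, unbiased=False), x1.var(0, unbiased=False)])
count_all = torch.tensor([x0.size(0), x1.size(0)])

mean, var, inv_std = welford_merge_reference(mean_all, var_all, count_all, eps=1e-5)

# The merged statistics should match those of the concatenated batch.
full = torch.cat([x0, x1])
assert torch.allclose(mean, full.mean(0), atol=1e-5)
assert torch.allclose(var, full.var(0, unbiased=True), atol=1e-5)
```
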
@@ -52,7 +60,7 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
@@ -71,7 +79,7 @@ def backward(ctx, grad_output):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
@@ -83,26 +91,24 @@ def backward(ctx, grad_output):
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()

# TODO(jie): why do I have to clone here? life time of grad_output?
# TODO: update kernel to not pre_divide by item_num
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

# calculate grad_input
if ctx.needs_input_grad[0]:

if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, ReduceOp.SUM, process_group)
mean_dy = mean_dy / world_size
sum_dy, ReduceOp.SUM, process_group)
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
sum_dy_xmu, ReduceOp.SUM, process_group)
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)

if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
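
On the backward path, the per-rank reductions now return raw sums (`sum_dy`, `sum_dy_xmu`) instead of per-rank means; the sums are all-reduced, and the division by the element count moves into the elementwise backward kernel, which receives the gathered per-rank counts. Averaging the per-rank means and dividing by `world_size`, as before, would be wrong when ranks hold different numbers of samples. Below is a minimal reference sketch of the grad_input computation, assuming NCHW layout; the helper name and the plain PyTorch implementation are illustrative, not the `syncbn.batchnorm_backward` kernel.

```python
import torch

def syncbn_backward_input_reference(grad_output, saved_input, mean, inv_std,
                                    weight, sum_dy, sum_dy_xmu, total_count):
    """grad_input for sync BN (reference sketch, NCHW layout).

    sum_dy, sum_dy_xmu: per-channel sums of dy and dy * (x - mean), already
                        all-reduced across the process group.
    total_count:        total element count per channel across all ranks,
                        i.e. count_all.sum() (N*H*W summed over ranks).
    """
    c = saved_input.size(1)
    shape = (1, c, 1, 1)                      # broadcast per-channel stats
    mean = mean.view(shape)
    inv_std = inv_std.view(shape)
    mean_dy = (sum_dy / total_count).view(shape)
    mean_dy_xmu = (sum_dy_xmu / total_count).view(shape)
    grad_input = (grad_output
                  - mean_dy
                  - (saved_input - mean) * inv_std * inv_std * mean_dy_xmu) * inv_std
    if weight is not None:
        grad_input = grad_input * weight.view(shape)
    return grad_input
```
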
20 changes: 11 additions & 9 deletions csrc/syncbn.cpp
@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const at::Tensor numel,
const float eps);

// elementwise BN operation, returns output
@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);

// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);

// returns {mean, biased_var}
// implemented using welford
@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> shift,
const bool fuse_relu);

// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);

at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
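
In the header, the scalar `numel` argument of `welford_parallel_CUDA` and the implicit uniform count in the backward kernels are replaced by a count tensor, matching the Python-side change. A hedged sketch of how the updated binding would be called from Python follows, assuming the `syncbn` extension is built and importable as it is in `apex.parallel`; the concrete shapes, dtypes, and values are illustrative.

```python
import torch
import syncbn  # apex's compiled extension (assumed built and on the path)

C, world_size = 64, 4
device = torch.device("cuda")

# Per-rank statistics as gathered in the forward pass (fp32, per the header comments).
mean_all = torch.randn(world_size, C, device=device)
var_biased_all = torch.rand(world_size, C, device=device)
# Non-uniform per-rank element counts, as an int tensor
# (built with torch.cuda.IntTensor in the autograd function).
count_all = torch.tensor([32, 32, 16, 8], dtype=torch.int32, device=device)

# welford_parallel now takes the per-rank count tensor instead of a scalar numel.
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_biased_all, count_all, 1e-5)
```
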