
Commit
Fix layer_normalize gradients (#3001)
* Fix layer_normalize gradients

* fix layer_norm CPU

* attempt to fix the cuda version

* fix gamma_grad and beta_grad

* update cuda test

* use a block of size 1 to avoid race conditions

* improve the speed of CUDA path of layer_norm

* improve the speed of CUDA path of layer_norm
arrufat authored Sep 1, 2024
1 parent 27a0135 commit 253098e
Showing 8 changed files with 213 additions and 149 deletions.
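
Before the diff, a note on what changes: layer_normalize now computes its mean and variance per sample over all k() * nr() * nc() elements, while gamma and beta become per-channel parameters of shape (k, 1, 1). In LaTeX, the forward pass the new CPU code implements is (my reconstruction from the loops below, not text from the commit):

\[
\mu_n = \frac{1}{K N_r N_c} \sum_{k,r,c} x_{nkrc},
\qquad
v_n = \frac{1}{\sqrt{\frac{1}{K N_r N_c} \sum_{k,r,c} x_{nkrc}^2 - \mu_n^2 + \varepsilon}},
\qquad
y_{nkrc} = \gamma_k \, (x_{nkrc} - \mu_n) \, v_n + \beta_k
\]

with K = src.k(), N_r = src.nr(), N_c = src.nc(), and n the sample index.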
130 changes: 73 additions & 57 deletions dlib/cuda/cpu_dlib.cpp
@@ -1270,22 +1270,19 @@ namespace dlib
             const tensor& beta
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
             DLIB_CASSERT(
                 have_same_dimensions(gamma, beta) &&
-                src.k() == gamma.k() &&
-                src.nr() == gamma.nr() &&
-                src.nc() == gamma.nc() &&
+                gamma.k() == src.k() &&
+                gamma.nr() == 1 &&
+                gamma.nc() == 1 &&
                 eps > 0,
+                "\nsrc.k():    " << src.k() <<
                 "\ngamma.k():  " << gamma.k() <<
                 "\ngamma.nr(): " << gamma.nr() <<
                 "\ngamma.nc(): " << gamma.nc() <<
                 "\nbeta.k():   " << beta.k() <<
                 "\nbeta.nr():  " << beta.nr() <<
                 "\nbeta.nc():  " << beta.nc() <<
-                "\nsrc.k():    " << src.k() <<
-                "\nsrc.nr():   " << src.nr() <<
-                "\nsrc.nc():   " << src.nc() <<
                 "\neps:  " << eps
             );
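
The tightened assertion means callers must now supply per-channel gamma and beta. A minimal sketch of allocating matching parameters, assuming a tensor src in scope and dlib's resizable_tensor::set_size(num_samples, k, nr, nc) overload (illustrative, not part of the commit):

    // Per-channel affine parameters satisfying the new assertion:
    // gamma.k() == src.k(), gamma.nr() == 1, gamma.nc() == 1.
    resizable_tensor gamma, beta;
    gamma.set_size(1, src.k(), 1, 1);  // one scale per channel
    beta.set_size(1, src.k(), 1, 1);   // one offset per channel
    gamma = 1;  // dlib tensors support scalar assignment
    beta = 0;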

@@ -1296,43 +1293,50 @@ namespace dlib
             // first compute means and invstds
             means = 0;
             invstds = 0;
-            const auto p_invstds = invstds.host();
-            const auto p_means = means.host();
-            auto p_src = src.host();
+            const float* p_src = src.host();
+            float* p_invstds = invstds.host();
+            float* p_means = means.host();
+            const long num = src.nr() * src.nc();
             // compute means, and sum of squares
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    float val = p_src[n*num+i];
-                    p_means[n] += val;
-                    p_invstds[n] += val*val;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        p_means[n] += *p_src;
+                        p_invstds[n] += (*p_src) * (*p_src);
+                        ++p_src;
+                    }
                 }
             }
-            means /= num;
-            invstds /= num;
+            means /= src.k() * num;
+            invstds /= src.k() * num;
             // copy data back to host
-            invstds.host(); means.host();
+            invstds.host();
+            means.host();

             // compute variances
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                auto var = p_invstds[n] - p_means[n] * p_means[n];
-                p_invstds[n] = 1.0f / std::sqrt(var + eps);
+                p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps);
             }

             p_src = src.host();
-            auto p_dest = dest.host();
-            auto p_gamma = gamma.host();
-            auto p_beta = beta.host();
+            float* p_dest = dest.host();
+            const float* p_gamma = gamma.host();
+            const float* p_beta = beta.host();
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    *p_dest = (*p_src - p_means[n])*p_invstds[n];
-                    *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
-                    ++p_src;
-                    ++p_dest;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        *p_dest = (*p_src - p_means[n]) * p_invstds[n];
+                        *p_dest = (*p_dest) * p_gamma[k] + p_beta[k];
+                        ++p_src;
+                        ++p_dest;
+                    }
                 }
             }
         }
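
To make the new indexing easy to verify, here is a self-contained reference implementation of the same forward pass on flat arrays (a sketch matching my reading of the loops above, not dlib code; layer_norm_ref and its parameter names are made up for illustration):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Statistics per sample over all k*num values; gamma/beta indexed by channel.
    void layer_norm_ref(const float* src, float* dest,
                        const float* gamma, const float* beta,
                        long num_samples, long k, long num /* nr*nc */, float eps)
    {
        for (long n = 0; n < num_samples; ++n)
        {
            const float* x = src + n * k * num;
            float mean = 0, sqmean = 0;
            for (long j = 0; j < k * num; ++j)
            {
                mean += x[j];
                sqmean += x[j] * x[j];
            }
            mean /= k * num;
            sqmean /= k * num;
            // invstd = 1 / sqrt(E[x^2] - E[x]^2 + eps), as in the code above
            const float invstd = 1.0f / std::sqrt(sqmean - mean * mean + eps);
            float* y = dest + n * k * num;
            for (long c = 0; c < k; ++c)
                for (long i = 0; i < num; ++i)
                {
                    const long j = c * num + i;
                    y[j] = (x[j] - mean) * invstd * gamma[c] + beta[c];
                }
        }
    }

    int main()
    {
        // one sample, k = 2 channels, nr*nc = 4
        std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
        std::vector<float> y(8), gamma = {1, 1}, beta = {0, 0};
        layer_norm_ref(x.data(), y.data(), gamma.data(), beta.data(), 1, 2, 4, 1e-5f);
        for (float v : y) std::printf("%f\n", v);  // zero mean, unit variance
    }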
@@ -1346,22 +1350,26 @@ namespace dlib
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& beta_grad
+            tensor& beta_grad,
+            resizable_tensor& dmeans,
+            resizable_tensor& dvars
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
+            const long num = src.nr() * src.nc();
             DLIB_CASSERT(src.num_samples() == means.size());
             DLIB_CASSERT(src.num_samples() == invstds.size());
-            DLIB_CASSERT(src.k() == gamma.k());
-            DLIB_CASSERT(src.nr() == gamma_grad.nr());
-            DLIB_CASSERT(src.nc() == beta_grad.nc());
             DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
+            DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
+            DLIB_CASSERT(gamma.k() == src.k());
+            DLIB_CASSERT(gamma.nr() == 1);
+            DLIB_CASSERT(gamma.nc() == 1);
             DLIB_CASSERT(have_same_dimensions(gradient_input, src));
             DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
-            DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
             DLIB_CASSERT(eps > 0);

             beta_grad = 0;
             gamma_grad = 0;

             auto p_grad = gradient_input.host();
             auto p_src = src.host();
             const auto p_gamma = gamma.host();
@@ -1370,7 +1378,6 @@ namespace dlib
             const auto p_invstds = invstds.host();
             const auto p_means = means.host();

-            resizable_tensor dvars, dmeans;
             dvars.copy_size(invstds);
             dmeans.copy_size(means);
             dvars = 0;
@@ -1380,53 +1387,62 @@ namespace dlib

             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f);
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float x_hat = (*p_src - p_means[n])*p_invstds[n];
-                    p_beta_grad[i] += *p_grad;
-                    p_gamma_grad[i] += (*p_grad)*x_hat;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float x_hat = (*p_src - p_means[n]) * p_invstds[n];
+                        p_beta_grad[k] += *p_grad;
+                        p_gamma_grad[k] += (*p_grad) * x_hat;

-                    const float dx = *p_grad * p_gamma[n];
+                        const float dx = *p_grad * p_gamma[k];

-                    p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
+                        p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow;

-                    ++p_grad;
-                    ++p_src;
+                        ++p_grad;
+                        ++p_src;
+                    }
                 }
             }

-            const float invnum = 1.0f/num;
             p_grad = gradient_input.host();
             p_src = src.host();
+            const float invnum = 1.0f / (src.k() * num);
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float dx = *p_grad * p_gamma[i];
-
-                    p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
-
-                    ++p_grad;
-                    ++p_src;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float dx = *p_grad * p_gamma[k];
+
+                        p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum;
+
+                        ++p_grad;
+                        ++p_src;
+                    }
                 }
             }
             p_grad = gradient_input.host();
             p_src = src.host();
             auto p_src_grad = src_grad.host();
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float dx = *p_grad * p_gamma[i];
-
-                    *p_src_grad += dx*p_invstds[n] +
-                        p_dvars[n]*2*(*p_src - p_means[n])*invnum +
-                        p_dmeans[n]*invnum;
-
-                    ++p_grad;
-                    ++p_src;
-                    ++p_src_grad;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float dx = *p_grad * p_gamma[k];
+
+                        *p_src_grad += dx * p_invstds[n] +
+                            p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum +
+                            p_dmeans[n] * invnum;
+
+                        ++p_grad;
+                        ++p_src;
+                        ++p_src_grad;
+                    }
                 }
            }
         }
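
For reference, the three passes above are the standard layer-norm backward equations; note that the removed line const float dx = *p_grad * p_gamma[n] indexed gamma by the sample rather than the channel, which appears to be the gradient bug the commit title refers to. Writing M = K N_r N_c, g for gradient_input, and \hat{x}_{nkrc} = (x_{nkrc} - \mu_n) v_n, the accumulations correspond to (my reconstruction, not text from the commit):

\[
\frac{\partial L}{\partial \beta_k} = \sum_{n,r,c} g_{nkrc},
\qquad
\frac{\partial L}{\partial \gamma_k} = \sum_{n,r,c} g_{nkrc} \, \hat{x}_{nkrc},
\qquad
\frac{\partial L}{\partial \sigma_n^2} = \sum_{k,r,c} g_{nkrc} \, \gamma_k \, (x_{nkrc} - \mu_n) \left(-\tfrac{1}{2} v_n^3\right)
\]
\[
\frac{\partial L}{\partial \mu_n} = \sum_{k,r,c} \left( -g_{nkrc} \, \gamma_k \, v_n - \frac{2}{M} \frac{\partial L}{\partial \sigma_n^2} (x_{nkrc} - \mu_n) \right),
\qquad
\frac{\partial L}{\partial x_{nkrc}} = g_{nkrc} \, \gamma_k \, v_n + \frac{2}{M} \frac{\partial L}{\partial \sigma_n^2} (x_{nkrc} - \mu_n) + \frac{1}{M} \frac{\partial L}{\partial \mu_n}
\]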
4 changes: 3 additions & 1 deletion dlib/cuda/cpu_dlib.h
@@ -250,7 +250,9 @@ namespace dlib
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& beta_grad
+            tensor& beta_grad,
+            resizable_tensor& dmeans,
+            resizable_tensor& dvars
         );

// -----------------------------------------------------------------------------------
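
The header change mirrors the .cpp file: dmeans and dvars stop being function-local temporaries and become caller-provided scratch tensors, so repeated backward passes can reuse their allocations instead of reallocating on every call. A hedged sketch of a call site (the leading parameters are assumed to match the existing declaration; this is not code from the commit):

    // Keep the scratch tensors alive across training iterations so
    // layer_normalize_gradient() can reuse their memory.
    resizable_tensor dmeans, dvars;
    cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma,
                                  src_grad, gamma_grad, beta_grad, dmeans, dvars);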
