Skip to content

Commit

Permalink
Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
Cydral authored Aug 29, 2024
1 parent 5a85514 commit 19a0cfe
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 13 deletions.
5 changes: 2 additions & 3 deletions dlib/cuda/cpu_dlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,8 @@ namespace dlib
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
)
{
const long num = src.k() * src.nr() * src.nc();
Expand All @@ -1519,8 +1520,6 @@ namespace dlib
const auto p_gamma_grad = gamma_grad.host();
const auto p_scale = scale.host();

resizable_tensor dscale;
dscale.copy_size(scale);
dscale = 0;
const auto p_dscale = dscale.host();

Expand Down
3 changes: 2 additions & 1 deletion dlib/cuda/cpu_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ namespace dlib
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
);

// -----------------------------------------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions dlib/cuda/cuda_dlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2345,7 +2345,8 @@ namespace dlib
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
)
{
const long num = src.k() * src.nr() * src.nc();
Expand All @@ -2359,8 +2360,6 @@ namespace dlib
DLIB_CASSERT(eps > 0);

gamma_grad = 0;
resizable_tensor dscale;
dscale.copy_size(scale);
dscale = 0;
launch_kernel(_cuda_rms_normalize_gradient, max_jobs(num, src.num_samples()),
src_grad.device(), gamma_grad.device(), src.device(),
Expand Down
3 changes: 2 additions & 1 deletion dlib/cuda/cuda_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ namespace dlib
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
);

// -----------------------------------------------------------------------------------
Expand Down
7 changes: 4 additions & 3 deletions dlib/cuda/tensor_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,13 +718,14 @@ namespace dlib { namespace tt
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
)
{
#ifdef DLIB_USE_CUDA
cuda::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad);
cuda::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#else
cpu::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad);
cpu::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion dlib/cuda/tensor_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,8 @@ namespace dlib { namespace tt
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad
tensor& gamma_grad,
tensor& dscale
);
/*!
requires
Expand Down
4 changes: 3 additions & 1 deletion dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,7 @@ namespace dlib
gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
params.set_size(gamma.size());
gamma(params, 0) = 1;
dscale.copy_size(gamma(params, 0));
}

template <typename SUBNET>
Expand All @@ -1550,7 +1551,7 @@ namespace dlib
{
auto g = gamma(params, 0);
auto g_grad = gamma(params_grad, 0);
tt::rms_normalize_gradient(eps, gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad);
tt::rms_normalize_gradient(eps, gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale);
}

const tensor& get_layer_params() const { return params; };
Expand Down Expand Up @@ -1605,6 +1606,7 @@ namespace dlib
resizable_tensor params;
alias_tensor gamma;
resizable_tensor scale;
resizable_tensor dscale;
double learning_rate_multiplier;
double weight_decay_multiplier;
double eps;
Expand Down

0 comments on commit 19a0cfe

Please sign in to comment.