diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp index f56f5a02e7..e8cbf9869e 100644 --- a/dlib/cuda/cpu_dlib.cpp +++ b/dlib/cuda/cpu_dlib.cpp @@ -1499,7 +1499,8 @@ namespace dlib const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ) { const long num = src.k() * src.nr() * src.nc(); @@ -1519,8 +1520,6 @@ namespace dlib const auto p_gamma_grad = gamma_grad.host(); const auto p_scale = scale.host(); - resizable_tensor dscale; - dscale.copy_size(scale); dscale = 0; const auto p_dscale = dscale.host(); diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h index 3cca447e75..8605a071db 100644 --- a/dlib/cuda/cpu_dlib.h +++ b/dlib/cuda/cpu_dlib.h @@ -270,7 +270,8 @@ namespace dlib const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ); // ----------------------------------------------------------------------------------- diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 82b440c24f..9c2f3ec9ee 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -2345,7 +2345,8 @@ namespace dlib const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ) { const long num = src.k() * src.nr() * src.nc(); @@ -2359,8 +2360,6 @@ namespace dlib DLIB_CASSERT(eps > 0); gamma_grad = 0; - resizable_tensor dscale; - dscale.copy_size(scale); dscale = 0; launch_kernel(_cuda_rms_normalize_gradient, max_jobs(num, src.num_samples()), src_grad.device(), gamma_grad.device(), src.device(), diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h index d4753bef6f..9086b43d8c 100644 --- a/dlib/cuda/cuda_dlib.h +++ b/dlib/cuda/cuda_dlib.h @@ -377,7 +377,8 @@ namespace dlib const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ); // ----------------------------------------------------------------------------------- diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp index 2cdf0c97bb..6e79b348e9 100644 --- a/dlib/cuda/tensor_tools.cpp +++ b/dlib/cuda/tensor_tools.cpp @@ -718,13 +718,14 @@ namespace dlib { namespace tt const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ) { #ifdef DLIB_USE_CUDA - cuda::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad); + cuda::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); #else - cpu::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad); + cpu::rms_normalize_gradient(eps, gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); #endif } diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h index 548fde4269..67c7d5808b 100644 --- a/dlib/cuda/tensor_tools.h +++ b/dlib/cuda/tensor_tools.h @@ -885,7 +885,8 @@ namespace dlib { namespace tt const tensor& src, const tensor& gamma, tensor& src_grad, - tensor& gamma_grad + tensor& gamma_grad, + tensor& dscale ); /*! requires diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index f2e8b0ce5b..43c0e1d379 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -1536,6 +1536,7 @@ namespace dlib gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc()); params.set_size(gamma.size()); gamma(params, 0) = 1; + dscale.copy_size(gamma(params, 0)); } template @@ -1550,7 +1551,7 @@ namespace dlib { auto g = gamma(params, 0); auto g_grad = gamma(params_grad, 0); - tt::rms_normalize_gradient(eps, gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad); + tt::rms_normalize_gradient(eps, gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale); } const tensor& get_layer_params() const { return params; }; @@ -1605,6 +1606,7 @@ namespace dlib resizable_tensor params; alias_tensor gamma; resizable_tensor scale; + resizable_tensor dscale; double learning_rate_multiplier; double weight_decay_multiplier; double eps;