From fafdac37f12814981fdc97013aaef4886b6d1bc0 Mon Sep 17 00:00:00 2001 From: Cydral <53169060+Cydral@users.noreply.github.com> Date: Sat, 7 Sep 2024 19:29:40 +0200 Subject: [PATCH] Add RMS Normalization Layer (#2999) * Add RMS Normalization Layer * Update dnn.cpp * Missing entry in visitors.h to take into account the new rms_norm_ layer * Fix test function name * Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient * Fixing the dnn.cpp test program for the new rms_norm_ layer * General update of the rms_norm_ class --- dlib/cuda/cpu_dlib.cpp | 138 ++++++++++++++++++++++++++++ dlib/cuda/cpu_dlib.h | 20 ++++ dlib/cuda/cuda_dlib.cu | 160 ++++++++++++++++++++++++++++++++ dlib/cuda/cuda_dlib.h | 20 ++++ dlib/cuda/tensor_tools.cpp | 34 +++++++ dlib/cuda/tensor_tools.h | 53 ++++++++++- dlib/dnn/layers.h | 125 +++++++++++++++++++++++++ dlib/dnn/layers_abstract.h | 172 +++++++++++++++++++++++++++++++++++ dlib/dnn/visitors.h | 39 ++++++++ dlib/dnn/visitors_abstract.h | 4 +- dlib/test/dnn.cpp | 102 ++++++++++++++++++++- 11 files changed, 863 insertions(+), 4 deletions(-) diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp index b8b5a4123..1acb35e00 100644 --- a/dlib/cuda/cpu_dlib.cpp +++ b/dlib/cuda/cpu_dlib.cpp @@ -1447,6 +1447,144 @@ namespace dlib } } +// ----------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ) + { + DLIB_CASSERT( + gamma.k() == src.k() && + gamma.nr() == 1 && + gamma.nc() == 1 && + eps > 0, + "\nsrc.k(): " << src.k() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\neps: " << eps + ); + + const long ns = src.num_samples(); + const long ks = src.k(); + const long num = src.nr() * src.nc(); + + dest.copy_size(src); + scale.set_size(ns); + + // Compute RMS values + scale = 0; + const float* p_src 
= src.host(); + float* p_scale = scale.host(); + for (long n = 0; n < ns; ++n) + { + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + p_scale[n] += (*p_src) * (*p_src); + ++p_src; + } + } + p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps)); + } + scale.host(); + + // Apply RMS normalization + p_src = src.host(); + float* p_dest = dest.host(); + const float* p_gamma = gamma.host(); + for (long n = 0; n < ns; ++n) + { + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + *p_dest = (*p_src) * p_scale[n] * p_gamma[k]; + ++p_src; + ++p_dest; + } + } + } + } + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ) + { + DLIB_CASSERT(src.num_samples() == scale.size()); + DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); + DLIB_CASSERT(gamma.k() == src.k()); + DLIB_CASSERT(gamma.nr() == 1); + DLIB_CASSERT(gamma.nc() == 1); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + + const long ns = src.num_samples(); + const long ks = src.k(); + const long num = src.nr() * src.nc(); + + gamma_grad = 0; + dscale.copy_size(scale); + dscale = 0; + + auto p_grad = gradient_input.host(); + auto p_src = src.host(); + const auto p_gamma = gamma.host(); + const auto p_gamma_grad = gamma_grad.host(); + const auto p_scale = scale.host(); + auto p_dscale = dscale.host(); + + for (long n = 0; n < ns; ++n) + { + const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f); + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + const float x_hat = *p_src * p_scale[n]; + p_gamma_grad[k] += (*p_grad) * x_hat; + + const float dx = *p_grad * p_gamma[k]; + p_dscale[n] += dx * *p_src * scale_pow; + + ++p_grad; + ++p_src; + } + } + } + + p_grad = gradient_input.host(); + p_src = src.host(); 
auto p_src_grad = src_grad.host(); + const float invnum = 1.0f / (ks * num); + for (long n = 0; n < ns; ++n) + { + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; + *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum; + + ++p_grad; + ++p_src; + ++p_src_grad; + } + } + } + } + // ----------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h index 79ef9842b..45bc57fa9 100644 --- a/dlib/cuda/cpu_dlib.h +++ b/dlib/cuda/cpu_dlib.h @@ -255,6 +255,26 @@ namespace dlib resizable_tensor& dvars ); + // ----------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ); + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ); + // ----------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 70b8ccb0e..3484baa7b 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -2280,6 +2280,166 @@ namespace dlib dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num); } + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_rms_normalize( + float* dest, + float* scale, + const float* src, + const float* gamma, + float eps, + size_t ns, + size_t ks, + size_t num + ) + { + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = src + n * ks * num; + float sum_squares = 0.0f; + for (auto i : grid_stride_range(0, ks * num)) + { + sum_squares += ps[i] * ps[i]; + } + warp_reduce_atomic_add(scale[n], sum_squares / 
(ks * num)); + } + __syncthreads(); + + for (auto n : grid_stride_range_y(0, ns)) + { + for (auto i : grid_stride_range(0, 1)) + { + scale[n] = 1.0f / std::sqrt(scale[n] + eps); + } + } + __syncthreads(); + + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = src + n * ks * num; + const auto pd = dest + n * ks * num; + for (auto i : grid_stride_range(0, ks * num)) + { + pd[i] = ps[i] * scale[n] * gamma[i / num]; + } + } + } + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ) + { + DLIB_CASSERT( + gamma.k() == src.k() && + gamma.nr() == 1 && + gamma.nc() == 1 && + eps > 0, + "\nsrc.k(): " << src.k() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\neps: " << eps + ); + + const long ns = src.num_samples(); + const long ks = src.k(); + const long num = src.nr() * src.nc(); + + dest.copy_size(src); + scale.set_size(ns); + scale = 0; + + launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns), + dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_rms_normalize_gradient( + float* src_grad, + float* gamma_grad, + float* dscale, + const float* src, + const float* gradient_input, + const float* scale, + const float* gamma, + size_t ns, + size_t ks, + size_t num + ) + { + for (auto nk : grid_stride_range_y(0, ns * ks)) + { + const auto n = nk / ks; + const auto k = nk % ks; + const auto ps = src + (n * ks + k) * num; + const auto pgi = gradient_input + (n * ks + k) * num; + const float scale_pow = -0.5f * std::pow(scale[n], 3.0f); + float temp_gg = 0.0f; + float temp_ds = 0.0f; + for (auto i : grid_stride_range(0, num)) + { + const float x_hat = ps[i] * scale[n]; + const float dx = pgi[i] * gamma[k]; // i only spans one channel here, so index gamma by k as the CPU path does + temp_gg += pgi[i] * x_hat; + temp_ds += dx * 
ps[i] * scale_pow; + } + warp_reduce_atomic_add(gamma_grad[k], temp_gg); + warp_reduce_atomic_add(dscale[n], temp_ds); + } + __syncthreads(); + + const float invnum = 1.0f / (ks * num); + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = src + n * ks * num; + const auto pgi = gradient_input + n * ks * num; + const auto psg = src_grad + n * ks * num; + for (auto i : grid_stride_range(0, ks * num)) + { + const float dx = pgi[i] * gamma[i / num]; + psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum; + } + } + } + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ) + { + DLIB_CASSERT(src.num_samples() == scale.size()); + DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); + DLIB_CASSERT(gamma.k() == src.k()); + DLIB_CASSERT(gamma.nr() == 1); + DLIB_CASSERT(gamma.nc() == 1); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + + const long ns = src.num_samples(); + const long ks = src.k(); + const long num = src.nr() * src.nc(); + + gamma_grad = 0; + dscale.copy_size(scale); + dscale = 0; + + // Launch the CUDA kernel + launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns), + src_grad.device(), gamma_grad.device(), dscale.device(), + src.device(), gradient_input.device(), scale.device(), gamma.device(), + ns, ks, num); + } + // ---------------------------------------------------------------------------------------- __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h index d157e1b65..059c6dd44 100644 --- a/dlib/cuda/cuda_dlib.h +++ b/dlib/cuda/cuda_dlib.h @@ -362,6 +362,26 @@ namespace dlib resizable_tensor& dvars ); + // 
----------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ); + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ); + // ----------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp index 7e11000ce..f4b684dec 100644 --- a/dlib/cuda/tensor_tools.cpp +++ b/dlib/cuda/tensor_tools.cpp @@ -696,6 +696,40 @@ namespace dlib { namespace tt #endif } +// ---------------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ) + { +#ifdef DLIB_USE_CUDA + cuda::rms_normalize(eps, dest, scale, src, gamma); +#else + cpu::rms_normalize(eps, dest, scale, src, gamma); +#endif + } + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ) + { +#ifdef DLIB_USE_CUDA + cuda::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); +#else + cpu::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); +#endif + } + // ---------------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h index 31310a961..245035c56 100644 --- a/dlib/cuda/tensor_tools.h +++ b/dlib/cuda/tensor_tools.h @@ -857,7 +857,58 @@ namespace dlib { namespace tt - Assigns the gradient of f() with respect to beta to #beta_grad. 
!*/ - // ----------------------------------------------------------------------------------- +// ----------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ); + /*! + requires + - eps > 0 + - gamma.k() == src.k() + - gamma.nr() == 1 + - gamma.nc() == 1 + ensures + - have_same_dimensions(#dest, src) == true + - #scale.size() == src.num_samples() + - #dest == the RMS normalized version of src + - #scale contains, for each sample of src, the reciprocal of its RMS (Root Mean Square), i.e. 1/sqrt(mean(src[n]^2) + eps). + - Each element of #dest is computed as: + - #dest[n, k, i, j] == src[n, k, i, j] * gamma[k] * #scale[n] + where n is the sample index, k is the channel index, and i, j are the spatial indices. + !*/ + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ); + /*! + requires + - scale.size() == src.num_samples() + - have_same_dimensions(gamma, gamma_grad) + - gamma.k() == src.k() + - gamma.nr() == 1 + - gamma.nc() == 1 + - have_same_dimensions(gradient_input, src) + - have_same_dimensions(gradient_input, src_grad) + ensures + - Let f(src, gamma) == dot(gradient_input, dest output of + rms_normalize(eps, dest, scale, src, gamma)) + - Adds the gradient of f() with respect to src to #src_grad + - Assigns the gradient of f() with respect to gamma to #gamma_grad + - #dscale contains the gradients of f() with respect to each sample's mean of squares (the quantity whose inverse square root, after adding eps, is stored in scale). 
+ !*/ + +// ----------------------------------------------------------------------------------- void threshold ( tensor& data, diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index 7dd6b51e4..ef11c1b34 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -1504,6 +1504,131 @@ namespace dlib template using layer_norm = add_layer; +// ---------------------------------------------------------------------------------------- + + const double DEFAULT_RMS_NORM_EPS = 1e-5; + + class rms_norm_ + { + public: + explicit rms_norm_( + double eps_ = DEFAULT_RMS_NORM_EPS + ) : + learning_rate_multiplier(1), + weight_decay_multiplier(0), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(1), + eps(eps_) + { + } + + double get_eps() const { return eps; } + + double get_learning_rate_multiplier() const { return learning_rate_multiplier; } + double get_weight_decay_multiplier() const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier() const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier() const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + inline dpoint map_input_to_output(const dpoint& p) const { return p; } + inline dpoint map_output_to_input(const dpoint& p) const { return p; } + + template + void setup(const SUBNET& sub) + { + gamma = alias_tensor(1, sub.get_output().k()); + params.set_size(gamma.size()); + gamma(params, 0) = 1; + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto g = gamma(params, 0); + tt::rms_normalize(eps, output, scale, sub.get_output(), g); + } + + template + void backward(const tensor& 
gradient_input, SUBNET& sub, tensor& params_grad) + { + auto g = gamma(params, 0); + auto g_grad = gamma(params_grad, 0); + tt::rms_normalize_gradient(gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale); + } + + const tensor& get_layer_params() const { return params; }; + tensor& get_layer_params() { return params; }; + + friend void serialize(const rms_norm_& item, std::ostream& out) + { + serialize("rms_norm_", out); + serialize(item.params, out); + serialize(item.gamma, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + serialize(item.eps, out); + } + + friend void deserialize(rms_norm_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "rms_norm_") + throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::rms_norm_."); + deserialize(item.params, in); + deserialize(item.gamma, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + deserialize(item.eps, in); + } + + friend std::ostream& operator<<(std::ostream& out, const rms_norm_& item) + { + out << "rms_norm"; + out << " (eps=" << item.eps << ")"; + out << " learning_rate_mult=" << item.learning_rate_multiplier; + out << " weight_decay_mult=" << item.weight_decay_multiplier; + out << " bias_learning_rate_mult=" << item.bias_learning_rate_multiplier; + out << " bias_weight_decay_mult=" << item.bias_weight_decay_multiplier; + return out; + } + + friend void to_xml(const rms_norm_& item, std::ostream& out) + { + out << "\n"; + out << mat(item.params); + out << "\n"; + } + + private: + resizable_tensor params; + alias_tensor gamma; + resizable_tensor scale; + resizable_tensor 
dscale; + double learning_rate_multiplier; + double weight_decay_multiplier; + double bias_learning_rate_multiplier; + double bias_weight_decay_multiplier; + double eps; + }; + + template + using rms_norm = add_layer; + // ---------------------------------------------------------------------------------------- enum layer_mode { diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h index 7a29ab134..8c09442e7 100644 --- a/dlib/dnn/layers_abstract.h +++ b/dlib/dnn/layers_abstract.h @@ -1468,6 +1468,7 @@ namespace dlib using dropout_rate = add_layer, SUBNET>; template using dropout_10 = add_layer, SUBNET>; + // ---------------------------------------------------------------------------------------- class multiply_ @@ -1665,6 +1666,177 @@ namespace dlib !*/ }; +// ---------------------------------------------------------------------------------------- + + const float DEFAULT_RMS_NORM_EPS = 1e-5f; + + class rms_norm_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above, specifically defining a root mean square (RMS) normalization layer. + + RMS normalization is a technique that normalizes the input tensor based on the + root mean square (RMS) of its elements. Unlike traditional layer normalization, + which both centers and scales the data, RMS normalization only scales by the RMS + value. This makes it computationally more efficient, as it avoids the need to + compute the mean and subtract it from each element. + + This layer produces output tensors with the same dimensionality as the input tensors. + Specifically, for an input tensor with shape [num_samples, k, nr, nc], the RMS + normalization is applied across the [k, nr, nc] dimensions of each sample, so one + scaling factor is computed per sample in the [num_samples] dimension. The learnable + scaling parameter (gamma) has size [k] and is applied per channel after the + normalization. 
+ + The key characteristics of this layer are: + - The RMS of the elements in each sample is standardized to 1. + - It does not center the data (i.e., it does not subtract the mean). + - A learnable scaling factor (gamma) is applied after normalization, allowing the + model to adapt the scaling dynamically. + + This layer is particularly effective in various natural language processing tasks, + where it has been shown to provide performance similar to or better than traditional + layer normalization, with reduced computational overhead. + !*/ + + public: + rms_norm_( + ); + /*! + ensures + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 0 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 1 + - #get_eps() == DEFAULT_RMS_NORM_EPS + !*/ + + explicit rms_norm_( + float eps_ = DEFAULT_RMS_NORM_EPS + ); + /*! + requires + - eps_ > 0 + ensures + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 0 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 1 + - #get_eps() == eps_ + !*/ + + float get_eps( + ) const; + /*! + ensures + - When doing RMS normalization, we are dividing by the root mean square. + The epsilon value returned by this function is added to the + mean square to prevent division by zero. + !*/ + + void set_eps( + float val + ); + /*! + requires + - val > 0 + ensures + - #get_eps() == val + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). 
+ !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. 
+ !*/ + }; + + template + using rms_norm = add_layer; + // ---------------------------------------------------------------------------------------- enum layer_mode diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h index c40bcbd33..a60106679 100644 --- a/dlib/dnn/visitors.h +++ b/dlib/dnn/visitors.h @@ -308,6 +308,14 @@ namespace dlib set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0); } + template + void disable_input_bias(add_layer& l) + { + disable_bias(l.subnet().layer_details()); + set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0); + set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0); + } + template void disable_input_bias(add_layer, U, E>& l) { @@ -333,6 +341,14 @@ namespace dlib set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); } + template class R, typename U, typename E> + void disable_input_bias(add_layer, E>& l) + { + disable_bias(l.subnet().get_repeated_layer(0).layer_details()); + set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); + set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); + } + // handle input repeat layer with tag case template void disable_input_bias(add_layer, add_tag_layer, E>& ) @@ -344,6 +360,11 @@ namespace dlib { } + template + void disable_input_bias(add_layer, E>& ) + { + } + // handle tag layer case template void disable_input_bias(add_layer, add_tag_layer, E>& ) @@ -355,6 +376,11 @@ namespace dlib { } + template + void disable_input_bias(add_layer, E>& ) + { + } + // handle skip layer case template class TAG, typename U, typename E> void disable_input_bias(add_layer, add_skip_layer, E>& ) @@ -366,6 +392,11 @@ namespace dlib { } + template