Add RMS Normalization Layer (#2999)

* Add RMS Normalization Layer

* Update dnn.cpp

* Missing entry in visitors.h to take into account the new rms_norm_ layer

* Fix test function name

* Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient

* Fixing the dnn.cpp test program for the new rms_norm_ layer

* General update of the rms_norm_ class
This commit is contained in:
Cydral 2024-09-07 19:29:40 +02:00 committed by GitHub
parent 253098eb1b
commit fafdac37f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 863 additions and 4 deletions

View File

@ -1447,6 +1447,144 @@ namespace dlib
} }
} }
// -----------------------------------------------------------------------------------
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
)
{
DLIB_CASSERT(
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\neps: " << eps
);
const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();
dest.copy_size(src);
scale.set_size(ns);
// Compute RMS values
scale = 0;
const float* p_src = src.host();
float* p_scale = scale.host();
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
p_scale[n] += (*p_src) * (*p_src);
++p_src;
}
}
p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
}
scale.host();
// Apply RMS normalization
p_src = src.host();
float* p_dest = dest.host();
const float* p_gamma = gamma.host();
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
*p_dest = (*p_src) * p_scale[n] * p_gamma[k];
++p_src;
++p_dest;
}
}
}
}
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
)
{
DLIB_CASSERT(src.num_samples() == scale.size());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();
gamma_grad = 0;
dscale.copy_size(scale);
dscale = 0;
auto p_grad = gradient_input.host();
auto p_src = src.host();
const auto p_gamma = gamma.host();
const auto p_gamma_grad = gamma_grad.host();
const auto p_scale = scale.host();
auto p_dscale = dscale.host();
for (long n = 0; n < ns; ++n)
{
const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f);
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
const float x_hat = *p_src * p_scale[n];
p_gamma_grad[k] += (*p_grad) * x_hat;
const float dx = *p_grad * p_gamma[k];
p_dscale[n] += dx * *p_src * scale_pow;
++p_grad;
++p_src;
}
}
}
p_grad = gradient_input.host();
p_src = src.host();
auto p_src_grad = src_grad.host();
const float invnum = 1.0f / (ks * num);
for (long n = 0; n < ns; ++n)
{
for (long k = 0; k < ks; ++k)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[k];
*p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum;
++p_grad;
++p_src;
++p_src_grad;
}
}
}
}
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void threshold ( void threshold (

View File

@ -255,6 +255,26 @@ namespace dlib
resizable_tensor& dvars resizable_tensor& dvars
); );
// -----------------------------------------------------------------------------------
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
);
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
);
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void threshold ( void threshold (

View File

@ -2280,6 +2280,166 @@ namespace dlib
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num); dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
} }
// ----------------------------------------------------------------------------------------
__global__ void _cuda_rms_normalize(
float* dest,
float* scale,
const float* src,
const float* gamma,
float eps,
size_t ns,
size_t ks,
size_t num
)
{
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
float sum_squares = 0.0f;
for (auto i : grid_stride_range(0, ks * num))
{
sum_squares += ps[i] * ps[i];
}
warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, 1))
{
scale[n] = 1.0f / std::sqrt(scale[n] + eps);
}
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
const auto pd = dest + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
pd[i] = ps[i] * scale[n] * gamma[i / num];
}
}
}
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
)
{
DLIB_CASSERT(
gamma.k() == src.k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\neps: " << eps
);
const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();
dest.copy_size(src);
scale.set_size(ns);
scale = 0;
launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_rms_normalize_gradient(
float* src_grad,
float* gamma_grad,
float* dscale,
const float* src,
const float* gradient_input,
const float* scale,
const float* gamma,
size_t ns,
size_t ks,
size_t num
)
{
for (auto nk : grid_stride_range_y(0, ns * ks))
{
const auto n = nk / ks;
const auto k = nk % ks;
const auto ps = src + (n * ks + k) * num;
const auto pgi = gradient_input + (n * ks + k) * num;
const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
float temp_gg = 0.0f;
float temp_ds = 0.0f;
for (auto i : grid_stride_range(0, num))
{
const float x_hat = ps[i] * scale[n];
const float dx = pgi[i] * gamma[i / num];
temp_gg += pgi[i] * x_hat;
temp_ds += dx * ps[i] * scale_pow;
}
warp_reduce_atomic_add(gamma_grad[k], temp_gg);
warp_reduce_atomic_add(dscale[n], temp_ds);
}
__syncthreads();
const float invnum = 1.0f / (ks * num);
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = src + n * ks * num;
const auto pgi = gradient_input + n * ks * num;
const auto psg = src_grad + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * gamma[i / num];
psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
}
}
}
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
)
{
DLIB_CASSERT(src.num_samples() == scale.size());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
const long ns = src.num_samples();
const long ks = src.k();
const long num = src.nr() * src.nc();
gamma_grad = 0;
dscale.copy_size(scale);
dscale = 0;
// Lancement du kernel CUDA
launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
src_grad.device(), gamma_grad.device(), dscale.device(),
src.device(), gradient_input.device(), scale.device(), gamma.device(),
ns, ks, num);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
__global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)

View File

@ -362,6 +362,26 @@ namespace dlib
resizable_tensor& dvars resizable_tensor& dvars
); );
// -----------------------------------------------------------------------------------
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
);
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
);
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void threshold ( void threshold (

View File

@ -696,6 +696,40 @@ namespace dlib { namespace tt
#endif #endif
} }
// ----------------------------------------------------------------------------------------
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
)
{
#ifdef DLIB_USE_CUDA
cuda::rms_normalize(eps, dest, scale, src, gamma);
#else
cpu::rms_normalize(eps, dest, scale, src, gamma);
#endif
}
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
)
{
#ifdef DLIB_USE_CUDA
cuda::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#else
cpu::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void threshold ( void threshold (

View File

@ -857,7 +857,58 @@ namespace dlib { namespace tt
- Assigns the gradient of f() with respect to beta to #beta_grad. - Assigns the gradient of f() with respect to beta to #beta_grad.
!*/ !*/
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void rms_normalize(
const double eps,
resizable_tensor& dest,
resizable_tensor& scale,
const tensor& src,
const tensor& gamma
);
/*!
requires
- eps > 0
- gamma.k() == src.k()
- gamma.nr() == 1
- gamma.nc() == 1
ensures
- have_same_dimensions(#dest, src) == true
- #scale.size() == src.num_samples()
- #dest == the RMS normalized version of src
- #scale contains the RMS (Root Mean Square) values used to normalize each sample of src.
- Each element of #dest is computed as:
- #dest[n, k, i, j] == src[n, k, i, j] * gamma[k] / scale[n]
where n is the sample index, k is the channel index, and i, j are the spatial indices.
!*/
void rms_normalize_gradient(
const tensor& gradient_input,
const tensor& scale,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
resizable_tensor& dscale
);
/*!
requires
- scale.size() == src.num_samples()
- have_same_dimensions(gamma, gamma_grad)
- gamma.k() == src.k()
- gamma.nr() == 1
- gamma.nc() == 1
- have_same_dimensions(gradient_input, src)
- have_same_dimensions(gradient_input, src_grad)
ensures
- Let f(src, gamma) == dot(gradient_input, dest output of
rms_normalize(eps, dest, scale, src, gamma))
- Adds the gradient of f() with respect to src to #src_grad
- Assigns the gradient of f() with respect to gamma to #gamma_grad
- #dscale contains the gradients of f() with respect to the RMS values.
!*/
// -----------------------------------------------------------------------------------
void threshold ( void threshold (
tensor& data, tensor& data,

View File

@ -1504,6 +1504,131 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using layer_norm = add_layer<layer_norm_, SUBNET>; using layer_norm = add_layer<layer_norm_, SUBNET>;
// ----------------------------------------------------------------------------------------
const double DEFAULT_RMS_NORM_EPS = 1e-5;
class rms_norm_
{
public:
explicit rms_norm_(
double eps_ = DEFAULT_RMS_NORM_EPS
) :
learning_rate_multiplier(1),
weight_decay_multiplier(0),
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1),
eps(eps_)
{
}
double get_eps() const { return eps; }
double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
double get_weight_decay_multiplier() const { return weight_decay_multiplier; }
void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
double get_bias_learning_rate_multiplier() const { return bias_learning_rate_multiplier; }
double get_bias_weight_decay_multiplier() const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
inline dpoint map_input_to_output(const dpoint& p) const { return p; }
inline dpoint map_output_to_input(const dpoint& p) const { return p; }
template <typename SUBNET>
void setup(const SUBNET& sub)
{
gamma = alias_tensor(1, sub.get_output().k());
params.set_size(gamma.size());
gamma(params, 0) = 1;
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto g = gamma(params, 0);
tt::rms_normalize(eps, output, scale, sub.get_output(), g);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
{
auto g = gamma(params, 0);
auto g_grad = gamma(params_grad, 0);
tt::rms_normalize_gradient(gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale);
}
const tensor& get_layer_params() const { return params; };
tensor& get_layer_params() { return params; };
friend void serialize(const rms_norm_& item, std::ostream& out)
{
serialize("rms_norm_", out);
serialize(item.params, out);
serialize(item.gamma, out);
serialize(item.learning_rate_multiplier, out);
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.eps, out);
}
friend void deserialize(rms_norm_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "rms_norm_")
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::rms_norm_.");
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
deserialize(item.eps, in);
}
friend std::ostream& operator<<(std::ostream& out, const rms_norm_& item)
{
out << "rms_norm";
out << " (eps=" << item.eps << ")";
out << " learning_rate_mult=" << item.learning_rate_multiplier;
out << " weight_decay_mult=" << item.weight_decay_multiplier;
out << " bias_learning_rate_mult=" << item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult=" << item.bias_weight_decay_multiplier;
return out;
}
friend void to_xml(const rms_norm_& item, std::ostream& out)
{
out << "<rms_norm";
out << " eps='" << item.eps << "'";
out << " learning_rate_mult='" << item.learning_rate_multiplier << "'";
out << " weight_decay_mult='" << item.weight_decay_multiplier << "'";
out << " bias_learning_rate_mult='" << item.bias_learning_rate_multiplier << "'";
out << " bias_weight_decay_mult='" << item.bias_weight_decay_multiplier << "'";
out << ">\n";
out << mat(item.params);
out << "</rms_norm>\n";
}
private:
resizable_tensor params;
alias_tensor gamma;
resizable_tensor scale;
resizable_tensor dscale;
double learning_rate_multiplier;
double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
double eps;
};
template <typename SUBNET>
using rms_norm = add_layer<rms_norm_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum layer_mode enum layer_mode
{ {

View File

@ -1468,6 +1468,7 @@ namespace dlib
using dropout_rate = add_layer<dropout_rate_<DROP_RATE>, SUBNET>; using dropout_rate = add_layer<dropout_rate_<DROP_RATE>, SUBNET>;
template <typename SUBNET> template <typename SUBNET>
using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>; using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class multiply_ class multiply_
@ -1665,6 +1666,177 @@ namespace dlib
!*/ !*/
}; };
// ----------------------------------------------------------------------------------------
const float DEFAULT_RMS_NORM_EPS = 1e-5f;
class rms_norm_
{
/*!
WHAT THIS OBJECT REPRESENTS
This object implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above, specifically defining a root mean square (RMS) normalization layer.
RMS normalization is a technique that normalizes the input tensor based on the
root mean square (RMS) of its elements. Unlike traditional layer normalization,
which both centers and scales the data, RMS normalization only scales by the RMS
value. This makes it computationally more efficient, as it avoids the need to
compute the mean and subtract it from each element.
This layer produces output tensors with the same dimensionality as the input tensors.
Specifically, for an input tensor with shape [num_samples, k, nr, nc], the RMS
normalization is applied across the [nr, nc] dimensions independently for each
element in the [k] dimension and for each sample in the [num_samples] dimension.
The scaling factor (RMS) and the learnable scaling parameter (gamma) are both of
size [k].
The key characteristics of this layer are:
- The RMS of the elements in each sample is standardized to 1.
- It does not center the data (i.e., it does not subtract the mean).
- A learnable scaling factor (gamma) is applied after normalization, allowing the
model to adapt the scaling dynamically.
This layer is particularly effective in various natural language processing tasks,
where it has been shown to provide performance similar to or better than traditional
layer normalization, with reduced computational overhead.
!*/
public:
rms_norm_(
);
/*!
ensures
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == DEFAULT_RMS_NORM_EPS
!*/
explicit rms_norm_(
float eps_ = DEFAULT_RMS_NORM_EPS
);
/*!
requires
- eps > 0
ensures
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == eps_
!*/
float get_eps(
) const;
/*!
ensures
- When doing RMS normalization, we are dividing by the root mean square.
This epsilon value returned by this function is added to the
mean square to prevent division by zero.
!*/
void set_eps(
float val
);
/*!
requires
- val > 0
ensures
- #get_eps() == val
!*/
double get_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its parameters be
multiplied by get_learning_rate_multiplier().
!*/
double get_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its parameters be
multiplied by get_weight_decay_multiplier().
!*/
void set_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_learning_rate_multiplier() == val
!*/
void set_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_weight_decay_multiplier() == val
!*/
double get_bias_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its bias parameters be
multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
!*/
double get_bias_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its bias parameters be
multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
!*/
void set_bias_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_learning_rate_multiplier() == val
!*/
void set_bias_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_weight_decay_multiplier() == val
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
!*/
};
template <typename SUBNET>
using rms_norm = add_layer<rms_norm_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum layer_mode enum layer_mode

View File

@ -308,6 +308,14 @@ namespace dlib
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0); set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
} }
template <typename U, typename E>
void disable_input_bias(add_layer<rms_norm_, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template <layer_mode mode, typename U, typename E> template <layer_mode mode, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, U, E>& l) void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
{ {
@ -333,6 +341,14 @@ namespace dlib
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
} }
template <size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<rms_norm_, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
// handle input repeat layer with tag case // handle input repeat layer with tag case
template <layer_mode mode, unsigned long ID, typename E> template <layer_mode mode, unsigned long ID, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& ) void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
@ -344,6 +360,11 @@ namespace dlib
{ {
} }
template <unsigned long ID, typename E>
void disable_input_bias(add_layer<rms_norm_, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
// handle tag layer case // handle tag layer case
template <layer_mode mode, unsigned long ID, typename U, typename E> template <layer_mode mode, unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& ) void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& )
@ -355,6 +376,11 @@ namespace dlib
{ {
} }
template <unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<rms_norm_, add_tag_layer<ID, U>, E>& )
{
}
// handle skip layer case // handle skip layer case
template <layer_mode mode, template <typename> class TAG, typename U, typename E> template <layer_mode mode, template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& ) void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& )
@ -366,6 +392,11 @@ namespace dlib
{ {
} }
template <template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<rms_norm_, add_skip_layer<TAG, U>, E>& )
{
}
template<typename input_layer_type> template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const void operator()(size_t , input_layer_type& ) const
{ {
@ -741,6 +772,14 @@ namespace dlib
update(i); update(i);
} }
template <typename U, typename E>
void operator()(size_t i, const add_layer<rms_norm_, U, E>&)
{
start_node(i, "rms_norm");
end_node();
update(i);
}
template <layer_mode MODE, typename U, typename E> template <layer_mode MODE, typename U, typename E>
void operator()(size_t i, const add_layer<bn_<MODE>, U, E>&) void operator()(size_t i, const add_layer<bn_<MODE>, U, E>&)
{ {

View File

@ -38,9 +38,9 @@ namespace dlib
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer. add_tag_layer.
ensures ensures
- Disables bias for all bn_ and layer_norm_ inputs. - Disables bias for all bn_, layer_norm_ and rms_norms_ inputs.
- Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier() - Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
to zero of all bn_ and layer_norm_ inputs. to zero of all bn_, layer_norm_ and rms_norm_ inputs.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------

View File

@ -631,7 +631,7 @@ namespace
DLIB_TEST(::std::abs(rs.stddev() - 1.0f) < 0.01); DLIB_TEST(::std::abs(rs.stddev() - 1.0f) < 0.01);
} }
// check that the CPU and the CUDA implementation are equivalent // check that the CPU and the CUDA implementation are equivalent
#if DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
resizable_tensor y_cuda(x); resizable_tensor y_cuda(x);
resizable_tensor means_cuda(x.num_samples()), invstds_cuda(x.num_samples()); resizable_tensor means_cuda(x.num_samples()), invstds_cuda(x.num_samples());
cuda::layer_normalize(eps, y_cuda, means_cuda, invstds_cuda, x, gamma, beta); cuda::layer_normalize(eps, y_cuda, means_cuda, invstds_cuda, x, gamma, beta);
@ -655,6 +655,99 @@ namespace
#endif #endif
} }
// ----------------------------------------------------------------------------------------
void test_rms_normalize()
{
resizable_tensor x(2, 3, 4, 5);
resizable_tensor y_cpu(x);
tt::tensor_rand rnd(0);
rnd.fill_uniform(x);
resizable_tensor scale_cpu;
resizable_tensor gamma(1, x.k());
gamma = 1;
const float eps = 1e-5;
cpu::rms_normalize(eps, y_cpu, scale_cpu, x, gamma);
// check that the output is correctly normalized
const float* p_x = x.host();
const float* p_y = y_cpu.host();
const float* p_scale = scale_cpu.host();
bool error_found = false;
for (long n = 0; n < x.num_samples(); ++n)
{
for (long k = 0; k < x.k(); ++k)
{
for (long r = 0; r < x.nr(); ++r)
{
for (long c = 0; c < x.nc(); ++c)
{
float x_val = p_x[tensor_index(x, n, k, r, c)];
float y_val = p_y[tensor_index(y_cpu, n, k, r, c)];
float rms_val = p_scale[n];
if (std::abs(y_val - x_val * rms_val) >= 1e-5) error_found = true;
}
}
}
}
DLIB_TEST(!error_found);
// check the backward pass
resizable_tensor gradient_input(x);
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k());
resizable_tensor dscale_cpu(x.num_samples());
rnd.fill_gaussian(gradient_input);
src_grad_cpu = 0;
cpu::rms_normalize_gradient(gradient_input, scale_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, dscale_cpu);
const float* p_gradient_input = gradient_input.host();
const float* p_src = x.host();
const float* p_src_grad_cpu = src_grad_cpu.host();
const float* p_gamma = gamma.host();
const float* p_scale_cpu = scale_cpu.host();
const float* p_dscale_cpu = dscale_cpu.host();
bool backward_error_found = false;
for (long n = 0; n < x.num_samples(); ++n)
{
const float scale_pow = -0.5 * std::pow(p_scale_cpu[n], 3.0f);
for (long k = 0; k < x.k(); ++k)
{
for (long r = 0; r < x.nr(); ++r)
{
for (long c = 0; c < x.nc(); ++c)
{
float gradient_input_val = p_gradient_input[tensor_index(gradient_input, n, k, r, c)];
float src_val = p_src[tensor_index(x, n, k, r, c)];
float rms_val = p_scale_cpu[n];
float expected_src_grad = gradient_input_val * p_gamma[k] * rms_val + p_dscale_cpu[n] * 2 * src_val * 1.0f / (x.k() * x.nr() * x.nc());
float src_grad_val = p_src_grad_cpu[tensor_index(src_grad_cpu, n, k, r, c)];
if (std::abs(src_grad_val - expected_src_grad) >= 1e-4)
backward_error_found = true;
}
}
}
}
DLIB_TEST(!backward_error_found);
// check that the CPU and the CUDA implementation are equivalent
#ifdef DLIB_USE_CUDA
resizable_tensor y_cuda(x);
resizable_tensor scale_cuda;
cuda::rms_normalize(eps, y_cuda, scale_cuda, x, gamma);
DLIB_TEST(max(abs(mat(y_cpu) - mat(y_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(scale_cpu) - mat(scale_cuda))) < 1e-5);
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k());
resizable_tensor dscale_cuda(x.num_samples());
src_grad_cuda = 0;
cuda::rms_normalize_gradient(gradient_input, scale_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, dscale_cuda);
DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(dscale_cpu) - mat(dscale_cuda))) < 1e-5);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void test_basic_tensor_ops() void test_basic_tensor_ops()
@ -2013,6 +2106,12 @@ namespace
auto res = test_layer(l); auto res = test_layer(l);
DLIB_TEST_MSG(res, res); DLIB_TEST_MSG(res, res);
} }
{
print_spinner();
rms_norm_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{ {
print_spinner(); print_spinner();
cont_<3,3,3,2,2,0,0> l; cont_<3,3,3,2,2,0,0> l;
@ -4389,6 +4488,7 @@ namespace
test_batch_normalize(); test_batch_normalize();
test_batch_normalize_conv(); test_batch_normalize_conv();
test_layer_normalize(); test_layer_normalize();
test_rms_normalize();
test_basic_tensor_ops(); test_basic_tensor_ops();
test_layers(); test_layers();
test_visit_functions(); test_visit_functions();