Add RMS Normalization Layer (#2999)

* Add RMS Normalization Layer
* Update dnn.cpp
* Missing entry in visitors.h to take into account the new rms_norm_ layer
* Fix test function name
* Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient
* Fixing the dnn.cpp test program for the new rms_norm_ layer
* General update of the rms_norm_ class
This commit is contained in:
parent 253098eb1b
commit fafdac37f1
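As context for the diffs below, here is a hedged usage sketch, not part of the commit itself: a tiny network using the new rms_norm layer, assuming this change is built into dlib. The layer sizes and the loss are arbitrary illustration choices.

#include <dlib/dnn.h>
#include <iostream>

using namespace dlib;

// fc -> rms_norm -> relu -> fc over a small column-vector input (sizes are illustrative)
using net_type = loss_mean_squared_multioutput<
                 fc<4,
                 relu<rms_norm<
                 fc<16,
                 input<matrix<float>>>>>>>;

int main()
{
    net_type net;
    matrix<float> x(8, 1);
    x = 1;
    matrix<float> out = net(x);   // forward pass through the RMS-normalized stack
    std::cout << trans(out);
}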
@@ -1447,6 +1447,144 @@ namespace dlib
            }
        }

        // -----------------------------------------------------------------------------------

        void rms_normalize(
            const double eps,
            resizable_tensor& dest,
            resizable_tensor& scale,
            const tensor& src,
            const tensor& gamma
        )
        {
            DLIB_CASSERT(
                gamma.k() == src.k() &&
                gamma.nr() == 1 &&
                gamma.nc() == 1 &&
                eps > 0,
                "\nsrc.k(): " << src.k() <<
                "\ngamma.k(): " << gamma.k() <<
                "\ngamma.nr(): " << gamma.nr() <<
                "\ngamma.nc(): " << gamma.nc() <<
                "\neps: " << eps
            );

            const long ns = src.num_samples();
            const long ks = src.k();
            const long num = src.nr() * src.nc();

            dest.copy_size(src);
            scale.set_size(ns);

            // Compute RMS values
            scale = 0;
            const float* p_src = src.host();
            float* p_scale = scale.host();
            for (long n = 0; n < ns; ++n)
            {
                for (long k = 0; k < ks; ++k)
                {
                    for (long i = 0; i < num; ++i)
                    {
                        p_scale[n] += (*p_src) * (*p_src);
                        ++p_src;
                    }
                }
                // scale[n] holds the reciprocal of the RMS of sample n
                p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
            }
            scale.host();

            // Apply RMS normalization
            p_src = src.host();
            float* p_dest = dest.host();
            const float* p_gamma = gamma.host();
            for (long n = 0; n < ns; ++n)
            {
                for (long k = 0; k < ks; ++k)
                {
                    for (long i = 0; i < num; ++i)
                    {
                        *p_dest = (*p_src) * p_scale[n] * p_gamma[k];
                        ++p_src;
                        ++p_dest;
                    }
                }
            }
        }

        void rms_normalize_gradient(
            const tensor& gradient_input,
            const tensor& scale,
            const tensor& src,
            const tensor& gamma,
            tensor& src_grad,
            tensor& gamma_grad,
            resizable_tensor& dscale
        )
        {
            DLIB_CASSERT(src.num_samples() == scale.size());
            DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
            DLIB_CASSERT(gamma.k() == src.k());
            DLIB_CASSERT(gamma.nr() == 1);
            DLIB_CASSERT(gamma.nc() == 1);
            DLIB_CASSERT(have_same_dimensions(gradient_input, src));
            DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));

            const long ns = src.num_samples();
            const long ks = src.k();
            const long num = src.nr() * src.nc();

            gamma_grad = 0;
            dscale.copy_size(scale);
            dscale = 0;

            auto p_grad = gradient_input.host();
            auto p_src = src.host();
            const auto p_gamma = gamma.host();
            const auto p_gamma_grad = gamma_grad.host();
            const auto p_scale = scale.host();
            auto p_dscale = dscale.host();

            // First pass: accumulate the gradient w.r.t. gamma and w.r.t. the scale values.
            for (long n = 0; n < ns; ++n)
            {
                const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f);
                for (long k = 0; k < ks; ++k)
                {
                    for (long i = 0; i < num; ++i)
                    {
                        const float x_hat = *p_src * p_scale[n];
                        p_gamma_grad[k] += (*p_grad) * x_hat;

                        const float dx = *p_grad * p_gamma[k];
                        p_dscale[n] += dx * *p_src * scale_pow;

                        ++p_grad;
                        ++p_src;
                    }
                }
            }

            // Second pass: add the gradient with respect to the input to src_grad.
            p_grad = gradient_input.host();
            p_src = src.host();
            auto p_src_grad = src_grad.host();
            const float invnum = 1.0f / (ks * num);
            for (long n = 0; n < ns; ++n)
            {
                for (long k = 0; k < ks; ++k)
                {
                    for (long i = 0; i < num; ++i)
                    {
                        const float dx = *p_grad * p_gamma[k];
                        *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum;

                        ++p_grad;
                        ++p_src;
                        ++p_src_grad;
                    }
                }
            }
        }

        // -----------------------------------------------------------------------------------

        void threshold (
@@ -255,6 +255,26 @@ namespace dlib
            resizable_tensor& dvars
        );

        // -----------------------------------------------------------------------------------

        void rms_normalize(
            const double eps,
            resizable_tensor& dest,
            resizable_tensor& scale,
            const tensor& src,
            const tensor& gamma
        );

        void rms_normalize_gradient(
            const tensor& gradient_input,
            const tensor& scale,
            const tensor& src,
            const tensor& gamma,
            tensor& src_grad,
            tensor& gamma_grad,
            resizable_tensor& dscale
        );

        // -----------------------------------------------------------------------------------

        void threshold (
@@ -2280,6 +2280,166 @@ namespace dlib
                dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
        }

        // ----------------------------------------------------------------------------------------

        __global__ void _cuda_rms_normalize(
            float* dest,
            float* scale,
            const float* src,
            const float* gamma,
            float eps,
            size_t ns,
            size_t ks,
            size_t num
        )
        {
            // Each sample accumulates its mean of squares into scale[n].
            for (auto n : grid_stride_range_y(0, ns))
            {
                const auto ps = src + n * ks * num;
                float sum_squares = 0.0f;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    sum_squares += ps[i] * ps[i];
                }
                warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
            }
            __syncthreads();

            // Convert the mean of squares into the reciprocal RMS.
            for (auto n : grid_stride_range_y(0, ns))
            {
                for (auto i : grid_stride_range(0, 1))
                {
                    scale[n] = 1.0f / std::sqrt(scale[n] + eps);
                }
            }
            __syncthreads();

            // Scale every element by its sample's reciprocal RMS and by gamma.
            for (auto n : grid_stride_range_y(0, ns))
            {
                const auto ps = src + n * ks * num;
                const auto pd = dest + n * ks * num;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    pd[i] = ps[i] * scale[n] * gamma[i / num];
                }
            }
        }

        void rms_normalize(
            const double eps,
            resizable_tensor& dest,
            resizable_tensor& scale,
            const tensor& src,
            const tensor& gamma
        )
        {
            DLIB_CASSERT(
                gamma.k() == src.k() &&
                gamma.nr() == 1 &&
                gamma.nc() == 1 &&
                eps > 0,
                "\nsrc.k(): " << src.k() <<
                "\ngamma.k(): " << gamma.k() <<
                "\ngamma.nr(): " << gamma.nr() <<
                "\ngamma.nc(): " << gamma.nc() <<
                "\neps: " << eps
            );

            const long ns = src.num_samples();
            const long ks = src.k();
            const long num = src.nr() * src.nc();

            dest.copy_size(src);
            scale.set_size(ns);
            scale = 0;

            launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
                dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
        }

        // ----------------------------------------------------------------------------------------

        __global__ void _cuda_rms_normalize_gradient(
            float* src_grad,
            float* gamma_grad,
            float* dscale,
            const float* src,
            const float* gradient_input,
            const float* scale,
            const float* gamma,
            size_t ns,
            size_t ks,
            size_t num
        )
        {
            // First pass: accumulate the gamma gradient and the scale gradient.
            for (auto nk : grid_stride_range_y(0, ns * ks))
            {
                const auto n = nk / ks;
                const auto k = nk % ks;
                const auto ps = src + (n * ks + k) * num;
                const auto pgi = gradient_input + (n * ks + k) * num;
                const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
                float temp_gg = 0.0f;
                float temp_ds = 0.0f;
                for (auto i : grid_stride_range(0, num))
                {
                    const float x_hat = ps[i] * scale[n];
                    const float dx = pgi[i] * gamma[k];
                    temp_gg += pgi[i] * x_hat;
                    temp_ds += dx * ps[i] * scale_pow;
                }
                warp_reduce_atomic_add(gamma_grad[k], temp_gg);
                warp_reduce_atomic_add(dscale[n], temp_ds);
            }
            __syncthreads();

            // Second pass: add the gradient with respect to the input to src_grad.
            const float invnum = 1.0f / (ks * num);
            for (auto n : grid_stride_range_y(0, ns))
            {
                const auto ps = src + n * ks * num;
                const auto pgi = gradient_input + n * ks * num;
                const auto psg = src_grad + n * ks * num;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    const float dx = pgi[i] * gamma[i / num];
                    psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
                }
            }
        }

        void rms_normalize_gradient(
            const tensor& gradient_input,
            const tensor& scale,
            const tensor& src,
            const tensor& gamma,
            tensor& src_grad,
            tensor& gamma_grad,
            resizable_tensor& dscale
        )
        {
            DLIB_CASSERT(src.num_samples() == scale.size());
            DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
            DLIB_CASSERT(gamma.k() == src.k());
            DLIB_CASSERT(gamma.nr() == 1);
            DLIB_CASSERT(gamma.nc() == 1);
            DLIB_CASSERT(have_same_dimensions(gradient_input, src));
            DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));

            const long ns = src.num_samples();
            const long ks = src.k();
            const long num = src.nr() * src.nc();

            gamma_grad = 0;
            dscale.copy_size(scale);
            dscale = 0;

            // Launch the CUDA kernel
            launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
                src_grad.device(), gamma_grad.device(), dscale.device(),
                src.device(), gradient_input.device(), scale.device(), gamma.device(),
                ns, ks, num);
        }

        // ----------------------------------------------------------------------------------------

        __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
@@ -362,6 +362,26 @@ namespace dlib
            resizable_tensor& dvars
        );

        // -----------------------------------------------------------------------------------

        void rms_normalize(
            const double eps,
            resizable_tensor& dest,
            resizable_tensor& scale,
            const tensor& src,
            const tensor& gamma
        );

        void rms_normalize_gradient(
            const tensor& gradient_input,
            const tensor& scale,
            const tensor& src,
            const tensor& gamma,
            tensor& src_grad,
            tensor& gamma_grad,
            resizable_tensor& dscale
        );

        // -----------------------------------------------------------------------------------

        void threshold (
@@ -696,6 +696,40 @@ namespace dlib { namespace tt
#endif
    }

    // ----------------------------------------------------------------------------------------

    void rms_normalize(
        const double eps,
        resizable_tensor& dest,
        resizable_tensor& scale,
        const tensor& src,
        const tensor& gamma
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::rms_normalize(eps, dest, scale, src, gamma);
#else
        cpu::rms_normalize(eps, dest, scale, src, gamma);
#endif
    }

    void rms_normalize_gradient(
        const tensor& gradient_input,
        const tensor& scale,
        const tensor& src,
        const tensor& gamma,
        tensor& src_grad,
        tensor& gamma_grad,
        resizable_tensor& dscale
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#else
        cpu::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
#endif
    }

    // ----------------------------------------------------------------------------------------

    void threshold (
@@ -857,7 +857,58 @@ namespace dlib { namespace tt
            - Assigns the gradient of f() with respect to beta to #beta_grad.
    !*/

    // -----------------------------------------------------------------------------------

    void rms_normalize(
        const double eps,
        resizable_tensor& dest,
        resizable_tensor& scale,
        const tensor& src,
        const tensor& gamma
    );
    /*!
        requires
            - eps > 0
            - gamma.k() == src.k()
            - gamma.nr() == 1
            - gamma.nc() == 1
        ensures
            - have_same_dimensions(#dest, src) == true
            - #scale.size() == src.num_samples()
            - #dest == the RMS normalized version of src
            - #scale contains the reciprocal of the RMS (root mean square) value of each
              sample of src, i.e. the factor each sample is multiplied by during
              normalization.
            - Each element of #dest is computed as:
                - #dest[n, k, i, j] == src[n, k, i, j] * gamma[k] * scale[n]
              where n is the sample index, k is the channel index, and i, j are the
              spatial indices.
    !*/

    void rms_normalize_gradient(
        const tensor& gradient_input,
        const tensor& scale,
        const tensor& src,
        const tensor& gamma,
        tensor& src_grad,
        tensor& gamma_grad,
        resizable_tensor& dscale
    );
    /*!
        requires
            - scale.size() == src.num_samples()
            - have_same_dimensions(gamma, gamma_grad)
            - gamma.k() == src.k()
            - gamma.nr() == 1
            - gamma.nc() == 1
            - have_same_dimensions(gradient_input, src)
            - have_same_dimensions(gradient_input, src_grad)
        ensures
            - Let f(src, gamma) == dot(gradient_input, dest output of
              rms_normalize(eps, dest, scale, src, gamma))
            - Adds the gradient of f() with respect to src to #src_grad
            - Assigns the gradient of f() with respect to gamma to #gamma_grad
            - #dscale contains the gradients of f() with respect to the RMS values.
    !*/

    // -----------------------------------------------------------------------------------

    void threshold (
        tensor& data,
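As a quick illustration of the specs above, here is a hedged, self-contained sketch, not part of the commit, that calls the new tensor-tools routines directly; this is roughly what the rms_norm_ layer does with its gamma parameter and its scale/dscale buffers.

#include <dlib/dnn.h>

using namespace dlib;

void rms_normalize_example()
{
    resizable_tensor src(2, 3, 4, 5), dest, scale;
    resizable_tensor gamma(1, 3);       // one learnable scale per channel (k == 3)
    gamma = 1;
    tt::tensor_rand rnd;
    rnd.fill_uniform(src);

    // forward: dest[n,k,r,c] == src[n,k,r,c] * gamma[k] * scale[n],
    // where scale[n] is the reciprocal RMS of sample n
    tt::rms_normalize(1e-5, dest, scale, src, gamma);

    // backward: accumulate gradients w.r.t. src and gamma
    resizable_tensor grad_in(src), src_grad(src), gamma_grad(1, 3), dscale;
    rnd.fill_gaussian(grad_in);
    src_grad = 0;
    tt::rms_normalize_gradient(grad_in, scale, src, gamma, src_grad, gamma_grad, dscale);
}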
@@ -1504,6 +1504,131 @@ namespace dlib
    template <typename SUBNET>
    using layer_norm = add_layer<layer_norm_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    const double DEFAULT_RMS_NORM_EPS = 1e-5;

    class rms_norm_
    {
    public:
        explicit rms_norm_(
            double eps_ = DEFAULT_RMS_NORM_EPS
        ) :
            learning_rate_multiplier(1),
            weight_decay_multiplier(0),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(1),
            eps(eps_)
        {
        }

        double get_eps() const { return eps; }

        double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
        double get_weight_decay_multiplier() const { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }

        double get_bias_learning_rate_multiplier() const { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier() const { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }

        inline dpoint map_input_to_output(const dpoint& p) const { return p; }
        inline dpoint map_output_to_input(const dpoint& p) const { return p; }

        template <typename SUBNET>
        void setup(const SUBNET& sub)
        {
            gamma = alias_tensor(1, sub.get_output().k());
            params.set_size(gamma.size());
            gamma(params, 0) = 1;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto g = gamma(params, 0);
            tt::rms_normalize(eps, output, scale, sub.get_output(), g);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto g = gamma(params, 0);
            auto g_grad = gamma(params_grad, 0);
            tt::rms_normalize_gradient(gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale);
        }

        const tensor& get_layer_params() const { return params; };
        tensor& get_layer_params() { return params; };

        friend void serialize(const rms_norm_& item, std::ostream& out)
        {
            serialize("rms_norm_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.eps, out);
        }

        friend void deserialize(rms_norm_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "rms_norm_")
                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::rms_norm_.");
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.learning_rate_multiplier, in);
            deserialize(item.weight_decay_multiplier, in);
            deserialize(item.bias_learning_rate_multiplier, in);
            deserialize(item.bias_weight_decay_multiplier, in);
            deserialize(item.eps, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const rms_norm_& item)
        {
            out << "rms_norm";
            out << " (eps=" << item.eps << ")";
            out << " learning_rate_mult=" << item.learning_rate_multiplier;
            out << " weight_decay_mult=" << item.weight_decay_multiplier;
            out << " bias_learning_rate_mult=" << item.bias_learning_rate_multiplier;
            out << " bias_weight_decay_mult=" << item.bias_weight_decay_multiplier;
            return out;
        }

        friend void to_xml(const rms_norm_& item, std::ostream& out)
        {
            out << "<rms_norm";
            out << " eps='" << item.eps << "'";
            out << " learning_rate_mult='" << item.learning_rate_multiplier << "'";
            out << " weight_decay_mult='" << item.weight_decay_multiplier << "'";
            out << " bias_learning_rate_mult='" << item.bias_learning_rate_multiplier << "'";
            out << " bias_weight_decay_mult='" << item.bias_weight_decay_multiplier << "'";
            out << ">\n";
            out << mat(item.params);
            out << "</rms_norm>\n";
        }

    private:
        resizable_tensor params;
        alias_tensor gamma;
        resizable_tensor scale;
        resizable_tensor dscale;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        double eps;
    };

    template <typename SUBNET>
    using rms_norm = add_layer<rms_norm_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    enum layer_mode
    {
@@ -1468,6 +1468,7 @@ namespace dlib
    using dropout_rate = add_layer<dropout_rate_<DROP_RATE>, SUBNET>;
    template <typename SUBNET>
    using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class multiply_
@@ -1665,6 +1666,177 @@ namespace dlib
        !*/
    };

    // ----------------------------------------------------------------------------------------

    const float DEFAULT_RMS_NORM_EPS = 1e-5f;

    class rms_norm_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                defined above, specifically defining a root mean square (RMS)
                normalization layer.

                RMS normalization is a technique that normalizes the input tensor based
                on the root mean square (RMS) of its elements. Unlike traditional layer
                normalization, which both centers and scales the data, RMS normalization
                only scales by the RMS value. This makes it computationally more
                efficient, as it avoids the need to compute the mean and subtract it
                from each element.

                This layer produces output tensors with the same dimensionality as the
                input tensors. Specifically, for an input tensor with shape
                [num_samples, k, nr, nc], a single RMS value is computed over the
                [k, nr, nc] dimensions of each sample, so every sample is normalized by
                its own RMS. The normalizing factor therefore has size [num_samples],
                while the learnable scaling parameter (gamma) has size [k].

                The key characteristics of this layer are:
                    - The RMS of the elements in each sample is standardized to 1.
                    - It does not center the data (i.e., it does not subtract the mean).
                    - A learnable scaling factor (gamma) is applied after normalization,
                      allowing the model to adapt the scaling dynamically.

                This layer is particularly effective in various natural language
                processing tasks, where it has been shown to provide performance
                similar to or better than traditional layer normalization, with reduced
                computational overhead.
        !*/

    public:
        rms_norm_(
        );
        /*!
            ensures
                - #get_learning_rate_multiplier() == 1
                - #get_weight_decay_multiplier() == 0
                - #get_bias_learning_rate_multiplier() == 1
                - #get_bias_weight_decay_multiplier() == 1
                - #get_eps() == DEFAULT_RMS_NORM_EPS
        !*/

        explicit rms_norm_(
            float eps_ = DEFAULT_RMS_NORM_EPS
        );
        /*!
            requires
                - eps_ > 0
            ensures
                - #get_learning_rate_multiplier() == 1
                - #get_weight_decay_multiplier() == 0
                - #get_bias_learning_rate_multiplier() == 1
                - #get_bias_weight_decay_multiplier() == 1
                - #get_eps() == eps_
        !*/

        float get_eps(
        ) const;
        /*!
            ensures
                - When doing RMS normalization, we are dividing by the root mean square.
                  The epsilon value returned by this function is added to the mean
                  square to prevent division by zero.
        !*/

        void set_eps(
            float val
        );
        /*!
            requires
                - val > 0
            ensures
                - #get_eps() == val
        !*/

        double get_learning_rate_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number. The interpretation is that this object is
                  requesting that the learning rate used to optimize its parameters be
                  multiplied by get_learning_rate_multiplier().
        !*/

        double get_weight_decay_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number. The interpretation is that this object is
                  requesting that the weight decay used to optimize its parameters be
                  multiplied by get_weight_decay_multiplier().
        !*/

        void set_learning_rate_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_learning_rate_multiplier() == val
        !*/

        void set_weight_decay_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_weight_decay_multiplier() == val
        !*/

        double get_bias_learning_rate_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number. The interpretation is that this object is
                  requesting that the learning rate used to optimize its bias parameters
                  be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
        !*/

        double get_bias_weight_decay_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number. The interpretation is that this object is
                  requesting that the weight decay used to optimize its bias parameters
                  be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
        !*/

        void set_bias_learning_rate_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_bias_learning_rate_multiplier() == val
        !*/

        void set_bias_weight_decay_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_bias_weight_decay_multiplier() == val
        !*/

        template <typename SUBNET> void setup (const SUBNET& sub);
        template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
        template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
        dpoint map_input_to_output(dpoint p) const;
        dpoint map_output_to_input(dpoint p) const;
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
        !*/
    };

    template <typename SUBNET>
    using rms_norm = add_layer<rms_norm_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    enum layer_mode
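To make the interface above concrete, here is a hedged sketch, not part of the commit, that constructs the layer object directly and inspects it; the values are illustrative only.

#include <dlib/dnn.h>
#include <iostream>

int main()
{
    // construct with a custom epsilon and tweak the per-layer multipliers
    dlib::rms_norm_ l(1e-6);
    l.set_learning_rate_multiplier(0.5);

    // prints something like: rms_norm (eps=1e-06) learning_rate_mult=0.5 ...
    std::cout << l << std::endl;
}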
@@ -308,6 +308,14 @@ namespace dlib
                set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
            }

            template <typename U, typename E>
            void disable_input_bias(add_layer<rms_norm_, U, E>& l)
            {
                disable_bias(l.subnet().layer_details());
                set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
            }

            template <layer_mode mode, typename U, typename E>
            void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
            {

@@ -333,6 +341,14 @@ namespace dlib
                set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
            }

            template <size_t N, template <typename> class R, typename U, typename E>
            void disable_input_bias(add_layer<rms_norm_, repeat<N, R, U>, E>& l)
            {
                disable_bias(l.subnet().get_repeated_layer(0).layer_details());
                set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
            }

            // handle input repeat layer with tag case
            template <layer_mode mode, unsigned long ID, typename E>
            void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& )

@@ -344,6 +360,11 @@ namespace dlib
            {
            }

            template <unsigned long ID, typename E>
            void disable_input_bias(add_layer<rms_norm_, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
            {
            }

            // handle tag layer case
            template <layer_mode mode, unsigned long ID, typename U, typename E>
            void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& )

@@ -355,6 +376,11 @@ namespace dlib
            {
            }

            template <unsigned long ID, typename U, typename E>
            void disable_input_bias(add_layer<rms_norm_, add_tag_layer<ID, U>, E>& )
            {
            }

            // handle skip layer case
            template <layer_mode mode, template <typename> class TAG, typename U, typename E>
            void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& )

@@ -366,6 +392,11 @@ namespace dlib
            {
            }

            template <template <typename> class TAG, typename U, typename E>
            void disable_input_bias(add_layer<rms_norm_, add_skip_layer<TAG, U>, E>& )
            {
            }

            template<typename input_layer_type>
            void operator()(size_t , input_layer_type& ) const
            {

@@ -741,6 +772,14 @@ namespace dlib
                update(i);
            }

            template <typename U, typename E>
            void operator()(size_t i, const add_layer<rms_norm_, U, E>&)
            {
                start_node(i, "rms_norm");
                end_node();
                update(i);
            }

            template <layer_mode MODE, typename U, typename E>
            void operator()(size_t i, const add_layer<bn_<MODE>, U, E>&)
            {
@@ -38,9 +38,9 @@ namespace dlib
            - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
              add_tag_layer.
        ensures
-           - Disables bias for all bn_ and layer_norm_ inputs.
+           - Disables bias for all bn_, layer_norm_ and rms_norm_ inputs.
            - Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
-             to zero of all bn_ and layer_norm_ inputs.
+             to zero of all bn_, layer_norm_ and rms_norm_ inputs.
    !*/

    // ----------------------------------------------------------------------------------------
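The visitor changes above extend disable_duplicative_biases() so that layers feeding an rms_norm also get their bias disabled. Here is a hedged sketch of the intended usage, with a network shape that is purely illustrative:

#include <dlib/dnn.h>

using namespace dlib;

using net_type = loss_multiclass_log<
                 fc<10,
                 relu<rms_norm<con<16, 3, 3, 1, 1,
                 input<matrix<float>>>>>>>;

void example()
{
    net_type net;
    // Turns off the bias of the con_ layer that feeds the rms_norm and zeroes its
    // bias learning-rate and weight-decay multipliers, matching the existing
    // bn_/layer_norm_ behaviour documented above.
    disable_duplicative_biases(net);
}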
@@ -631,7 +631,7 @@ namespace
            DLIB_TEST(::std::abs(rs.stddev() - 1.0f) < 0.01);
        }
        // check that the CPU and the CUDA implementation are equivalent
-#if DLIB_USE_CUDA
+#ifdef DLIB_USE_CUDA
        resizable_tensor y_cuda(x);
        resizable_tensor means_cuda(x.num_samples()), invstds_cuda(x.num_samples());
        cuda::layer_normalize(eps, y_cuda, means_cuda, invstds_cuda, x, gamma, beta);
@@ -655,6 +655,99 @@ namespace
#endif
    }

    // ----------------------------------------------------------------------------------------

    void test_rms_normalize()
    {
        resizable_tensor x(2, 3, 4, 5);
        resizable_tensor y_cpu(x);
        tt::tensor_rand rnd(0);
        rnd.fill_uniform(x);
        resizable_tensor scale_cpu;
        resizable_tensor gamma(1, x.k());
        gamma = 1;
        const float eps = 1e-5;
        cpu::rms_normalize(eps, y_cpu, scale_cpu, x, gamma);

        // check that the output is correctly normalized
        const float* p_x = x.host();
        const float* p_y = y_cpu.host();
        const float* p_scale = scale_cpu.host();
        bool error_found = false;
        for (long n = 0; n < x.num_samples(); ++n)
        {
            for (long k = 0; k < x.k(); ++k)
            {
                for (long r = 0; r < x.nr(); ++r)
                {
                    for (long c = 0; c < x.nc(); ++c)
                    {
                        float x_val = p_x[tensor_index(x, n, k, r, c)];
                        float y_val = p_y[tensor_index(y_cpu, n, k, r, c)];
                        float rms_val = p_scale[n];
                        if (std::abs(y_val - x_val * rms_val) >= 1e-5) error_found = true;
                    }
                }
            }
        }
        DLIB_TEST(!error_found);

        // check the backward pass
        resizable_tensor gradient_input(x);
        resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k());
        resizable_tensor dscale_cpu(x.num_samples());
        rnd.fill_gaussian(gradient_input);
        src_grad_cpu = 0;
        cpu::rms_normalize_gradient(gradient_input, scale_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, dscale_cpu);

        const float* p_gradient_input = gradient_input.host();
        const float* p_src = x.host();
        const float* p_src_grad_cpu = src_grad_cpu.host();
        const float* p_gamma = gamma.host();
        const float* p_scale_cpu = scale_cpu.host();
        const float* p_dscale_cpu = dscale_cpu.host();

        bool backward_error_found = false;
        for (long n = 0; n < x.num_samples(); ++n)
        {
            const float scale_pow = -0.5 * std::pow(p_scale_cpu[n], 3.0f);
            for (long k = 0; k < x.k(); ++k)
            {
                for (long r = 0; r < x.nr(); ++r)
                {
                    for (long c = 0; c < x.nc(); ++c)
                    {
                        float gradient_input_val = p_gradient_input[tensor_index(gradient_input, n, k, r, c)];
                        float src_val = p_src[tensor_index(x, n, k, r, c)];
                        float rms_val = p_scale_cpu[n];
                        float expected_src_grad = gradient_input_val * p_gamma[k] * rms_val + p_dscale_cpu[n] * 2 * src_val * 1.0f / (x.k() * x.nr() * x.nc());
                        float src_grad_val = p_src_grad_cpu[tensor_index(src_grad_cpu, n, k, r, c)];
                        if (std::abs(src_grad_val - expected_src_grad) >= 1e-4)
                            backward_error_found = true;
                    }
                }
            }
        }
        DLIB_TEST(!backward_error_found);

        // check that the CPU and the CUDA implementation are equivalent
#ifdef DLIB_USE_CUDA
        resizable_tensor y_cuda(x);
        resizable_tensor scale_cuda;
        cuda::rms_normalize(eps, y_cuda, scale_cuda, x, gamma);
        DLIB_TEST(max(abs(mat(y_cpu) - mat(y_cuda))) < 1e-5);
        DLIB_TEST(max(abs(mat(scale_cpu) - mat(scale_cuda))) < 1e-5);

        resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k());
        resizable_tensor dscale_cuda(x.num_samples());
        src_grad_cuda = 0;
        cuda::rms_normalize_gradient(gradient_input, scale_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, dscale_cuda);
        DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
        DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
        DLIB_TEST(max(abs(mat(dscale_cpu) - mat(dscale_cuda))) < 1e-5);
#endif
    }

    // ----------------------------------------------------------------------------------------

    void test_basic_tensor_ops()

@@ -2013,6 +2106,12 @@ namespace
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            rms_norm_ l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            cont_<3,3,3,2,2,0,0> l;

@@ -4389,6 +4488,7 @@ namespace
        test_batch_normalize();
        test_batch_normalize_conv();
        test_layer_normalize();
        test_rms_normalize();
        test_basic_tensor_ops();
        test_layers();
        test_visit_functions();