@ -2085,21 +2085,32 @@ namespace dlib
// ----------------------------------------------------------------------------------------
__global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num)
__global__ void _cuda_layer_normalize(
float* out,
const float* s,
float* m,
float* v,
const float* g,
const float* b,
float eps,
size_t ns,
size_t k,
size_t num
)
{
// compute means and sum of squares
for (auto n : grid_stride_range_y(0, ns))
{
auto p = s + n * num;
const auto ps = s + n * k * num;
float means = 0;
float invstds = 0;
for (auto i : grid_stride_range(0, num))
for (auto i : grid_stride_range(0, k * num))
{
means += p[i];
invstds += p[i] * p[i];
means += ps [i];
invstds += ps [i] * ps [i];
}
warp_reduce_atomic_add(m[n], means/num);
warp_reduce_atomic_add(v[n], invstds/num);
warp_reduce_atomic_add(m[n], means / (k * num) );
warp_reduce_atomic_add(v[n], invstds / (k * num) );
}
__syncthreads();
@ -2108,61 +2119,19 @@ namespace dlib
{
for (auto i : grid_stride_range(0, 1))
{
auto var = v[n] - m[n] * m[n];
v[n] = 1.0f / std::sqrt(var + eps);
v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps);
}
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, num))
const auto ps = s + n * k * num;
const auto pout = out + n * k * num;
for (auto i : grid_stride_range(0, k * num))
{
const float val = (s[n*num+i]-m[n])*v[n];
out[n*num+i] = val*g[i]+b[i];
}
}
}
__global__ void _cuda_layer_normalize_gradient(float* out, float* gg, float* bg, const float* s, const float* gi, const float* m, const float* v, const float* g, float* dm, float* dv, float eps, size_t ns, size_t num)
{
for (auto n : grid_stride_range_y(0, ns))
{
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float x_hat = (s[idx] - m[n])*v[n];
bg[i] += gi[idx];
gg[i] += gi[idx]*x_hat;
const float dx = gi[idx] * g[n];
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
}
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
float temp_dm = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[i];
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
}
warp_reduce_atomic_add(dm[n], temp_dm);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[i];
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
pout[i] = (ps[i] - m[n]) * v[n];
pout[i] = pout[i] * g[i / num] + b[i / num];
}
}
}
@ -2177,22 +2146,20 @@ namespace dlib
const tensor& beta
)
{
const long num = src.k() * src. nr() * src.nc();
const long num = src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.k() == gamma .k() &&
src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
gamma.k() == src .k() &&
gamma.nr() == 1 &&
gamma.nc() == 1 &&
eps > 0,
"\nsrc.k(): " << src.k() <<
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
@ -2201,8 +2168,78 @@ namespace dlib
invstds.set_size(src.num_samples());
means = 0;
invstds = 0;
launch_kernel(_cuda_layer_normalize, max_jobs(num, src.num_samples()), dest.device(), src.device(),
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), num);
launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(),
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_layer_normalize_gradient(
float* out,
float* gg,
float* bg,
const float* s,
const float* gi,
const float* m,
const float* v,
const float* g,
float* dm,
float* dv,
float eps,
size_t ns,
size_t ks,
size_t num)
{
for (auto nk : grid_stride_range_y(0, ns * ks))
{
const auto n = nk / ks;
const auto k = nk % ks;
const auto ps = s + (n * ks + k) * num;
const auto pgi = gi + (n * ks + k) * num;
const float invstd_pow = -0.5 * std::pow(v[n], 3.0f);
float temp_bg = 0;
float temp_gg = 0;
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
const float x_hat = (ps[i] - m[n]) * v[n];
const float dx = pgi[i] * g[i / num];
temp_bg += pgi[i];
temp_gg += pgi[i] * x_hat;
temp_dv += dx * (ps[i] - m[n]) * invstd_pow;
}
warp_reduce_atomic_add(bg[k], temp_bg);
warp_reduce_atomic_add(gg[k], temp_gg);
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
const float invnum = 1.0f / (ks * num);
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = s + n * ks * num;
const auto pgi = gi + n * ks * num;
float temp_dm = 0;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * g[i / num];
temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum;
}
warp_reduce_atomic_add(dm[n], temp_dm);
}
__syncthreads();
for (auto n : grid_stride_range_y(0, ns))
{
const auto ps = s + n * ks * num;
const auto pgi = gi + n * ks * num;
const auto pout = out + n * ks * num;
for (auto i : grid_stride_range(0, ks * num))
{
const float dx = pgi[i] * g[i / num];
pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum;
}
}
}
void layer_normalize_gradient (
@ -2214,32 +2251,33 @@ namespace dlib
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
tensor& beta_grad,
resizable_tensor& dmeans,
resizable_tensor& dvars
)
{
const long num = src.k() * src. nr() * src.nc();
const long num = src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.nr() == gamma.nr());
DLIB_CASSERT(src.nc() == gamma.nc());
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(gamma.k() == src.k());
DLIB_CASSERT(gamma.nr() == 1);
DLIB_CASSERT(gamma.nc() == 1);
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
gamma_grad = 0;
resizable_tensor dvars, dmeans;
dvars.copy_size(invstds);
dmeans.copy_size(means);
dvars = 0;
dmeans = 0;
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(num, src.num_samples()),
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()),
src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
gradient_input.device(), means.device(), invstds.device(), gamma.device(),
dmeans.device(), dvars.device(), eps, src.num_samples(), num);
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
}
// ----------------------------------------------------------------------------------------