Fix Layer Normalize (#2489)

* Fix Layer Normalize

* remove unneeded temporary variables
This commit is contained in:
Adrià Arrufat 2022-01-24 01:29:28 +09:00 committed by GitHub
parent aaac87a224
commit 3da3e81181
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 35 deletions

View File

@@ -1273,8 +1273,9 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() &&
src.num_samples() == beta.size() &&
src.k() == gamma.k() &&
src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
eps > 0,
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
@@ -1282,9 +1283,9 @@ namespace dlib
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
@@ -1329,7 +1330,7 @@ namespace dlib
for (long i = 0; i < num; ++i)
{
*p_dest = (*p_src - p_means[n])*p_invstds[n];
*p_dest = (*p_dest)*p_gamma[n] + p_beta[n];
*p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
++p_src;
++p_dest;
}
@@ -1351,11 +1352,12 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size());
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
DLIB_CASSERT(src.num_samples() == beta_grad.size());
DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.nr() == gamma_grad.nr());
DLIB_CASSERT(src.nc() == beta_grad.nc());
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
@@ -1381,12 +1383,12 @@ namespace dlib
for (long i = 0; i < num; ++i)
{
const float x_hat = (*p_src - p_means[n])*p_invstds[n];
p_beta_grad[n] += *p_grad;
p_gamma_grad[n] += (*p_grad)*x_hat;
p_beta_grad[i] += *p_grad;
p_gamma_grad[i] += (*p_grad)*x_hat;
const float dx = *p_grad * p_gamma[n];
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*std::pow(p_invstds[n], 3.0f);
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
++p_grad;
++p_src;
@@ -1400,7 +1402,7 @@ namespace dlib
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[n];
const float dx = *p_grad * p_gamma[i];
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
@@ -1415,7 +1417,7 @@ namespace dlib
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[n];
const float dx = *p_grad * p_gamma[i];
*p_src_grad += dx*p_invstds[n] +
p_dvars[n] *2*(*p_src - p_means[n])*invnum +

View File

@@ -1908,7 +1908,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num))
{
const float val = (s[n*num+i]-m[n])*v[n];
out[n*num+i] = val*g[n]+b[n];
out[n*num+i] = val*g[i]+b[i];
}
}
}
@@ -1917,21 +1917,17 @@ namespace dlib
{
for (auto n : grid_stride_range_y(0, ns))
{
float temp_bg = 0;
float temp_gg = 0;
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float x_hat = (s[idx] - m[n])*v[n];
temp_bg += gi[idx];
temp_gg += gi[idx]*x_hat;
bg[i] += gi[idx];
gg[i] += gi[idx]*x_hat;
const float dx = gi[idx] * g[n];
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
}
warp_reduce_atomic_add(bg[n], temp_bg);
warp_reduce_atomic_add(gg[n], temp_gg);
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
@@ -1942,7 +1938,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[n];
const float dx = gi[idx]*g[i];
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
}
warp_reduce_atomic_add(dm[n], temp_dm);
@@ -1954,7 +1950,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[n];
const float dx = gi[idx]*g[i];
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
}
}
@@ -1973,8 +1969,9 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() &&
src.num_samples() == beta.size() &&
src.k() == gamma.k() &&
src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
eps > 0,
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
@@ -1982,9 +1979,9 @@ namespace dlib
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
@@ -2012,11 +2009,13 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size());
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
DLIB_CASSERT(src.num_samples() == beta_grad.size());
DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.nr() == gamma.nr());
DLIB_CASSERT(src.nc() == gamma.nc());
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;

View File

@@ -1371,7 +1371,7 @@ namespace dlib
template <typename SUBNET>
void setup (const SUBNET& sub)
{
gamma = alias_tensor(sub.get_output().num_samples());
gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
beta = gamma;
params.set_size(gamma.size()+beta.size());

View File

@@ -556,7 +556,7 @@ namespace
tt::tensor_rand rnd(0);
rnd.fill_uniform(x);
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
resizable_tensor gamma(x.num_samples()), beta(x.num_samples());
resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
gamma = 1;
beta = 0;
const float eps = 1e-5;
@@ -588,8 +588,8 @@ namespace
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
resizable_tensor gradient_input(x);
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(x.num_samples()), beta_grad_cpu(x.num_samples());
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(x.num_samples()), beta_grad_cuda(x.num_samples());
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
rnd.fill_gaussian(gradient_input);
src_grad_cpu = 0;
src_grad_cuda = 0;