mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Fix Layer Normalize (#2489)
* Fix Layer Normalize * remove unneeded temporary variables
This commit is contained in:
parent
aaac87a224
commit
3da3e81181
@ -1273,8 +1273,9 @@ namespace dlib
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
DLIB_CASSERT(
|
||||
have_same_dimensions(gamma, beta) &&
|
||||
src.num_samples() == gamma.size() &&
|
||||
src.num_samples() == beta.size() &&
|
||||
src.k() == gamma.k() &&
|
||||
src.nr() == gamma.nr() &&
|
||||
src.nc() == gamma.nc() &&
|
||||
eps > 0,
|
||||
"\ngamma.k(): " << gamma.k() <<
|
||||
"\ngamma.nr(): " << gamma.nr() <<
|
||||
@ -1282,9 +1283,9 @@ namespace dlib
|
||||
"\nbeta.k(): " << beta.k() <<
|
||||
"\nbeta.nr(): " << beta.nr() <<
|
||||
"\nbeta.nc(): " << beta.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\neps: " << eps
|
||||
);
|
||||
|
||||
@ -1329,7 +1330,7 @@ namespace dlib
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
*p_dest = (*p_src - p_means[n])*p_invstds[n];
|
||||
*p_dest = (*p_dest)*p_gamma[n] + p_beta[n];
|
||||
*p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
|
||||
++p_src;
|
||||
++p_dest;
|
||||
}
|
||||
@ -1351,11 +1352,12 @@ namespace dlib
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
DLIB_CASSERT(src.num_samples() == means.size());
|
||||
DLIB_CASSERT(src.num_samples() == invstds.size());
|
||||
DLIB_CASSERT(src.num_samples() == gamma.size());
|
||||
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
|
||||
DLIB_CASSERT(src.num_samples() == beta_grad.size());
|
||||
DLIB_CASSERT(src.k() == gamma.k());
|
||||
DLIB_CASSERT(src.nr() == gamma_grad.nr());
|
||||
DLIB_CASSERT(src.nc() == beta_grad.nc());
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(eps > 0);
|
||||
|
||||
beta_grad = 0;
|
||||
@ -1381,12 +1383,12 @@ namespace dlib
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float x_hat = (*p_src - p_means[n])*p_invstds[n];
|
||||
p_beta_grad[n] += *p_grad;
|
||||
p_gamma_grad[n] += (*p_grad)*x_hat;
|
||||
p_beta_grad[i] += *p_grad;
|
||||
p_gamma_grad[i] += (*p_grad)*x_hat;
|
||||
|
||||
const float dx = *p_grad * p_gamma[n];
|
||||
|
||||
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*std::pow(p_invstds[n], 3.0f);
|
||||
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
|
||||
|
||||
++p_grad;
|
||||
++p_src;
|
||||
@ -1400,7 +1402,7 @@ namespace dlib
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float dx = *p_grad * p_gamma[n];
|
||||
const float dx = *p_grad * p_gamma[i];
|
||||
|
||||
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
|
||||
|
||||
@ -1415,7 +1417,7 @@ namespace dlib
|
||||
{
|
||||
for (long i = 0; i < num; ++i)
|
||||
{
|
||||
const float dx = *p_grad * p_gamma[n];
|
||||
const float dx = *p_grad * p_gamma[i];
|
||||
|
||||
*p_src_grad += dx*p_invstds[n] +
|
||||
p_dvars[n] *2*(*p_src - p_means[n])*invnum +
|
||||
|
@ -1908,7 +1908,7 @@ namespace dlib
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
const float val = (s[n*num+i]-m[n])*v[n];
|
||||
out[n*num+i] = val*g[n]+b[n];
|
||||
out[n*num+i] = val*g[i]+b[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1917,21 +1917,17 @@ namespace dlib
|
||||
{
|
||||
for (auto n : grid_stride_range_y(0, ns))
|
||||
{
|
||||
float temp_bg = 0;
|
||||
float temp_gg = 0;
|
||||
float temp_dv = 0;
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float x_hat = (s[idx] - m[n])*v[n];
|
||||
temp_bg += gi[idx];
|
||||
temp_gg += gi[idx]*x_hat;
|
||||
bg[i] += gi[idx];
|
||||
gg[i] += gi[idx]*x_hat;
|
||||
|
||||
const float dx = gi[idx] * g[n];
|
||||
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
|
||||
}
|
||||
warp_reduce_atomic_add(bg[n], temp_bg);
|
||||
warp_reduce_atomic_add(gg[n], temp_gg);
|
||||
warp_reduce_atomic_add(dv[n], temp_dv);
|
||||
}
|
||||
__syncthreads();
|
||||
@ -1942,7 +1938,7 @@ namespace dlib
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float dx = gi[idx]*g[n];
|
||||
const float dx = gi[idx]*g[i];
|
||||
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
|
||||
}
|
||||
warp_reduce_atomic_add(dm[n], temp_dm);
|
||||
@ -1954,7 +1950,7 @@ namespace dlib
|
||||
for (auto i : grid_stride_range(0, num))
|
||||
{
|
||||
auto idx = n*num+i;
|
||||
const float dx = gi[idx]*g[n];
|
||||
const float dx = gi[idx]*g[i];
|
||||
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
|
||||
}
|
||||
}
|
||||
@ -1973,8 +1969,9 @@ namespace dlib
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
DLIB_CASSERT(
|
||||
have_same_dimensions(gamma, beta) &&
|
||||
src.num_samples() == gamma.size() &&
|
||||
src.num_samples() == beta.size() &&
|
||||
src.k() == gamma.k() &&
|
||||
src.nr() == gamma.nr() &&
|
||||
src.nc() == gamma.nc() &&
|
||||
eps > 0,
|
||||
"\ngamma.k(): " << gamma.k() <<
|
||||
"\ngamma.nr(): " << gamma.nr() <<
|
||||
@ -1982,9 +1979,9 @@ namespace dlib
|
||||
"\nbeta.k(): " << beta.k() <<
|
||||
"\nbeta.nr(): " << beta.nr() <<
|
||||
"\nbeta.nc(): " << beta.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\nsrc.k(): " << src.k() <<
|
||||
"\nsrc.nr(): " << src.nr() <<
|
||||
"\nsrc.nc(): " << src.nc() <<
|
||||
"\neps: " << eps
|
||||
);
|
||||
|
||||
@ -2012,11 +2009,13 @@ namespace dlib
|
||||
const long num = src.k() * src.nr() * src.nc();
|
||||
DLIB_CASSERT(src.num_samples() == means.size());
|
||||
DLIB_CASSERT(src.num_samples() == invstds.size());
|
||||
DLIB_CASSERT(src.num_samples() == gamma.size());
|
||||
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
|
||||
DLIB_CASSERT(src.num_samples() == beta_grad.size());
|
||||
DLIB_CASSERT(src.k() == gamma.k());
|
||||
DLIB_CASSERT(src.nr() == gamma.nr());
|
||||
DLIB_CASSERT(src.nc() == gamma.nc());
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
|
||||
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
|
||||
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
|
||||
DLIB_CASSERT(eps > 0);
|
||||
|
||||
beta_grad = 0;
|
||||
|
@ -1371,7 +1371,7 @@ namespace dlib
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& sub)
|
||||
{
|
||||
gamma = alias_tensor(sub.get_output().num_samples());
|
||||
gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
|
||||
beta = gamma;
|
||||
|
||||
params.set_size(gamma.size()+beta.size());
|
||||
|
@ -556,7 +556,7 @@ namespace
|
||||
tt::tensor_rand rnd(0);
|
||||
rnd.fill_uniform(x);
|
||||
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
|
||||
resizable_tensor gamma(x.num_samples()), beta(x.num_samples());
|
||||
resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
|
||||
gamma = 1;
|
||||
beta = 0;
|
||||
const float eps = 1e-5;
|
||||
@ -588,8 +588,8 @@ namespace
|
||||
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
|
||||
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
|
||||
resizable_tensor gradient_input(x);
|
||||
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(x.num_samples()), beta_grad_cpu(x.num_samples());
|
||||
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(x.num_samples()), beta_grad_cuda(x.num_samples());
|
||||
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
|
||||
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
|
||||
rnd.fill_gaussian(gradient_input);
|
||||
src_grad_cpu = 0;
|
||||
src_grad_cuda = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user