mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Upgraded resize_bilinear() to let the user specify independent row and channel
stride values. This lets you run the tensor resizing routine on subwindows in a tensor.
parent f5a68ded86
commit 9ba4b45ffd
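The intended use is to run the resize on a subwindow viewed through an alias_tensor, passing the parent tensor's row and channel strides. A minimal sketch (the shapes and the one-pixel border offset are illustrative, mirroring the new tests at the bottom of this commit):

#include <dlib/dnn.h>

using namespace dlib;

int main()
{
    // Parent tensors: 1 sample, 3 channels.
    resizable_tensor img(1, 3, 34, 34), out(1, 3, 66, 66);
    tt::tensor_rand rnd;
    rnd.fill_uniform(img);

    // Interior views that skip a one pixel border on each side.
    alias_tensor aimg(1, 3, 32, 32), aout(1, 3, 64, 64);
    auto wimg = aimg(img, img.nc()*1+1);   // start at row 1, column 1
    auto wout = aout(out, out.nc()*1+1);

    // Resize only the subwindows.  The strides come from the parent
    // tensors, so each row step jumps over a full parent row.
    tt::resize_bilinear(wout, out.nc(), out.nr()*out.nc(),
                        wimg, img.nc(), img.nr()*img.nc());
    return 0;
}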
@@ -1478,7 +1478,11 @@ namespace dlib
         void resize_bilinear (
             tensor& dest,
-            const tensor& src
+            long dest_row_stride,
+            long dest_channel_stride,
+            const tensor& src,
+            long src_row_stride,
+            long src_channel_stride
         )
         {
             DLIB_CASSERT(is_same_object(dest, src)==false);
@@ -1509,27 +1513,31 @@ namespace dlib
                             const long right = std::min(left+1, src.nc()-1);
                             const float lr_frac = x - left;

-                            float tl = s[top*src.nc()+left];
-                            float tr = s[top*src.nc()+right];
-                            float bl = s[bottom*src.nc()+left];
-                            float br = s[bottom*src.nc()+right];
+                            float tl = s[top*src_row_stride+left];
+                            float tr = s[top*src_row_stride+right];
+                            float bl = s[bottom*src_row_stride+left];
+                            float br = s[bottom*src_row_stride+right];

                             float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
                                          tb_frac*((1-lr_frac)*bl + lr_frac*br);

-                            d[r*dest.nc()+c] = temp;
+                            d[r*dest_row_stride+c] = temp;
                         }
                     }

-                    d += dest.nr()*dest.nc();
-                    s += src.nr()*src.nc();
+                    d += dest_channel_stride;
+                    s += src_channel_stride;
                 }
             }
         }

         void resize_bilinear_gradient (
             tensor& grad,
-            const tensor& gradient_input
+            long grad_row_stride,
+            long grad_channel_stride,
+            const tensor& gradient_input,
+            long gradient_input_row_stride,
+            long gradient_input_channel_stride
         )
         {
             DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
@@ -1560,17 +1568,17 @@ namespace dlib
                             const long right = std::min(left+1, grad.nc()-1);
                             const float lr_frac = x - left;

-                            const float tmp = gi[r*gradient_input.nc()+c];
+                            const float tmp = gi[r*gradient_input_row_stride+c];

-                            g[top*grad.nc()+left] += tmp*(1-tb_frac)*(1-lr_frac);
-                            g[top*grad.nc()+right] += tmp*(1-tb_frac)*(lr_frac);
-                            g[bottom*grad.nc()+left] += tmp*(tb_frac)*(1-lr_frac);
-                            g[bottom*grad.nc()+right] += tmp*(tb_frac)*(lr_frac);
+                            g[top*grad_row_stride+left] += tmp*(1-tb_frac)*(1-lr_frac);
+                            g[top*grad_row_stride+right] += tmp*(1-tb_frac)*(lr_frac);
+                            g[bottom*grad_row_stride+left] += tmp*(tb_frac)*(1-lr_frac);
+                            g[bottom*grad_row_stride+right] += tmp*(tb_frac)*(lr_frac);
                         }
                     }

-                    g += grad.nr()*grad.nc();
-                    gi += gradient_input.nr()*gradient_input.nc();
+                    g += grad_channel_stride;
+                    gi += gradient_input_channel_stride;
                 }
            }
        }
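For reference, the blend computed in the loops above is unchanged by this commit; only the addressing is. Writing t for tb_frac and l for lr_frac, each output pixel is

\[
\mathrm{temp} = (1-t)\bigl((1-l)\,tl + l\,tr\bigr) + t\bigl((1-l)\,bl + l\,br\bigr)
\]

and resize_bilinear_gradient() scatters each tmp back onto the same four taps with the same four weights.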
@@ -307,14 +307,32 @@ namespace dlib

         void resize_bilinear (
             tensor& dest,
-            const tensor& src
+            long dest_row_stride,
+            long dest_channel_stride,
+            const tensor& src,
+            long src_row_stride,
+            long src_channel_stride
         );

         void resize_bilinear_gradient (
             tensor& grad,
-            const tensor& gradient_input
+            long grad_row_stride,
+            long grad_channel_stride,
+            const tensor& gradient_input,
+            long gradient_input_row_stride,
+            long gradient_input_channel_stride
         );

+        inline void resize_bilinear (
+            tensor& dest,
+            const tensor& src
+        ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
+
+        inline void resize_bilinear_gradient (
+            tensor& grad,
+            const tensor& gradient_input
+        ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
+
     // -----------------------------------------------------------------------------------

         class pooling
@@ -1301,9 +1301,50 @@ namespace dlib
             }
         }

+        __global__ void _cuda_resize_bilinear_strided(size_t dsize, size_t dchan_size, size_t dnc, float* d,
+            size_t schan_size, int snr, int snc, const float* s,
+            const float x_scale, const float y_scale,
+            size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
+        )
+        {
+            for(auto i : grid_stride_range(0, dsize))
+            {
+                const int idx = i%dchan_size;
+                const int channel = i/dchan_size;
+                const int sidx = channel*schan_size;
+                const int r = idx/dnc;
+                const int c = idx%dnc;
+                const int didx = channel*dest_chan_size_strided + r*dest_row_stride+c;
+
+                const float y = r*y_scale;
+                const int top = static_cast<int>(::floor(y));
+                const int bottom = ::min(top+1, snr-1);
+                const float tb_frac = y - top;
+
+                const float x = c*x_scale;
+                const int left = static_cast<int>(::floor(x));
+                const int right = ::min(left+1, snc-1);
+                const float lr_frac = x - left;
+
+                float tl = s[sidx+top*src_row_stride+left];
+                float tr = s[sidx+top*src_row_stride+right];
+                float bl = s[sidx+bottom*src_row_stride+left];
+                float br = s[sidx+bottom*src_row_stride+right];
+
+                float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+                             tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+                d[didx] = temp;
+            }
+        }
+
         void resize_bilinear (
             tensor& dest,
-            const tensor& src
+            long dest_row_stride,
+            long dest_channel_stride,
+            const tensor& src,
+            long src_row_stride,
+            long src_channel_stride
         )
         {
             DLIB_CASSERT(is_same_object(dest, src)==false);
@@ -1316,12 +1357,25 @@ namespace dlib
             const float x_scale = (src.nc()-1)/(float)std::max<long>((dest.nc()-1),1);
             const float y_scale = (src.nr()-1)/(float)std::max<long>((dest.nr()-1),1);

-            launch_kernel(_cuda_resize_bilinear,
-                dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
-                src.nr()*src.nc(), src.nr(), src.nc(), src.device(),
-                x_scale, y_scale);
+            if (dest.nc() == dest_row_stride && dest.nr()*dest.nc()==dest_channel_stride &&
+                src.nc() == src_row_stride && src.nr()*src.nc()==src_channel_stride)
+            {
+                launch_kernel(_cuda_resize_bilinear,
+                    dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
+                    src.nr()*src.nc(), src.nr(), src.nc(), src.device(),
+                    x_scale, y_scale);
+            }
+            else
+            {
+                launch_kernel(_cuda_resize_bilinear_strided,
+                    dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
+                    src_channel_stride, src.nr(), src.nc(), src.device(),
+                    x_scale, y_scale, dest_row_stride, src_row_stride, dest_channel_stride);
+            }
         }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_resize_bilinear_gradient(size_t dsize, size_t dchan_size, size_t dnc, const float* d,
            size_t schan_size, int snr, int snc, float* s,
            const float x_scale, const float y_scale)
@@ -1354,9 +1408,49 @@ namespace dlib
             }
         }

+        __global__ void _cuda_resize_bilinear_gradient_strided(size_t dsize, size_t dchan_size, size_t dnc, const float* d,
+            size_t schan_size, int snr, int snc, float* s,
+            const float x_scale, const float y_scale,
+            size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
+        )
+        {
+            for(auto i : grid_stride_range(0, dsize))
+            {
+
+                const int idx = i%dchan_size;
+                const int channel = i/dchan_size;
+                const int didx = channel*dest_chan_size_strided;
+                const int sidx = channel*schan_size;
+                const int r = idx/dnc;
+                const int c = idx%dnc;
+
+                const float tmp = d[didx + r*dest_row_stride+c];
+
+                const float y = r*y_scale;
+                const int top = static_cast<int>(::floor(y));
+                const int bottom = ::min(top+1, snr-1);
+                const float tb_frac = y - top;
+
+                const float x = c*x_scale;
+                const int left = static_cast<int>(::floor(x));
+                const int right = ::min(left+1, snc-1);
+                const float lr_frac = x - left;
+
+
+                atomicAdd(s+sidx+top*src_row_stride+left, tmp*(1-tb_frac)*(1-lr_frac));
+                atomicAdd(s+sidx+top*src_row_stride+right, tmp*(1-tb_frac)*(lr_frac));
+                atomicAdd(s+sidx+bottom*src_row_stride+left, tmp*(tb_frac)*(1-lr_frac));
+                atomicAdd(s+sidx+bottom*src_row_stride+right, tmp*(tb_frac)*(lr_frac));
+            }
+        }
+
         void resize_bilinear_gradient (
             tensor& grad,
-            const tensor& gradient_input
+            long grad_row_stride,
+            long grad_channel_stride,
+            const tensor& gradient_input,
+            long gradient_input_row_stride,
+            long gradient_input_channel_stride
         )
         {
             DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
@@ -1369,10 +1463,21 @@ namespace dlib
             const float x_scale = (grad.nc()-1)/(float)std::max<long>((gradient_input.nc()-1),1);
             const float y_scale = (grad.nr()-1)/(float)std::max<long>((gradient_input.nr()-1),1);

-            launch_kernel(_cuda_resize_bilinear_gradient,
-                gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
-                grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(),
-                x_scale, y_scale);
+            if (grad.nc() == grad_row_stride && grad.nr()*grad.nc()==grad_channel_stride &&
+                gradient_input.nc() == gradient_input_row_stride && gradient_input.nr()*gradient_input.nc()==gradient_input_channel_stride)
+            {
+                launch_kernel(_cuda_resize_bilinear_gradient,
+                    gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
+                    grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(),
+                    x_scale, y_scale);
+            }
+            else
+            {
+                launch_kernel(_cuda_resize_bilinear_gradient_strided,
+                    gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
+                    grad_channel_stride, grad.nr(), grad.nc(), grad.device(),
+                    x_scale, y_scale, gradient_input_row_stride, grad_row_stride, gradient_input_channel_stride);
+            }
         }

    // ----------------------------------------------------------------------------------------
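Both strided kernels keep the original flat loop over the output and only change how a flat index i becomes an address. A host-side sketch of that arithmetic, with made-up sizes (the names mirror the kernel parameters):

#include <cassert>
#include <cstddef>

int main()
{
    // A 4x5 destination window whose parent image has rows 9 floats apart
    // and channels 9*8 floats apart; 3 channels in total.
    const std::size_t dnc = 5, dchan_size = 4*5, channels = 3;
    const std::size_t dest_row_stride = 9, dest_chan_size_strided = 9*8;

    for (std::size_t i = 0; i < channels*dchan_size; ++i)
    {
        const std::size_t idx     = i % dchan_size;   // position within one channel
        const std::size_t channel = i / dchan_size;
        const std::size_t r       = idx / dnc;
        const std::size_t c       = idx % dnc;

        // Address in the strided (sub-window) layout:
        const std::size_t didx = channel*dest_chan_size_strided + r*dest_row_stride + c;

        // With the default strides (dest_row_stride == dnc and
        // dest_chan_size_strided == dchan_size) didx collapses back to i:
        assert(channel*dchan_size + r*dnc + c == i);
        (void)didx;
    }
    return 0;
}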
@@ -358,14 +358,32 @@ namespace dlib

         void resize_bilinear (
             tensor& dest,
-            const tensor& src
+            long dest_row_stride,
+            long dest_channel_stride,
+            const tensor& src,
+            long src_row_stride,
+            long src_channel_stride
         );

         void resize_bilinear_gradient (
             tensor& grad,
-            const tensor& gradient_input
+            long grad_row_stride,
+            long grad_channel_stride,
+            const tensor& gradient_input,
+            long gradient_input_row_stride,
+            long gradient_input_channel_stride
         );

+        inline void resize_bilinear (
+            tensor& dest,
+            const tensor& src
+        ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
+
+        inline void resize_bilinear_gradient (
+            tensor& grad,
+            const tensor& gradient_input
+        ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
+
     // ----------------------------------------------------------------------------------------

         void copy_tensor(
@@ -856,25 +856,33 @@ namespace dlib { namespace tt

     void resize_bilinear (
         tensor& dest,
-        const tensor& src
+        long dest_row_stride,
+        long dest_channel_stride,
+        const tensor& src,
+        long src_row_stride,
+        long src_channel_stride
     )
     {
 #ifdef DLIB_USE_CUDA
-        cuda::resize_bilinear(dest,src);
+        cuda::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride);
 #else
-        cpu::resize_bilinear(dest,src);
+        cpu::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride);
 #endif
     }

     void resize_bilinear_gradient (
         tensor& grad,
-        const tensor& gradient_input
+        long grad_row_stride,
+        long grad_channel_stride,
+        const tensor& gradient_input,
+        long gradient_input_row_stride,
+        long gradient_input_channel_stride
     )
     {
 #ifdef DLIB_USE_CUDA
-        cuda::resize_bilinear_gradient(grad,gradient_input);
+        cuda::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride);
 #else
-        cpu::resize_bilinear_gradient(grad,gradient_input);
+        cpu::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride);
 #endif
     }

@@ -1371,8 +1371,59 @@ namespace dlib { namespace tt

     void resize_bilinear (
         tensor& dest,
-        const tensor& src
+        long dest_row_stride,
+        long dest_channel_stride,
+        const tensor& src,
+        long src_row_stride,
+        long src_channel_stride
     );
+    /*!
+        requires
+            - is_same_object(dest, src)==false
+            - dest.num_samples() == src.num_samples()
+            - dest.k() == src.k()
+        ensures
+            - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k)
+              that has been bilinearly interpolated to fit into the shape of
+              image_plane(dest,i,k).
+            - Instead of assuming the row stride and channel stride in the tensors are
+              given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the
+              provided stride values to transition from one row and channel to the next.
+              This is useful in combination with alias_tensor objects since it allows you
+              to operate on subwindows in an image.
+    !*/
+
+    void resize_bilinear_gradient (
+        tensor& grad,
+        long grad_row_stride,
+        long grad_channel_stride,
+        const tensor& gradient_input,
+        long gradient_input_row_stride,
+        long gradient_input_channel_stride
+    );
+    /*!
+        requires
+            - is_same_object(grad, gradient_input)==false
+            - gradient_input.num_samples() == grad.num_samples()
+            - gradient_input.k() == grad.k()
+        ensures
+            - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC
+              tensor, and let f(SRC) == dot(gradient_input,DEST).  Then this function computes
+              the gradient of f() with respect to SRC and adds it to grad.  It should be
+              noted that we don't need to know the contents of DEST to compute this
+              gradient.  All that matters is that gradient_input have the same dimensions
+              as DEST.
+            - Instead of assuming the row stride and channel stride in the tensors are
+              given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the
+              provided stride values to transition from one row and channel to the next.
+              This is useful in combination with alias_tensor objects since it allows you
+              to operate on subwindows in an image.
+    !*/
+
+    inline void resize_bilinear (
+        tensor& dest,
+        const tensor& src
+    ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
     /*!
         requires
             - is_same_object(dest, src)==false
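Because bilinear resizing is linear in SRC, the contract above has a compact reading: if the resize is DEST = A·SRC for a fixed interpolation operator A, then

\[
f(\mathrm{SRC}) = \langle \mathrm{gradient\_input},\, A\,\mathrm{SRC} \rangle
\qquad\text{and}\qquad
\mathrm{grad} \mathrel{+}= A^{\mathsf{T}}\, \mathrm{gradient\_input},
\]

which is why DEST itself is never needed, only its dimensions.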
@@ -1384,10 +1435,10 @@ namespace dlib { namespace tt
               image_plane(dest,i,k).
     !*/

-    void resize_bilinear_gradient (
+    inline void resize_bilinear_gradient (
         tensor& grad,
         const tensor& gradient_input
-    );
+    ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
     /*!
         requires
             - is_same_object(grad, gradient_input)==false
@@ -1774,8 +1774,8 @@ namespace

         net_type2 pnet;

-        DLIB_CASSERT(pnet.num_layers == 131, pnet.num_layers);
-        DLIB_CASSERT(pnet.num_computational_layers == 109, pnet.num_computational_layers);
+        DLIB_TEST_MSG(pnet.num_layers == 131, pnet.num_layers);
+        DLIB_TEST_MSG(pnet.num_computational_layers == 109, pnet.num_computational_layers);

         std::vector<bool> hit(pnet.num_computational_layers, false);
         size_t count = 0;
@@ -2322,7 +2322,7 @@ namespace
         for (int is_bias = 0; is_bias <= 1; ++is_bias) {
             for (uint16_t k = 0; k < num_classes; ++k) {
                 size_t index = k + is_bias * num_classes;
-                DLIB_CASSERT(index < learned_params.size());
+                DLIB_TEST(index < learned_params.size());
                 if (k == true_label) {
                     DLIB_TEST(learned_params_data[index] > 1e5);
                 }
@@ -2419,13 +2419,13 @@ namespace

                 for (long k = 0; k < num_classes; ++k) {
                     const size_t index = ((ii * output_tensor.k() + k) * output_tensor.nr() + jj) * output_tensor.nc() + kk;
-                    DLIB_CASSERT(index < output_tensor.size());
+                    DLIB_TEST(index < output_tensor.size());

                     if (k == true_label) {
-                        DLIB_TEST_MSG(out_data[index] > 1e4, "");
+                        DLIB_TEST(out_data[index] > 1e4);
                     }
                     else {
-                        DLIB_TEST_MSG(out_data[index] < -1e4, "");
+                        DLIB_TEST(out_data[index] < -1e4);
                     }
                 }
             }
@@ -2745,7 +2745,7 @@ namespace
         cpu::resize_bilinear(out, img);
 #ifdef DLIB_USE_CUDA
         cuda::resize_bilinear(out2, img);
-        DLIB_CASSERT(max(abs(mat(out)-mat(out2))) < 1e-5);
+        DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5);
 #endif

         resizable_tensor gradient_input;
@@ -2775,16 +2775,87 @@ namespace

         cpu::resize_bilinear_gradient(grad2, gradient_input);
         dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1;
-        DLIB_CASSERT(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+        DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);

 #ifdef DLIB_USE_CUDA
         cuda::resize_bilinear_gradient(grad, gradient_input);
         dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1;
-        DLIB_CASSERT(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
-        DLIB_CASSERT(max(abs(mat(grad)-mat(grad2))) < 1e-5);
+        DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+        DLIB_TEST(max(abs(mat(grad)-mat(grad2))) < 1e-5);
 #endif

         }

+
+        // now test with strided/sub-window calls
+        alias_tensor aimg(samps, k, nr-2,nc-2);
+        alias_tensor aout(samps, k, onr-2,onc-2);
+        for (int iter = 0; iter < 10; ++iter)
+        {
+            print_spinner();
+
+            const size_t idx = rnd.get_random_64bit_number()%img.size();
+
+            img = 1;
+            img.host()[idx] = 2;
+            out = 9;
+            out2 = 9;
+            auto wout = aout(out, out.nc()*1+1);
+            auto wimg = aimg(img, img.nc()*1+1);
+            cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+#ifdef DLIB_USE_CUDA
+            auto wout2 = aout(out2, out2.nc()*1+1);
+            cuda::resize_bilinear(wout2,out2.nc(),out2.nr()*out2.nc(), wimg,img.nc(),img.nr()*img.nc());
+            DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5);
+#endif
+
+            resizable_tensor gradient_input;
+            gradient_input.copy_size(out);
+            tt::tensor_rand rnd;
+            rnd.fill_uniform(gradient_input);
+
+            const float h = 1e-2;
+
+            img.host()[idx] = 2;
+            out = 0;
+            wout = aout(out, out.nc()*1+1);
+            wimg = aimg(img, img.nc()*1+1);
+            cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+            float f1 = dot(out, gradient_input);
+
+            img.host()[idx] = 2+h;
+            out = 0;
+            cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+            float f2 = dot(out, gradient_input);
+
+            const float numerical_grad = (f2-f1)/h;
+            dlog << LINFO << "numerical grad: " << numerical_grad;
+
+            resizable_tensor grad, grad2;
+            grad.copy_size(img);
+            grad = 0.1;
+            grad2.copy_size(img);
+            grad2 = 0.1;
+
+            auto wgrad2 = aimg(grad2, grad2.nc()*1+1);
+            auto wgradient_input = aout(gradient_input, gradient_input.nc()*1+1);
+            cpu::resize_bilinear_gradient(wgrad2,grad2.nc(),grad2.nr()*grad2.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc());
+            dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1;
+            DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+
+#ifdef DLIB_USE_CUDA
+            wgrad2 = aimg(grad, grad.nc()*1+1);
+            wgradient_input = aout(gradient_input, gradient_input.nc()*1+1);
+            cuda::resize_bilinear_gradient(wgrad2,grad.nc(),grad.nr()*grad.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc());
+            dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1;
+            DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+            DLIB_TEST_MSG(max(abs(mat(grad)-mat(grad2))) < 1e-5, max(abs(mat(grad)-mat(grad2))));
+#endif
+
+
+        }
     }
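The gradient checks above compare the analytic value against a forward difference. With the step h = 1e-2 from the test, and remembering that grad is pre-filled with 0.1 (hence the +0.1 correction inside the assertions), the test asserts

\[
\frac{f_2 - f_1}{h} \;\approx\; \frac{\partial f}{\partial\,\mathrm{SRC}[\mathrm{idx}]}
\;=\; \mathrm{grad.host()[idx]} - 0.1
\]

to within an absolute tolerance of 1e-2.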
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user