mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Changed conv layer to use cross-correlation rather than convolution.
This commit is contained in:
parent
595f01289b
commit
8e6d8ae01a
@ -1631,9 +1631,11 @@ namespace dlib
|
||||
|
||||
// now fill in the Toeplitz output matrix for the n-th sample in data.
|
||||
size_t cnt = 0;
|
||||
for (long r = filter_nr-1-padding_y; r-padding_y < data.nr(); r+=stride_y)
|
||||
const long max_r = data.nr() + padding_y-(filter_nr-1);
|
||||
const long max_c = data.nc() + padding_x-(filter_nc-1);
|
||||
for (long r = -padding_y; r < max_r; r+=stride_y)
|
||||
{
|
||||
for (long c = filter_nc-1-padding_x; c-padding_x < data.nc(); c+=stride_x)
|
||||
for (long c = -padding_x; c < max_c; c+=stride_x)
|
||||
{
|
||||
for (long k = 0; k < data.k(); ++k)
|
||||
{
|
||||
@ -1642,8 +1644,8 @@ namespace dlib
|
||||
for (long x = 0; x < filter_nc; ++x)
|
||||
{
|
||||
DLIB_ASSERT(cnt < output.size(),"");
|
||||
long xx = c-x;
|
||||
long yy = r-y;
|
||||
long xx = c+x;
|
||||
long yy = r+y;
|
||||
if (boundary.contains(xx,yy))
|
||||
*t = d[(k*data.nr() + yy)*data.nc() + xx];
|
||||
else
|
||||
@ -1676,9 +1678,11 @@ namespace dlib
|
||||
const float* t = &output(0,0);
|
||||
|
||||
// now fill in the Toeplitz output matrix for the n-th sample in data.
|
||||
for (long r = filter_nr-1-padding_y; r-padding_y < data.nr(); r+=stride_y)
|
||||
const long max_r = data.nr() + padding_y-(filter_nr-1);
|
||||
const long max_c = data.nc() + padding_x-(filter_nc-1);
|
||||
for (long r = -padding_y; r < max_r; r+=stride_y)
|
||||
{
|
||||
for (long c = filter_nc-1-padding_x; c-padding_x < data.nc(); c+=stride_x)
|
||||
for (long c = -padding_x; c < max_c; c+=stride_x)
|
||||
{
|
||||
for (long k = 0; k < data.k(); ++k)
|
||||
{
|
||||
@ -1686,8 +1690,8 @@ namespace dlib
|
||||
{
|
||||
for (long x = 0; x < filter_nc; ++x)
|
||||
{
|
||||
long xx = c-x;
|
||||
long yy = r-y;
|
||||
long xx = c+x;
|
||||
long yy = r+y;
|
||||
if (boundary.contains(xx,yy))
|
||||
d[(k*data.nr() + yy)*data.nc() + xx] += *t;
|
||||
++t;
|
||||
|
@ -827,7 +827,7 @@ namespace dlib
|
||||
stride_y,
|
||||
stride_x,
|
||||
1, 1, // must be 1,1
|
||||
CUDNN_CONVOLUTION)); // could also be CUDNN_CROSS_CORRELATION
|
||||
CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION
|
||||
|
||||
CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
|
@ -160,7 +160,7 @@ namespace dlib
|
||||
|
||||
friend void serialize(const con_& item, std::ostream& out)
|
||||
{
|
||||
serialize("con_3", out);
|
||||
serialize("con_4", out);
|
||||
serialize(item.params, out);
|
||||
serialize(_num_filters, out);
|
||||
serialize(_nr, out);
|
||||
@ -186,6 +186,33 @@ namespace dlib
|
||||
long nc;
|
||||
int stride_y;
|
||||
int stride_x;
|
||||
if (version == "con_4")
|
||||
{
|
||||
deserialize(item.params, in);
|
||||
deserialize(num_filters, in);
|
||||
deserialize(nr, in);
|
||||
deserialize(nc, in);
|
||||
deserialize(stride_y, in);
|
||||
deserialize(stride_x, in);
|
||||
deserialize(item.padding_y_, in);
|
||||
deserialize(item.padding_x_, in);
|
||||
deserialize(item.filters, in);
|
||||
deserialize(item.biases, in);
|
||||
deserialize(item.learning_rate_multiplier, in);
|
||||
deserialize(item.weight_decay_multiplier, in);
|
||||
deserialize(item.bias_learning_rate_multiplier, in);
|
||||
deserialize(item.bias_weight_decay_multiplier, in);
|
||||
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
|
||||
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
|
||||
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
|
||||
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
|
||||
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
|
||||
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
|
||||
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
if (version == "con_")
|
||||
{
|
||||
deserialize(item.params, in);
|
||||
@ -237,6 +264,20 @@ namespace dlib
|
||||
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
|
||||
}
|
||||
|
||||
|
||||
// now flip all the filters
|
||||
alias_tensor at(_nr, _nc);
|
||||
size_t off = 0;
|
||||
for (int i = 0; i < item.filters.num_samples(); ++i)
|
||||
{
|
||||
for (int j = 0; j < item.filters.k(); ++j)
|
||||
{
|
||||
auto temp = at(item.params,off);
|
||||
off += _nr*_nc;
|
||||
temp = flipud(fliplr(mat(temp)));
|
||||
}
|
||||
}
|
||||
|
||||
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
|
||||
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
|
||||
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
|
||||
|
Loading…
Reference in New Issue
Block a user