mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add Reorg Layer (#2496)
* Add Reorg Layer * Add Reorg Layer * Fix typo * fix grammar * add missing input <-> output mappings to reorg * Add reorg docs and term index entry * Update dlib/cuda/tensor_tools.h Co-authored-by: Davis E. King <davis@dlib.net>
This commit is contained in:
parent
c91959a73d
commit
ffca3b3a6d
@ -2080,6 +2080,84 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(is_same_object(dest, src)==false);
|
||||
DLIB_CASSERT(src.nr() % row_stride == 0);
|
||||
DLIB_CASSERT(src.nc() % col_stride == 0);
|
||||
DLIB_CASSERT(dest.num_samples() == src.num_samples());
|
||||
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
|
||||
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
|
||||
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
|
||||
const float* s = src.host();
|
||||
float* d = dest.host();
|
||||
|
||||
parallel_for(0, dest.num_samples(), [&](long n)
|
||||
{
|
||||
for (long k = 0; k < dest.k(); ++k)
|
||||
{
|
||||
for (long r = 0; r < dest.nr(); ++r)
|
||||
{
|
||||
for (long c = 0; c < dest.nc(); ++c)
|
||||
{
|
||||
const auto out_idx = tensor_index(dest, n, k, r, c);
|
||||
const auto in_idx = tensor_index(src,
|
||||
n,
|
||||
k % src.k(),
|
||||
r * row_stride + (k / src.k()) / row_stride,
|
||||
c * col_stride + (k / src.k()) % col_stride);
|
||||
d[out_idx] = s[in_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
|
||||
DLIB_CASSERT(grad.nr() % row_stride == 0);
|
||||
DLIB_CASSERT(grad.nc() % col_stride == 0);
|
||||
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
|
||||
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
|
||||
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
|
||||
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
|
||||
const float* gi = gradient_input.host();
|
||||
float* g = grad.host();
|
||||
|
||||
parallel_for(0, gradient_input.num_samples(), [&](long n)
|
||||
{
|
||||
for (long k = 0; k < gradient_input.k(); ++k)
|
||||
{
|
||||
for (long r = 0; r < gradient_input.nr(); ++r)
|
||||
{
|
||||
for (long c = 0; c < gradient_input.nc(); ++c)
|
||||
{
|
||||
const auto in_idx = tensor_index(gradient_input, n, k, r, c);
|
||||
const auto out_idx = tensor_index(grad,
|
||||
n,
|
||||
k % grad.k(),
|
||||
r * row_stride + (k / grad.k()) / row_stride,
|
||||
c * col_stride + (k / grad.k()) % col_stride);
|
||||
g[out_idx] += gi[in_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// ------------------------------------------------------------------------------------
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
@ -449,6 +449,22 @@ namespace dlib
|
||||
const tensor& gradient_input
|
||||
) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
);
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------------
|
||||
|
||||
class pooling
|
||||
|
@ -1872,6 +1872,92 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
__global__ void _cuda_reorg(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
|
||||
size_t sk, size_t snr, int snc, const float* s,
|
||||
const size_t row_stride, const size_t col_stride)
|
||||
{
|
||||
const auto out_plane_size = dnr * dnc;
|
||||
const auto sample_size = dk * out_plane_size;
|
||||
for(auto i : grid_stride_range(0, dsize))
|
||||
{
|
||||
const auto n = i / sample_size;
|
||||
const auto idx = i % out_plane_size;
|
||||
const auto out_k = (i / out_plane_size) % dk;
|
||||
const auto out_r = idx / dnc;
|
||||
const auto out_c = idx % dnc;
|
||||
|
||||
const auto in_k = out_k % sk;
|
||||
const auto in_r = out_r * row_stride + (out_k / sk) / row_stride;
|
||||
const auto in_c = out_c * col_stride + (out_k / sk) % col_stride;
|
||||
|
||||
const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
|
||||
d[i] = s[in_idx];
|
||||
}
|
||||
}
|
||||
__global__ void _cuda_reorg_gradient(size_t ssize, size_t dk, size_t dnr, size_t dnc, float* d,
|
||||
size_t sk, size_t snr, int snc, const float* s,
|
||||
const size_t row_stride, const size_t col_stride)
|
||||
{
|
||||
const auto in_plane_size = snr * snc;
|
||||
const auto sample_size = sk * in_plane_size;
|
||||
for(auto i : grid_stride_range(0, ssize))
|
||||
{
|
||||
const auto n = i / sample_size;
|
||||
const auto idx = i % in_plane_size;
|
||||
const auto in_k = (i / in_plane_size) % sk;
|
||||
const auto in_r = idx / snc;
|
||||
const auto in_c = idx % snc;
|
||||
|
||||
const auto out_k = in_k % dk;
|
||||
const auto out_r = in_r * row_stride + (in_k / dk) / row_stride;
|
||||
const auto out_c = in_c * col_stride + (in_k / dk) % col_stride;
|
||||
|
||||
const auto out_idx = ((n * dk + out_k) * dnr + out_r) * dnc + out_c;
|
||||
d[out_idx] += s[i];
|
||||
}
|
||||
}
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(is_same_object(dest, src)==false);
|
||||
DLIB_CASSERT(src.nr() % row_stride == 0);
|
||||
DLIB_CASSERT(src.nc() % col_stride == 0);
|
||||
DLIB_CASSERT(dest.num_samples() == src.num_samples());
|
||||
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
|
||||
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
|
||||
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
|
||||
|
||||
launch_kernel(_cuda_reorg, dest.size(), dest.k(), dest.nr(), dest.nc(), dest.device(),
|
||||
src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride);
|
||||
}
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
|
||||
DLIB_CASSERT(grad.nr() % row_stride == 0);
|
||||
DLIB_CASSERT(grad.nc() % col_stride == 0);
|
||||
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
|
||||
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
|
||||
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
|
||||
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
|
||||
|
||||
launch_kernel(_cuda_reorg_gradient, gradient_input.size(), grad.k(), grad.nr(), grad.nc(), grad.device(),
|
||||
gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
|
||||
row_stride, col_stride);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
__global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num)
|
||||
|
@ -493,6 +493,22 @@ namespace dlib
|
||||
const tensor& gradient_input
|
||||
) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
);
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
);
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void copy_tensor(
|
||||
|
@ -1118,6 +1118,36 @@ namespace dlib { namespace tt
|
||||
#endif
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
)
|
||||
{
|
||||
#ifdef DLIB_USE_CUDA
|
||||
cuda::reorg(dest, row_stride, col_stride, src);
|
||||
#else
|
||||
cpu::reorg(dest, row_stride, col_stride, src);
|
||||
#endif
|
||||
}
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
#ifdef DLIB_USE_CUDA
|
||||
cuda::reorg_gradient(grad, row_stride, col_stride, gradient_input);
|
||||
#else
|
||||
cpu::reorg_gradient(grad, row_stride, col_stride, gradient_input);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
void copy_tensor(
|
||||
@ -1156,4 +1186,3 @@ namespace dlib { namespace tt
|
||||
}}
|
||||
|
||||
#endif // DLIB_TeNSOR_TOOLS_CPP_
|
||||
|
||||
|
@ -1833,6 +1833,59 @@ namespace dlib { namespace tt
|
||||
as DEST.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void reorg (
|
||||
tensor& dest,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& src
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_same_object(dest, src)==false
|
||||
- src.nr() % row_stride == 0
|
||||
- src.nc() % col_stride == 0
|
||||
- dest.num_samples() == src.num_samples()
|
||||
- dest.k() == src.k() * row_stride * col_stride
|
||||
- dest.nr() == src.nr() / row_stride
|
||||
- dest.nc() == src.nc() / col_stride
|
||||
ensures
|
||||
- Converts the spatial resolution into channel information. So all the values in the input tensor
|
||||
appear in the output tensor, just in different positions.
|
||||
- For all n, k, r, c in dest:
|
||||
dest.host[tensor_index(dest, n, k, r, c)] ==
|
||||
src.host[tensor_index(src,
|
||||
n,
|
||||
k % src.k(),
|
||||
r * row_stride + (k / src.k()) / row_stride,
|
||||
c * col_stride + (k / src.k()) % col_stride)]
|
||||
|
||||
|
||||
!*/
|
||||
|
||||
void reorg_gradient (
|
||||
tensor& grad,
|
||||
const int row_stride,
|
||||
const int col_stride,
|
||||
const tensor& gradient_input
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_same_object(dest, src)==false
|
||||
- gradient_input.nr % row_stride == 0
|
||||
- gradient_input.nc % col_stride == 0
|
||||
- dest.num_samples() == src.num_samples()
|
||||
- grad.k() == gradient_input.k() / row_stride / col_stride
|
||||
- grad.nr() == gradient_input.nr() * row_stride
|
||||
- grad.nc() == gradient_input.nc() * col_stride
|
||||
ensures
|
||||
- Suppose that DEST is the output of reog(DEST, row_stride, col_stride, SRC)
|
||||
for some SRC tensor, let f(SRC) == dot(gradient_input,DEST). Then this
|
||||
function computes the gradient of f() with respect to SRC and adds it to grad.
|
||||
- It effectively reverts the reorg operation
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class multi_device_tensor_averager
|
||||
|
@ -3273,8 +3273,8 @@ namespace dlib
|
||||
// layer.
|
||||
const long num_samples = rnd.get_random_32bit_number()%4+3;
|
||||
const long k = rnd.get_random_32bit_number()%4+2;
|
||||
const long nr = rnd.get_random_32bit_number()%4+2;
|
||||
const long nc = rnd.get_random_32bit_number()%4+2;
|
||||
const long nr = ((rnd.get_random_32bit_number()%4)/2)*2+2;
|
||||
const long nc = ((rnd.get_random_32bit_number()%4)/2)*2+2;
|
||||
|
||||
output.set_size(num_samples, k, nr, nc);
|
||||
gradient_input.set_size(num_samples, k, nr, nc);
|
||||
|
@ -4327,6 +4327,107 @@ namespace dlib
|
||||
>
|
||||
using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <long long row_stride = 2, long long col_stride = 2>
|
||||
class reorg_
|
||||
{
|
||||
static_assert(row_stride >= 1, "The row_stride must be >= 1");
|
||||
static_assert(row_stride >= 1, "The col_stride must be >= 1");
|
||||
|
||||
public:
|
||||
reorg_(
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& sub)
|
||||
{
|
||||
DLIB_CASSERT(sub.get_output().nr() % row_stride == 0);
|
||||
DLIB_CASSERT(sub.get_output().nc() % col_stride == 0);
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void forward(const SUBNET& sub, resizable_tensor& output)
|
||||
{
|
||||
output.set_size(
|
||||
sub.get_output().num_samples(),
|
||||
sub.get_output().k() * col_stride * row_stride,
|
||||
sub.get_output().nr() / row_stride,
|
||||
sub.get_output().nc() / col_stride
|
||||
);
|
||||
tt::reorg(output, row_stride, col_stride, sub.get_output());
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
|
||||
{
|
||||
tt::reorg_gradient(sub.get_gradient_input(), row_stride, col_stride, gradient_input);
|
||||
}
|
||||
|
||||
inline dpoint map_input_to_output (dpoint p) const
|
||||
{
|
||||
p.x() = p.x() / col_stride;
|
||||
p.y() = p.y() / row_stride;
|
||||
return p;
|
||||
}
|
||||
inline dpoint map_output_to_input (dpoint p) const
|
||||
{
|
||||
p.x() = p.x() * col_stride;
|
||||
p.y() = p.y() * row_stride;
|
||||
return p;
|
||||
}
|
||||
|
||||
const tensor& get_layer_params() const { return params; }
|
||||
tensor& get_layer_params() { return params; }
|
||||
|
||||
friend void serialize(const reorg_& /*item*/, std::ostream& out)
|
||||
{
|
||||
serialize("reorg_", out);
|
||||
serialize(row_stride, out);
|
||||
serialize(col_stride, out);
|
||||
}
|
||||
|
||||
friend void deserialize(reorg_& /*item*/, std::istream& in)
|
||||
{
|
||||
std::string version;
|
||||
deserialize(version, in);
|
||||
if (version != "reorg_")
|
||||
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::reorg_.");
|
||||
long long rs;
|
||||
long long cs;
|
||||
deserialize(rs, in);
|
||||
deserialize(cs, in);
|
||||
if (rs != row_stride) throw serialization_error("Wrong row_stride found while deserializing dlib::reorg_");
|
||||
if (cs != col_stride) throw serialization_error("Wrong col_stride found while deserializing dlib::reorg_");
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const reorg_& /*item*/)
|
||||
{
|
||||
out << "reorg\t ("
|
||||
<< "row_stride=" << row_stride
|
||||
<< ", col_stride=" << col_stride
|
||||
<< ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
friend void to_xml(const reorg_ /*item*/, std::ostream& out)
|
||||
{
|
||||
out << "<reorg";
|
||||
out << " row_stride='" << row_stride << "'";
|
||||
out << " col_stride='" << col_stride << "'";
|
||||
out << "/>\n";
|
||||
}
|
||||
|
||||
private:
|
||||
resizable_tensor params; // unused
|
||||
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using reorg = add_layer<reorg_<2, 2>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
namespace impl
|
||||
|
@ -3300,6 +3300,59 @@ namespace dlib
|
||||
>
|
||||
using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <long long row_stride = 2, long long col_stride = 2>
|
||||
class reorg_
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON TEMPLATE ARGUMENTS
|
||||
- row_stride >= 1
|
||||
- col_stride >= 1
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
|
||||
defined above. In particular, the output of this layer is simply a copy of
|
||||
the input tensor. However, it rearranges spatial information along the
|
||||
channel dimension. The dimensions of the tensor output by this layer are as
|
||||
follows (letting IN be the input tensor and OUT the output tensor):
|
||||
- OUT.num_samples() == IN.num_samples()
|
||||
- OUT.k() == IN.k() * row_stride * col_stride
|
||||
- OUT.nr() == IN.nr() / row_stride
|
||||
- OUT.nc() == IN.nc() / col_stride
|
||||
|
||||
So the output will always have the same number of samples as the input, but
|
||||
within each sample (the k,nr,nc part) we will reorganize the values. To be
|
||||
very precise, we will have, for all n, k, r, c in OUT:
|
||||
OUT.host[tensor_index(OUT, n, k, r, c)] ==
|
||||
IN.host[tensor_index(IN,
|
||||
n,
|
||||
k % IN.k(),
|
||||
r * row_stride + (k / IN.k()) / row_stride,
|
||||
c * col_stride + (k / IN.k()) % col_stride)]
|
||||
|
||||
|
||||
Finally, you can think of this layer as an alternative to a strided convolutonal
|
||||
layer to downsample a tensor.
|
||||
!*/
|
||||
|
||||
public:
|
||||
|
||||
template <typename SUBNET> void setup (const SUBNET& sub);
|
||||
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
|
||||
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
|
||||
dpoint map_input_to_output (dpoint p) const;
|
||||
dpoint map_output_to_input (dpoint p) const;
|
||||
const tensor& get_layer_params() const;
|
||||
tensor& get_layer_params();
|
||||
/*!
|
||||
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
|
||||
!*/
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
using reorg = add_layer<reorg_<2, 2>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename net_type>
|
||||
|
@ -1839,6 +1839,12 @@ namespace
|
||||
|
||||
void test_layers()
|
||||
{
|
||||
{
|
||||
print_spinner();
|
||||
reorg_<2,2> l;
|
||||
auto res = test_layer(l);
|
||||
DLIB_TEST_MSG(res, res);
|
||||
}
|
||||
{
|
||||
print_spinner();
|
||||
extract_<0,2,2,2> l;
|
||||
@ -4187,6 +4193,26 @@ namespace
|
||||
DLIB_TEST(max(squared(mat(out_nobias) - mat(out_nobias_fused))) < 1e-10);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_reorg()
|
||||
{
|
||||
#ifdef DLIB_USE_CUDA
|
||||
print_spinner();
|
||||
resizable_tensor x(2, 4, 8, 16);
|
||||
resizable_tensor out_cpu(2, 16, 4, 8), out_cuda(2, 16, 4, 8);
|
||||
resizable_tensor grad_cpu(x), grad_cuda(x);
|
||||
tt::tensor_rand rnd;
|
||||
rnd.fill_gaussian(x);
|
||||
cpu::reorg(out_cpu, 2, 2, x);
|
||||
cuda::reorg(out_cuda, 2, 2, x);
|
||||
DLIB_TEST(max(squared(mat(out_cuda) - mat(out_cpu))) == 0);
|
||||
cpu::reorg_gradient(grad_cpu, 2, 2, out_cpu);
|
||||
cuda::reorg_gradient(grad_cuda, 2, 2, out_cuda);
|
||||
DLIB_TEST(max(squared(mat(out_cuda) - mat(out_cpu))) == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class dnn_tester : public tester
|
||||
@ -4294,6 +4320,7 @@ namespace
|
||||
test_set_learning_rate_multipliers();
|
||||
test_input_ouput_mappers();
|
||||
test_fuse_layers();
|
||||
test_reorg();
|
||||
}
|
||||
|
||||
void perform_test()
|
||||
|
@ -165,6 +165,10 @@ Davis E. King. <a href="http://jmlr.csail.mit.edu/papers/volume10/king09a/king09
|
||||
<name>extract</name>
|
||||
<link>dlib/dnn/layers_abstract.h.html#extract_</link>
|
||||
</item>
|
||||
<item>
|
||||
<name>reorg</name>
|
||||
<link>dlib/dnn/layers_abstract.h.html#reorg_</link>
|
||||
</item>
|
||||
<item>
|
||||
<name>mult_prev</name>
|
||||
<link>dlib/dnn/layers_abstract.h.html#mult_prev_</link>
|
||||
|
@ -170,6 +170,7 @@
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="resize_prev_to_tagged_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="mult_prev_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="extract_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="reorg_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="upsample_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="cont_" include="dlib/dnn.h"/>
|
||||
<term file="dlib/dnn/layers_abstract.h.html" name="scale_" include="dlib/dnn.h"/>
|
||||
|
Loading…
Reference in New Issue
Block a user