mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add support for fused convolutions (#2294)
* add helper methods to implement fused convolutions
* fix grammar
* add method to disable affine layer and updated serialization
* add documentation for .disable()
* add fuse_convolutions visitor and documentation
* update docs: net is not constant
* fix xml formatting and use std::boolalpha
* fix warning and updated net requirement for visitor
* fix segfault in fuse_convolutions visitor
* copy unconditionally
* make the visitor class a friend of the con_ class
* setup the biases alias tensor after enabling bias
* simplify visitor a bit
* fix comment
* setup the biases size, somehow this got lost
* copy the parameters before resizing
* remove enable_bias() method, since the visitor is now a friend
* Revert "remove enable_bias() method, since the visitor is now a friend"
This reverts commit 35b92b1631
.
* update the visitor to remove the friend requirement
* improve behavior of enable_bias
* better describe the behavior of enable_bias
* wip: use cudnnConvolutionBiasActivationForward when activation has bias
* wip: fix cpu compilation
* WIP: not working fused ReLU
* WIP: forgot to disable ReLU in visitor (does not change the fact that it does not work)
* WIP: more general set of 4d tensor (still not working)
* fused convolutions seem to be working now, more testing needed
* move visitor to the bottom of the file
* fix CPU-side and code clean up
* Do not try to fuse the activation layers
Fusing the activation layers in one cuDNN call is only supported when using
the cuDNN ones (ReLU, Sigmoid, TanH...) which might lead to surprising
behavior. So, let's just fuse the batch norm and the convolution into one
cuDNN call using the IDENTITY activation function.
* Set the correct forward algorithm for the identity activation
Ref: https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
* move the affine alias template to its original position
* wip
* remove unused param in relu and simplify example (I will delete it before merge)
* simplify conv bias logic and fix deserialization issue
* fix enabling bias on convolutions
* remove test example
* fix typo
* update documentation
* update documentation
* remove ccache leftovers from CMakeLists.txt
* Re-add new line
* fix enable/disable bias on unallocated networks
* update comment to mention cudnnConvolutionBiasActivationForward
* fix typo
Co-authored-by: Davis E. King <davis@dlib.net>
* Apply documentation suggestions from code review
Co-authored-by: Davis E. King <davis@dlib.net>
* update affine docs to talk in terms of gamma and beta
* simplify tensor_conv interface
* fix tensor_conv operator() with biases
* add fuse_layers test
* add an example on how to use the fuse_layers function
* fix typo
Co-authored-by: Davis E. King <davis@dlib.net>
This commit is contained in:
parent
8a2c744207
commit
adca7472df
@ -2465,6 +2465,33 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
void tensor_conv::operator() (
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(filters.num_samples() == biases.k());
|
||||
(*this)(add_to_output, output,data,filters);
|
||||
tt::add(1, output, 1, biases);
|
||||
}
|
||||
|
||||
void tensor_conv::operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(filters.num_samples() == biases.k());
|
||||
(*this)(add_to_output, output, data, filters);
|
||||
tt::add(1, output, 1, biases);
|
||||
}
|
||||
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
void tensor_conv::
|
||||
|
@ -554,6 +554,22 @@ namespace dlib
|
||||
const tensor& filters
|
||||
);
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
);
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
);
|
||||
|
||||
void get_gradient_for_data (
|
||||
const bool add_to_output,
|
||||
const tensor& gradient_input,
|
||||
|
@ -152,6 +152,12 @@ namespace dlib
|
||||
cudnnActivationDescriptor_t handle;
|
||||
};
|
||||
|
||||
// Returns the handle of a cuDNN activation descriptor configured with
// CUDNN_ACTIVATION_IDENTITY.  The descriptor object is thread_local, so each
// thread builds its own instance the first time it calls this function.
static cudnnActivationDescriptor_t identity_activation_descriptor()
{
    thread_local cudnn_activation_descriptor identity_desc(
        CUDNN_ACTIVATION_IDENTITY,
        CUDNN_PROPAGATE_NAN,
        0);
    return identity_desc.get_handle();
}
|
||||
|
||||
static cudnnActivationDescriptor_t relu_activation_descriptor()
|
||||
{
|
||||
thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0);
|
||||
@ -1128,6 +1134,89 @@ namespace dlib
|
||||
&beta,
|
||||
descriptor(output),
|
||||
output.device()));
|
||||
|
||||
}
|
||||
|
||||
void tensor_conv::operator() (
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");
|
||||
|
||||
output.set_size(out_num_samples, out_k, out_nr, out_nc);
|
||||
(*this)(add_to_output, static_cast<tensor&>(output), data, filters, biases);
|
||||
}
|
||||
|
||||
void tensor_conv::operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(is_same_object(output,data) == false);
|
||||
DLIB_CASSERT(is_same_object(output,filters) == false);
|
||||
DLIB_CASSERT(filters.k() == data.k());
|
||||
DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");
|
||||
DLIB_CASSERT(filters.nc() <= data.nc() + 2*padding_x,
|
||||
"Filter windows must be small enough to fit into the padded image."
|
||||
<< "\n\t filters.nc(): " << filters.nc()
|
||||
<< "\n\t data.nc(): " << data.nc()
|
||||
<< "\n\t padding_x: " << padding_x
|
||||
);
|
||||
DLIB_CASSERT(filters.nr() <= data.nr() + 2*padding_y,
|
||||
"Filter windows must be small enough to fit into the padded image."
|
||||
<< "\n\t filters.nr(): " << filters.nr()
|
||||
<< "\n\t data.nr(): " << data.nr()
|
||||
<< "\n\t padding_y: " << padding_y
|
||||
);
|
||||
|
||||
|
||||
DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples());
|
||||
DLIB_CASSERT(output.k() == filters.num_samples());
|
||||
DLIB_CASSERT(output.nr() == 1+(data.nr()+2*padding_y-filters.nr())/stride_y);
|
||||
DLIB_CASSERT(output.nc() == 1+(data.nc()+2*padding_x-filters.nc())/stride_x);
|
||||
DLIB_CASSERT(filters.num_samples() == biases.k());
|
||||
|
||||
|
||||
|
||||
const float alpha1 = 1;
|
||||
const float alpha2 = add_to_output ? 1 : 0;
|
||||
|
||||
// Since cudnnConvolutionBiasActivationForward() is an asynchronous call,
|
||||
// we need to hold a reference to the workspace buffer so we can be sure it
|
||||
// isn't reallocated while the function is still executing on the device.
|
||||
// But each time we come here, we make sure to grab the latest workspace
|
||||
// buffer so that, globally, we minimize the number of such buffers.
|
||||
forward_workspace = device_global_buffer(forward_workspace_size_in_bytes);
|
||||
|
||||
float* out = output.device();
|
||||
const cudnnTensorDescriptor_t out_desc = descriptor(output);
|
||||
|
||||
CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
|
||||
context(),
|
||||
&alpha1,
|
||||
descriptor(data),
|
||||
data.device(),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
filters.device(),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(cudnnConvolutionFwdAlgo_t)forward_algo,
|
||||
forward_workspace,
|
||||
forward_workspace_size_in_bytes,
|
||||
&alpha2,
|
||||
out_desc,
|
||||
out,
|
||||
descriptor(biases),
|
||||
biases.device(),
|
||||
identity_activation_descriptor(),
|
||||
out_desc,
|
||||
out));
|
||||
}
|
||||
|
||||
void tensor_conv::get_gradient_for_data (
|
||||
|
@ -171,6 +171,13 @@ namespace dlib
|
||||
~tensor_conv (
|
||||
);
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters
|
||||
);
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
@ -182,7 +189,16 @@ namespace dlib
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
);
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
);
|
||||
|
||||
void get_gradient_for_data (
|
||||
@ -208,6 +224,16 @@ namespace dlib
|
||||
int padding_x
|
||||
);
|
||||
|
||||
void setup(
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases,
|
||||
int stride_y,
|
||||
int stride_x,
|
||||
int padding_y,
|
||||
int padding_x
|
||||
);
|
||||
|
||||
private:
|
||||
|
||||
// These variables record the type of data given to the last call to setup().
|
||||
|
@ -1042,6 +1042,64 @@ namespace dlib { namespace tt
|
||||
- #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
|
||||
!*/
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
) { impl(add_to_output,output,data,filters,biases); }
|
||||
/*!
|
||||
requires
|
||||
- setup() has been called. Specifically, setup() has been called like this:
|
||||
this->setup(data, filters, stride_y, stride_x, padding_y, padding_x);
|
||||
- is_same_object(output,data) == false
|
||||
- is_same_object(output,filters) == false
|
||||
- filters.k() == data.k()
|
||||
- filters.nr() <= data.nr() + 2*padding_y
|
||||
- filters.nc() <= data.nc() + 2*padding_x
|
||||
- filters.num_samples() == biases.k()
|
||||
- #output.num_samples() == data.num_samples()
|
||||
- #output.k() == filters.num_samples()
|
||||
- #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
|
||||
- #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
|
||||
ensures
|
||||
- Convolves filters over data. If add_to_output==true then we add the
|
||||
results to output, otherwise we assign to output, overwriting the
|
||||
previous values in output.
|
||||
- Adds biases to the result of the convolved data
|
||||
- filters contains filters.num_samples() filters.
|
||||
!*/
|
||||
|
||||
void operator() (
|
||||
const bool add_to_output,
|
||||
resizable_tensor& output,
|
||||
const tensor& data,
|
||||
const tensor& filters,
|
||||
const tensor& biases
|
||||
) { impl(add_to_output,output,data,filters, biases); }
|
||||
/*!
|
||||
requires
|
||||
- setup() has been called. Specifically, setup() has been called like this:
|
||||
this->setup(data, filters, stride_y, stride_x, padding_y, padding_x);
|
||||
- is_same_object(output,data) == false
|
||||
- is_same_object(output,filters) == false
|
||||
- filters.k() == data.k()
|
||||
- filters.nr() <= data.nr() + 2*padding_y
|
||||
- filters.nc() <= data.nc() + 2*padding_x
|
||||
- filters.num_samples() == biases.k()
|
||||
ensures
|
||||
- Convolves filters over data. If add_to_output==true then we add the
|
||||
results to output, otherwise we assign to output, overwriting the
|
||||
previous values in output.
|
||||
- Adds biases to the result of the convolved data
|
||||
- filters contains filters.num_samples() filters.
|
||||
- #output.num_samples() == data.num_samples()
|
||||
- #output.k() == filters.num_samples()
|
||||
- #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
|
||||
- #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
|
||||
!*/
|
||||
|
||||
void get_gradient_for_data (
|
||||
const bool add_to_output,
|
||||
const tensor& gradient_input,
|
||||
@ -1141,7 +1199,7 @@ namespace dlib { namespace tt
|
||||
the tensors, or store any kind of references to the data or filter
|
||||
tensors.
|
||||
!*/
|
||||
|
||||
|
||||
private:
|
||||
#ifdef DLIB_USE_CUDA
|
||||
cuda::tensor_conv impl;
|
||||
|
@ -107,8 +107,38 @@ namespace dlib
|
||||
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
|
||||
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
|
||||
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
|
||||
void disable_bias() { use_bias = false; }
|
||||
bool bias_is_disabled() const { return !use_bias; }
|
||||
void disable_bias()
|
||||
{
|
||||
if (use_bias == false)
|
||||
return;
|
||||
|
||||
use_bias = false;
|
||||
if (params.size() == 0)
|
||||
return;
|
||||
|
||||
DLIB_CASSERT(params.size() == filters.size() + num_filters_);
|
||||
auto temp = params;
|
||||
params.set_size(params.size() - num_filters_);
|
||||
std::copy(temp.begin(), temp.end() - num_filters_, params.begin());
|
||||
biases = alias_tensor();
|
||||
}
|
||||
void enable_bias()
|
||||
{
|
||||
if (use_bias == true)
|
||||
return;
|
||||
|
||||
use_bias = true;
|
||||
if (params.size() == 0)
|
||||
return;
|
||||
|
||||
DLIB_CASSERT(params.size() == filters.size());
|
||||
auto temp = params;
|
||||
params.set_size(params.size() + num_filters_);
|
||||
std::copy(temp.begin(), temp.end(), params.begin());
|
||||
biases = alias_tensor(1, num_filters_);
|
||||
biases(params, filters.size()) = 0;
|
||||
}
|
||||
|
||||
inline dpoint map_input_to_output (
|
||||
dpoint p
|
||||
@ -202,12 +232,18 @@ namespace dlib
|
||||
_stride_x,
|
||||
padding_y_,
|
||||
padding_x_);
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0));
|
||||
if (use_bias)
|
||||
{
|
||||
tt::add(1,output,1,biases(params,filters.size()));
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0),
|
||||
biases(params, filters.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
conv(false, output,
|
||||
sub.get_output(),
|
||||
filters(params,0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +251,7 @@ namespace dlib
|
||||
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
|
||||
{
|
||||
conv.get_gradient_for_data (true, gradient_input, filters(params,0), sub.get_gradient_input());
|
||||
// no dpoint computing the parameter gradients if they won't be used.
|
||||
// no point computing the parameter gradients if they won't be used.
|
||||
if (learning_rate_multiplier != 0)
|
||||
{
|
||||
auto filt = filters(params_grad,0);
|
||||
@ -354,7 +390,6 @@ namespace dlib
|
||||
int padding_y_;
|
||||
int padding_x_;
|
||||
bool use_bias;
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
@ -2341,7 +2376,7 @@ namespace dlib
|
||||
|
||||
auto g = gamma(params,0);
|
||||
auto b = beta(params,gamma.size());
|
||||
|
||||
|
||||
resizable_tensor temp(item.params);
|
||||
auto sg = gamma(temp,0);
|
||||
auto sb = beta(temp,gamma.size());
|
||||
@ -2352,12 +2387,23 @@ namespace dlib
|
||||
|
||||
layer_mode get_mode() const { return mode; }
|
||||
|
||||
void disable()
|
||||
{
|
||||
params.clear();
|
||||
disabled = true;
|
||||
}
|
||||
|
||||
bool is_disabled() const { return disabled; }
|
||||
|
||||
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
|
||||
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& sub)
|
||||
{
|
||||
if (disabled)
|
||||
return;
|
||||
|
||||
if (mode == FC_MODE)
|
||||
{
|
||||
gamma = alias_tensor(1,
|
||||
@ -2379,6 +2425,9 @@ namespace dlib
|
||||
|
||||
void forward_inplace(const tensor& input, tensor& output)
|
||||
{
|
||||
if (disabled)
|
||||
return;
|
||||
|
||||
auto g = gamma(params,0);
|
||||
auto b = beta(params,gamma.size());
|
||||
if (mode == FC_MODE)
|
||||
@ -2393,6 +2442,9 @@ namespace dlib
|
||||
tensor& /*params_grad*/
|
||||
)
|
||||
{
|
||||
if (disabled)
|
||||
return;
|
||||
|
||||
auto g = gamma(params,0);
|
||||
auto b = beta(params,gamma.size());
|
||||
|
||||
@ -2413,16 +2465,23 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
alias_tensor_instance get_gamma() { return gamma(params, 0); };
|
||||
alias_tensor_const_instance get_gamma() const { return gamma(params, 0); };
|
||||
|
||||
alias_tensor_instance get_beta() { return beta(params, gamma.size()); };
|
||||
alias_tensor_const_instance get_beta() const { return beta(params, gamma.size()); };
|
||||
|
||||
const tensor& get_layer_params() const { return empty_params; }
|
||||
tensor& get_layer_params() { return empty_params; }
|
||||
|
||||
friend void serialize(const affine_& item, std::ostream& out)
|
||||
{
|
||||
serialize("affine_", out);
|
||||
serialize("affine_2", out);
|
||||
serialize(item.params, out);
|
||||
serialize(item.gamma, out);
|
||||
serialize(item.beta, out);
|
||||
serialize((int)item.mode, out);
|
||||
serialize(item.disabled, out);
|
||||
}
|
||||
|
||||
friend void deserialize(affine_& item, std::istream& in)
|
||||
@ -2450,7 +2509,7 @@ namespace dlib
|
||||
return;
|
||||
}
|
||||
|
||||
if (version != "affine_")
|
||||
if (version != "affine_" && version != "affine_2")
|
||||
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
|
||||
deserialize(item.params, in);
|
||||
deserialize(item.gamma, in);
|
||||
@ -2458,21 +2517,27 @@ namespace dlib
|
||||
int mode;
|
||||
deserialize(mode, in);
|
||||
item.mode = (layer_mode)mode;
|
||||
if (version == "affine_2")
|
||||
deserialize(item.disabled, in);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const affine_& /*item*/)
|
||||
friend std::ostream& operator<<(std::ostream& out, const affine_& item)
|
||||
{
|
||||
out << "affine";
|
||||
if (item.disabled)
|
||||
out << "\t (disabled)";
|
||||
return out;
|
||||
}
|
||||
|
||||
friend void to_xml(const affine_& item, std::ostream& out)
|
||||
{
|
||||
if (item.mode==CONV_MODE)
|
||||
out << "<affine_con>\n";
|
||||
out << "<affine_con";
|
||||
else
|
||||
out << "<affine_fc>\n";
|
||||
|
||||
out << "<affine_fc";
|
||||
if (item.disabled)
|
||||
out << " disabled='"<< std::boolalpha << item.disabled << "'";
|
||||
out << ">\n";
|
||||
out << mat(item.params);
|
||||
|
||||
if (item.mode==CONV_MODE)
|
||||
@ -2485,6 +2550,7 @@ namespace dlib
|
||||
resizable_tensor params, empty_params;
|
||||
alias_tensor gamma, beta;
|
||||
layer_mode mode;
|
||||
bool disabled = false;
|
||||
};
|
||||
|
||||
template <typename SUBNET>
|
||||
@ -4261,6 +4327,83 @@ namespace dlib
|
||||
>
|
||||
using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
namespace impl
|
||||
{
|
||||
class visitor_fuse_layers
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
void fuse_convolution(T&) const
|
||||
{
|
||||
// disable other layer types
|
||||
}
|
||||
|
||||
// handle the standard case (convolutional layer followed by affine;
|
||||
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
|
||||
void fuse_convolution(add_layer<affine_, add_layer<con_<nf, nr, nc, sy, sx, py, px>, U>, E>& l)
|
||||
{
|
||||
if (l.layer_details().is_disabled())
|
||||
return;
|
||||
|
||||
// get the convolution below the affine layer
|
||||
auto& conv = l.subnet().layer_details();
|
||||
|
||||
// get the parameters from the affine layer as alias_tensor_instance
|
||||
alias_tensor_instance gamma = l.layer_details().get_gamma();
|
||||
alias_tensor_instance beta = l.layer_details().get_beta();
|
||||
|
||||
if (conv.bias_is_disabled())
|
||||
{
|
||||
conv.enable_bias();
|
||||
}
|
||||
|
||||
tensor& params = conv.get_layer_params();
|
||||
|
||||
// update the biases
|
||||
auto biases = alias_tensor(1, conv.num_filters());
|
||||
biases(params, params.size() - conv.num_filters()) += mat(beta);
|
||||
|
||||
// guess the number of input channels
|
||||
const long k_in = (params.size() - conv.num_filters()) / conv.num_filters() / conv.nr() / conv.nc();
|
||||
|
||||
// rescale the filters
|
||||
DLIB_CASSERT(conv.num_filters() == gamma.k());
|
||||
alias_tensor filter(1, k_in, conv.nr(), conv.nc());
|
||||
const float* g = gamma.host();
|
||||
for (long n = 0; n < conv.num_filters(); ++n)
|
||||
{
|
||||
filter(params, n * filter.size()) *= g[n];
|
||||
}
|
||||
|
||||
// disable the affine layer
|
||||
l.layer_details().disable();
|
||||
}
|
||||
|
||||
template <typename input_layer_type>
|
||||
void operator()(size_t , input_layer_type& ) const
|
||||
{
|
||||
// ignore other layers
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename E>
|
||||
void operator()(size_t , add_layer<T, U, E>& l)
|
||||
{
|
||||
fuse_convolution(l);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
template <typename net_type>
|
||||
void fuse_layers (
|
||||
net_type& net
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(count_parameters(net) > 0, "The network has to be allocated before fusing the layers.");
|
||||
visit_layers(net, impl::visitor_fuse_layers());
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
@ -944,6 +944,17 @@ namespace dlib
|
||||
/*!
|
||||
ensures
|
||||
- bias_is_disabled() returns true
|
||||
- if bias was enabled and allocated, it resizes the layer parameters
|
||||
to accommodate the filter parameters only, and frees the bias parameters.
|
||||
!*/
|
||||
|
||||
void enable_bias(
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- bias_is_disabled() returns false
|
||||
- if bias was disabled and the layer parameters are allocated, it resizes the layer parameters
|
||||
to accommodate the new zero-initialized biases
|
||||
!*/
|
||||
|
||||
bool bias_is_disabled(
|
||||
@ -1852,24 +1863,24 @@ namespace dlib
|
||||
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
|
||||
defined above. In particular, it applies a simple pointwise linear
|
||||
transformation to an input tensor. You can think of it as having two
|
||||
parameter tensors, A and B. If the input tensor is called INPUT then the
|
||||
output of this layer is:
|
||||
A*INPUT+B
|
||||
parameter tensors, gamma and beta. If the input tensor is called INPUT
|
||||
then the output of this layer is:
|
||||
gamma*INPUT+beta
|
||||
where all operations are performed element wise and each sample in the
|
||||
INPUT tensor is processed separately.
|
||||
|
||||
Moreover, this object has two modes that affect the dimensionalities of A
|
||||
and B and how they are applied to compute A*INPUT+B. If
|
||||
get_mode()==FC_MODE then A and B each have the same dimensionality as the
|
||||
input tensor, except their num_samples() dimensions are 1. If
|
||||
get_mode()==CONV_MODE then A and B have all their dimensions set to 1
|
||||
except for k(), which is equal to INPUT.k().
|
||||
Moreover, this object has two modes that affect the dimensionalities of
|
||||
gamma and beta and how they are applied to compute gamma*INPUT+beta. If
|
||||
get_mode()==FC_MODE then gamma and beta each have the same dimensionality
|
||||
as the input tensor, except their num_samples() dimensions are 1. If
|
||||
get_mode()==CONV_MODE then gamma and beta have all their dimensions set
|
||||
to 1 except for k(), which is equal to INPUT.k().
|
||||
|
||||
In either case, the computation of A*INPUT+B is performed pointwise over all
|
||||
the elements of INPUT using either:
|
||||
OUTPUT(n,k,r,c) == A(1,k,r,c)*INPUT(n,k,r,c)+B(1,k,r,c)
|
||||
In either case, the computation of gamma*INPUT+beta is performed pointwise
|
||||
over all the elements of INPUT using either:
|
||||
OUTPUT(n,k,r,c) == gamma(1,k,r,c)*INPUT(n,k,r,c)+beta(1,k,r,c)
|
||||
or
|
||||
OUTPUT(n,k,r,c) == A(1,k,1,1)*INPUT(n,k,r,c)+B(1,k,1,1)
|
||||
OUTPUT(n,k,r,c) == gamma(1,k,1,1)*INPUT(n,k,r,c)+beta(1,k,1,1)
|
||||
as appropriate.
|
||||
|
||||
|
||||
@ -1919,6 +1930,39 @@ namespace dlib
|
||||
- returns the mode of this layer, either CONV_MODE or FC_MODE.
|
||||
!*/
|
||||
|
||||
void disable(
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_layer_params().size() == 0.
|
||||
- when forward_inplace and backward_inplace are called, they return immediately doing nothing.
|
||||
This causes the layer to trivially perform an identity transform.
|
||||
!*/
|
||||
|
||||
alias_tensor_instance get_gamma();
|
||||
/*!
|
||||
ensures
|
||||
- returns the gamma parameter that defines the behavior of forward().
|
||||
!*/
|
||||
|
||||
alias_tensor_const_instance get_gamma() const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the gamma parameter that defines the behavior of forward().
|
||||
!*/
|
||||
|
||||
alias_tensor_instance get_beta();
|
||||
/*!
|
||||
ensures
|
||||
- returns the beta parameter that defines the behavior of forward().
|
||||
!*/
|
||||
|
||||
alias_tensor_const_instance get_beta() const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the beta parameter that defines the behavior of forward().
|
||||
!*/
|
||||
|
||||
template <typename SUBNET> void setup (const SUBNET& sub);
|
||||
void forward_inplace(const tensor& input, tensor& output);
|
||||
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
|
||||
@ -3256,6 +3300,23 @@ namespace dlib
|
||||
>
|
||||
using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename net_type>
|
||||
void fuse_layers (
|
||||
net_type& net
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
|
||||
add_tag_layer.
|
||||
- net has been properly allocated, that is: count_parameters(net) > 0.
|
||||
ensures
|
||||
- Disables all the affine_ layers that have a convolution as an input.
|
||||
- Updates the convolution weights beneath the affine_ layers to produce the same
|
||||
output as with the affine_ layers enabled.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
@ -4157,6 +4157,36 @@ namespace
|
||||
DLIB_TEST(dets.size() < approximate_desired_det_count * 1.05);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_fuse_layers()
|
||||
{
|
||||
print_spinner();
|
||||
using net_type = fc<10, avg_pool_everything<relu<bn_con<con<16, 3, 3, 1, 1, input_rgb_image>>>>>;
|
||||
using net_type_fused = fc<10, avg_pool_everything<relu<affine<con<16, 3, 3, 1, 1, input_rgb_image>>>>>;
|
||||
net_type net_bias, net_nobias;
|
||||
disable_duplicative_biases(net_nobias);
|
||||
resizable_tensor x;
|
||||
matrix<rgb_pixel> image(8, 8);
|
||||
net_bias.to_tensor(&image, &image+1, x);
|
||||
net_nobias.to_tensor(&image, &image+1, x);
|
||||
net_bias.forward(x);
|
||||
net_nobias.forward(x);
|
||||
net_type_fused net_fused_bias(net_bias);
|
||||
net_type_fused net_fused_nobias(net_nobias);
|
||||
const resizable_tensor out_bias = net_bias.get_output();
|
||||
const resizable_tensor out_nobias = net_nobias.get_output();
|
||||
fuse_layers(net_fused_bias);
|
||||
fuse_layers(net_fused_nobias);
|
||||
net_fused_bias.forward(x);
|
||||
net_fused_nobias.forward(x);
|
||||
const resizable_tensor out_bias_fused = net_fused_bias.get_output();
|
||||
const resizable_tensor out_nobias_fused = net_fused_nobias.get_output();
|
||||
|
||||
DLIB_TEST(max(squared(mat(out_bias) - mat(out_bias_fused))) < 1e-10);
|
||||
DLIB_TEST(max(squared(mat(out_nobias) - mat(out_nobias_fused))) < 1e-10);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class dnn_tester : public tester
|
||||
@ -4263,6 +4293,7 @@ namespace
|
||||
test_disable_duplicative_biases();
|
||||
test_set_learning_rate_multipliers();
|
||||
test_input_ouput_mappers();
|
||||
test_fuse_layers();
|
||||
}
|
||||
|
||||
void perform_test()
|
||||
|
@ -388,7 +388,7 @@ try
|
||||
}
|
||||
|
||||
|
||||
// Create some data loaders which will load the data, and perform som data augmentation.
|
||||
// Create some data loaders which will load the data, and perform some data augmentation.
|
||||
dlib::pipe<std::pair<matrix<rgb_pixel>, std::vector<yolo_rect>>> train_data(1000);
|
||||
const auto loader = [&dataset, &data_directory, &train_data, &image_size](time_t seed)
|
||||
{
|
||||
@ -516,7 +516,16 @@ try
|
||||
for (auto& worker : data_loaders)
|
||||
worker.join();
|
||||
|
||||
serialize("yolov3.dnn") << net;
|
||||
// Before saving the network, we can assign it to the "infer" version, so that it won't
|
||||
// perform batch normalization with batch sizes larger than one, as usual. Moreover,
|
||||
// we can also fuse the batch normalization (affine) layers into the convolutional
|
||||
// layers, so that the network can run a bit faster. Notice that, after fusing the
|
||||
// layers, the network can no longer be used for training, so you should save the
|
||||
// yolov3_train_type network if you plan to further train or finetune the network.
|
||||
darknet::yolov3_infer_type inet(net);
|
||||
fuse_layers(inet);
|
||||
|
||||
serialize("yolov3.dnn") << inet;
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
|
Loading…
Reference in New Issue
Block a user