Mirror of https://github.com/davisking/dlib.git (synced 2024-11-01 10:14:53 +08:00)
Add visitor to remove bias from bn_ layer inputs (#closes 2155) (#2156)
* add visitor to remove bias from bn_ inputs (#closes 2155)
* remove unused parameter and make documentation more clear
* remove bias from bn_ layers too and use better name
* let the batch norm keep their bias, use even better name
* be more consistent with impl naming
* remove default constructor
* do not use method to prevent some errors
* add disable bias method to pertinent layers
* update dcgan example:
  - grammar
  - print number of network parameters to be able to check bias is not allocated
  - at the end, give feedback to the user about what the discriminator thinks about each generated sample
* fix fc_ logic
* add documentation
* add bias_is_disabled methods and update to_xml
* print use_bias=false when bias is disabled
This commit is contained in:
parent ed22f0400a
commit e7ec6b7777
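In short, the commit gives the con_, cont_, and fc_ layers a disable_bias()/bias_is_disabled() pair and adds a set_all_bn_inputs_no_bias() visitor that applies it to every layer feeding a bn_ layer. A minimal sketch of the new per-layer API (the con_ template arguments below are illustrative, not taken from the commit):

    #include <dlib/dnn.h>
    #include <iostream>

    int main()
    {
        // con_<num_filters, nr, nc, stride_y, stride_x>
        dlib::con_<16, 3, 3, 1, 1> conv;
        std::cout << conv.bias_is_disabled() << '\n'; // 0: biases are enabled by default
        conv.disable_bias();                          // bias will not be allocated, learned, or applied
        std::cout << conv.bias_is_disabled() << '\n'; // 1
    }

With the bias disabled, the layer's params tensor is allocated without the bias block, so the bias parameters simply do not exist rather than being trained toward zero.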
@@ -183,6 +183,28 @@ namespace dlib
         impl::set_bias_weight_decay_multiplier(obj, special_(), bias_weight_decay_multiplier);
     }
 
+// ----------------------------------------------------------------------------------------
+
+    namespace impl
+    {
+        template <typename T, typename int_<decltype(&T::disable_bias)>::type = 0>
+        void disable_bias(
+            T& obj,
+            special_
+        ) { obj.disable_bias(); }
+
+        template <typename T>
+        void disable_bias( const T& , general_) { }
+    }
+
+    template <typename T>
+    void disable_bias(
+        T& obj
+    )
+    {
+        impl::disable_bias(obj, special_());
+    }
+
 // ----------------------------------------------------------------------------------------
 
     namespace impl
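For readers unfamiliar with the dispatch trick above: the first impl::disable_bias overload can only be instantiated when &T::disable_bias is well-formed, and passing special_ (which derives from general_) makes that overload the better match when both are viable. A self-contained sketch of the same idiom, with stand-in tag types mirroring dlib's internal general_/special_/int_ helpers:

    #include <iostream>

    struct general_ {};
    struct special_ : general_ {};
    template <typename> struct int_ { typedef int type; };

    struct with_bias    { void disable_bias() { std::cout << "bias disabled\n"; } };
    struct without_bias { };

    namespace impl
    {
        // Participates in overload resolution only when T::disable_bias exists,
        // because &T::disable_bias must be well-formed (SFINAE).
        template <typename T, typename int_<decltype(&T::disable_bias)>::type = 0>
        void disable_bias(T& obj, special_) { obj.disable_bias(); }

        // Fallback: do nothing for types without the member function.
        template <typename T>
        void disable_bias(const T&, general_) { }
    }

    template <typename T>
    void disable_bias(T& obj)
    {
        // special_ converts to general_, so the specific overload wins when viable.
        impl::disable_bias(obj, special_());
    }

    int main()
    {
        with_bias a;
        without_bias b;
        disable_bias(a); // prints "bias disabled"
        disable_bias(b); // no-op
    }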
@@ -157,6 +157,20 @@ namespace dlib
             - does nothing
     !*/
 
 // ----------------------------------------------------------------------------------------
 
+    template <typename T>
+    void disable_bias(
+        T& obj
+    );
+    /*!
+        ensures
+            - if (obj has a disable_bias() member function) then
+                - calls obj.disable_bias()
+            - else
+                - does nothing
+    !*/
+
+// ----------------------------------------------------------------------------------------
+
     bool dnn_prefer_fastest_algorithms(
@@ -59,7 +59,8 @@ namespace dlib
             bias_weight_decay_multiplier(0),
             num_filters_(o.num_outputs),
             padding_y_(_padding_y),
-            padding_x_(_padding_x)
+            padding_x_(_padding_x),
+            use_bias(true)
         {
             DLIB_CASSERT(num_filters_ > 0);
         }
@@ -106,6 +107,8 @@ namespace dlib
         double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
         void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
         void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+        void disable_bias() { use_bias = false; }
+        bool bias_is_disabled() const { return !use_bias; }
 
         inline dpoint map_input_to_output (
             dpoint p
@@ -137,7 +140,8 @@ namespace dlib
             bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
             num_filters_(item.num_filters_),
             padding_y_(item.padding_y_),
-            padding_x_(item.padding_x_)
+            padding_x_(item.padding_x_),
+            use_bias(item.use_bias)
         {
             // this->conv is non-copyable and basically stateless, so we have to write our
             // own copy to avoid trying to copy it and getting an error.
@@ -162,6 +166,7 @@ namespace dlib
             bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
             bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
             num_filters_ = item.num_filters_;
+            use_bias = item.use_bias;
             return *this;
         }
 
@@ -174,17 +179,19 @@ namespace dlib
             long num_inputs = filt_nr*filt_nc*sub.get_output().k();
             long num_outputs = num_filters_;
             // allocate params for the filters and also for the filter bias values.
-            params.set_size(num_inputs*num_filters_ + num_filters_);
+            params.set_size(num_inputs*num_filters_ + static_cast<int>(use_bias) * num_filters_);
 
             dlib::rand rnd(std::rand());
             randomize_parameters(params, num_inputs+num_outputs, rnd);
 
             filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc);
+            if (use_bias)
+            {
                 biases = alias_tensor(1,num_filters_);
 
                 // set the initial bias values to zero
                 biases(params,filters.size()) = 0;
+            }
         }
 
         template <typename SUBNET>
         void forward(const SUBNET& sub, resizable_tensor& output)
@@ -198,9 +205,11 @@ namespace dlib
             conv(false, output,
                 sub.get_output(),
                 filters(params,0));
 
+            if (use_bias)
+            {
                 tt::add(1,output,1,biases(params,filters.size()));
+            }
         }
 
         template <typename SUBNET>
         void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
@@ -211,17 +220,20 @@ namespace dlib
             {
                 auto filt = filters(params_grad,0);
                 conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt);
+                if (use_bias)
+                {
                     auto b = biases(params_grad, filters.size());
                     tt::assign_conv_bias_gradient(b, gradient_input);
+                }
             }
         }
 
         const tensor& get_layer_params() const { return params; }
         tensor& get_layer_params() { return params; }
 
         friend void serialize(const con_& item, std::ostream& out)
         {
-            serialize("con_4", out);
+            serialize("con_5", out);
             serialize(item.params, out);
             serialize(item.num_filters_, out);
             serialize(_nr, out);
@@ -236,6 +248,7 @@ namespace dlib
             serialize(item.weight_decay_multiplier, out);
             serialize(item.bias_learning_rate_multiplier, out);
             serialize(item.bias_weight_decay_multiplier, out);
+            serialize(item.use_bias, out);
         }
 
         friend void deserialize(con_& item, std::istream& in)
@@ -246,7 +259,7 @@ namespace dlib
             long nc;
             int stride_y;
             int stride_x;
-            if (version == "con_4")
+            if (version == "con_4" || version == "con_5")
             {
                 deserialize(item.params, in);
                 deserialize(item.num_filters_, in);
@@ -268,6 +281,10 @@ namespace dlib
                 if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
                 if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
                 if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
+                if (version == "con_5")
+                {
+                    deserialize(item.use_bias, in);
+                }
             }
             else
             {
@@ -289,8 +306,15 @@ namespace dlib
                 << ")";
             out << " learning_rate_mult="<<item.learning_rate_multiplier;
             out << " weight_decay_mult="<<item.weight_decay_multiplier;
+            if (item.use_bias)
+            {
                 out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                 out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+            }
+            else
+            {
+                out << " use_bias=false";
+            }
             return out;
         }
 
@@ -307,7 +331,8 @@ namespace dlib
                 << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                 << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                 << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
-                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
+                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
+                << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
             out << mat(item.params);
             out << "</con>";
         }
@@ -328,6 +353,7 @@ namespace dlib
             // serialized to disk) used different padding settings.
             int padding_y_;
             int padding_x_;
+            bool use_bias;
 
         };
 
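To make the parameter bookkeeping in setup() above concrete: multiplying the bias block by static_cast<int>(use_bias) removes num_filters_ values from the params tensor altogether when the bias is disabled. A hypothetical standalone check (the helper and the numbers are illustrative):

    #include <cstdio>

    // Mirrors params.set_size(num_inputs*num_filters_ + static_cast<int>(use_bias)*num_filters_).
    long con_param_count(long k, long nr, long nc, long num_filters, bool use_bias)
    {
        const long num_inputs = nr * nc * k;               // inputs seen by each filter
        return num_inputs * num_filters                    // filter weights
             + static_cast<long>(use_bias) * num_filters;  // one bias per filter, or none
    }

    int main()
    {
        std::printf("%ld\n", con_param_count(3, 4, 4, 64, true));  // 3136 = 3072 weights + 64 biases
        std::printf("%ld\n", con_param_count(3, 4, 4, 64, false)); // 3072 = weights only
    }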
@@ -373,7 +399,8 @@ namespace dlib
             bias_weight_decay_multiplier(0),
             num_filters_(o.num_outputs),
             padding_y_(_padding_y),
-            padding_x_(_padding_x)
+            padding_x_(_padding_x),
+            use_bias(true)
         {
             DLIB_CASSERT(num_filters_ > 0);
         }
@@ -408,6 +435,8 @@ namespace dlib
         double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
         void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
         void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+        void disable_bias() { use_bias = false; }
+        bool bias_is_disabled() const { return !use_bias; }
 
         inline dpoint map_output_to_input (
             dpoint p
@@ -439,7 +468,8 @@ namespace dlib
             bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
             num_filters_(item.num_filters_),
             padding_y_(item.padding_y_),
-            padding_x_(item.padding_x_)
+            padding_x_(item.padding_x_),
+            use_bias(item.use_bias)
         {
             // this->conv is non-copyable and basically stateless, so we have to write our
             // own copy to avoid trying to copy it and getting an error.
@@ -464,6 +494,7 @@ namespace dlib
             bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
             bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
             num_filters_ = item.num_filters_;
+            use_bias = item.use_bias;
             return *this;
         }
 
@@ -473,17 +504,19 @@ namespace dlib
             long num_inputs = _nr*_nc*sub.get_output().k();
             long num_outputs = num_filters_;
             // allocate params for the filters and also for the filter bias values.
-            params.set_size(num_inputs*num_filters_ + num_filters_);
+            params.set_size(num_inputs*num_filters_ + num_filters_ * static_cast<int>(use_bias));
 
             dlib::rand rnd(std::rand());
             randomize_parameters(params, num_inputs+num_outputs, rnd);
 
             filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
+            if (use_bias)
+            {
                 biases = alias_tensor(1,num_filters_);
 
                 // set the initial bias values to zero
                 biases(params,filters.size()) = 0;
+            }
         }
 
         template <typename SUBNET>
         void forward(const SUBNET& sub, resizable_tensor& output)
@@ -496,8 +529,11 @@ namespace dlib
             output.set_size(gnsamps,gk,gnr,gnc);
             conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_);
             conv.get_gradient_for_data(false, sub.get_output(),filt,output);
+            if (use_bias)
+            {
                 tt::add(1,output,1,biases(params,filters.size()));
+            }
         }
 
         template <typename SUBNET>
         void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
@@ -509,17 +545,20 @@ namespace dlib
             {
                 auto filt = filters(params_grad,0);
                 conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt);
+                if (use_bias)
+                {
                     auto b = biases(params_grad, filters.size());
                     tt::assign_conv_bias_gradient(b, gradient_input);
+                }
             }
         }
 
         const tensor& get_layer_params() const { return params; }
         tensor& get_layer_params() { return params; }
 
         friend void serialize(const cont_& item, std::ostream& out)
         {
-            serialize("cont_1", out);
+            serialize("cont_2", out);
             serialize(item.params, out);
             serialize(item.num_filters_, out);
             serialize(_nr, out);
@@ -534,6 +573,7 @@ namespace dlib
             serialize(item.weight_decay_multiplier, out);
             serialize(item.bias_learning_rate_multiplier, out);
             serialize(item.bias_weight_decay_multiplier, out);
+            serialize(item.use_bias, out);
         }
 
         friend void deserialize(cont_& item, std::istream& in)
@@ -544,7 +584,7 @@ namespace dlib
             long nc;
             int stride_y;
             int stride_x;
-            if (version == "cont_1")
+            if (version == "cont_1" || version == "cont_2")
             {
                 deserialize(item.params, in);
                 deserialize(item.num_filters_, in);
@@ -566,6 +606,10 @@ namespace dlib
                 if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
                 if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
                 if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
+                if (version == "cont_2")
+                {
+                    deserialize(item.use_bias, in);
+                }
             }
             else
             {
@@ -587,8 +631,15 @@ namespace dlib
                 << ")";
             out << " learning_rate_mult="<<item.learning_rate_multiplier;
             out << " weight_decay_mult="<<item.weight_decay_multiplier;
+            if (item.use_bias)
+            {
                 out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                 out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+            }
+            else
+            {
+                out << " use_bias=false";
+            }
             return out;
         }
 
@@ -605,7 +656,8 @@ namespace dlib
                 << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                 << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                 << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
-                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
+                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
+                << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
             out << mat(item.params);
             out << "</cont>";
         }
@@ -625,6 +677,8 @@ namespace dlib
             int padding_y_;
             int padding_x_;
 
+            bool use_bias;
+
         };
 
         template <
@@ -1522,6 +1576,37 @@ namespace dlib
 
             unsigned long new_window_size;
         };
 
+        class visitor_bn_input_no_bias
+        {
+        public:
+
+            template <typename T>
+            void set_input_no_bias(T&) const
+            {
+                // ignore other layer types
+            }
+
+            template <layer_mode mode, typename U, typename E>
+            void set_input_no_bias(add_layer<bn_<mode>, U, E>& l)
+            {
+                disable_bias(l.subnet().layer_details());
+                set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
+                set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
+            }
+
+            template<typename input_layer_type>
+            void operator()(size_t , input_layer_type& ) const
+            {
+                // ignore other layers
+            }
+
+            template <typename T, typename U, typename E>
+            void operator()(size_t , add_layer<T,U,E>& l)
+            {
+                set_input_no_bias(l);
+            }
+        };
     }
 
     template <typename net_type>
@@ -1533,6 +1618,14 @@ namespace dlib
         visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
     }
 
+    template <typename net_type>
+    void set_all_bn_inputs_no_bias (
+        net_type& net
+    )
+    {
+        visit_layers(net, impl::visitor_bn_input_no_bias());
+    }
+
 // ----------------------------------------------------------------------------------------
 // ----------------------------------------------------------------------------------------
 
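Note that the visitor matches only add_layer<bn_<mode>, ...> nodes and then reaches one layer down, through l.subnet().layer_details(), to the layer feeding the batch norm; the bn_ layer itself keeps its bias term, as the commit message says. A hedged usage sketch on a made-up toy network (the architecture is illustrative, the layer types are dlib's):

    #include <dlib/dnn.h>
    using namespace dlib;

    // A toy network where a con layer feeds a bn_con layer.
    using tiny_net = loss_multiclass_log<
                     fc<10,
                     relu<bn_con<con<16, 3, 3, 1, 1,
                     input<matrix<unsigned char>>
                     >>>>>;

    int main()
    {
        tiny_net net;
        set_all_bn_inputs_no_bias(net);
        // The con layer under bn_con now has bias_is_disabled() == true and zero
        // bias multipliers; the fc layer, which feeds no bn_ layer, keeps its bias.
    }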
@@ -1561,7 +1654,8 @@ namespace dlib
             learning_rate_multiplier(1),
             weight_decay_multiplier(1),
             bias_learning_rate_multiplier(1),
-            bias_weight_decay_multiplier(0)
+            bias_weight_decay_multiplier(0),
+            use_bias(true)
         {}
 
         fc_() : fc_(num_fc_outputs(num_outputs_)) {}
@@ -1575,6 +1669,8 @@ namespace dlib
         double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
         void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
         void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+        void disable_bias() { use_bias = false; }
+        bool bias_is_disabled() const { return !use_bias; }
 
         unsigned long get_num_outputs (
         ) const { return num_outputs; }
@@ -1597,7 +1693,7 @@ namespace dlib
         void setup (const SUBNET& sub)
         {
             num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
-            if (bias_mode == FC_HAS_BIAS)
+            if (bias_mode == FC_HAS_BIAS && use_bias)
                 params.set_size(num_inputs+1, num_outputs);
             else
                 params.set_size(num_inputs, num_outputs);
@@ -1607,7 +1703,7 @@ namespace dlib
 
             weights = alias_tensor(num_inputs, num_outputs);
 
-            if (bias_mode == FC_HAS_BIAS)
+            if (bias_mode == FC_HAS_BIAS && use_bias)
             {
                 biases = alias_tensor(1,num_outputs);
                 // set the initial bias values to zero
@@ -1624,7 +1720,7 @@ namespace dlib
 
             auto w = weights(params, 0);
             tt::gemm(0,output, 1,sub.get_output(),false, w,false);
-            if (bias_mode == FC_HAS_BIAS)
+            if (bias_mode == FC_HAS_BIAS && use_bias)
             {
                 auto b = biases(params, weights.size());
                 tt::add(1,output,1,b);
@@ -1641,7 +1737,7 @@ namespace dlib
             auto pw = weights(params_grad, 0);
             tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);
 
-            if (bias_mode == FC_HAS_BIAS)
+            if (bias_mode == FC_HAS_BIAS && use_bias)
             {
                 // compute the gradient of the bias parameters.
                 auto pb = biases(params_grad, weights.size());
@@ -1683,7 +1779,7 @@ namespace dlib
 
         friend void serialize(const fc_& item, std::ostream& out)
         {
-            serialize("fc_2", out);
+            serialize("fc_3", out);
             serialize(item.num_outputs, out);
             serialize(item.num_inputs, out);
             serialize(item.params, out);
@@ -1694,15 +1790,15 @@ namespace dlib
             serialize(item.weight_decay_multiplier, out);
             serialize(item.bias_learning_rate_multiplier, out);
             serialize(item.bias_weight_decay_multiplier, out);
+            serialize(item.use_bias, out);
         }
 
         friend void deserialize(fc_& item, std::istream& in)
         {
             std::string version;
             deserialize(version, in);
-            if (version != "fc_2")
-                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
-
+            if (version == "fc_2" || version == "fc_3")
+            {
                 deserialize(item.num_outputs, in);
                 deserialize(item.num_inputs, in);
                 deserialize(item.params, in);
@@ -1715,6 +1811,15 @@ namespace dlib
                 deserialize(item.weight_decay_multiplier, in);
                 deserialize(item.bias_learning_rate_multiplier, in);
                 deserialize(item.bias_weight_decay_multiplier, in);
+                if (version == "fc_3")
+                {
+                    deserialize(item.use_bias, in);
+                }
+            }
+            else
+            {
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
+            }
         }
 
         friend std::ostream& operator<<(std::ostream& out, const fc_& item)
@@ -1726,10 +1831,17 @@ namespace dlib
                 << ")";
             out << " learning_rate_mult="<<item.learning_rate_multiplier;
             out << " weight_decay_mult="<<item.weight_decay_multiplier;
+            if (item.use_bias)
+            {
                 out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                 out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+            }
+            else
+            {
+                out << " use_bias=false";
+            }
         }
         else
         {
             out << "fc_no_bias ("
                 << "num_outputs="<<item.num_outputs
@@ -1749,7 +1861,8 @@ namespace dlib
                 << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                 << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                 << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
-                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
+                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
+                << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
             out << ">\n";
             out << mat(item.params);
             out << "</fc>\n";
@@ -1776,6 +1889,7 @@ namespace dlib
             double weight_decay_multiplier;
             double bias_learning_rate_multiplier;
             double bias_weight_decay_multiplier;
+            bool use_bias;
         };
 
         template <
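The fc_2 to fc_3 bump above follows the same versioned-serialization pattern as con_5 and cont_2: always write the newest tag, accept every tag ever written, and read a field only when its tag guarantees it, so old streams leave use_bias at its default of true. A generic sketch of that pattern, with made-up names unrelated to dlib's serialize helpers:

    #include <iostream>
    #include <sstream>
    #include <stdexcept>
    #include <string>

    struct thing { int x = 0; bool use_bias = true; };

    void save(const thing& t, std::ostream& out)
    {
        out << "thing_2 " << t.x << ' ' << t.use_bias;   // always write the newest version tag
    }

    void load(thing& t, std::istream& in)
    {
        std::string version;
        in >> version;
        if (version == "thing_1" || version == "thing_2")
        {
            in >> t.x;
            if (version == "thing_2")   // field introduced in version 2
                in >> t.use_bias;       // older streams keep the default (true)
        }
        else
        {
            throw std::runtime_error("Unexpected version '" + version + "'");
        }
    }

    int main()
    {
        std::istringstream old_stream("thing_1 42");     // written before use_bias existed
        thing t;
        load(t, old_stream);
        std::cout << t.x << ' ' << t.use_bias << '\n';   // prints: 42 1
    }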
@@ -573,6 +573,22 @@ namespace dlib
             - #get_bias_weight_decay_multiplier() == val
         !*/
 
+        void disable_bias(
+        );
+        /*!
+            ensures
+                - bias_is_disabled() returns true
+        !*/
+
+        bool bias_is_disabled(
+        ) const;
+        /*!
+            ensures
+                - returns true if bias learning is disabled for this layer. This means
+                  the biases will not be learned during training and will not be used
+                  in the forward or backward methods either.
+        !*/
+
         alias_tensor_const_instance get_weights(
         ) const;
         /*!
@@ -903,6 +919,22 @@ namespace dlib
             - #get_bias_weight_decay_multiplier() == val
         !*/
 
+        void disable_bias(
+        );
+        /*!
+            ensures
+                - bias_is_disabled() returns true
+        !*/
+
+        bool bias_is_disabled(
+        ) const;
+        /*!
+            ensures
+                - returns true if bias learning is disabled for this layer. This means
+                  the biases will not be learned during training and will not be used
+                  in the forward or backward methods either.
+        !*/
+
         template <typename SUBNET> void setup (const SUBNET& sub);
         template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
         template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@@ -1147,6 +1179,22 @@ namespace dlib
             - #get_bias_weight_decay_multiplier() == val
         !*/
 
+        void disable_bias(
+        );
+        /*!
+            ensures
+                - bias_is_disabled() returns true
+        !*/
+
+        bool bias_is_disabled(
+        ) const;
+        /*!
+            ensures
+                - returns true if bias learning is disabled for this layer. This means
+                  the biases will not be learned during training and will not be used
+                  in the forward or backward methods either.
+        !*/
+
         template <typename SUBNET> void setup (const SUBNET& sub);
         template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
         template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@@ -1616,6 +1664,22 @@ namespace dlib
           new_window_size.
     !*/
 
+// ----------------------------------------------------------------------------------------
+
+    template <typename net_type>
+    void set_all_bn_inputs_no_bias (
+        const net_type& net
+    );
+    /*!
+        requires
+            - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+              add_tag_layer.
+        ensures
+            - Disables bias for all bn_ layer inputs.
+            - Sets the bias learning rate multiplier and bias weight decay multiplier
+              of all bn_ layer inputs to zero.
+    !*/
+
 // ----------------------------------------------------------------------------------------
 
     class affine_
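A hypothetical illustration of the documented contract (the fc_ template arguments here are made up): after disable_bias(), setup() allocates no bias block, and forward()/backward() skip the bias entirely.

    #include <dlib/dnn.h>
    #include <cassert>

    int main()
    {
        dlib::fc_<8, dlib::FC_HAS_BIAS> layer;
        layer.disable_bias();
        assert(layer.bias_is_disabled());
        // Once a network containing this layer is set up, its get_layer_params()
        // holds only the num_inputs x num_outputs weight matrix: no bias row.
    }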
@@ -10,7 +10,7 @@
     by Alec Radford, Luke Metz, Soumith Chintala.
 
     The main idea is that there are two neural networks training at the same time:
-    - the generator is in charge of generating images that look as close as possible as the
+    - the generator is in charge of generating images that look as close as possible to the
       ones from the dataset.
     - the discriminator will decide whether an image is fake (created by the generator) or real
       (selected from the dataset).
@@ -35,25 +35,6 @@
 using namespace std;
 using namespace dlib;
 
-// We start by defining a simple visitor to disable bias learning in a network. By default,
-// biases are initialized to 0, so setting the multipliers to 0 disables bias learning.
-class visitor_no_bias
-{
-public:
-    template <typename input_layer_type>
-    void operator()(size_t , input_layer_type& ) const
-    {
-        // ignore other layers
-    }
-
-    template <typename T, typename U, typename E>
-    void operator()(size_t , add_layer<T, U, E>& l) const
-    {
-        set_bias_learning_rate_multiplier(l.layer_details(), 0);
-        set_bias_weight_decay_multiplier(l.layer_details(), 0);
-    }
-};
-
 // Some helper definitions for the noise generation
 const size_t noise_size = 100;
 using noise_t = std::array<matrix<float, 1, 1>, noise_size>;
@@ -149,16 +130,15 @@ int main(int argc, char** argv) try
 
     // Instantiate both generator and discriminator
     generator_type generator;
-    discriminator_type discriminator(
-        leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2));
-    // Remove the bias learning from the networks
-    visit_layers(generator, visitor_no_bias());
-    visit_layers(discriminator, visitor_no_bias());
+    discriminator_type discriminator(leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2));
+    // Remove the bias learning from all bn_ inputs in both networks
+    set_all_bn_inputs_no_bias(generator);
+    set_all_bn_inputs_no_bias(discriminator);
     // Forward random noise so that we see the tensor size at each layer
     discriminator(generate_image(generator, make_noise(rnd)));
-    cout << "generator" << endl;
+    cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
     cout << generator << endl;
-    cout << "discriminator" << endl;
+    cout << "discriminator (" << count_parameters(discriminator) << " parameters)" << endl;
     cout << discriminator << endl;
 
     // The solvers for the generator and discriminator networks. In this example, we are going to
@@ -257,8 +237,11 @@ int main(int argc, char** argv) try
         // output.
         while (!win.is_closed())
         {
-            win.set_image(generate_image(generator, make_noise(rnd)));
-            cout << "Hit enter to generate a new image";
+            const auto image = generate_image(generator, make_noise(rnd));
+            const auto real = discriminator(image) > 0;
+            win.set_image(image);
+            cout << "The discriminator thinks it's " << (real ? "real" : "fake");
+            cout << ". Hit enter to generate a new image";
             cin.get();
         }
 