Add visitor to remove bias from bn_ layer inputs (#closes 2155) (#2156)

* add visitor to remove bias from bn_ inputs (#closes 2155)

* remove unused parameter and make documentation more clear

* remove bias from bn_ layers too and use better name

* let the batch norm keep their bias, use even better name

* be more consistent with impl naming

* remove default constructor

* do not use method to prevent some errors

* add disable bias method to pertinent layers

* update dcgan example

- grammar
- print number of network parameters to be able to check bias is not allocated
- at the end, give feedback to the user about what the discriminator thinks about each generated sample

* fix fc_ logic

* add documentation

* add bias_is_disabled methods and update to_xml

* print use_bias=false when bias is disabled
pull/2162/head
Adrià Arrufat 4 years ago committed by GitHub
parent ed22f0400a
commit e7ec6b7777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -183,6 +183,28 @@ namespace dlib
impl::set_bias_weight_decay_multiplier(obj, special_(), bias_weight_decay_multiplier);
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename T, typename int_<decltype(&T::disable_bias)>::type = 0>
void disable_bias(
T& obj,
special_
) { obj.disable_bias(); }
template <typename T>
void disable_bias( const T& , general_) { }
}
template <typename T>
void disable_bias(
T& obj
)
{
impl::disable_bias(obj, special_());
}
// ----------------------------------------------------------------------------------------
namespace impl

@ -157,6 +157,20 @@ namespace dlib
- does nothing
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
void disable_bias(
T& obj
);
/*!
ensures
- if (obj has a disable_bias() member function) then
- calls obj.disable_bias()
- else
- does nothing
!*/
// ----------------------------------------------------------------------------------------
bool dnn_prefer_fastest_algorithms(

@ -59,7 +59,8 @@ namespace dlib
bias_weight_decay_multiplier(0),
num_filters_(o.num_outputs),
padding_y_(_padding_y),
padding_x_(_padding_x)
padding_x_(_padding_x),
use_bias(true)
{
DLIB_CASSERT(num_filters_ > 0);
}
@ -106,6 +107,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
inline dpoint map_input_to_output (
dpoint p
@ -137,7 +140,8 @@ namespace dlib
bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
num_filters_(item.num_filters_),
padding_y_(item.padding_y_),
padding_x_(item.padding_x_)
padding_x_(item.padding_x_),
use_bias(item.use_bias)
{
// this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error.
@ -162,6 +166,7 @@ namespace dlib
bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
num_filters_ = item.num_filters_;
use_bias = item.use_bias;
return *this;
}
@ -174,16 +179,18 @@ namespace dlib
long num_inputs = filt_nr*filt_nc*sub.get_output().k();
long num_outputs = num_filters_;
// allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*num_filters_ + num_filters_);
params.set_size(num_inputs*num_filters_ + static_cast<int>(use_bias) * num_filters_);
dlib::rand rnd(std::rand());
randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc);
biases = alias_tensor(1,num_filters_);
// set the initial bias values to zero
biases(params,filters.size()) = 0;
if (use_bias)
{
biases = alias_tensor(1,num_filters_);
// set the initial bias values to zero
biases(params,filters.size()) = 0;
}
}
template <typename SUBNET>
@ -198,9 +205,11 @@ namespace dlib
conv(false, output,
sub.get_output(),
filters(params,0));
tt::add(1,output,1,biases(params,filters.size()));
}
if (use_bias)
{
tt::add(1,output,1,biases(params,filters.size()));
}
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
@ -211,8 +220,11 @@ namespace dlib
{
auto filt = filters(params_grad,0);
conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt);
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
if (use_bias)
{
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
}
}
}
@ -221,7 +233,7 @@ namespace dlib
friend void serialize(const con_& item, std::ostream& out)
{
serialize("con_4", out);
serialize("con_5", out);
serialize(item.params, out);
serialize(item.num_filters_, out);
serialize(_nr, out);
@ -236,6 +248,7 @@ namespace dlib
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
}
friend void deserialize(con_& item, std::istream& in)
@ -246,7 +259,7 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "con_4")
if (version == "con_4" || version == "con_5")
{
deserialize(item.params, in);
deserialize(item.num_filters_, in);
@ -268,6 +281,10 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
if (version == "con_5")
{
deserialize(item.use_bias, in);
}
}
else
{
@ -289,8 +306,15 @@ namespace dlib
<< ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
if (item.use_bias)
{
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
return out;
}
@ -307,7 +331,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << mat(item.params);
out << "</con>";
}
@ -328,6 +353,7 @@ namespace dlib
// serialized to disk) used different padding settings.
int padding_y_;
int padding_x_;
bool use_bias;
};
@ -373,7 +399,8 @@ namespace dlib
bias_weight_decay_multiplier(0),
num_filters_(o.num_outputs),
padding_y_(_padding_y),
padding_x_(_padding_x)
padding_x_(_padding_x),
use_bias(true)
{
DLIB_CASSERT(num_filters_ > 0);
}
@ -408,6 +435,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
inline dpoint map_output_to_input (
dpoint p
@ -439,7 +468,8 @@ namespace dlib
bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
num_filters_(item.num_filters_),
padding_y_(item.padding_y_),
padding_x_(item.padding_x_)
padding_x_(item.padding_x_),
use_bias(item.use_bias)
{
// this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error.
@ -464,6 +494,7 @@ namespace dlib
bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
num_filters_ = item.num_filters_;
use_bias = item.use_bias;
return *this;
}
@ -473,16 +504,18 @@ namespace dlib
long num_inputs = _nr*_nc*sub.get_output().k();
long num_outputs = num_filters_;
// allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*num_filters_ + num_filters_);
params.set_size(num_inputs*num_filters_ + num_filters_ * static_cast<int>(use_bias));
dlib::rand rnd(std::rand());
randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
biases = alias_tensor(1,num_filters_);
// set the initial bias values to zero
biases(params,filters.size()) = 0;
if (use_bias)
{
biases = alias_tensor(1,num_filters_);
// set the initial bias values to zero
biases(params,filters.size()) = 0;
}
}
template <typename SUBNET>
@ -496,7 +529,10 @@ namespace dlib
output.set_size(gnsamps,gk,gnr,gnc);
conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_);
conv.get_gradient_for_data(false, sub.get_output(),filt,output);
tt::add(1,output,1,biases(params,filters.size()));
if (use_bias)
{
tt::add(1,output,1,biases(params,filters.size()));
}
}
template <typename SUBNET>
@ -509,8 +545,11 @@ namespace dlib
{
auto filt = filters(params_grad,0);
conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt);
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
if (use_bias)
{
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
}
}
}
@ -519,7 +558,7 @@ namespace dlib
friend void serialize(const cont_& item, std::ostream& out)
{
serialize("cont_1", out);
serialize("cont_2", out);
serialize(item.params, out);
serialize(item.num_filters_, out);
serialize(_nr, out);
@ -534,6 +573,7 @@ namespace dlib
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
}
friend void deserialize(cont_& item, std::istream& in)
@ -544,7 +584,7 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "cont_1")
if (version == "cont_1" || version == "cont_2")
{
deserialize(item.params, in);
deserialize(item.num_filters_, in);
@ -566,6 +606,10 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
if (version == "cont_2")
{
deserialize(item.use_bias, in);
}
}
else
{
@ -587,8 +631,15 @@ namespace dlib
<< ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
if (item.use_bias)
{
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
return out;
}
@ -605,7 +656,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << mat(item.params);
out << "</cont>";
}
@ -625,6 +677,8 @@ namespace dlib
int padding_y_;
int padding_x_;
bool use_bias;
};
template <
@ -1522,6 +1576,37 @@ namespace dlib
unsigned long new_window_size;
};
class visitor_bn_input_no_bias
{
public:
template <typename T>
void set_input_no_bias(T&) const
{
// ignore other layer types
}
template <layer_mode mode, typename U, typename E>
void set_input_no_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
set_input_no_bias(l);
}
};
}
template <typename net_type>
@ -1533,6 +1618,14 @@ namespace dlib
visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
}
template <typename net_type>
void set_all_bn_inputs_no_bias (
net_type& net
)
{
visit_layers(net, impl::visitor_bn_input_no_bias());
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
@ -1561,7 +1654,8 @@ namespace dlib
learning_rate_multiplier(1),
weight_decay_multiplier(1),
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(0)
bias_weight_decay_multiplier(0),
use_bias(true)
{}
fc_() : fc_(num_fc_outputs(num_outputs_)) {}
@ -1575,6 +1669,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
unsigned long get_num_outputs (
) const { return num_outputs; }
@ -1597,7 +1693,7 @@ namespace dlib
void setup (const SUBNET& sub)
{
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
if (bias_mode == FC_HAS_BIAS)
if (bias_mode == FC_HAS_BIAS && use_bias)
params.set_size(num_inputs+1, num_outputs);
else
params.set_size(num_inputs, num_outputs);
@ -1607,7 +1703,7 @@ namespace dlib
weights = alias_tensor(num_inputs, num_outputs);
if (bias_mode == FC_HAS_BIAS)
if (bias_mode == FC_HAS_BIAS && use_bias)
{
biases = alias_tensor(1,num_outputs);
// set the initial bias values to zero
@ -1624,7 +1720,7 @@ namespace dlib
auto w = weights(params, 0);
tt::gemm(0,output, 1,sub.get_output(),false, w,false);
if (bias_mode == FC_HAS_BIAS)
if (bias_mode == FC_HAS_BIAS && use_bias)
{
auto b = biases(params, weights.size());
tt::add(1,output,1,b);
@ -1641,7 +1737,7 @@ namespace dlib
auto pw = weights(params_grad, 0);
tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);
if (bias_mode == FC_HAS_BIAS)
if (bias_mode == FC_HAS_BIAS && use_bias)
{
// compute the gradient of the bias parameters.
auto pb = biases(params_grad, weights.size());
@ -1683,7 +1779,7 @@ namespace dlib
friend void serialize(const fc_& item, std::ostream& out)
{
serialize("fc_2", out);
serialize("fc_3", out);
serialize(item.num_outputs, out);
serialize(item.num_inputs, out);
serialize(item.params, out);
@ -1694,27 +1790,36 @@ namespace dlib
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
}
friend void deserialize(fc_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "fc_2")
if (version == "fc_2" || version == "fc_3")
{
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
deserialize(item.weights, in);
deserialize(item.biases, in);
int bmode = 0;
deserialize(bmode, in);
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
if (version == "fc_3")
{
deserialize(item.use_bias, in);
}
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
deserialize(item.weights, in);
deserialize(item.biases, in);
int bmode = 0;
deserialize(bmode, in);
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
}
}
friend std::ostream& operator<<(std::ostream& out, const fc_& item)
@ -1726,8 +1831,15 @@ namespace dlib
<< ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
if (item.use_bias)
{
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
}
else
{
@ -1749,7 +1861,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << ">\n";
out << mat(item.params);
out << "</fc>\n";
@ -1776,6 +1889,7 @@ namespace dlib
double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
bool use_bias;
};
template <

@ -573,6 +573,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val
!*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during the training and they will not be used in the forward or backward
methods either.
!*/
alias_tensor_const_instance get_weights(
) const;
/*!
@ -903,6 +919,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val
!*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during the training and they will not be used in the forward or backward
methods either.
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@ -1147,6 +1179,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val
!*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during the training and they will not be used in the forward or backward
methods either.
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@ -1616,6 +1664,22 @@ namespace dlib
new_window_size.
!*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
void set_all_bn_inputs_no_bias (
const net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Disables bias for all bn_ layer inputs.
- Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
to zero of all bn_ layer inputs.
!*/
// ----------------------------------------------------------------------------------------
class affine_

@ -10,7 +10,7 @@
by Alec Radford, Luke Metz, Soumith Chintala.
The main idea is that there are two neural networks training at the same time:
- the generator is in charge of generating images that look as close as possible as the
- the generator is in charge of generating images that look as close as possible to the
ones from the dataset.
- the discriminator will decide whether an image is fake (created by the generator) or real
(selected from the dataset).
@ -35,25 +35,6 @@
using namespace std;
using namespace dlib;
// We start by defining a simple visitor to disable bias learning in a network. By default,
// biases are initialized to 0, so setting the multipliers to 0 disables bias learning.
class visitor_no_bias
{
public:
template <typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T, U, E>& l) const
{
set_bias_learning_rate_multiplier(l.layer_details(), 0);
set_bias_weight_decay_multiplier(l.layer_details(), 0);
}
};
// Some helper definitions for the noise generation
const size_t noise_size = 100;
using noise_t = std::array<matrix<float, 1, 1>, noise_size>;
@ -149,16 +130,15 @@ int main(int argc, char** argv) try
// Instantiate both generator and discriminator
generator_type generator;
discriminator_type discriminator(
leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2));
// Remove the bias learning from the networks
visit_layers(generator, visitor_no_bias());
visit_layers(discriminator, visitor_no_bias());
discriminator_type discriminator(leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2));
// Remove the bias learning from all bn_ inputs in both networks
set_all_bn_inputs_no_bias(generator);
set_all_bn_inputs_no_bias(discriminator);
// Forward random noise so that we see the tensor size at each layer
discriminator(generate_image(generator, make_noise(rnd)));
cout << "generator" << endl;
cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
cout << generator << endl;
cout << "discriminator" << endl;
cout << "discriminator (" << count_parameters(discriminator) << " parameters)" << endl;
cout << discriminator << endl;
// The solvers for the generator and discriminator networks. In this example, we are going to
@ -204,7 +184,7 @@ int main(int argc, char** argv) try
{
noises.push_back(make_noise(rnd));
}
// 2. Convert noises into a tensor
// 2. Convert noises into a tensor
generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
// 3. Forward the noise through the network and convert the outputs into images.
const auto fake_samples = get_generated_images(generator.forward(noises_tensor));
@ -257,8 +237,11 @@ int main(int argc, char** argv) try
// output.
while (!win.is_closed())
{
win.set_image(generate_image(generator, make_noise(rnd)));
cout << "Hit enter to generate a new image";
const auto image = generate_image(generator, make_noise(rnd));
const auto real = discriminator(image) > 0;
win.set_image(image);
cout << "The discriminator thinks it's " << (real ? "real" : "fake");
cout << ". Hit enter to generate a new image";
cin.get();
}

Loading…
Cancel
Save