Mirror of https://github.com/davisking/dlib.git
Promote some of the sub-network methods into the add_loss_layer interface so users don't have to write .subnet() so often.
Commit 0057461a62 (parent c79f64f52d)
@@ -2461,6 +2461,27 @@ namespace dlib
             return results;
         }
 
+        void back_propagate_error(const tensor& x)
+        {
+            subnet().back_propagate_error(x);
+        }
+
+        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        {
+            subnet().back_propagate_error(x, gradient_input);
+        }
+
+        const tensor& get_final_data_gradient(
+        ) const
+        {
+            return subnet().get_final_data_gradient();
+        }
+
+        const tensor& forward(const tensor& x)
+        {
+            return subnet().forward(x);
+        }
+
         template <typename iterable_type>
         std::vector<output_label_type> operator() (
             const iterable_type& data,
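The practical effect of the forwarding methods added above is that training code can drive an add_loss_layer network directly instead of reaching through .subnet() every time. A minimal sketch of the new call pattern, assuming a deliberately tiny illustrative network, made-up data, and freshly initialized weights (none of which are part of this commit):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Illustrative network only: two fully connected layers behind a classification loss.
    using net_type = loss_multiclass_log<fc<2,relu<fc<8,input<matrix<float>>>>>>;

    int main()
    {
        net_type net;
        std::vector<matrix<float>> samples(4, matrix<float>(zeros_matrix<float>(3,1)));
        std::vector<unsigned long> labels(4, 0);

        resizable_tensor x;
        net.to_tensor(samples.begin(), samples.end(), x);

        // Each call below previously required net.subnet().<method>(...).
        net.compute_loss(x, labels.begin());                  // runs x through the net and fills subnet().get_gradient_input()
        net.back_propagate_error(x);                          // computes parameter and data gradients
        const tensor& grad = net.get_final_data_gradient();   // gradient of the loss with respect to x
        return have_same_dimensions(grad, x) ? 0 : 1;
    }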
@@ -857,6 +857,29 @@ namespace dlib
 
         // -------------
 
+        const tensor& forward(const tensor& x
+        );
+        /*!
+            requires
+                - sample_expansion_factor() != 0
+                  (i.e. to_tensor() must have been called to set sample_expansion_factor()
+                  to something non-zero.)
+                - x.num_samples()%sample_expansion_factor() == 0
+                - x.num_samples() > 0
+            ensures
+                - Runs x through the network and returns the results as a tensor. In particular,
+                  this function just performs:
+                    return subnet().forward(x);
+                  So if you want to get the outputs as an output_label_type then call one of the
+                  methods below instead, like operator().
+                - The return value from this function is also available in #subnet().get_output().
+                  i.e. this function returns #subnet().get_output().
+                - have_same_dimensions(#subnet().get_gradient_input(), #subnet().get_output()) == true
+                - All elements of #subnet().get_gradient_input() are set to 0.
+                  i.e. calling this function clears out #subnet().get_gradient_input() and ensures
+                  it has the same dimensions as the most recent output.
+        !*/
+
         template <typename output_iterator>
         void operator() (
             const tensor& x,
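The postconditions at the end of the spec above are easy to exercise directly. A small sketch, assuming net is any add_loss_layer network and x is a tensor satisfying the requires clause (both names are illustrative):

    #include <dlib/dnn.h>
    #include <cassert>

    // Illustrative helper that checks the documented contract of forward().
    template <typename net_type>
    void check_forward_contract(net_type& net, const dlib::tensor& x)
    {
        const dlib::tensor& out = net.forward(x);

        // forward() hands back subnet().get_output() itself, not a copy.
        assert(&out == &net.subnet().get_output());
        // The gradient input buffer now matches the output's dimensions (and has been zeroed).
        assert(dlib::have_same_dimensions(net.subnet().get_gradient_input(), out));
    }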
@@ -996,6 +1019,9 @@ namespace dlib
             - for all valid k:
                 - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
             - This function does not update the network parameters.
+            - For sub-layers that are immediate inputs into the loss layer, we also populate the
+              sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
+              to the sub-layer's output.
         !*/
 
         template <typename forward_iterator, typename label_iterator>
@@ -1016,6 +1042,9 @@
             - for all valid k:
                 - the expected label of *(ibegin+k) is *(lbegin+k).
             - This function does not update the network parameters.
+            - For sub-layers that are immediate inputs into the loss layer, we also populate the
+              sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
+              to the sub-layer's output.
         !*/
 
         // -------------
@@ -1034,6 +1063,9 @@
         ensures
             - runs x through the network and returns the resulting loss.
             - This function does not update the network parameters.
+            - For sub-layers that are immediate inputs into the loss layer, we also populate the
+              sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
+              to the sub-layer's output.
         !*/
 
         template <typename forward_iterator>
@@ -1049,6 +1081,9 @@
         ensures
             - runs [ibegin,iend) through the network and returns the resulting loss.
             - This function does not update the network parameters.
+            - For sub-layers that are immediate inputs into the loss layer, we also populate the
+              sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
+              to the sub-layer's output.
         !*/
 
         // -------------
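The same bullet point is added to all four compute_loss() overloads in this file. In practice it means a manual training step does not have to seed the gradient buffer itself before back-propagating; a short sketch under that assumption (the helper name and network are illustrative):

    #include <dlib/dnn.h>

    // Illustrative: one supervised loss-and-gradient pass, relying on the guarantee that
    // compute_loss() leaves the loss gradient in the sub-layers feeding the loss layer.
    template <typename net_type, typename label_iterator>
    double loss_and_gradients(net_type& net, const dlib::tensor& x, label_iterator lbegin)
    {
        const double loss = net.compute_loss(x, lbegin);  // fills subnet().get_gradient_input()
        net.back_propagate_error(x);                      // consumes that gradient and populates
                                                          // parameter gradients and get_final_data_gradient()
        return loss;
    }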
@@ -1163,12 +1198,72 @@
         !*/
 
         template <typename solver_type>
-        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
-        { update_parameters(make_sstack(solvers), learning_rate); }
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate
+        ) { update_parameters(make_sstack(solvers), learning_rate); }
         /*!
             Convenience method for calling update_parameters()
         !*/
 
+        void back_propagate_error(
+            const tensor& x
+        );
+        /*!
+            requires
+                - forward(x) was called to forward propagate x through the network.
+                  Moreover, this was the most recent call to forward() and x has not been
+                  subsequently modified in any way.
+                - subnet().get_gradient_input() has been set equal to the gradient of this network's
+                  output with respect to the loss function (generally this will be done by calling
+                  compute_loss()).
+            ensures
+                - Back propagates the error gradient, subnet().get_gradient_input(), through this
+                  network and computes parameter and data gradients, via backpropagation.
+                  Specifically, this function populates get_final_data_gradient() and also,
+                  for each layer, the tensor returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0.
+                - have_same_dimensions(#get_final_data_gradient(), x) == true.
+                - #get_final_data_gradient() contains the gradient of the network with
+                  respect to x.
+        !*/
+
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input
+        );
+        /*!
+            requires
+                - forward(x) was called to forward propagate x through the network.
+                  Moreover, this was the most recent call to forward() and x has not been
+                  subsequently modified in any way.
+                - have_same_dimensions(gradient_input, subnet().get_output()) == true
+            ensures
+                - This function is identical to the version of back_propagate_error()
+                  defined immediately above except that it back-propagates gradient_input
+                  through the network instead of subnet().get_gradient_input(). Therefore, this
+                  version of back_propagate_error() is equivalent to performing:
+                    subnet().get_gradient_input() = gradient_input;
+                    back_propagate_error(x);
+                  Except that calling back_propagate_error(x,gradient_input) avoids the
+                  copy and is therefore slightly more efficient.
+                - All elements of #subnet().get_gradient_input() are set to 0.
+                - have_same_dimensions(#get_final_data_gradient(), x) == true.
+                - #get_final_data_gradient() contains the gradient of the network with
+                  respect to x.
+        !*/
+
+        const tensor& get_final_data_gradient(
+        ) const;
+        /*!
+            ensures
+                - if back_propagate_error() has been called to back-propagate a gradient
+                  through this network then you can call get_final_data_gradient() to
+                  obtain the last data gradient computed. That is, this function returns
+                  the gradient of the network with respect to its inputs.
+                - Note that there is only one "final data gradient" for an entire network,
+                  not one per layer, since there is only one input to the entire network.
+        !*/
+
+
         // -------------
 
         void clean (
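Taken together, the methods documented above are what the DCGAN example changes below rely on: back-propagate through a downstream network, read its final data gradient, and push that gradient into an upstream network with the two-argument back_propagate_error(). A hedged sketch of that chaining (the network types, tensors and solvers are assumptions, and compute_loss() is assumed to have already been called on the downstream network):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Illustrative: update an upstream network using the data gradient of a downstream one,
    // e.g. a GAN generator driven by its discriminator's gradient.
    template <typename upstream_net, typename downstream_net, typename solver_type>
    void chained_update(
        upstream_net& upstream, downstream_net& downstream,
        const tensor& upstream_input,     // what was fed into the upstream network
        const tensor& downstream_input,   // the upstream network's output, as fed downstream
        std::vector<solver_type>& upstream_solvers, double learning_rate)
    {
        downstream.back_propagate_error(downstream_input);
        const tensor& g = downstream.get_final_data_gradient();  // d(loss)/d(downstream_input)

        // g has the same dimensions as the upstream output, so it is a valid gradient_input.
        upstream.back_propagate_error(upstream_input, g);
        upstream.update_parameters(upstream_solvers, learning_rate);
    }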
@@ -109,10 +109,9 @@ matrix<unsigned char> generate_image(generator_type& net, const noise_t& noise)
     return image;
 }
 
-std::vector<matrix<unsigned char>> get_generated_images(generator_type& net)
+std::vector<matrix<unsigned char>> get_generated_images(const tensor& out)
 {
     std::vector<matrix<unsigned char>> images;
-    const tensor& out = layer<1>(net).get_output();
     for (size_t n = 0; n < out.num_samples(); ++n)
     {
         matrix<float> output = image_plane(out, n);
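With the tensor-taking signature, the helper no longer cares which network produced the output; the later hunk in this commit simply hands it the generator's forward() result. A small fragment showing that call pattern (generator_type, noise_t and get_generated_images are the example program's own names; the noises vector is assumed to be filled already):

    // Illustrative wrapper around the example's helper: noise in, 8-bit images out.
    std::vector<dlib::matrix<unsigned char>> generate(
        generator_type& generator, const std::vector<noise_t>& noises)
    {
        dlib::resizable_tensor noises_tensor;
        generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
        // forward() returns subnet().get_output(), which is exactly what the helper consumes.
        return get_generated_images(generator.forward(noises_tensor));
    }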
@@ -194,8 +193,8 @@ int main(int argc, char** argv) try
         // The following lines are equivalent to calling train_one_step(real_samples, real_labels)
         discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
         double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin());
-        discriminator.subnet().back_propagate_error(real_samples_tensor);
-        discriminator.subnet().update_parameters(d_solvers, learning_rate);
+        discriminator.back_propagate_error(real_samples_tensor);
+        discriminator.update_parameters(d_solvers, learning_rate);
 
         // Train the discriminator with fake images
         // 1. generate some random noise
@@ -204,17 +203,16 @@ int main(int argc, char** argv) try
         {
             noises.push_back(make_noise(rnd));
         }
-        // 2. forward the noise through the generator
+        // 2. convert noises into a tensor
         generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
-        generator.subnet().forward(noises_tensor);
-        // 3. get the generated images from the generator
-        const auto fake_samples = get_generated_images(generator);
+        // 3. Then forward the noise through the network and convert the outputs into images.
+        const auto fake_samples = get_generated_images(generator.forward(noises_tensor));
         // 4. finally train the discriminator and wait for the threading to stop. The following
         // lines are equivalent to calling train_one_step(fake_samples, fake_labels)
         discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
         d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin());
-        discriminator.subnet().back_propagate_error(fake_samples_tensor);
-        discriminator.subnet().update_parameters(d_solvers, learning_rate);
+        discriminator.back_propagate_error(fake_samples_tensor);
+        discriminator.update_parameters(d_solvers, learning_rate);
 
         // Train the generator
         // This part is the essence of the Generative Adversarial Networks. Until now, we have
@@ -227,11 +225,11 @@ int main(int argc, char** argv) try
         // Forward the fake samples and compute the loss with real labels
         const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin());
         // Back propagate the error to fill the final data gradient
-        discriminator.subnet().back_propagate_error(fake_samples_tensor);
+        discriminator.back_propagate_error(fake_samples_tensor);
         // Get the gradient that will tell the generator how to update itself
-        const tensor& d_grad = discriminator.subnet().get_final_data_gradient();
-        generator.subnet().back_propagate_error(noises_tensor, d_grad);
-        generator.subnet().update_parameters(g_solvers, learning_rate);
+        const tensor& d_grad = discriminator.get_final_data_gradient();
+        generator.back_propagate_error(noises_tensor, d_grad);
+        generator.update_parameters(g_solvers, learning_rate);
 
         // At some point, we should see that the generated images start looking like samples from
         // the MNIST dataset
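One detail worth noting in the generator update above: the two-argument back_propagate_error() requires its gradient_input to have the same dimensions as the network's output, which holds here because d_grad is the gradient with respect to the discriminator's input, i.e. the fake images the generator just produced. A sanity check one could add to the loop (illustrative, not part of this commit):

    // The discriminator's final data gradient must line up with the generator's output
    // before it can be back-propagated through the generator.
    DLIB_CASSERT(have_same_dimensions(d_grad, generator.subnet().get_output()),
        "d_grad must match the dimensions of the generator's output");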