Add dnn_introduction3_ex (#1991)

* Add dnn_introduction3_ex
This commit is contained in:
Adrià Arrufat 2020-02-07 21:59:36 +09:00 committed by GitHub
parent c90cb0bc14
commit 10d7f119ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 577 additions and 2 deletions

View File

@ -47,6 +47,66 @@ namespace dlib
template <typename T>
double get_learning_rate_multiplier(const T& obj) { return impl::get_learning_rate_multiplier(obj, special_()); }
namespace impl
{
template <typename T, typename int_<decltype(&T::set_learning_rate_multiplier)>::type = 0>
void set_learning_rate_multiplier (
T& obj,
special_,
double learning_rate_multiplier
) { obj.set_learning_rate_multiplier(learning_rate_multiplier); }
template <typename T>
void set_learning_rate_multiplier (T& , general_, double) { }
}
template <typename T>
void set_learning_rate_multiplier(
T& obj,
double learning_rate_multiplier
)
{
DLIB_CASSERT(learning_rate_multiplier >= 0);
impl::set_learning_rate_multiplier(obj, special_(), learning_rate_multiplier);
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename T, typename int_<decltype(&T::get_bias_learning_rate_multiplier)>::type = 0>
double get_bias_learning_rate_multiplier (
const T& obj,
special_
) { return obj.get_bias_learning_rate_multiplier(); }
template <typename T>
double get_bias_learning_rate_multiplier ( const T& , general_) { return 1; }
}
template <typename T>
double get_bias_learning_rate_multiplier(const T& obj) { return impl::get_bias_learning_rate_multiplier(obj, special_()); }
namespace impl
{
template <typename T, typename int_<decltype(&T::set_bias_learning_rate_multiplier)>::type = 0>
void set_bias_learning_rate_multiplier (
T& obj,
special_,
double bias_learning_rate_multiplier
) { obj.set_bias_learning_rate_multiplier(bias_learning_rate_multiplier); }
template <typename T>
void set_bias_learning_rate_multiplier (T& , general_, double) { }
}
template <typename T>
void set_bias_learning_rate_multiplier(
T& obj,
double bias_learning_rate_multiplier
)
{
DLIB_CASSERT(bias_learning_rate_multiplier >= 0);
impl::set_bias_learning_rate_multiplier(obj, special_(), bias_learning_rate_multiplier);
}
// ----------------------------------------------------------------------------------------
namespace impl
@ -63,6 +123,66 @@ namespace dlib
template <typename T>
double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); }
namespace impl
{
template <typename T, typename int_<decltype(&T::set_weight_decay_multiplier)>::type = 0>
void set_weight_decay_multiplier (
T& obj,
special_,
double weight_decay_multiplier
) { obj.set_weight_decay_multiplier(weight_decay_multiplier); }
template <typename T>
void set_weight_decay_multiplier (T& , general_, double) { }
}
template <typename T>
void set_weight_decay_multiplier(
T& obj,
double weight_decay_multiplier
)
{
DLIB_CASSERT(weight_decay_multiplier >= 0);
impl::set_weight_decay_multiplier(obj, special_(), weight_decay_multiplier);
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename T, typename int_<decltype(&T::get_bias_weight_decay_multiplier)>::type = 0>
double get_bias_weight_decay_multiplier (
const T& obj,
special_
) { return obj.get_bias_weight_decay_multiplier(); }
template <typename T>
double get_bias_weight_decay_multiplier ( const T& , general_) { return 1; }
}
template <typename T>
double get_bias_weight_decay_multiplier(const T& obj) { return impl::get_bias_weight_decay_multiplier(obj, special_()); }
namespace impl
{
template <typename T, typename int_<decltype(&T::set_bias_weight_decay_multiplier)>::type = 0>
void set_bias_weight_decay_multiplier (
T& obj,
special_,
double bias_weight_decay_multiplier
) { obj.set_bias_weight_decay_multiplier(bias_weight_decay_multiplier); }
template <typename T>
void set_bias_weight_decay_multiplier (T& , general_, double) { }
}
template <typename T>
void set_bias_weight_decay_multiplier(
T& obj,
double bias_weight_decay_multiplier
)
{
DLIB_CASSERT(bias_weight_decay_multiplier >= 0);
impl::set_bias_weight_decay_multiplier(obj, special_(), bias_weight_decay_multiplier);
}
// ----------------------------------------------------------------------------------------
namespace impl

View File

@ -55,10 +55,56 @@ namespace dlib
- returns 1
!*/
template <typename T>
void set_learning_rate_multiplier(
T& obj,
double learning_rate_multiplier
)
/*!
requires
- learning_rate_multiplier >= 0
ensures
- if (obj has a set_learning_rate_multiplier() member function) then
- calls obj.set_learning_rate_multiplier(learning_rate_multiplier)
- else
- does nothing
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
double get_bias_learning_rate_multiplier(
const T& obj
);
/*!
ensures
- if (obj has a get_bias_learning_rate_multiplier() member function) then
- returns obj.get_bias_learning_rate_multiplier()
- else
- returns 1
!*/
template <typename T>
void set_bias_learning_rate_multiplier(
T& obj,
double bias_learning_rate_multiplier
)
/*!
requires
- bias_learning_rate_multiplier >= 0
ensures
- if (obj has a set_bias_learning_rate_multiplier() member function) then
- calls obj.set_bias_learning_rate_multiplier(bias_learning_rate_multiplier)
- else
- does nothing
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
double get_weight_decay_multiplier(
const T& obj
);
);
/*!
ensures
- if (obj has a get_weight_decay_multiplier() member function) then
@ -67,6 +113,50 @@ namespace dlib
- returns 1
!*/
template <typename T>
void set_weight_decay_multiplier(
T& obj,
double weight_decay_multiplier
);
/*!
requires
- weight_decay_multiplier >= 0
ensures
- if (obj has a set_weight_decay_multiplier() member function) then
- calls obj.set_weight_decay_multiplier(weight_decay_multiplier)
- else
- does nothing
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
double get_bias_weight_decay_multiplier(
const T& obj
);
/*!
ensures
- if (obj has a get_bias_weight_decay_multiplier() member function) then
- returns obj.get_bias_weight_decay_multiplier()
- else
- returns 1
!*/
template <typename T>
void set_bias_weight_decay_multiplier(
T& obj,
double bias_weight_decay_multiplier
);
/*!
requires:
- bias_weight_decay_multiplier >= 0
ensures
- if (obj has a set_bias_weight_decay_multiplier() member function) then
- calls obj.set_bias_weight_decay_multiplier(bias_weight_decay_multiplier)
- else
- does nothing
!*/
// ----------------------------------------------------------------------------------------
bool dnn_prefer_fastest_algorithms(

View File

@ -278,7 +278,7 @@ namespace dlib
class visitor_count_parameters
{
public:
visitor_count_parameters(size_t& num_parameters_): num_parameters(num_parameters_) {}
visitor_count_parameters(size_t& num_parameters_) : num_parameters(num_parameters_) {}
void operator()(size_t, const tensor& t)
{
@ -301,6 +301,64 @@ namespace dlib
return num_parameters;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
class visitor_learning_rate_multiplier
{
public:
visitor_learning_rate_multiplier(double new_learning_rate_multiplier_) :
new_learning_rate_multiplier(new_learning_rate_multiplier_) {}
template <typename T>
void set_new_learning_rate_multiplier(T& l) const
{
set_learning_rate_multiplier(l, new_learning_rate_multiplier);
}
template <typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_new_learning_rate_multiplier(l.layer_details());
}
private:
double new_learning_rate_multiplier;
};
}
template <typename net_type>
void set_all_learning_rate_multipliers(
net_type& net,
double learning_rate_multiplier
)
{
DLIB_CASSERT(learning_rate_multiplier >= 0);
impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
visit_layers(net, temp);
}
template <size_t begin, size_t end, typename net_type>
void set_learning_rate_multipliers_range(
net_type& net,
double learning_rate_multiplier
)
{
static_assert(begin <= end, "Invalid range");
static_assert(end <= net_type::num_layers, "Invalid range");
DLIB_CASSERT(learning_rate_multiplier >= 0);
impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
visit_layers_range<begin, end>(net, temp);
}
// ----------------------------------------------------------------------------------------
}

View File

@ -133,6 +133,42 @@ namespace dlib
been trained then, since nothing has been allocated yet, it will return 0.
!*/
// ----------------------------------------------------------------------------------------
template<typename net_type>
void set_all_learning_rate_multipliers(
net_type& net,
double learning_rate_multiplier
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
- learning_rate_multiplier >= 0
ensures
- Sets all learning_rate_multipliers and bias_learning_rate_multipliers in net
to learning_rate_multiplier.
!*/
// ----------------------------------------------------------------------------------------
template <size_t begin, size_t end, typename net_type>
void set_learning_rate_multipliers_range(
net_type& net,
double learning_rate_multiplier
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
- learning_rate_multiplier >= 0
- begin <= end <= net_type::num_layers
ensures
- Loops over the layers in the range [begin,end) in net and calls
set_learning_rate_multiplier on them with the value of
learning_rate_multiplier.
!*/
// ----------------------------------------------------------------------------------------
}

View File

@ -138,6 +138,7 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
add_gui_example(dnn_face_recognition_ex)
add_example(dnn_introduction_ex)
add_example(dnn_introduction2_ex)
add_example(dnn_introduction3_ex)
add_example(dnn_inception_ex)
add_gui_example(dnn_mmod_ex)
add_gui_example(dnn_mmod_face_detection_ex)

View File

@ -0,0 +1,169 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the
dlib C++ Library. I'm assuming you have already read the dnn_introduction_ex.cpp and
the dnn_introduction2_ex.cpp examples. So in this example program I'm going to go
over a transfer learning example, which includes:
- Defining a layer visitor to modify the some network parameters for fine-tuning
- Using pretrained layers of a network for another task
*/
#include <dlib/dnn.h>
#include <iostream>
// This header file includes a generic definition of the most common ResNet architectures
#include "resnet.h"
using namespace std;
using namespace dlib;
// In this simple example we will show how to load a pretrained network and use it for a
// different task. In particular, we will load a ResNet50 trained on ImageNet, adjust
// some of its parameters and use it as a pretrained backbone for some metric learning
// task.
// Let's start by defining a network that will use the ResNet50 backbone from resnet.h
namespace model
{
template<template<typename> class BN>
using net_type = loss_metric<
fc_no_bias<128,
avg_pool_everything<
typename resnet<BN>::template backbone_50<
input_rgb_image
>>>>;
using train = net_type<bn_con>;
using infer = net_type<affine>;
}
// Next, we define a layer visitor that will modify the weight decay of a network. The
// main interest of this class is to show how one can define custom visitors that modify
// some network parameters.
class visitor_weight_decay_multiplier
{
public:
visitor_weight_decay_multiplier(double new_weight_decay_multiplier_) :
new_weight_decay_multiplier(new_weight_decay_multiplier_) {}
template <typename T>
void set_new_weight_decay_mulitplier(T& l) const
{
set_weight_decay_multiplier(l, new_weight_decay_multiplier);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_new_weight_decay_mulitplier(l.layer_details());
}
private:
double new_weight_decay_multiplier;
};
int main() try
{
// Let's instantiate our network in train mode.
model::train net;
// We create a new scope so that resources from the loaded network are freed
// automatically when leaving the scope.
{
// Now, let's define the classic ResNet50 network and load the pretrained model on
// ImageNet.
resnet<bn_con>::n50 resnet50;
std::vector<string> labels;
deserialize("resnet50_1000_imagenet_classifier.dnn") >> resnet50 >> labels;
// For transfer learning, we are only interested in the ResNet50's backbone, which
// lays below the loss and the fc layers, so we can extract it as:
auto backbone = std::move(resnet50.subnet().subnet());
// We can now assign ResNet50's backbone to our network skipping the different
// layers, in our case, the loss layer and the fc layer:
net.subnet().subnet() = backbone;
// An alternative way to use the pretrained network on a different
// network is to extract the relevant part of the network (we remove
// loss and fc layers), stack the new layers on top of it and assign
// the network.
using net_type = loss_metric<fc_no_bias<128, decltype(backbone)>>;
net_type net2;
net2.subnet().subnet() = backbone;
}
// We can use the visit_layers function to modify the weight decay of the entire
// network:
visit_layers(net, visitor_weight_decay_multiplier(0.001));
// We can also use predefined visitors to affect the learning rate of the whole
// network.
set_all_learning_rate_multipliers(net, 0.5);
// Modifying the learning rates of a network is a common practice for fine tuning, for
// this reason it is already provided. However, it is implemented internally using a
// visitor that is very similar to the one defined in this example.
// Usually, we want to freeze the network, except for the top layers:
visit_layers(net.subnet().subnet(), visitor_weight_decay_multiplier(0));
set_all_learning_rate_multipliers(net.subnet().subnet(), 0);
// Alternatively, we can use the visit_layers_range to modify only a specific set of
// layers:
visit_layers_range<0, 2>(net, visitor_weight_decay_multiplier(1));
// Sometimes we might want to set the learning rate differently thoughout the network.
// Here we show how to use adjust the learning rate at the different ResNet50's
// convolutional blocks:
set_learning_rate_multipliers_range< 0, 2>(net, 1);
set_learning_rate_multipliers_range< 2, 38>(net, 0.1);
set_learning_rate_multipliers_range< 38, 107>(net, 0.01);
set_learning_rate_multipliers_range<107, 154>(net, 0.001);
set_learning_rate_multipliers_range<154, 193>(net, 0.0001);
// Finally, we can check the results by printing the network. But before, if we
// forward an image through the network, we will see tensors shape at every layer.
matrix<rgb_pixel> image(224, 224);
assign_all_pixels(image, rgb_pixel(0, 0, 0));
std::vector<matrix<rgb_pixel>> minibatch(1, image);
resizable_tensor input;
net.to_tensor(minibatch.begin(), minibatch.end(), input);
net.subnet().forward(input);
cout << net << endl;
cout << "input size=(" <<
"num:" << input.num_samples() << ", " <<
"k:" << input.k() << ", " <<
"nr:" << input.nr() << ", "
"nc:" << input.nc() << ")" << endl;
// We can also print the number of parameters of the network:
cout << "number of network parameters: " << count_parameters(net) << endl;
// From this point on, we can finetune the new network using this pretrained backbone
// on another task, such as the one showed in dnn_metric_learning_on_images_ex.cpp.
return EXIT_SUCCESS;
}
catch (const serialization_error& e)
{
cout << e.what() << endl;
cout << "You need to download a copy of the file resnet50_1000_imagenet_classifier.dnn" << endl;
cout << "available at http://dlib.net/files/resnet50_1000_imagenet_classifier.dnn.bz2" << endl;
cout << endl;
return EXIT_FAILURE;
}
catch (const exception& e)
{
cout << e.what() << endl;
return EXIT_FAILURE;
}

101
examples/resnet.h Normal file
View File

@ -0,0 +1,101 @@
#ifndef ResNet_H
#define ResNet_H
#include <dlib/dnn.h>
// BATCHNORM must be bn_con or affine layer
template<template<typename> class BATCHNORM>
struct resnet
{
// the resnet basic block, where BN is bn_con or affine
template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
using basicblock = BN<dlib::con<num_filters, 3, 3, 1, 1,
dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride, SUBNET>>>>>;
// the resnet bottleneck block
template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
using bottleneck = BN<dlib::con<4 * num_filters, 1, 1, 1, 1,
dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride,
dlib::relu<BN<dlib::con<num_filters, 1, 1, 1, 1, SUBNET>>>>>>>>;
// the resnet residual
template<
template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
long num_filters,
template<typename> class BN, // bn_con or affine
typename SUBNET
> // adds the block to the result of tag1 (the subnet)
using residual = dlib::add_prev1<BLOCK<num_filters, BN, 1, dlib::tag1<SUBNET>>>;
// a resnet residual that does subsampling on both paths
template<
template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
long num_filters,
template<typename> class BN, // bn_con or affine
typename SUBNET
>
using residual_down = dlib::add_prev2<dlib::avg_pool<2, 2, 2, 2,
dlib::skip1<dlib::tag2<BLOCK<num_filters, BN, 2,
dlib::tag1<SUBNET>>>>>>;
// residual block with optional downsampling and custom regularization (bn_con or affine)
template<
template<template<long, template<typename> class, int, typename> class, long, template<typename>class, typename> class RESIDUAL,
template<long, template<typename> class, int, typename> class BLOCK,
long num_filters,
template<typename> class BN, // bn_con or affine
typename SUBNET
>
using residual_block = dlib::relu<RESIDUAL<BLOCK, num_filters, BN, SUBNET>>;
template<long num_filters, typename SUBNET>
using resbasicblock_down = residual_block<residual_down, basicblock, num_filters, BATCHNORM, SUBNET>;
template<long num_filters, typename SUBNET>
using resbottleneck_down = residual_block<residual_down, bottleneck, num_filters, BATCHNORM, SUBNET>;
// some definitions to allow the use of the repeat layer
template<typename SUBNET> using resbasicblock_512 = residual_block<residual, basicblock, 512, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbasicblock_256 = residual_block<residual, basicblock, 256, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbasicblock_128 = residual_block<residual, basicblock, 128, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbasicblock_64 = residual_block<residual, basicblock, 64, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbottleneck_512 = residual_block<residual, bottleneck, 512, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbottleneck_256 = residual_block<residual, bottleneck, 256, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbottleneck_128 = residual_block<residual, bottleneck, 128, BATCHNORM, SUBNET>;
template<typename SUBNET> using resbottleneck_64 = residual_block<residual, bottleneck, 64, BATCHNORM, SUBNET>;
// common processing for standard resnet inputs
template<template<typename> class BN, typename INPUT>
using input_processing = dlib::max_pool<3, 3, 2, 2, dlib::relu<BN<dlib::con<64, 7, 7, 2, 2, INPUT>>>>;
// the resnet backbone with basicblocks
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
using backbone_basicblock =
dlib::repeat<nb_512, resbasicblock_512, resbasicblock_down<512,
dlib::repeat<nb_256, resbasicblock_256, resbasicblock_down<256,
dlib::repeat<nb_128, resbasicblock_128, resbasicblock_down<128,
dlib::repeat<nb_64, resbasicblock_64, input_processing<BATCHNORM, INPUT>>>>>>>>;
// the resnet backbone with bottlenecks
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
using backbone_bottleneck =
dlib::repeat<nb_512, resbottleneck_512, resbottleneck_down<512,
dlib::repeat<nb_256, resbottleneck_256, resbottleneck_down<256,
dlib::repeat<nb_128, resbottleneck_128, resbottleneck_down<128,
dlib::repeat<nb_64, resbottleneck_64, input_processing<BATCHNORM, INPUT>>>>>>>>;
// the backbones for the classic architectures
template<typename INPUT> using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>;
template<typename INPUT> using backbone_34 = backbone_basicblock<2, 5, 3, 3, INPUT>;
template<typename INPUT> using backbone_50 = backbone_bottleneck<2, 5, 3, 3, INPUT>;
template<typename INPUT> using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>;
template<typename INPUT> using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>;
// the typical classifier models
using n18 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_18<dlib::input_rgb_image>>>>;
using n34 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_34<dlib::input_rgb_image>>>>;
using n50 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_50<dlib::input_rgb_image>>>>;
using n101 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_101<dlib::input_rgb_image>>>>;
using n152 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_152<dlib::input_rgb_image>>>>;
};
#endif // ResNet_H