Add U-net style skip connections to the semantic-segmentation example (#1600)

* Add concat_prev layer, and U-net example for semantic segmentation

* Allow supplying the mini-batch size as a command-line parameter

* Decrease default mini-batch size from 30 to 24

* Resize t1, if needed

* Use DenseNet-style blocks instead of residual learning

* Increase default mini-batch size to 50

* Increase default mini-batch size from 50 to 60

* Resize even during the backward step, if needed

* Use resize_bilinear_gradient for the backward step

* Fix function call ambiguity problem

* Clear destination before adding gradient

* Works OK-ish

* Add more U-tags

* Tweak default mini-batch size

* Define a simpler network when using Microsoft Visual C++ compiler; clean up the DenseNet stuff (leaving it for a later PR)

* Decrease default mini-batch size from 24 to 23

* Define separate dnn filenames for MSVC++ and non-MSVC++ builds

* Add documentation for the resize_to_prev layer; move the implementation so that it comes after mult_prev

* Fix previous typo

* Minor formatting changes

* Reverse the ordering of levels

* Increase the learning-rate stopping criterion back to 1e-4 (was 1e-8)

* Use more U-tags even on Windows

* Minor formatting

* Latest MSVC 2017 builds fast, so there's no need to limit the depth any longer

* Tweak default mini-batch size again

* Even though the latest MSVC can now build the extra layers, that doesn't mean we should add them!

* Fix naming
Branch: pull/1614/head
Authored by Juha Reunanen, committed by Davis E. King
Parent: fb4c62cc67 · Commit: f685cb4249

@@ -2386,6 +2386,106 @@ namespace dlib
using mult_prev9_ = mult_prev_<tag9>;
using mult_prev10_ = mult_prev_<tag10>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class tag
>
class resize_prev_to_tagged_
{
public:
const static unsigned long id = tag_id<tag>::id;
resize_prev_to_tagged_()
{
}
template <typename SUBNET>
void setup (const SUBNET& /*sub*/)
{
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto& prev = sub.get_output();
auto& tagged = layer<tag>(sub).get_output();
DLIB_CASSERT(prev.num_samples() == tagged.num_samples());
output.set_size(prev.num_samples(),
prev.k(),
tagged.nr(),
tagged.nc());
if (prev.nr() == tagged.nr() && prev.nc() == tagged.nc())
{
tt::copy_tensor(false, output, 0, prev, 0, prev.k());
}
else
{
tt::resize_bilinear(output, prev);
}
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{
auto& prev = sub.get_gradient_input();
DLIB_CASSERT(prev.k() == gradient_input.k());
DLIB_CASSERT(prev.num_samples() == gradient_input.num_samples());
if (prev.nr() == gradient_input.nr() && prev.nc() == gradient_input.nc())
{
tt::copy_tensor(true, prev, 0, gradient_input, 0, prev.k());
}
else
{
tt::resize_bilinear_gradient(prev, gradient_input);
}
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
friend void serialize(const resize_prev_to_tagged_& , std::ostream& out)
{
serialize("resize_prev_to_tagged_", out);
}
friend void deserialize(resize_prev_to_tagged_& , std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "resize_prev_to_tagged_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_prev_to_tagged_.");
}
friend std::ostream& operator<<(std::ostream& out, const resize_prev_to_tagged_& item)
{
out << "resize_prev_to_tagged"<<id;
return out;
}
friend void to_xml(const resize_prev_to_tagged_& item, std::ostream& out)
{
out << "<resize_prev_to_tagged tag='"<<id<<"'/>\n";
}
private:
resizable_tensor params;
};
template <
template<typename> class tag,
typename SUBNET
>
using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
// ----------------------------------------------------------------------------------------
template <

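For reference, the size bookkeeping that forward() does above can be reproduced with plain tensor operations. The following is a minimal sketch, not part of this commit; the tensor dimensions (1x3x5x5 and 1x16x8x8) are made up purely for illustration.

    // Minimal sketch (not part of this commit) of what the forward pass above does,
    // using plain tensors: a 1x3x5x5 "prev" tensor is resized to the 8x8 spatial size
    // of a "tagged" tensor, while the sample and channel counts come from prev.
    #include <dlib/dnn.h>
    #include <iostream>

    int main()
    {
        using namespace dlib;

        resizable_tensor prev(1, 3, 5, 5);
        resizable_tensor tagged(1, 16, 8, 8);
        prev = 1;  // fill prev with ones

        // Same size rule as forward() above: samples and k from prev, nr/nc from tagged.
        resizable_tensor output(prev.num_samples(), prev.k(), tagged.nr(), tagged.nc());
        tt::resize_bilinear(output, prev);  // bilinear upsampling, since the sizes differ

        std::cout << output.num_samples() << "x" << output.k() << "x"
                  << output.nr() << "x" << output.nc() << std::endl;  // prints 1x3x8x8
    }
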
@@ -2382,6 +2382,56 @@ namespace dlib
using mult_prev9_ = mult_prev_<tag9>;
using mult_prev10_ = mult_prev_<tag10>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class tag
>
class resize_prev_to_tagged_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. This layer resizes the output channels of the previous layer
to have the same number of rows and columns as the output of the tagged layer.
This layer uses bilinear interpolation. If the sizes match already, then it
simply copies the data.
Therefore, you supply a tag via resize_prev_to_tagged's template argument that
tells it what layer to use for the target size.
If tensor PREV is resized to size of tensor TAGGED, then a tensor OUT is
produced such that:
- OUT.num_samples() == PREV.num_samples()
- OUT.k() == PREV.k()
- OUT.nr() == TAGGED.nr()
- OUT.nc() == TAGGED.nc()
!*/
public:
resize_prev_to_tagged_(
);
template <typename SUBNET> void setup(const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
!*/
};
template <
template<typename> class tag,
typename SUBNET
>
using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
// ----------------------------------------------------------------------------------------
template <

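To make the specification above concrete, here is a minimal usage sketch. It is not part of this commit; the layer sizes and the tag id 1001 are illustrative assumptions. The idea is that an upsampling path whose output may not land exactly on the original grid is snapped back to the spatial size of an earlier, tagged layer.

    // Minimal usage sketch (not part of this commit); layer sizes and the tag id
    // are illustrative only.
    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    // Tag the layer whose spatial size we want to restore later on.
    template <typename SUBNET> using size_tag = add_tag_layer<1001, SUBNET>;

    using toy_net = loss_multiclass_log_per_pixel<
        con<2,1,1,1,1,                      // per-pixel class scores
        resize_prev_to_tagged<size_tag,     // snap back to size_tag's nr()/nc()
        cont<16,3,3,2,2,                    // upsample (can overshoot/undershoot by a pixel)
        con<16,3,3,2,2,                     // downsample by 2
        size_tag<
        input<matrix<rgb_pixel>>>>>>>>;

    int main()
    {
        toy_net net;
        std::cout << net << std::endl;      // print the resulting architecture
    }
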
@@ -1910,7 +1910,7 @@ namespace
template <typename SUBNET>
using pres = prelu<add_prev1<bn_con<con<8,3,3,1,1,prelu<bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;
void test_visit_funcions()
void test_visit_functions()
{
using net_type2 = loss_multiclass_log<fc<10,
avg_pool_everything<
@@ -3243,7 +3243,7 @@ namespace
test_batch_normalize_conv();
test_basic_tensor_ops();
test_layers();
test_visit_funcions();
test_visit_functions();
test_copy_tensor_cpu();
test_copy_tensor_add_to_cpu();
test_concat();

@@ -16,7 +16,7 @@
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
@@ -111,16 +111,16 @@ int main(int argc, char** argv) try
cout << "You call this program like this: " << endl;
cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
cout << endl;
cout << "You will also need a trained 'semantic_segmentation_voc2012net.dnn' file." << endl;
cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl;
cout << "You can either train it yourself (see example program" << endl;
cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
cout << "copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn" << endl;
cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl;
return 1;
}
// Read the file containing the trained network from the working directory.
anet_type net;
deserialize("semantic_segmentation_voc2012net.dnn") >> net;
deserialize(semantic_segmentation_net_filename) >> net;
// Show inference results in a window.
image_window win;

@@ -23,7 +23,7 @@
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
@@ -116,10 +116,10 @@ const Voc2012class& find_voc2012_class(Predicate predicate)
// Introduce the building blocks used to define the segmentation network.
// The network first does residual downsampling (similar to the dnn_imagenet_(train_)ex
// example program), and then residual upsampling. The network could be improved e.g.
// by introducing skip connections from the input image, and/or the first layers, to the
// last layer(s). (See Long et al., Fully Convolutional Networks for Semantic Segmentation,
// https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)
// example program), and then residual upsampling. In addition, U-Net style skip
// connections are used, so that not every simple detail needs to be represented at the low
// levels. (See Ronneberger et al. (2015), U-Net: Convolutional Networks for Biomedical
// Image Segmentation, https://arxiv.org/pdf/1505.04597.pdf)
template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
@@ -145,55 +145,98 @@ template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block
// ----------------------------------------------------------------------------------------
template <typename SUBNET> using res512 = res<512, SUBNET>;
template <typename SUBNET> using res256 = res<256, SUBNET>;
template <typename SUBNET> using res128 = res<128, SUBNET>;
template <typename SUBNET> using res64 = res<64,SUBNET>;
template <typename SUBNET> using ares512 = ares<512, SUBNET>;
template <typename SUBNET> using ares256 = ares<256, SUBNET>;
template <typename SUBNET> using ares128 = ares<128, SUBNET>;
template <typename SUBNET> using res128 = res<128,SUBNET>;
template <typename SUBNET> using res256 = res<256,SUBNET>;
template <typename SUBNET> using res512 = res<512,SUBNET>;
template <typename SUBNET> using ares64 = ares<64,SUBNET>;
template <typename SUBNET> using ares128 = ares<128,SUBNET>;
template <typename SUBNET> using ares256 = ares<256,SUBNET>;
template <typename SUBNET> using ares512 = ares<512,SUBNET>;
template <typename SUBNET> using level1 = dlib::repeat<2,res64,res<64,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
template <typename SUBNET> using level1 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res64,res<64,SUBNET>>;
template <typename SUBNET> using level1t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
template <typename SUBNET> using level1t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class TAGGED,
template<typename> class PREV_RESIZED,
typename SUBNET
>
using resize_and_concat = dlib::add_layer<
dlib::concat_<TAGGED,PREV_RESIZED>,
PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
static const char* semantic_segmentation_net_filename = "semantic_segmentation_voc2012net_v2.dnn";
// ----------------------------------------------------------------------------------------
// training network type
using net_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<class_count,7,7,2,2,
level4t<level3t<level2t<level1t<
level1<level2<level3<level4<
dlib::max_pool<3,3,2,2,dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
using bnet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<class_count,1,1,1,1,
dlib::relu<dlib::bn_con<dlib::cont<64,7,7,2,2,
concat_utag1<level1t<
concat_utag2<level2t<
concat_utag3<level3t<
concat_utag4<level4t<
level4<utag4<
level3<utag3<
level2<utag2<
level1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>;
>>>>>>>>>>>>>>>>>>>>>>>>>;
// testing network type (replaced batch normalization with fixed affine transforms)
using anet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<class_count,7,7,2,2,
alevel4t<alevel3t<alevel2t<alevel1t<
alevel1<alevel2<alevel3<alevel4<
dlib::max_pool<3,3,2,2,dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
dlib::cont<class_count,1,1,1,1,
dlib::relu<dlib::affine<dlib::cont<64,7,7,2,2,
concat_utag1<alevel1t<
concat_utag2<alevel2t<
concat_utag3<alevel3t<
concat_utag4<alevel4t<
alevel4<utag4<
alevel3<utag3<
alevel2<utag2<
alevel1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>;
>>>>>>>>>>>>>>>>>>>>>>>>>;
// ----------------------------------------------------------------------------------------

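The resize_and_concat pattern above is what actually implements the U-Net style skip connection: tag a point in the downsampling path, resize the upsampling path back to that tagged size, and concatenate the two along the channel dimension. Below is a self-contained sketch, not part of this commit, that wires a single such skip connection and runs one forward pass; the layer sizes, tag ids, and image size are made-up assumptions used only to show that the shapes line up.

    // Sketch (not part of this commit): one U-Net style skip connection built with
    // the same resize_and_concat pattern as above.  Sizes and tag ids are illustrative.
    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    template <typename SUBNET> using skip_src = add_tag_layer<3100+1, SUBNET>; // encoder output to reuse
    template <typename SUBNET> using skip_dst = add_tag_layer<3110+1, SUBNET>; // resized decoder output

    template <typename SUBNET> using concat_skip = add_layer<
        concat_<skip_src, skip_dst>,
        skip_dst<resize_prev_to_tagged<skip_src, SUBNET>>>;

    using toy_unet = loss_multiclass_log_per_pixel<
        con<2,1,1,1,1,            // per-pixel scores for 2 classes
        concat_skip<              // channels become k(decoder) + k(encoder) = 8 + 8 = 16
        cont<8,3,3,2,2,           // decoder: upsample back towards the input resolution
        con<8,3,3,2,2,            // bottleneck: downsample by 2
        skip_src<                 // remember the encoder output here
        con<8,3,3,1,1,
        input<matrix<rgb_pixel>>>>>>>>>;

    int main()
    {
        toy_unet net;

        matrix<rgb_pixel> img(64, 64);
        for (long r = 0; r < img.nr(); ++r)
            for (long c = 0; c < img.nc(); ++c)
                img(r, c) = rgb_pixel(0, 0, 0);

        // Forward pass: would throw if the skip-connection shapes did not line up.
        const auto labels = net(img);
        std::cout << "label image: " << labels.nr() << "x" << labels.nc() << std::endl;
    }
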
@@ -41,7 +41,7 @@ struct training_sample
// ----------------------------------------------------------------------------------------
rectangle make_random_cropping_rect_resnet(
rectangle make_random_cropping_rect(
const matrix<rgb_pixel>& img,
dlib::rand& rnd
)
@@ -66,7 +66,7 @@ void randomly_crop_image (
dlib::rand& rnd
)
{
const auto rect = make_random_cropping_rect_resnet(input_image, rnd);
const auto rect = make_random_cropping_rect(input_image, rnd);
const chip_details chip_details(rect, chip_dims(227, 227));
@@ -259,12 +259,12 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& datase
int main(int argc, char** argv) try
{
if (argc != 2)
if (argc < 2 || argc > 3)
{
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
cout << endl;
cout << "You call this program like this: " << endl;
cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012" << endl;
cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl;
return 1;
}
@@ -278,13 +278,16 @@ int main(int argc, char** argv) try
return 1;
}
// a mini-batch smaller than the default can be used with GPUs having less memory
const int minibatch_size = argc == 3 ? std::stoi(argv[2]) : 23;
cout << "mini-batch size: " << minibatch_size << endl;
const double initial_learning_rate = 0.1;
const double weight_decay = 0.0001;
const double momentum = 0.9;
net_type net;
dnn_trainer<net_type> trainer(net,sgd(weight_decay, momentum));
bnet_type bnet;
dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum));
trainer.be_verbose();
trainer.set_learning_rate(initial_learning_rate);
trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
@@ -292,7 +295,7 @@ int main(int argc, char** argv) try
trainer.set_iterations_without_progress_threshold(5000);
// Since the progress threshold is so large might as well set the batch normalization
// stats window to something big too.
set_all_bn_running_stats_window_sizes(net, 1000);
set_all_bn_running_stats_window_sizes(bnet, 1000);
// Output training parameters.
cout << endl << trainer << endl;
@@ -345,9 +348,9 @@ int main(int argc, char** argv) try
samples.clear();
labels.clear();
// make a 30-image mini-batch
// make a mini-batch
training_sample temp;
while(samples.size() < 30)
while(samples.size() < minibatch_size)
{
data.dequeue(temp);
@@ -369,13 +372,13 @@ int main(int argc, char** argv) try
// also wait for threaded processing to stop in the trainer.
trainer.get_net();
net.clean();
bnet.clean();
cout << "saving network" << endl;
serialize("semantic_segmentation_voc2012net.dnn") << net;
serialize(semantic_segmentation_net_filename) << bnet;
// Make a copy of the network to use it for inference.
anet_type anet = net;
anet_type anet = bnet;
cout << "Testing the network..." << endl;
