mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add U-net style skip connections to the semantic-segmentation example (#1600)
* Add concat_prev layer, and U-net example for semantic segmentation * Allow to supply mini-batch size as command-line parameter * Decrease default mini-batch size from 30 to 24 * Resize t1, if needed * Use DenseNet-style blocks instead of residual learning * Increase default mini-batch size to 50 * Increase default mini-batch size from 50 to 60 * Resize even during the backward step, if needed * Use resize_bilinear_gradient for the backward step * Fix function call ambiguity problem * Clear destination before adding gradient * Works OK-ish * Add more U-tags * Tweak default mini-batch size * Define a simpler network when using Microsoft Visual C++ compiler; clean up the DenseNet stuff (leaving it for a later PR) * Decrease default mini-batch size from 24 to 23 * Define separate dnn filename for MSVC++ and not * Add documentation for the resize_to_prev layer; move the implementation so that it comes after mult_prev * Fix previous typo * Minor formatting changes * Reverse the ordering of levels * Increase the learning-rate stopping criterion back to 1e-4 (was 1e-8) * Use more U-tags even on Windows * Minor formatting * Latest MSVC 2017 builds fast, so there's no need to limit the depth any longer * Tweak default mini-batch size again * Even though latest MSVC can now build the extra layers, it does not mean we should add them! * Fix naming
This commit is contained in:
parent
fb4c62cc67
commit
f685cb4249
@ -2386,6 +2386,106 @@ namespace dlib
|
||||
using mult_prev9_ = mult_prev_<tag9>;
|
||||
using mult_prev10_ = mult_prev_<tag10>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
template<typename> class tag
|
||||
>
|
||||
class resize_prev_to_tagged_
|
||||
{
|
||||
public:
|
||||
const static unsigned long id = tag_id<tag>::id;
|
||||
|
||||
resize_prev_to_tagged_()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void setup (const SUBNET& /*sub*/)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void forward(const SUBNET& sub, resizable_tensor& output)
|
||||
{
|
||||
auto& prev = sub.get_output();
|
||||
auto& tagged = layer<tag>(sub).get_output();
|
||||
|
||||
DLIB_CASSERT(prev.num_samples() == tagged.num_samples());
|
||||
|
||||
output.set_size(prev.num_samples(),
|
||||
prev.k(),
|
||||
tagged.nr(),
|
||||
tagged.nc());
|
||||
|
||||
if (prev.nr() == tagged.nr() && prev.nc() == tagged.nc())
|
||||
{
|
||||
tt::copy_tensor(false, output, 0, prev, 0, prev.k());
|
||||
}
|
||||
else
|
||||
{
|
||||
tt::resize_bilinear(output, prev);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
|
||||
{
|
||||
auto& prev = sub.get_gradient_input();
|
||||
|
||||
DLIB_CASSERT(prev.k() == gradient_input.k());
|
||||
DLIB_CASSERT(prev.num_samples() == gradient_input.num_samples());
|
||||
|
||||
if (prev.nr() == gradient_input.nr() && prev.nc() == gradient_input.nc())
|
||||
{
|
||||
tt::copy_tensor(true, prev, 0, gradient_input, 0, prev.k());
|
||||
}
|
||||
else
|
||||
{
|
||||
tt::resize_bilinear_gradient(prev, gradient_input);
|
||||
}
|
||||
}
|
||||
|
||||
const tensor& get_layer_params() const { return params; }
|
||||
tensor& get_layer_params() { return params; }
|
||||
|
||||
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
|
||||
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
|
||||
|
||||
friend void serialize(const resize_prev_to_tagged_& , std::ostream& out)
|
||||
{
|
||||
serialize("resize_prev_to_tagged_", out);
|
||||
}
|
||||
|
||||
friend void deserialize(resize_prev_to_tagged_& , std::istream& in)
|
||||
{
|
||||
std::string version;
|
||||
deserialize(version, in);
|
||||
if (version != "resize_prev_to_tagged_")
|
||||
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_prev_to_tagged_.");
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const resize_prev_to_tagged_& item)
|
||||
{
|
||||
out << "resize_prev_to_tagged"<<id;
|
||||
return out;
|
||||
}
|
||||
|
||||
friend void to_xml(const resize_prev_to_tagged_& item, std::ostream& out)
|
||||
{
|
||||
out << "<resize_prev_to_tagged tag='"<<id<<"'/>\n";
|
||||
}
|
||||
|
||||
private:
|
||||
resizable_tensor params;
|
||||
};
|
||||
|
||||
template <
|
||||
template<typename> class tag,
|
||||
typename SUBNET
|
||||
>
|
||||
using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
|
@ -2382,6 +2382,56 @@ namespace dlib
|
||||
using mult_prev9_ = mult_prev_<tag9>;
|
||||
using mult_prev10_ = mult_prev_<tag10>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
template<typename> class tag
|
||||
>
|
||||
class resize_prev_to_tagged_
|
||||
{
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
|
||||
defined above. This layer resizes the output channels of the previous layer
|
||||
to have the same number of rows and columns as the output of the tagged layer.
|
||||
|
||||
This layer uses bilinear interpolation. If the sizes match already, then it
|
||||
simply copies the data.
|
||||
|
||||
Therefore, you supply a tag via resize_prev_to_tagged's template argument that
|
||||
tells it what layer to use for the target size.
|
||||
|
||||
If tensor PREV is resized to size of tensor TAGGED, then a tensor OUT is
|
||||
produced such that:
|
||||
- OUT.num_samples() == PREV.num_samples()
|
||||
- OUT.k() == PREV.k()
|
||||
- OUT.nr() == TAGGED.nr()
|
||||
- OUT.nc() == TAGGED.nc()
|
||||
!*/
|
||||
|
||||
public:
|
||||
resize_prev_to_tagged_(
|
||||
);
|
||||
|
||||
template <typename SUBNET> void setup(const SUBNET& sub);
|
||||
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
|
||||
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
|
||||
dpoint map_input_to_output(dpoint p) const;
|
||||
dpoint map_output_to_input(dpoint p) const;
|
||||
const tensor& get_layer_params() const;
|
||||
tensor& get_layer_params();
|
||||
/*!
|
||||
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
|
||||
!*/
|
||||
};
|
||||
|
||||
|
||||
template <
|
||||
template<typename> class tag,
|
||||
typename SUBNET
|
||||
>
|
||||
using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
|
@ -1910,7 +1910,7 @@ namespace
|
||||
template <typename SUBNET>
|
||||
using pres = prelu<add_prev1<bn_con<con<8,3,3,1,1,prelu<bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;
|
||||
|
||||
void test_visit_funcions()
|
||||
void test_visit_functions()
|
||||
{
|
||||
using net_type2 = loss_multiclass_log<fc<10,
|
||||
avg_pool_everything<
|
||||
@ -3243,7 +3243,7 @@ namespace
|
||||
test_batch_normalize_conv();
|
||||
test_basic_tensor_ops();
|
||||
test_layers();
|
||||
test_visit_funcions();
|
||||
test_visit_functions();
|
||||
test_copy_tensor_cpu();
|
||||
test_copy_tensor_add_to_cpu();
|
||||
test_concat();
|
||||
|
@ -16,7 +16,7 @@
|
||||
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
|
||||
|
||||
An alternative to steps 2-4 above is to download a pre-trained network
|
||||
from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
|
||||
from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn
|
||||
|
||||
It would be a good idea to become familiar with dlib's DNN tooling before reading this
|
||||
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
|
||||
@ -111,16 +111,16 @@ int main(int argc, char** argv) try
|
||||
cout << "You call this program like this: " << endl;
|
||||
cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
|
||||
cout << endl;
|
||||
cout << "You will also need a trained 'semantic_segmentation_voc2012net.dnn' file." << endl;
|
||||
cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl;
|
||||
cout << "You can either train it yourself (see example program" << endl;
|
||||
cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
|
||||
cout << "copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn" << endl;
|
||||
cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Read the file containing the trained network from the working directory.
|
||||
anet_type net;
|
||||
deserialize("semantic_segmentation_voc2012net.dnn") >> net;
|
||||
deserialize(semantic_segmentation_net_filename) >> net;
|
||||
|
||||
// Show inference results in a window.
|
||||
image_window win;
|
||||
|
@ -23,7 +23,7 @@
|
||||
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
|
||||
|
||||
An alternative to steps 2-4 above is to download a pre-trained network
|
||||
from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
|
||||
from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn
|
||||
|
||||
It would be a good idea to become familiar with dlib's DNN tooling before reading this
|
||||
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
|
||||
@ -116,13 +116,13 @@ const Voc2012class& find_voc2012_class(Predicate predicate)
|
||||
|
||||
// Introduce the building blocks used to define the segmentation network.
|
||||
// The network first does residual downsampling (similar to the dnn_imagenet_(train_)ex
|
||||
// example program), and then residual upsampling. The network could be improved e.g.
|
||||
// by introducing skip connections from the input image, and/or the first layers, to the
|
||||
// last layer(s). (See Long et al., Fully Convolutional Networks for Semantic Segmentation,
|
||||
// https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)
|
||||
// example program), and then residual upsampling. In addition, U-Net style skip
|
||||
// connections are used, so that not every simple detail needs to reprented on the low
|
||||
// levels. (See Ronneberger et al. (2015), U-Net: Convolutional Networks for Biomedical
|
||||
// Image Segmentation, https://arxiv.org/pdf/1505.04597.pdf)
|
||||
|
||||
template <int N, template <typename> class BN, int stride, typename SUBNET>
|
||||
using block = BN<dlib::con<N,3,3,1,1, dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
|
||||
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
|
||||
|
||||
template <int N, template <typename> class BN, int stride, typename SUBNET>
|
||||
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
|
||||
@ -145,55 +145,98 @@ template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename SUBNET> using res512 = res<512, SUBNET>;
|
||||
template <typename SUBNET> using res256 = res<256, SUBNET>;
|
||||
template <typename SUBNET> using res128 = res<128, SUBNET>;
|
||||
template <typename SUBNET> using res64 = res<64, SUBNET>;
|
||||
template <typename SUBNET> using ares512 = ares<512, SUBNET>;
|
||||
template <typename SUBNET> using ares256 = ares<256, SUBNET>;
|
||||
template <typename SUBNET> using ares128 = ares<128, SUBNET>;
|
||||
template <typename SUBNET> using ares64 = ares<64, SUBNET>;
|
||||
template <typename SUBNET> using res64 = res<64,SUBNET>;
|
||||
template <typename SUBNET> using res128 = res<128,SUBNET>;
|
||||
template <typename SUBNET> using res256 = res<256,SUBNET>;
|
||||
template <typename SUBNET> using res512 = res<512,SUBNET>;
|
||||
template <typename SUBNET> using ares64 = ares<64,SUBNET>;
|
||||
template <typename SUBNET> using ares128 = ares<128,SUBNET>;
|
||||
template <typename SUBNET> using ares256 = ares<256,SUBNET>;
|
||||
template <typename SUBNET> using ares512 = ares<512,SUBNET>;
|
||||
|
||||
template <typename SUBNET> using level1 = dlib::repeat<2,res64,res<64,SUBNET>>;
|
||||
template <typename SUBNET> using level2 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using level3 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using level4 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using level1 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
|
||||
template <typename SUBNET> using level2 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using level3 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using level4 = dlib::repeat<2,res64,res<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
|
||||
template <typename SUBNET> using level1t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
|
||||
template <typename SUBNET> using level2t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using level3t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using level4t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using level1t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
|
||||
template <typename SUBNET> using level2t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using level3t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using level4t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
template<typename> class TAGGED,
|
||||
template<typename> class PREV_RESIZED,
|
||||
typename SUBNET
|
||||
>
|
||||
using resize_and_concat = dlib::add_layer<
|
||||
dlib::concat_<TAGGED,PREV_RESIZED>,
|
||||
PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
|
||||
|
||||
template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
|
||||
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
|
||||
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
|
||||
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
|
||||
|
||||
template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
|
||||
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
|
||||
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
|
||||
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
|
||||
|
||||
template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
|
||||
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
|
||||
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
|
||||
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
static const char* semantic_segmentation_net_filename = "semantic_segmentation_voc2012net_v2.dnn";
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// training network type
|
||||
using net_type = dlib::loss_multiclass_log_per_pixel<
|
||||
dlib::cont<class_count,7,7,2,2,
|
||||
level4t<level3t<level2t<level1t<
|
||||
level1<level2<level3<level4<
|
||||
dlib::max_pool<3,3,2,2,dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
|
||||
using bnet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
dlib::cont<class_count,1,1,1,1,
|
||||
dlib::relu<dlib::bn_con<dlib::cont<64,7,7,2,2,
|
||||
concat_utag1<level1t<
|
||||
concat_utag2<level2t<
|
||||
concat_utag3<level3t<
|
||||
concat_utag4<level4t<
|
||||
level4<utag4<
|
||||
level3<utag3<
|
||||
level2<utag2<
|
||||
level1<dlib::max_pool<3,3,2,2,utag1<
|
||||
dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
|
||||
dlib::input<dlib::matrix<dlib::rgb_pixel>>
|
||||
>>>>>>>>>>>>>>;
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>>;
|
||||
|
||||
// testing network type (replaced batch normalization with fixed affine transforms)
|
||||
using anet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
dlib::cont<class_count,7,7,2,2,
|
||||
alevel4t<alevel3t<alevel2t<alevel1t<
|
||||
alevel1<alevel2<alevel3<alevel4<
|
||||
dlib::max_pool<3,3,2,2,dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
|
||||
dlib::cont<class_count,1,1,1,1,
|
||||
dlib::relu<dlib::affine<dlib::cont<64,7,7,2,2,
|
||||
concat_utag1<alevel1t<
|
||||
concat_utag2<alevel2t<
|
||||
concat_utag3<alevel3t<
|
||||
concat_utag4<alevel4t<
|
||||
alevel4<utag4<
|
||||
alevel3<utag3<
|
||||
alevel2<utag2<
|
||||
alevel1<dlib::max_pool<3,3,2,2,utag1<
|
||||
dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
|
||||
dlib::input<dlib::matrix<dlib::rgb_pixel>>
|
||||
>>>>>>>>>>>>>>;
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -41,7 +41,7 @@ struct training_sample
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
rectangle make_random_cropping_rect_resnet(
|
||||
rectangle make_random_cropping_rect(
|
||||
const matrix<rgb_pixel>& img,
|
||||
dlib::rand& rnd
|
||||
)
|
||||
@ -66,7 +66,7 @@ void randomly_crop_image (
|
||||
dlib::rand& rnd
|
||||
)
|
||||
{
|
||||
const auto rect = make_random_cropping_rect_resnet(input_image, rnd);
|
||||
const auto rect = make_random_cropping_rect(input_image, rnd);
|
||||
|
||||
const chip_details chip_details(rect, chip_dims(227, 227));
|
||||
|
||||
@ -259,12 +259,12 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& datase
|
||||
|
||||
int main(int argc, char** argv) try
|
||||
{
|
||||
if (argc != 2)
|
||||
if (argc < 2 || argc > 3)
|
||||
{
|
||||
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
|
||||
cout << endl;
|
||||
cout << "You call this program like this: " << endl;
|
||||
cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012" << endl;
|
||||
cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -278,13 +278,16 @@ int main(int argc, char** argv) try
|
||||
return 1;
|
||||
}
|
||||
|
||||
// a mini-batch smaller than the default can be used with GPUs having less memory
|
||||
const int minibatch_size = argc == 3 ? std::stoi(argv[2]) : 23;
|
||||
cout << "mini-batch size: " << minibatch_size << endl;
|
||||
|
||||
const double initial_learning_rate = 0.1;
|
||||
const double weight_decay = 0.0001;
|
||||
const double momentum = 0.9;
|
||||
|
||||
net_type net;
|
||||
dnn_trainer<net_type> trainer(net,sgd(weight_decay, momentum));
|
||||
bnet_type bnet;
|
||||
dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum));
|
||||
trainer.be_verbose();
|
||||
trainer.set_learning_rate(initial_learning_rate);
|
||||
trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
|
||||
@ -292,7 +295,7 @@ int main(int argc, char** argv) try
|
||||
trainer.set_iterations_without_progress_threshold(5000);
|
||||
// Since the progress threshold is so large might as well set the batch normalization
|
||||
// stats window to something big too.
|
||||
set_all_bn_running_stats_window_sizes(net, 1000);
|
||||
set_all_bn_running_stats_window_sizes(bnet, 1000);
|
||||
|
||||
// Output training parameters.
|
||||
cout << endl << trainer << endl;
|
||||
@ -345,9 +348,9 @@ int main(int argc, char** argv) try
|
||||
samples.clear();
|
||||
labels.clear();
|
||||
|
||||
// make a 30-image mini-batch
|
||||
// make a mini-batch
|
||||
training_sample temp;
|
||||
while(samples.size() < 30)
|
||||
while(samples.size() < minibatch_size)
|
||||
{
|
||||
data.dequeue(temp);
|
||||
|
||||
@ -369,13 +372,13 @@ int main(int argc, char** argv) try
|
||||
// also wait for threaded processing to stop in the trainer.
|
||||
trainer.get_net();
|
||||
|
||||
net.clean();
|
||||
bnet.clean();
|
||||
cout << "saving network" << endl;
|
||||
serialize("semantic_segmentation_voc2012net.dnn") << net;
|
||||
serialize(semantic_segmentation_net_filename) << bnet;
|
||||
|
||||
|
||||
// Make a copy of the network to use it for inference.
|
||||
anet_type anet = net;
|
||||
anet_type anet = bnet;
|
||||
|
||||
cout << "Testing the network..." << endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user