// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
Semantic segmentation using the PASCAL VOC2012 dataset.

In segmentation, the task is to assign each pixel of an input image
a label, for example 'dog'. The idea is then that neighboring pixels
that have the same label can be connected together to form a larger
region, representing a complete (or partially occluded) dog.
So technically, segmentation can be viewed as classification of
individual pixels (using the relevant context in the input images);
however, the goal usually is to identify meaningful regions that
represent complete entities of interest (such as dogs).
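
The two ideas above, per-pixel classification followed by grouping
neighboring pixels with equal labels into regions, can be sketched in
plain C++ without dlib. This is an illustration only: the score layout
scores[k][r][c], the names argmax_per_pixel and connected_regions, and
the choice of 4-connectivity are all assumptions, and nothing in this
header performs the region-grouping step.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// labels[r][c] holds the class index assigned to pixel (r, c).
using label_image = std::vector<std::vector<int>>;

// scores[k][r][c] is a (hypothetical) score of class k for pixel (r, c).
// Each pixel independently takes the class with the highest score.
label_image argmax_per_pixel(const std::vector<std::vector<std::vector<float>>>& scores)
{
    assert(!scores.empty() && !scores[0].empty());
    const std::size_t rows = scores[0].size();
    const std::size_t cols = scores[0][0].size();
    label_image labels(rows, std::vector<int>(cols, 0));
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
        {
            std::size_t best = 0;
            for (std::size_t k = 1; k < scores.size(); ++k)
                if (scores[k][r][c] > scores[best][r][c])
                    best = k;
            labels[r][c] = static_cast<int>(best);
        }
    return labels;
}

// Groups 4-connected pixels that share a label into regions. Returns a
// region id per pixel together with the number of regions found.
std::pair<label_image, int> connected_regions(const label_image& labels)
{
    const int rows = static_cast<int>(labels.size());
    const int cols = rows > 0 ? static_cast<int>(labels[0].size()) : 0;
    label_image region(rows, std::vector<int>(cols, -1));
    const int dr[] = {-1, 1, 0, 0};
    const int dc[] = {0, 0, -1, 1};
    int count = 0;
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
        {
            if (region[r][c] != -1)
                continue;
            // Breadth-first flood fill starting from an unvisited pixel.
            std::queue<std::pair<int, int>> pending;
            pending.push(std::make_pair(r, c));
            region[r][c] = count;
            while (!pending.empty())
            {
                const std::pair<int, int> p = pending.front();
                pending.pop();
                for (int d = 0; d < 4; ++d)
                {
                    const int nr = p.first + dr[d];
                    const int nc = p.second + dc[d];
                    if (nr < 0 || nr >= rows || nc < 0 || nc >= cols)
                        continue;
                    if (region[nr][nc] != -1 || labels[nr][nc] != labels[r][c])
                        continue;
                    region[nr][nc] = count;
                    pending.push(std::make_pair(nr, nc));
                }
            }
            ++count;
        }
    return std::make_pair(region, count);
}
```

With a 2x2 checkerboard of labels, the diagonal pixels are not
4-connected, so four separate regions are reported.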

Instructions on how to run the example:
1. Download the PASCAL VOC2012 data and untar it somewhere:
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_semantic_segmentation_train_ex example program.
3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_semantic_segmentation_ex example program.
6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn

It would be a good idea to become familiar with dlib's DNN tooling before reading this
example, so read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp first.
*/
#ifndef DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
#define DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
#include <dlib/dnn.h>
#include "pascal_voc_2012.h"
// ----------------------------------------------------------------------------------------
// Introduce the building blocks used to define the segmentation network.
// The network first does residual downsampling (similar to the dnn_imagenet_(train_)ex
// example program), and then residual upsampling. In addition, U-Net style skip
// connections are used, so that not every simple detail needs to be represented on the low
// levels. (See Ronneberger et al. (2015), U-Net: Convolutional Networks for Biomedical
// Image Segmentation, https://arxiv.org/pdf/1505.04597.pdf)
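/*
Downsampling and then upsampling does not always restore the original spatial
size, which is one reason the skip connections below are combined through a
resize step (dlib::resize_prev_to_tagged) rather than concatenated directly.
The sketch below uses the textbook output-size formulas for a strided
convolution and a transposed convolution. The helper names conv_out and
convt_out are hypothetical, and dlib's layers choose their own default
padding, so the exact sizes inside this network may differ.

```cpp
#include <cassert>

// Output length of a strided convolution or pooling window:
// floor((in + 2 * pad - k) / stride) + 1
long conv_out(long in, long k, long stride, long pad)
{
    assert(stride > 0 && in + 2 * pad >= k);
    return (in + 2 * pad - k) / stride + 1;
}

// Output length of a transposed convolution:
// (in - 1) * stride + k - 2 * pad
long convt_out(long in, long k, long stride, long pad)
{
    assert(stride > 0 && in > 0);
    return (in - 1) * stride + k - 2 * pad;
}
```

For example, a 2x2, stride-2 layer maps 55 rows to 27, and the matching
transposed layer maps 27 back to 54, one row short of the skip connection it
should line up with. An even size such as 56 round-trips exactly (56 to 28 to 56).
*/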
template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
template <int N, template <typename> class BN, int stride, typename SUBNET>
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual = dlib::add_prev1<block<N,BN,1,dlib::tag1<SUBNET>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_down = dlib::add_prev2<dlib::avg_pool<2,2,2,2,dlib::skip1<dlib::tag2<block<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_up = dlib::add_prev2<dlib::cont<N,2,2,2,2,dlib::skip1<dlib::tag2<blockt<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <int N, typename SUBNET> using res = dlib::relu<residual<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares = dlib::relu<residual<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_down = dlib::relu<residual_down<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_down = dlib::relu<residual_down<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_up = dlib::relu<residual_up<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block,N,dlib::affine,SUBNET>>;
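/*
The residual templates above add two branches: residual computes x + F(x),
while residual_down adds avg_pool(x) (a 2x2, stride-2 average pool on the skip
path) to a branch whose first convolution has stride 2, so both branches
shrink by the same factor; res and res_down then apply relu to the sum. The
one-dimensional sketch below mirrors only that arithmetic. It assumes a single
channel and equal branch lengths, and it replaces the learned branch F with
hypothetical fixed weights; all function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using vec1d = std::vector<float>;

// Element-wise ReLU, applied after the two branches are added.
vec1d relu1d(vec1d v)
{
    for (float& x : v)
        x = std::max(x, 0.0f);
    return v;
}

// Element-wise sum of the two branches; this sketch requires equal sizes.
vec1d add1d(const vec1d& a, const vec1d& b)
{
    assert(a.size() == b.size());
    vec1d out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i)
        out[i] = a[i] + b[i];
    return out;
}

// Skip path of the downsampling residual: average of each pair, stride 2.
vec1d avg_pool2(const vec1d& x)
{
    vec1d out;
    for (std::size_t i = 0; i + 1 < x.size(); i += 2)
        out.push_back(0.5f * (x[i] + x[i + 1]));
    return out;
}

// Hypothetical stand-in for the learned stride-2 branch:
// out[i] = w0 * x[2i] + w1 * x[2i + 1].
vec1d strided_branch(const vec1d& x, float w0, float w1)
{
    vec1d out;
    for (std::size_t i = 0; i + 1 < x.size(); i += 2)
        out.push_back(w0 * x[i] + w1 * x[i + 1]);
    return out;
}

// Analogue of res: relu(x + F(x)), with F hypothetically x * w.
vec1d residual1d(const vec1d& x, float w)
{
    vec1d fx(x);
    for (float& v : fx)
        v *= w;
    return relu1d(add1d(x, fx));
}

// Analogue of res_down: relu(avg_pool2(x) + F_stride2(x)).
vec1d residual_down1d(const vec1d& x, float w0, float w1)
{
    return relu1d(add1d(avg_pool2(x), strided_branch(x, w0, w1)));
}
```

The skip path needs no learned weights; it only has to reach the same size as
the learned branch so that the two can be summed.
*/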
// ----------------------------------------------------------------------------------------
template <typename SUBNET> using res64 = res<64,SUBNET>;
template <typename SUBNET> using res128 = res<128,SUBNET>;
template <typename SUBNET> using res256 = res<256,SUBNET>;
template <typename SUBNET> using res512 = res<512,SUBNET>;
template <typename SUBNET> using ares64 = ares<64,SUBNET>;
template <typename SUBNET> using ares128 = ares<128,SUBNET>;
template <typename SUBNET> using ares256 = ares<256,SUBNET>;
template <typename SUBNET> using ares512 = ares<512,SUBNET>;
template <typename SUBNET> using level1 = dlib::repeat<2,res64,res<64,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
template <typename SUBNET> using level1t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class TAGGED,
template<typename> class PREV_RESIZED,
typename SUBNET
>
using resize_and_concat = dlib::add_layer<
dlib::concat_<TAGGED,PREV_RESIZED>,
PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
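/*
resize_and_concat first resizes the output of the previous layer to the
spatial size of the TAGGED layer (dlib::resize_prev_to_tagged), tags the result
as PREV_RESIZED, and then joins TAGGED and PREV_RESIZED along the channel
dimension (dlib::concat_). A plain C++ sketch of that data flow follows.
Nearest-neighbour resizing is a simplification chosen for brevity and is not
necessarily the interpolation dlib uses; the names channel, feature_map,
resize_nearest and resize_and_concat_sketch are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One channel is rows x cols; a feature map is a list of channels: fm[ch][r][c].
using channel = std::vector<std::vector<float>>;
using feature_map = std::vector<channel>;

// Nearest-neighbour resize of every channel to rows x cols.
feature_map resize_nearest(const feature_map& in, std::size_t rows, std::size_t cols)
{
    feature_map out;
    for (const channel& ch : in)
    {
        assert(!ch.empty() && !ch[0].empty());
        const std::size_t in_rows = ch.size();
        const std::size_t in_cols = ch[0].size();
        channel resized(rows, std::vector<float>(cols));
        for (std::size_t r = 0; r < rows; ++r)
            for (std::size_t c = 0; c < cols; ++c)
                resized[r][c] = ch[r * in_rows / rows][c * in_cols / cols];
        out.push_back(resized);
    }
    return out;
}

// Resize `prev` to the spatial size of `tagged`, then stack the channels of
// `tagged` followed by the resized channels of `prev`.
feature_map resize_and_concat_sketch(const feature_map& tagged, const feature_map& prev)
{
    assert(!tagged.empty() && !tagged[0].empty() && !tagged[0][0].empty());
    const std::size_t rows = tagged[0].size();
    const std::size_t cols = tagged[0][0].size();
    feature_map out = tagged;
    const feature_map resized = resize_nearest(prev, rows, cols);
    out.insert(out.end(), resized.begin(), resized.end());
    return out;
}
```

Because the resize happens before the concatenation, the upsampling path and
the skip connection may differ in size by a row or column without breaking
the network.
*/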
template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;
// ----------------------------------------------------------------------------------------
static const char* semantic_segmentation_net_filename = "semantic_segmentation_voc2012net_v2.dnn";
// ----------------------------------------------------------------------------------------
// training network type
using bnet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<class_count,1,1,1,1,
dlib::relu<dlib::bn_con<dlib::cont<64,7,7,2,2,
concat_utag1<level1t<
concat_utag2<level2t<
concat_utag3<level3t<
concat_utag4<level4t<
level4<utag4<
level3<utag3<
level2<utag2<
level1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
// testing network type (replaced batch normalization with fixed affine transforms)
using anet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<class_count,1,1,1,1,
dlib::relu<dlib::affine<dlib::cont<64,7,7,2,2,
concat_utag1<alevel1t<
concat_utag2<alevel2t<
concat_utag3<alevel3t<
concat_utag4<alevel4t<
alevel4<utag4<
alevel3<utag3<
alevel2<utag2<
alevel1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
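/*
anet_type mirrors bnet_type layer for layer, except that every dlib::bn_con is
replaced by dlib::affine. The substitution works because, once training is
finished, batch normalization uses fixed statistics, so for each channel it
reduces to y = scale * x + shift. A single-channel sketch of this standard
folding follows; the struct and function names are hypothetical, and eps is
passed in explicitly rather than taken from dlib's default.

```cpp
#include <cassert>
#include <cmath>

// Per-channel batch normalization parameters after training (hypothetical layout).
struct bn_params
{
    float gamma;
    float beta;
    float running_mean;
    float running_var;
};

// Per-channel parameters of the equivalent affine transform y = scale * x + shift.
struct affine_params
{
    float scale;
    float shift;
};

// With fixed statistics, y = gamma * (x - mean) / sqrt(var + eps) + beta
// is the same map as y = scale * x + shift for the values computed here.
affine_params fold_batch_norm(const bn_params& bn, float eps)
{
    assert(bn.running_var + eps > 0.0f);
    const float scale = bn.gamma / std::sqrt(bn.running_var + eps);
    return affine_params{scale, bn.beta - scale * bn.running_mean};
}
```

For instance, gamma = 2, beta = 1, running mean 3 and running variance 4 with
eps = 0 give scale = 1 and shift = -2, so an input of 5 maps to 3 under both
formulas.
*/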
// ----------------------------------------------------------------------------------------
#endif // DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_