Instance segmentation (#1918)

* Add instance segmentation example - first version of training code

* Add MMOD options; get rid of the cache approach, and instead load all MMOD rects upfront

* Improve console output

* Set filter count

* Minor tweaking

* Inference - first version, at least compiles!

* Ignore overlapped boxes

* Also ignore small instances

* Set overlaps_ignore

* Add TODO remarks

* Revert "Set overlaps_ignore"

This reverts commit 65adeff1f8.

* Set result size

* Set label image size

* Take ignore-color into account

* Fix the cropping rect's aspect ratio; also slightly expand the rect

* Draw the largest findings last

* Improve masking of the current instance

* Add some perturbation to the inputs

* Simplify ground-truth reading; fix random cropping

* Also read class labels

* Tweak default minibatch size

* Learn only one class

* Really train only instances of the selected class

* Remove outdated TODO remark

* Automatically skip images with no detections

* Print to console what was found

* Fix class index problem

* Fix indentation

* Allow choosing multiple classes

* Draw rect in the color of the corresponding class

* Write detector window classes to ostream; also group detection windows by class (when ostreaming)

* Train a separate instance segmentation network for each classlabel

* Use separate synchronization file for each seg net of each class

* Allow more overlap

* Fix sorting criterion

* Fix interpolating the predicted mask

* Improve bilinear interpolation: if output type is an integer, round instead of truncating

* Add helpful comments

* Ignore large aspect ratios; refactor the code; tweak some network parameters

* Simplify the segmentation network structure; make the object detection network more complex in turn

* Problem: CUDA errors not reported properly to console
Solution: stop and join data loader threads even in case of exceptions

* Minor parameters tweaking

* Loss may have increased, even if prob_loss_increasing_thresh > prob_loss_increasing_thresh_max_value

* Add previous_loss_values_dump_amount to previous_loss_values.size() when deciding if loss has been increasing

* Improve behaviour when loss actually increased after disk sync

* Revert some of the earlier change

* Disregard dumped loss values only when deciding if learning rate should be shrunk, but *not* when deciding if loss has been going up since last disk sync

* Revert "Revert some of the earlier change"

This reverts commit 6c852124ef.

* Keep enough previous loss values, until the disk sync

* Fix maintaining the dumped (now "effectively disregarded") loss values count

* Detect cats instead of aeroplanes

* Add helpful logging

* Clarify the intention and the code

* Review fixes

* Add operator== for the other pixel types as well; remove the inline

* If available, use constexpr if

* Revert "If available, use constexpr if"

This reverts commit 503d4dd335.

* Simplify code as per review comments

* Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh

* Clarify console output

* Revert "Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh"

This reverts commit 9191ebc776.

* To keep the changes to a bare minimum, revert the steps_since_last_learning_rate_shrink change after all (at least for now)

* Also empty out some of the previous test loss values

* Minor review fixes

* Can't use C++14 features here

* Do not use the struct name as a variable name
Authored by Juha Reunanen on 2019-11-15 05:53:16 +02:00; committed by Davis E. King
parent 7f2be82e33
commit d175c35074
11 changed files with 1483 additions and 213 deletions

View File

@ -993,6 +993,36 @@ namespace dlib
}
}
inline std::ostream& operator<<(std::ostream& out, const std::vector<mmod_options::detector_window_details>& detector_windows)
{
// write detector windows grouped by label
// example output: aeroplane:74x30,131x30,70x45,54x70,198x30;bicycle:70x57,32x70,70x32,51x70,128x30,30x121;car:70x36,70x60,99x30,52x70,30x83,30x114,30x200
std::map<std::string, std::deque<mmod_options::detector_window_details>> detector_windows_by_label;
for (const auto& detector_window : detector_windows)
detector_windows_by_label[detector_window.label].push_back(detector_window);
size_t label_count = 0;
for (const auto& i : detector_windows_by_label)
{
const auto& label = i.first;
const auto& detector_windows = i.second;
if (label_count++ > 0)
out << ";";
out << label << ":";
for (size_t j = 0; j < detector_windows.size(); ++j)
{
out << detector_windows[j].width << "x" << detector_windows[j].height;
if (j + 1 < detector_windows.size())
out << ",";
}
}
return out;
}
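A minimal usage sketch of the grouped operator<< above; it assumes the (width, height, label) constructor that mmod_options::detector_window_details provides:

    std::vector<dlib::mmod_options::detector_window_details> windows;
    windows.emplace_back(74, 30, "aeroplane");
    windows.emplace_back(70, 57, "bicycle");
    std::cout << windows << std::endl; // prints: aeroplane:74x30;bicycle:70x57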
// ----------------------------------------------------------------------------------------
class loss_mmod_
@ -1395,15 +1425,10 @@ namespace dlib
{
out << "loss_mmod\t (";
out << "detector_windows:(";
auto& opts = item.options;
for (size_t i = 0; i < opts.detector_windows.size(); ++i)
{
out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height;
if (i+1 < opts.detector_windows.size())
out << ",";
}
out << ")";
out << "detector_windows:(" << opts.detector_windows << ")";
out << ", loss per FA:" << opts.loss_per_false_alarm;
out << ", loss per miss:" << opts.loss_per_missed_target;
out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold;

View File

@ -460,6 +460,7 @@ namespace dlib
test_steps_without_progress = 0;
previous_loss_values.clear();
test_previous_loss_values.clear();
previous_loss_values_to_keep_until_disk_sync.clear();
}
learning_rate = lr;
lr_schedule.set_size(0);
@ -602,6 +603,11 @@ namespace dlib
// discard really old loss values.
while (previous_loss_values.size() > iter_without_progress_thresh)
previous_loss_values.pop_front();
// separately keep another loss history until disk sync
// (but only if disk sync is enabled)
if (!sync_filename.empty())
previous_loss_values_to_keep_until_disk_sync.push_back(loss);
}
template <typename T>
@ -700,10 +706,10 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
test_steps_without_progress = 0;
// Empty out some of the previous loss values so that test_steps_without_progress
// will decrease below test_iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
test_previous_loss_values.pop_front();
drop_some_test_previous_loss_values();
}
}
}
@ -820,10 +826,10 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
steps_without_progress = 0;
// Empty out some of the previous loss values so that steps_without_progress
// will decrease below iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
previous_loss_values.pop_front();
drop_some_previous_loss_values();
}
}
}
@ -895,7 +901,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 12;
int version = 13;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
@ -924,14 +930,14 @@ namespace dlib
serialize(item.test_previous_loss_values, out);
serialize(item.previous_loss_values_dump_amount, out);
serialize(item.test_previous_loss_values_dump_amount, out);
serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 12)
if (version != 13)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
@ -970,6 +976,7 @@ namespace dlib
deserialize(item.test_previous_loss_values, in);
deserialize(item.previous_loss_values_dump_amount, in);
deserialize(item.test_previous_loss_values_dump_amount, in);
deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
if (item.devices.size() > 1)
{
@ -987,6 +994,20 @@ namespace dlib
}
}
// Empty out some of the previous loss values so that steps_without_progress will decrease below iter_without_progress_thresh.
void drop_some_previous_loss_values()
{
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount + iter_without_progress_thresh / 10 && previous_loss_values.size() > 0; ++cnt)
previous_loss_values.pop_front();
}
// Empty out some of the previous test loss values so that test_steps_without_progress will decrease below test_iter_without_progress_thresh.
void drop_some_test_previous_loss_values()
{
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10 && test_previous_loss_values.size() > 0; ++cnt)
test_previous_loss_values.pop_front();
}
void sync_to_disk (
bool do_it_now = false
)
@ -1020,6 +1041,20 @@ namespace dlib
sync_file_reloaded = true;
if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl;
// Are we repeatedly hitting our head against the wall? If so, then we
// might be better off giving up at this learning rate, and trying a
// lower one instead.
if (prob_loss_increasing_thresh >= prob_loss_increasing_thresh_max_value)
{
std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
learning_rate = learning_rate_shrink * learning_rate;
steps_without_progress = 0;
test_steps_without_progress = 0;
drop_some_previous_loss_values();
drop_some_test_previous_loss_values();
}
}
else
{
@ -1057,34 +1092,39 @@ namespace dlib
if (!std::ifstream(newest_syncfile(), std::ios::binary))
return false;
for (auto x : previous_loss_values)
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
previous_loss_values_to_keep_until_disk_sync.pop_front();
running_gradient g;
for (auto x : previous_loss_values_to_keep_until_disk_sync)
{
// If we get a NaN value of loss assume things have gone horribly wrong and
// we should reload the state of the trainer.
if (std::isnan(x))
return true;
g.add(x);
}
// if we haven't seen much data yet then just say false. Or, alternatively, if
// it's been too long since the last sync then don't reload either.
if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
// if we haven't seen much data yet then just say false.
if (gradient_updates_since_last_sync < 30)
return false;
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting bettor or worse.
running_gradient g;
for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
g.add(previous_loss_values[i]);
// if the loss is very likely to be increasing then return true
const double prob = g.probability_gradient_greater_than(0);
if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
if (prob > prob_loss_increasing_thresh)
{
// Exponentially decay the threshold towards 1 so that if we keep finding
// the loss to be increasing over and over we will make the test
// progressively harder and harder until it fails, therefore ensuring we
// can't get stuck reloading from a previous state over and over.
prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
prob_loss_increasing_thresh = std::min(
0.1*prob_loss_increasing_thresh + 0.9*1,
prob_loss_increasing_thresh_max_value
);
return true;
}
else
@ -1247,6 +1287,8 @@ namespace dlib
std::atomic<unsigned long> test_steps_without_progress;
std::deque<double> test_previous_loss_values;
std::deque<double> previous_loss_values_to_keep_until_disk_sync;
std::atomic<double> learning_rate_shrink;
std::chrono::time_point<std::chrono::system_clock> last_sync_time;
std::string sync_filename;
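A small numeric sketch of the capped threshold decay above (values assumed for illustration): each reload moves prob_loss_increasing_thresh 90% of the way toward 1, and the std::min pins it at prob_loss_increasing_thresh_max_value, so the "repeatedly hitting our head against the wall" branch in sync_to_disk can actually fire.

    double thresh = 0.99;             // prob_loss_increasing_thresh
    const double max_value = 0.99999; // assumed cap
    for (int i = 0; i < 10; ++i)
        thresh = std::min(0.1*thresh + 0.9*1, max_value);
    // 0.999, 0.9999, 0.99999, then pinned at the cap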

View File

@ -865,10 +865,18 @@ namespace dlib
float fout[4];
out.store(fout);
out_img[r][c] = static_cast<T>(fout[0]);
out_img[r][c+1] = static_cast<T>(fout[1]);
out_img[r][c+2] = static_cast<T>(fout[2]);
out_img[r][c+3] = static_cast<T>(fout[3]);
const auto convert_to_output_type = [](float value)
{
if (std::is_integral<T>::value)
return static_cast<T>(value + 0.5);
else
return static_cast<T>(value);
};
out_img[r][c] = convert_to_output_type(fout[0]);
out_img[r][c+1] = convert_to_output_type(fout[1]);
out_img[r][c+2] = convert_to_output_type(fout[2]);
out_img[r][c+3] = convert_to_output_type(fout[3]);
}
x = -x_scale + c*x_scale;
for (; c < out_img.nc(); ++c)
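A quick sketch of why the +0.5 matters when the output type is integral (assumed values): a plain cast truncates toward zero, while adding 0.5 first rounds to the nearest integer.

    const float v = 126.7f;
    const auto truncated = static_cast<unsigned char>(v);        // 126
    const auto rounded   = static_cast<unsigned char>(v + 0.5f); // 127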

View File

@ -125,6 +125,13 @@ namespace dlib
unsigned char red;
unsigned char green;
unsigned char blue;
bool operator == (const rgb_pixel& that) const
{
return this->red == that.red
&& this->green == that.green
&& this->blue == that.blue;
}
};
// ----------------------------------------------------------------------------------------
@ -151,6 +158,13 @@ namespace dlib
unsigned char blue;
unsigned char green;
unsigned char red;
bool operator == (const bgr_pixel& that) const
{
return this->blue == that.blue
&& this->green == that.green
&& this->red == that.red;
}
};
// ----------------------------------------------------------------------------------------
@ -177,6 +191,14 @@ namespace dlib
unsigned char green;
unsigned char blue;
unsigned char alpha;
bool operator == (const rgb_alpha_pixel& that) const
{
return this->red == that.red
&& this->green == that.green
&& this->blue == that.blue
&& this->alpha == that.alpha;
}
};
// ----------------------------------------------------------------------------------------
@ -200,6 +222,13 @@ namespace dlib
unsigned char h;
unsigned char s;
unsigned char i;
bool operator == (const hsi_pixel& that) const
{
return this->h == that.h
&& this->s == that.s
&& this->i == that.i;
}
};
// ----------------------------------------------------------------------------------------
@ -222,6 +251,13 @@ namespace dlib
unsigned char l;
unsigned char a;
unsigned char b;
bool operator == (const lab_pixel& that) const
{
return this->l == that.l
&& this->a == that.a
&& this->b == that.b;
}
};
// ----------------------------------------------------------------------------------------
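With these member operators, ground-truth label colors can be compared directly, as the instance segmentation example below does; a usage sketch:

    const dlib::rgb_pixel background(0, 0, 0);
    const dlib::rgb_pixel p(128, 0, 0); // 'aeroplane' in the VOC2012 palette
    const bool is_background = (p == background); // false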

View File

@ -148,8 +148,10 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
add_gui_example(dnn_mmod_find_cars2_ex)
add_example(dnn_mmod_train_find_cars_ex)
add_gui_example(dnn_semantic_segmentation_ex)
add_gui_example(dnn_instance_segmentation_ex)
add_example(dnn_imagenet_train_ex)
add_example(dnn_semantic_segmentation_train_ex)
add_example(dnn_instance_segmentation_train_ex)
add_example(dnn_metric_learning_on_images_ex)
endif()

View File

@ -0,0 +1,178 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This example shows how to do instance segmentation on an image using a net pretrained
on the PASCAL VOC2012 dataset. For an introduction to what instance segmentation is,
see the accompanying header file dnn_instance_segmentation_ex.h.
Instructions on how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
before reading this example program.
*/
#include "dnn_instance_segmentation_ex.h"
#include "pascal_voc_2012.h"
#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>
using namespace std;
using namespace dlib;
// ----------------------------------------------------------------------------------------
int main(int argc, char** argv) try
{
if (argc != 2)
{
cout << "You call this program like this: " << endl;
cout << "./dnn_instance_segmentation_train_ex /path/to/images" << endl;
cout << endl;
cout << "You will also need a trained '" << instance_segmentation_net_filename << "' file." << endl;
cout << "You can either train it yourself (see example program" << endl;
cout << "dnn_instance_segmentation_train_ex), or download a" << endl;
cout << "copy from here: http://dlib.net/files/" << instance_segmentation_net_filename << endl;
return 1;
}
// Read the file containing the trained networks from the working directory.
det_anet_type det_net;
std::map<std::string, seg_bnet_type> seg_nets_by_class;
deserialize(instance_segmentation_net_filename) >> det_net >> seg_nets_by_class;
// Show inference results in a window.
image_window win;
matrix<rgb_pixel> input_image;
// Find supported image files.
const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
dlib::match_endings(".jpeg .jpg .png"));
dlib::rand rnd;
cout << "Found " << files.size() << " images, processing..." << endl;
for (const file& file : files)
{
// Load the input image.
load_image(input_image, file.full_name());
// Draw largest objects last
const auto sort_instances = [](const std::vector<mmod_rect>& input) {
auto output = input;
const auto compare_area = [](const mmod_rect& lhs, const mmod_rect& rhs) {
return lhs.rect.area() < rhs.rect.area();
};
std::sort(output.begin(), output.end(), compare_area);
return output;
};
// Find instances in the input image
const auto instances = sort_instances(det_net(input_image));
matrix<rgb_pixel> rgb_label_image;
matrix<rgb_pixel> input_chip;
rgb_label_image.set_size(input_image.nr(), input_image.nc());
rgb_label_image = rgb_pixel(0, 0, 0);
bool found_something = false;
for (const auto& instance : instances)
{
if (!found_something)
{
cout << "Found ";
found_something = true;
}
else
{
cout << ", ";
}
cout << instance.label;
const auto cropping_rect = get_cropping_rect(instance.rect);
const chip_details chip_details(cropping_rect, chip_dims(seg_dim, seg_dim));
extract_image_chip(input_image, chip_details, input_chip, interpolate_bilinear());
const auto i = seg_nets_by_class.find(instance.label);
if (i == seg_nets_by_class.end())
{
// per-class segmentation net not found, so we must be using the same net for all classes
// (see bool separate_seg_net_for_each_class in dnn_instance_segmentation_train_ex.cpp)
DLIB_CASSERT(seg_nets_by_class.size() == 1);
DLIB_CASSERT(seg_nets_by_class.begin()->first == "");
}
auto& seg_net = i != seg_nets_by_class.end()
? i->second // use the segmentation net trained for this class
: seg_nets_by_class.begin()->second; // use the same segmentation net for all classes
const auto mask = seg_net(input_chip);
const rgb_pixel random_color(
rnd.get_random_8bit_number(),
rnd.get_random_8bit_number(),
rnd.get_random_8bit_number()
);
dlib::matrix<uint16_t> resized_mask(
static_cast<int>(chip_details.rect.height()),
static_cast<int>(chip_details.rect.width())
);
dlib::resize_image(mask, resized_mask);
for (int r = 0; r < resized_mask.nr(); ++r)
{
for (int c = 0; c < resized_mask.nc(); ++c)
{
if (resized_mask(r, c))
{
const auto y = chip_details.rect.top() + r;
const auto x = chip_details.rect.left() + c;
if (y >= 0 && y < rgb_label_image.nr() && x >= 0 && x < rgb_label_image.nc())
rgb_label_image(y, x) = random_color;
}
}
}
const Voc2012class& voc2012_class = find_voc2012_class(
[&instance](const Voc2012class& candidate) {
return candidate.classlabel == instance.label;
}
);
dlib::draw_rectangle(rgb_label_image, instance.rect, voc2012_class.rgb_label, 1);
}
// Show the input image on the left, and the predicted RGB labels on the right.
win.set_image(join_rows(input_image, rgb_label_image));
if (!instances.empty())
{
cout << " in " << file.name() << " - hit enter to process the next image";
cin.get();
}
}
}
catch(std::exception& e)
{
cout << e.what() << endl;
}

View File

@ -0,0 +1,200 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
Instance segmentation using the PASCAL VOC2012 dataset.
Instance segmentation sort of combines object detection with semantic
segmentation. While each dog, for example, is detected separately,
the output is not only a bounding-box but a more accurate, per-pixel
mask.
For introductions to object detection and semantic segmentation, you
can have a look at dnn_mmod_ex.cpp and dnn_semantic_segmentation.h,
respectively.
Instructions on how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
before reading this example program.
*/
#ifndef DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
#define DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
#include <dlib/dnn.h>
// ----------------------------------------------------------------------------------------
namespace {
// Segmentation will be performed using patches having this size.
constexpr int seg_dim = 227;
}
dlib::rectangle get_cropping_rect(const dlib::rectangle& rectangle)
{
DLIB_ASSERT(!rectangle.is_empty());
const auto center_point = dlib::center(rectangle);
const auto max_dim = std::max(rectangle.width(), rectangle.height());
const auto d = static_cast<long>(std::round(max_dim / 2.0 * 1.5)); // add +50%
return dlib::rectangle(
center_point.x() - d,
center_point.y() - d,
center_point.x() + d,
center_point.y() + d
);
}
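A worked example of the helper above (input values assumed): a 100x40 detection centered at (200, 150) gives max_dim = 100 and d = round(50 * 1.5) = 75, so the crop is always square and expanded by 50% relative to the longer side.

    const auto crop = get_cropping_rect(dlib::centered_rect(200, 150, 100, 40));
    // -> the 151x151 rectangle spanning (125, 75) to (275, 225)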
// ----------------------------------------------------------------------------------------
// The object detection network.
// Adapted from dnn_mmod_train_find_cars_ex.cpp and friends.
template <long num_filters, typename SUBNET> using con5d = dlib::con<num_filters,5,5,2,2,SUBNET>;
template <long num_filters, typename SUBNET> using con5 = dlib::con<num_filters,5,5,1,1,SUBNET>;
template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<32,SUBNET>>>>>>>>>;
template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<32,SUBNET>>>>>>>>>;
template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<256,SUBNET>>>;
template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<256,SUBNET>>>;
using det_bnet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,brcon5<brcon5<brcon5<bdownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
using det_anet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,arcon5<arcon5<arcon5<adownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
// The segmentation network.
// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h, although the network is made narrower (smaller feature maps).
template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
template <int N, template <typename> class BN, int stride, typename SUBNET>
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual = dlib::add_prev1<block<N,BN,1,dlib::tag1<SUBNET>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_down = dlib::add_prev2<dlib::avg_pool<2,2,2,2,dlib::skip1<dlib::tag2<block<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_up = dlib::add_prev2<dlib::cont<N,2,2,2,2,dlib::skip1<dlib::tag2<blockt<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <int N, typename SUBNET> using res = dlib::relu<residual<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares = dlib::relu<residual<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_down = dlib::relu<residual_down<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_down = dlib::relu<residual_down<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_up = dlib::relu<residual_up<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block,N,dlib::affine,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <typename SUBNET> using res16 = res<16,SUBNET>;
template <typename SUBNET> using res24 = res<24,SUBNET>;
template <typename SUBNET> using res32 = res<32,SUBNET>;
template <typename SUBNET> using res48 = res<48,SUBNET>;
template <typename SUBNET> using ares16 = ares<16,SUBNET>;
template <typename SUBNET> using ares24 = ares<24,SUBNET>;
template <typename SUBNET> using ares32 = ares<32,SUBNET>;
template <typename SUBNET> using ares48 = ares<48,SUBNET>;
template <typename SUBNET> using level1 = dlib::repeat<2,res16,res<16,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res24,res_down<24,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res32,res_down<32,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res48,res_down<48,SUBNET>>;
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares16,ares<16,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares24,ares_down<24,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares32,ares_down<32,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares48,ares_down<48,SUBNET>>;
template <typename SUBNET> using level1t = dlib::repeat<2,res16,res_up<16,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res24,res_up<24,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res32,res_up<32,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res48,res_up<48,SUBNET>>;
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares16,ares_up<16,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares24,ares_up<24,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares32,ares_up<32,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares48,ares_up<48,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class TAGGED,
template<typename> class PREV_RESIZED,
typename SUBNET
>
using resize_and_concat = dlib::add_layer<
dlib::concat_<TAGGED,PREV_RESIZED>,
PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;
// ----------------------------------------------------------------------------------------
static const char* instance_segmentation_net_filename = "instance_segmentation_voc2012net.dnn";
// ----------------------------------------------------------------------------------------
// training network type
using seg_bnet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<2,1,1,1,1,
dlib::relu<dlib::bn_con<dlib::cont<16,7,7,2,2,
concat_utag1<level1t<
concat_utag2<level2t<
concat_utag3<level3t<
concat_utag4<level4t<
level4<utag4<
level3<utag3<
level2<utag2<
level1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::bn_con<dlib::con<16,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
// testing network type (replaced batch normalization with fixed affine transforms)
using seg_anet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<2,1,1,1,1,
dlib::relu<dlib::affine<dlib::cont<16,7,7,2,2,
concat_utag1<alevel1t<
concat_utag2<alevel2t<
concat_utag3<alevel3t<
concat_utag4<alevel4t<
alevel4<utag4<
alevel3<utag3<
alevel2<utag2<
alevel1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::affine<dlib::con<16,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
// ----------------------------------------------------------------------------------------
#endif // DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_

View File

@ -0,0 +1,776 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This example shows how to train an instance segmentation net using the PASCAL VOC2012
dataset. For an introduction to what segmentation is, see the accompanying header file
dnn_instance_segmentation_ex.h.
Instructions on how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp, dnn_introduction2_ex.cpp,
and dnn_semantic_segmentation_train_ex.cpp before reading this example program.
*/
#include "dnn_instance_segmentation_ex.h"
#include "pascal_voc_2012.h"
#include <iostream>
#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>
#include <iterator>
#include <thread>
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
#include <execution>
#endif // __cplusplus >= 201703L
using namespace std;
using namespace dlib;
// ----------------------------------------------------------------------------------------
// A single training sample for detection. A mini-batch comprises many of these.
struct det_training_sample
{
matrix<rgb_pixel> input_image;
std::vector<dlib::mmod_rect> mmod_rects;
};
// A single training sample for segmentation. A mini-batch comprises many of these.
struct seg_training_sample
{
matrix<rgb_pixel> input_image;
matrix<uint16_t> label_image; // The ground-truth label of each pixel.
};
// ----------------------------------------------------------------------------------------
bool is_instance_pixel(const dlib::rgb_pixel& rgb_label)
{
if (rgb_label == dlib::rgb_pixel(0, 0, 0))
return false; // Background
if (rgb_label == dlib::rgb_pixel(224, 224, 192))
return false; // The cream-colored `void' label is used in border regions and to mask difficult objects
return true;
}
// Provide a hash function for dlib::rgb_pixel
namespace std {
template <>
struct hash<dlib::rgb_pixel>
{
std::size_t operator()(const dlib::rgb_pixel& p) const
{
return (static_cast<uint32_t>(p.red) << 16)
| (static_cast<uint32_t>(p.green) << 8)
| (static_cast<uint32_t>(p.blue));
}
};
}
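The hash above packs the three 8-bit channels into a single 24-bit value, so distinct colors never collide. Together with the rgb_pixel::operator== added by this commit, that is all an unordered container needs; a usage sketch:

    std::unordered_map<dlib::rgb_pixel, int> pixel_counts;
    ++pixel_counts[dlib::rgb_pixel(128, 0, 0)];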
struct truth_instance
{
dlib::rgb_pixel rgb_label;
dlib::mmod_rect mmod_rect;
};
std::vector<truth_instance> rgb_label_images_to_truth_instances(
const dlib::matrix<dlib::rgb_pixel>& instance_label_image,
const dlib::matrix<dlib::rgb_pixel>& class_label_image
)
{
std::unordered_map<dlib::rgb_pixel, mmod_rect> result_map;
DLIB_CASSERT(instance_label_image.nr() == class_label_image.nr());
DLIB_CASSERT(instance_label_image.nc() == class_label_image.nc());
const auto nr = instance_label_image.nr();
const auto nc = instance_label_image.nc();
for (int r = 0; r < nr; ++r)
{
for (int c = 0; c < nc; ++c)
{
const auto rgb_instance_label = instance_label_image(r, c);
if (!is_instance_pixel(rgb_instance_label))
continue;
const auto rgb_class_label = class_label_image(r, c);
const Voc2012class& voc2012_class = find_voc2012_class(rgb_class_label);
const auto i = result_map.find(rgb_instance_label);
if (i == result_map.end())
{
// Encountered a new instance
result_map[rgb_instance_label] = rectangle(c, r, c, r);
result_map[rgb_instance_label].label = voc2012_class.classlabel;
}
else
{
// Not the first occurrence - update the rect
auto& rect = i->second.rect;
if (c < rect.left())
rect.set_left(c);
else if (c > rect.right())
rect.set_right(c);
if (r > rect.bottom())
rect.set_bottom(r);
DLIB_CASSERT(i->second.label == voc2012_class.classlabel);
}
}
}
std::vector<truth_instance> flat_result;
flat_result.reserve(result_map.size());
for (const auto& i : result_map) {
flat_result.push_back(truth_instance{
i.first, i.second
});
}
return flat_result;
}
// ----------------------------------------------------------------------------------------
struct truth_image
{
image_info info;
std::vector<truth_instance> truth_instances;
};
std::vector<mmod_rect> extract_mmod_rects(
const std::vector<truth_instance>& truth_instances
)
{
std::vector<mmod_rect> mmod_rects(truth_instances.size());
std::transform(
truth_instances.begin(),
truth_instances.end(),
mmod_rects.begin(),
[](const truth_instance& truth) { return truth.mmod_rect; }
);
return mmod_rects;
}
std::vector<std::vector<mmod_rect>> extract_mmod_rect_vectors(
const std::vector<truth_image>& truth_images
)
{
std::vector<std::vector<mmod_rect>> mmod_rects(truth_images.size());
const auto extract_mmod_rects_from_truth_image = [](const truth_image& truth_image)
{
return extract_mmod_rects(truth_image.truth_instances);
};
std::transform(
truth_images.begin(),
truth_images.end(),
mmod_rects.begin(),
extract_mmod_rects_from_truth_image
);
return mmod_rects;
}
det_bnet_type train_detection_network(
const std::vector<truth_image>& truth_images,
unsigned int det_minibatch_size
)
{
const double initial_learning_rate = 0.1;
const double weight_decay = 0.0001;
const double momentum = 0.9;
const double min_detector_window_overlap_iou = 0.65;
const int target_size = 70;
const int min_target_size = 30;
mmod_options options(
extract_mmod_rect_vectors(truth_images),
target_size, min_target_size,
min_detector_window_overlap_iou
);
options.overlaps_ignore = test_box_overlap(0.5, 0.9);
det_bnet_type det_net(options);
det_net.subnet().layer_details().set_num_filters(options.detector_windows.size());
dlib::pipe<det_training_sample> data(200);
auto f = [&data, &truth_images, target_size, min_target_size](time_t seed)
{
dlib::rand rnd(time(0) + seed);
matrix<rgb_pixel> input_image;
random_cropper cropper;
cropper.set_seed(time(0));
cropper.set_chip_dims(350, 350);
// Usually you want to give the cropper whatever min sizes you passed to the
// mmod_options constructor, or very slightly smaller sizes, which is what we do here.
cropper.set_min_object_size(target_size - 2, min_target_size - 2);
cropper.set_max_rotation_degrees(2);
det_training_sample temp;
while (data.is_enabled())
{
// Pick a random input image.
const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
const auto& truth_image = truth_images[random_index];
// Load the input image.
load_image(input_image, truth_image.info.image_filename);
// Get a random crop of the input.
const auto mmod_rects = extract_mmod_rects(truth_image.truth_instances);
cropper(input_image, mmod_rects, temp.input_image, temp.mmod_rects);
disturb_colors(temp.input_image, rnd);
// Push the result to be used by the trainer.
data.enqueue(temp);
}
};
std::thread data_loader1([f]() { f(1); });
std::thread data_loader2([f]() { f(2); });
std::thread data_loader3([f]() { f(3); });
std::thread data_loader4([f]() { f(4); });
const auto stop_data_loaders = [&]()
{
data.disable();
data_loader1.join();
data_loader2.join();
data_loader3.join();
data_loader4.join();
};
dnn_trainer<det_bnet_type> det_trainer(det_net, sgd(weight_decay, momentum));
try
{
det_trainer.be_verbose();
det_trainer.set_learning_rate(initial_learning_rate);
det_trainer.set_synchronization_file("pascal_voc2012_det_trainer_state_file.dat", std::chrono::minutes(10));
det_trainer.set_iterations_without_progress_threshold(5000);
// Output training parameters.
cout << det_trainer << endl;
std::vector<matrix<rgb_pixel>> samples;
std::vector<std::vector<mmod_rect>> labels;
// The main training loop. Keep making mini-batches and giving them to the trainer.
// We will run until the learning rate becomes small enough.
while (det_trainer.get_learning_rate() >= 1e-4)
{
samples.clear();
labels.clear();
// make a mini-batch
det_training_sample temp;
while (samples.size() < det_minibatch_size)
{
data.dequeue(temp);
samples.push_back(std::move(temp.input_image));
labels.push_back(std::move(temp.mmod_rects));
}
det_trainer.train_one_step(samples, labels);
}
}
catch (std::exception&)
{
stop_data_loaders();
throw;
}
// Training done, tell threads to stop and make sure to wait for them to finish before
// moving on.
stop_data_loaders();
// also wait for threaded processing to stop in the trainer.
det_trainer.get_net();
det_net.clean();
return det_net;
}
// ----------------------------------------------------------------------------------------
matrix<uint16_t> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_image, const rgb_pixel rgb_label)
{
const auto nr = rgb_label_image.nr();
const auto nc = rgb_label_image.nc();
matrix<uint16_t> result(nr, nc);
for (long r = 0; r < nr; ++r)
{
for (long c = 0; c < nc; ++c)
{
const auto& index = rgb_label_image(r, c);
if (index == rgb_label)
result(r, c) = 1;
else if (index == dlib::rgb_pixel(224, 224, 192))
result(r, c) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore;
else
result(r, c) = 0;
}
}
return result;
}
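A small sketch of the mapping above, with colors assumed from the VOC2012 palette: the current instance's color becomes 1, the 'void' border color becomes label_to_ignore, and everything else (background or other instances) becomes 0.

    dlib::matrix<dlib::rgb_pixel> labels(1, 3);
    labels(0, 0) = dlib::rgb_pixel(128, 0, 0);     // this instance -> 1
    labels(0, 1) = dlib::rgb_pixel(224, 224, 192); // 'void' border -> label_to_ignore
    labels(0, 2) = dlib::rgb_pixel(64, 0, 0);      // anything else -> 0
    const auto mask = keep_only_current_instance(labels, dlib::rgb_pixel(128, 0, 0));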
seg_bnet_type train_segmentation_network(
const std::vector<truth_image>& truth_images,
unsigned int seg_minibatch_size,
const std::string& classlabel
)
{
seg_bnet_type seg_net;
const double initial_learning_rate = 0.1;
const double weight_decay = 0.0001;
const double momentum = 0.9;
const std::string synchronization_file_name
= "pascal_voc2012_seg_trainer_state_file"
+ (classlabel.empty() ? "" : ("_" + classlabel))
+ ".dat";
dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
seg_trainer.be_verbose();
seg_trainer.set_learning_rate(initial_learning_rate);
seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
seg_trainer.set_iterations_without_progress_threshold(2000);
set_all_bn_running_stats_window_sizes(seg_net, 1000);
// Output training parameters.
cout << seg_trainer << endl;
std::vector<matrix<rgb_pixel>> samples;
std::vector<matrix<uint16_t>> labels;
// Start a bunch of threads that read images from disk and pull out random crops. It's
// important to be sure to feed the GPU fast enough to keep it busy. Using multiple
// threads for this kind of data preparation helps us do that. Each thread puts the
// crops into the data queue.
dlib::pipe<seg_training_sample> data(200);
auto f = [&data, &truth_images](time_t seed)
{
dlib::rand rnd(time(0) + seed);
matrix<rgb_pixel> input_image;
matrix<rgb_pixel> rgb_label_image;
matrix<rgb_pixel> rgb_label_chip;
seg_training_sample temp;
while (data.is_enabled())
{
// Pick a random input image.
const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
const auto& truth_image = truth_images[random_index];
const auto image_truths = truth_image.truth_instances;
if (!image_truths.empty())
{
const image_info& info = truth_image.info;
// Load the input image.
load_image(input_image, info.image_filename);
// Load the ground-truth (RGB) instance labels.
load_image(rgb_label_image, info.instance_label_filename);
// Pick a random training instance.
const auto& truth_instance = image_truths[rnd.get_random_32bit_number() % image_truths.size()];
const auto& truth_rect = truth_instance.mmod_rect.rect;
const auto cropping_rect = get_cropping_rect(truth_rect);
// Pick a random crop around the instance.
const auto max_x_translate_amount = static_cast<long>(truth_rect.width() / 10.0);
const auto max_y_translate_amount = static_cast<long>(truth_rect.height() / 10.0);
const auto random_translate = point(
rnd.get_integer_in_range(-max_x_translate_amount, max_x_translate_amount + 1),
rnd.get_integer_in_range(-max_y_translate_amount, max_y_translate_amount + 1)
);
const rectangle random_rect(
cropping_rect.left() + random_translate.x(),
cropping_rect.top() + random_translate.y(),
cropping_rect.right() + random_translate.x(),
cropping_rect.bottom() + random_translate.y()
);
const chip_details chip_details(random_rect, chip_dims(seg_dim, seg_dim));
// Crop the input image.
extract_image_chip(input_image, chip_details, temp.input_image, interpolate_bilinear());
disturb_colors(temp.input_image, rnd);
// Crop the labels correspondingly. However, note that here bilinear
// interpolation would make absolutely no sense - you wouldn't say that
// a bicycle is half-way between an aeroplane and a bird, would you?
extract_image_chip(rgb_label_image, chip_details, rgb_label_chip, interpolate_nearest_neighbor());
// Clear pixels not related to the current instance.
temp.label_image = keep_only_current_instance(rgb_label_chip, truth_instance.rgb_label);
// Push the result to be used by the trainer.
data.enqueue(temp);
}
else
{
// TODO: use background samples as well
}
}
};
std::thread data_loader1([f]() { f(1); });
std::thread data_loader2([f]() { f(2); });
std::thread data_loader3([f]() { f(3); });
std::thread data_loader4([f]() { f(4); });
const auto stop_data_loaders = [&]()
{
data.disable();
data_loader1.join();
data_loader2.join();
data_loader3.join();
data_loader4.join();
};
try
{
// The main training loop. Keep making mini-batches and giving them to the trainer.
// We will run until the learning rate has dropped below 1e-4.
while (seg_trainer.get_learning_rate() >= 1e-4)
{
samples.clear();
labels.clear();
// make a mini-batch
seg_training_sample temp;
while (samples.size() < seg_minibatch_size)
{
data.dequeue(temp);
samples.push_back(std::move(temp.input_image));
labels.push_back(std::move(temp.label_image));
}
seg_trainer.train_one_step(samples, labels);
}
}
catch (std::exception&)
{
stop_data_loaders();
throw;
}
// Training done, tell threads to stop and make sure to wait for them to finish before
// moving on.
stop_data_loaders();
// also wait for threaded processing to stop in the trainer.
seg_trainer.get_net();
seg_net.clean();
return seg_net;
}
// ----------------------------------------------------------------------------------------
int ignore_overlapped_boxes(
std::vector<truth_instance>& truth_instances,
const test_box_overlap& overlaps
)
/*!
ensures
- Whenever two rectangles in boxes overlap, according to overlaps(), we set the
smallest box to ignore.
- returns the number of newly ignored boxes.
!*/
{
int num_ignored = 0;
for (size_t i = 0, end = truth_instances.size(); i < end; ++i)
{
auto& box_i = truth_instances[i].mmod_rect;
if (box_i.ignore)
continue;
for (size_t j = i+1; j < end; ++j)
{
auto& box_j = truth_instances[j].mmod_rect;
if (box_j.ignore)
continue;
if (overlaps(box_i, box_j))
{
++num_ignored;
if(box_i.rect.area() < box_j.rect.area())
box_i.ignore = true;
else
box_j.ignore = true;
}
}
}
return num_ignored;
}
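A usage sketch mirroring the call made later in ignore_some_truth_boxes; the two arguments of test_box_overlap are dlib's IoU and percent-covered thresholds, and truths is assumed to hold one image's instances:

    const int n = ignore_overlapped_boxes(truths, test_box_overlap(0.90, 0.95));
    cout << n << " overlapping truth boxes newly marked as ignored" << endl;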
std::vector<truth_instance> load_truth_instances(const image_info& info)
{
matrix<rgb_pixel> instance_label_image;
matrix<rgb_pixel> class_label_image;
load_image(instance_label_image, info.instance_label_filename);
load_image(class_label_image, info.class_label_filename);
return rgb_label_images_to_truth_instances(instance_label_image, class_label_image);
}
std::vector<std::vector<truth_instance>> load_all_truth_instances(const std::vector<image_info>& listing)
{
std::vector<std::vector<truth_instance>> truth_instances(listing.size());
std::transform(
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
std::execution::par,
#endif // __cplusplus >= 201703L
listing.begin(),
listing.end(),
truth_instances.begin(),
load_truth_instances
);
return truth_instances;
}
// ----------------------------------------------------------------------------------------
std::vector<truth_image> filter_based_on_classlabel(
const std::vector<truth_image>& truth_images,
const std::vector<std::string>& desired_classlabels
)
{
std::vector<truth_image> result;
const auto represents_desired_class = [&desired_classlabels](const truth_instance& truth_instance) {
return std::find(
desired_classlabels.begin(),
desired_classlabels.end(),
truth_instance.mmod_rect.label
) != desired_classlabels.end();
};
for (const auto& input : truth_images)
{
const auto has_desired_class = std::any_of(
input.truth_instances.begin(),
input.truth_instances.end(),
represents_desired_class
);
if (has_desired_class) {
// NB: This keeps only MMOD rects belonging to any of the desired classes.
// A reasonable alternative could be to keep all rects, but mark those
// belonging to other classes to be ignored during training.
std::vector<truth_instance> temp;
std::copy_if(
input.truth_instances.begin(),
input.truth_instances.end(),
std::back_inserter(temp),
represents_desired_class
);
result.push_back(truth_image{ input.info, temp });
}
}
return result;
}
// Ignore truth boxes that overlap too much, are too small, or have a large aspect ratio.
void ignore_some_truth_boxes(std::vector<truth_image>& truth_images)
{
for (auto& i : truth_images)
{
auto& truth_instances = i.truth_instances;
ignore_overlapped_boxes(truth_instances, test_box_overlap(0.90, 0.95));
for (auto& truth : truth_instances)
{
if (truth.mmod_rect.ignore)
continue;
const auto& rect = truth.mmod_rect.rect;
constexpr unsigned long min_width = 35;
constexpr unsigned long min_height = 35;
if (rect.width() < min_width && rect.height() < min_height)
{
truth.mmod_rect.ignore = true;
continue;
}
constexpr double max_aspect_ratio_width_to_height = 3.0;
constexpr double max_aspect_ratio_height_to_width = 1.5;
const double aspect_ratio_width_to_height = rect.width() / static_cast<double>(rect.height());
const double aspect_ratio_height_to_width = 1.0 / aspect_ratio_width_to_height;
const bool is_aspect_ratio_too_large
= aspect_ratio_width_to_height > max_aspect_ratio_width_to_height
|| aspect_ratio_height_to_width > max_aspect_ratio_height_to_width;
if (is_aspect_ratio_too_large)
truth.mmod_rect.ignore = true;
}
}
}
// Filter images that have no (non-ignored) truth
std::vector<truth_image> filter_images_with_no_truth(const std::vector<truth_image>& truth_images)
{
std::vector<truth_image> result;
for (const auto& truth_image : truth_images)
{
const auto ignored = [](const truth_instance& truth) { return truth.mmod_rect.ignore; };
const auto& truth_instances = truth_image.truth_instances;
if (!std::all_of(truth_instances.begin(), truth_instances.end(), ignored))
result.push_back(truth_image);
}
return result;
}
int main(int argc, char** argv) try
{
if (argc < 2)
{
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
cout << endl;
cout << "You call this program like this: " << endl;
cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ..." << endl;
return 1;
}
cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;
const auto listing = get_pascal_voc2012_train_listing(argv[1]);
cout << "images in entire dataset: " << listing.size() << endl;
if (listing.size() == 0)
{
cout << "Didn't find the VOC2012 dataset. " << endl;
return 1;
}
// mini-batches smaller than the default can be used with GPUs having less memory
const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 35;
const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 100;
cout << "det mini-batch size: " << det_minibatch_size << endl;
cout << "seg mini-batch size: " << seg_minibatch_size << endl;
std::vector<std::string> desired_classlabels;
for (int arg = 4; arg < argc; ++arg)
desired_classlabels.push_back(argv[arg]);
if (desired_classlabels.empty())
{
desired_classlabels.push_back("bicycle");
desired_classlabels.push_back("car");
desired_classlabels.push_back("cat");
}
cout << "desired classlabels:";
for (const auto& desired_classlabel : desired_classlabels)
cout << " " << desired_classlabel;
cout << endl;
// extract the truth instances (classlabels and MMOD rects)
cout << endl << "Extracting all truth instances...";
const auto truth_instances = load_all_truth_instances(listing);
cout << " Done!" << endl << endl;
DLIB_CASSERT(listing.size() == truth_instances.size());
std::vector<truth_image> original_truth_images;
for (size_t i = 0, end = listing.size(); i < end; ++i)
{
original_truth_images.push_back(truth_image{
listing[i], truth_instances[i]
});
}
auto truth_images_filtered_by_class = filter_based_on_classlabel(original_truth_images, desired_classlabels);
cout << "images in dataset filtered by class: " << truth_images_filtered_by_class.size() << endl;
ignore_some_truth_boxes(truth_images_filtered_by_class);
const auto truth_images = filter_images_with_no_truth(truth_images_filtered_by_class);
cout << "images in dataset after ignoring some truth boxes: " << truth_images.size() << endl;
// First train an object detector network (loss_mmod).
cout << endl << "Training detector network:" << endl;
const auto det_net = train_detection_network(truth_images, det_minibatch_size);
// Then train mask predictors (segmentation).
std::map<std::string, seg_bnet_type> seg_nets_by_class;
// This flag controls if a separate mask predictor is trained for each class.
// Note that it would also be possible to train a separate mask predictor for
// class groups, each containing somehow similar classes -- for example, one
// mask predictor for cars and buses, another for cats and dogs, and so on.
constexpr bool separate_seg_net_for_each_class = true;
if (separate_seg_net_for_each_class)
{
for (const auto& classlabel : desired_classlabels)
{
// Consider only the truth images belonging to this class.
const auto class_images = filter_based_on_classlabel(truth_images, { classlabel });
cout << endl << "Training segmentation network for class " << classlabel << ":" << endl;
seg_nets_by_class[classlabel] = train_segmentation_network(class_images, seg_minibatch_size, classlabel);
}
}
else
{
cout << "Training a single segmentation network:" << endl;
seg_nets_by_class[""] = train_segmentation_network(truth_images, seg_minibatch_size, "");
}
cout << "Saving networks" << endl;
serialize(instance_segmentation_net_filename) << det_net << seg_nets_by_class;
}
catch(std::exception& e)
{
cout << e.what() << endl;
}

View File

@ -34,83 +34,7 @@
#define DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
#include <dlib/dnn.h>
// ----------------------------------------------------------------------------------------
inline bool operator == (const dlib::rgb_pixel& a, const dlib::rgb_pixel& b)
{
return a.red == b.red && a.green == b.green && a.blue == b.blue;
}
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We associate each class also to an index in the
// range [0, 20], used internally by the network.
struct Voc2012class {
Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
: index(index), rgb_label(rgb_label), classlabel(classlabel)
{}
// The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
const uint16_t index = 0;
// The corresponding RGB representation of the class.
const dlib::rgb_pixel rgb_label;
// The label of the class in plain text.
const std::string classlabel;
};
namespace {
constexpr int class_count = 21; // background + 20 classes
const std::vector<Voc2012class> classes = {
Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background
// The cream-colored `void' label is used in border regions and to mask difficult objects
// (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
dlib::rgb_pixel(224, 224, 192), "border"),
Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
};
}
template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
const auto i = std::find_if(classes.begin(), classes.end(), predicate);
if (i != classes.end())
{
return *i;
}
else
{
throw std::runtime_error("Unable to find a matching VOC2012 class");
}
}
#include "pascal_voc_2012.h"
// ----------------------------------------------------------------------------------------

View File

@ -91,107 +91,6 @@ void randomly_crop_image (
// ----------------------------------------------------------------------------------------
// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
string image_filename;
string label_filename;
};
// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
const std::string& voc2012_folder,
const std::string& file = "train" // "train", "trainval", or "val"
)
{
std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");
std::vector<image_info> results;
while (in)
{
std::string basename;
in >> basename;
if (!basename.empty())
{
image_info image_info;
image_info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
image_info.label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
results.push_back(image_info);
}
}
return results;
}
// Read the list of image files belong to the "train" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_train_listing(
const std::string& voc2012_folder
)
{
return get_pascal_voc2012_listing(voc2012_folder, "train");
}
// Read the list of image files belong to the "val" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_val_listing(
const std::string& voc2012_folder
)
{
return get_pascal_voc2012_listing(voc2012_folder, "val");
}
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We associate each class also to an index in the
// range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.
// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
return find_voc2012_class(
[&rgb_label](const Voc2012class& voc2012class)
{
return rgb_label == voc2012class.rgb_label;
}
);
}
// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
return find_voc2012_class(rgb_label).index;
}
// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20].
void rgb_label_image_to_index_label_image(
const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
dlib::matrix<uint16_t>& index_label_image
)
{
const long nr = rgb_label_image.nr();
const long nc = rgb_label_image.nc();
index_label_image.set_size(nr, nc);
for (long r = 0; r < nr; ++r)
{
for (long c = 0; c < nc; ++c)
{
index_label_image(r, c) = rgb_label_to_index_label(rgb_label_image(r, c));
}
}
}
// ----------------------------------------------------------------------------------------
// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
@@ -209,14 +108,14 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
load_image(input_image, image_info.image_filename);
// Load the ground-truth (RGB) labels.
- load_image(rgb_label_image, image_info.label_filename);
+ load_image(rgb_label_image, image_info.class_label_filename);
// Create predictions for each pixel. At this point, the type of each prediction
// is an index (a value between 0 and 20). Note that the net may return an image
// that is not exactly the same size as the input.
const matrix<uint16_t> temp = anet(input_image);
- // Convert the indexes to RGB values.
+ // Convert the RGB values to indexes.
rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);
// Crop the net output to be exactly the same size as the input.
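The bookkeeping elided from this hunk can be pictured roughly as follows. This is a hedged sketch, not the diff's exact code; it assumes net_output holds the net's prediction already cropped to the input size, and index_label_image the ground truth in index form. Void/border pixels carry label_to_ignore and are skipped:

    // Hedged sketch of the per-pixel accuracy tally.
    size_t correct = 0, total = 0;
    for (long r = 0; r < index_label_image.nr(); ++r)
    {
        for (long c = 0; c < index_label_image.nc(); ++c)
        {
            const uint16_t truth = index_label_image(r, c);
            // Border/void pixels count neither for nor against the net.
            if (truth == dlib::loss_multiclass_log_per_pixel_::label_to_ignore)
                continue;
            ++total;
            if (net_output(r, c) == truth)
                ++correct;
        }
    }
    const double accuracy = correct / static_cast<double>(total);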
@@ -324,9 +223,9 @@ int main(int argc, char** argv) try
load_image(input_image, image_info.image_filename);
// Load the ground-truth (RGB) labels.
- load_image(rgb_label_image, image_info.label_filename);
+ load_image(rgb_label_image, image_info.class_label_filename);
- // Convert the indexes to RGB values.
+ // Convert the RGB values to indexes.
rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);
// Randomly pick a part of the image.
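The cropping code itself is elided from this hunk. A rough sketch of the idea (the helper name and the 227x227 crop size are hypothetical, and the input is assumed to be larger than the crop): matching chips are extracted from the image and from its index label image. The essential detail is that the labels must be sampled with nearest-neighbor interpolation, because blending two class indexes would produce a third, meaningless label:

    // Hedged sketch, not the example's actual randomly_crop_image.
    #include <dlib/image_transforms.h>
    #include <dlib/matrix.h>
    #include <dlib/rand.h>

    void crop_training_pair(
        const dlib::matrix<dlib::rgb_pixel>& image,
        const dlib::matrix<uint16_t>& labels,
        dlib::matrix<dlib::rgb_pixel>& image_crop,
        dlib::matrix<uint16_t>& label_crop,
        dlib::rand& rnd
    )
    {
        const long dim = 227; // hypothetical crop size
        const long x = rnd.get_random_32bit_number() % (image.nc() - dim);
        const long y = rnd.get_random_32bit_number() % (image.nr() - dim);
        const dlib::chip_details chip(
            dlib::rectangle(x, y, x + dim - 1, y + dim - 1),
            dlib::chip_dims(dim, dim));

        // Pixel values may be blended when cropping the input image...
        dlib::extract_image_chip(image, chip, image_crop, dlib::interpolate_bilinear());
        // ...but class indexes must never be blended, so use nearest neighbor.
        dlib::extract_image_chip(labels, chip, label_crop, dlib::interpolate_nearest_neighbor());
    }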

examples/pascal_voc_2012.h (new file, 180 lines)

@@ -0,0 +1,180 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
Helper definitions for working with the PASCAL VOC2012 dataset.
*/
#ifndef PASCAL_VOC_2012_H_
#define PASCAL_VOC_2012_H_
#include <dlib/dnn.h>    // for dlib::loss_multiclass_log_per_pixel_::label_to_ignore
#include <dlib/matrix.h>
#include <dlib/pixel.h>
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We also associate each class with an index in the
// range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.
struct Voc2012class {
Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
: index(index), rgb_label(rgb_label), classlabel(classlabel)
{}
// The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
const uint16_t index = 0;
// The corresponding RGB representation of the class.
const dlib::rgb_pixel rgb_label;
// The label of the class in plain text.
const std::string classlabel;
};
namespace {
constexpr int class_count = 21; // background + 20 classes
const std::vector<Voc2012class> classes = {
Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background
// The cream-colored `void' label is used in border regions and to mask difficult objects
// (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
dlib::rgb_pixel(224, 224, 192), "border"),
Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
};
}
template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
const auto i = std::find_if(classes.begin(), classes.end(), predicate);
if (i != classes.end())
{
return *i;
}
else
{
throw std::runtime_error("Unable to find a matching VOC2012 class");
}
}
// ----------------------------------------------------------------------------------------
// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
std::string image_filename;
std::string class_label_filename;
std::string instance_label_filename;
};
// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
const std::string& voc2012_folder,
const std::string& file = "train" // "train", "trainval", or "val"
)
{
std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");
std::vector<image_info> results;
while (in)
{
std::string basename;
in >> basename;
if (!basename.empty())
{
image_info info;
info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
info.class_label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
info.instance_label_filename = voc2012_folder + "/SegmentationObject/" + basename + ".png";
results.push_back(info);
}
}
return results;
}
// Read the list of image files belonging to the "train" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_train_listing(
const std::string& voc2012_folder
)
{
return get_pascal_voc2012_listing(voc2012_folder, "train");
}
// Read the list of image files belonging to the "val" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_val_listing(
const std::string& voc2012_folder
)
{
return get_pascal_voc2012_listing(voc2012_folder, "val");
}
// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
return find_voc2012_class(
[&rgb_label](const Voc2012class& voc2012class)
{
return rgb_label == voc2012class.rgb_label;
}
);
}
// ----------------------------------------------------------------------------------------
// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
return find_voc2012_class(rgb_label).index;
}
// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20].
void rgb_label_image_to_index_label_image(
const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
dlib::matrix<uint16_t>& index_label_image
)
{
const long nr = rgb_label_image.nr();
const long nc = rgb_label_image.nc();
index_label_image.set_size(nr, nc);
for (long r = 0; r < nr; ++r)
{
for (long c = 0; c < nc; ++c)
{
index_label_image(r, c) = rgb_label_to_index_label(rgb_label_image(r, c));
}
}
}
#endif // PASCAL_VOC_2012_H_
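For completeness, a minimal hypothetical consumer of this header (the dataset path is a placeholder) might look like:

    #include "pascal_voc_2012.h"
    #include <dlib/image_io.h>
    #include <iostream>

    int main()
    {
        // Placeholder path; point this at an actual VOC2012 extraction.
        const auto listing = get_pascal_voc2012_train_listing("/path/to/VOC2012");
        std::cout << listing.size() << " training images\n";

        // Load one ground-truth RGB label image...
        dlib::matrix<dlib::rgb_pixel> rgb_labels;
        dlib::load_image(rgb_labels, listing.front().class_label_filename);

        // ...and map it to the index representation the net trains on.
        dlib::matrix<uint16_t> index_labels;
        rgb_label_image_to_index_label_image(rgb_labels, index_labels);
    }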