Instance segmentation (#1918)
* Add instance segmentation example - first version of training code
* Add MMOD options; get rid of the cache approach, and instead load all MMOD rects upfront
* Improve console output
* Set filter count
* Minor tweaking
* Inference - first version, at least compiles!
* Ignore overlapped boxes
* Ignore even small instances
* Set overlaps_ignore
* Add TODO remarks
* Revert "Set overlaps_ignore" (this reverts commit 65adeff1f8)
* Set result size
* Set label image size
* Take ignore-color into account
* Fix the cropping rect's aspect ratio; also slightly expand the rect
* Draw the largest findings last
* Improve masking of the current instance
* Add some perturbation to the inputs
* Simplify ground-truth reading; fix random cropping
* Read even class labels
* Tweak default minibatch size
* Learn only one class
* Really train only instances of the selected class
* Remove outdated TODO remark
* Automatically skip images with no detections
* Print to console what was found
* Fix class index problem
* Fix indentation
* Allow to choose multiple classes
* Draw rect in the color of the corresponding class
* Write detector window classes to ostream; also group detection windows by class (when ostreaming)
* Train a separate instance segmentation network for each classlabel
* Use separate synchronization file for each seg net of each class
* Allow more overlap
* Fix sorting criterion
* Fix interpolating the predicted mask
* Improve bilinear interpolation: if output type is an integer, round instead of truncating
* Add helpful comments
* Ignore large aspect ratios; refactor the code; tweak some network parameters
* Simplify the segmentation network structure; make the object detection network more complex in turn
* Problem: CUDA errors not reported properly to console. Solution: stop and join data loader threads even in case of exceptions
* Minor parameters tweaking
* Loss may have increased, even if prob_loss_increasing_thresh > prob_loss_increasing_thresh_max_value
* Add previous_loss_values_dump_amount to previous_loss_values.size() when deciding if loss has been increasing
* Improve behaviour when loss actually increased after disk sync
* Revert some of the earlier change
* Disregard dumped loss values only when deciding if learning rate should be shrunk, but *not* when deciding if loss has been going up since last disk sync
* Revert "Revert some of the earlier change" (this reverts commit 6c852124ef)
* Keep enough previous loss values, until the disk sync
* Fix maintaining the dumped (now "effectively disregarded") loss values count
* Detect cats instead of aeroplanes
* Add helpful logging
* Clarify the intention and the code
* Review fixes
* Add operator== for the other pixel types as well; remove the inline
* If available, use constexpr if
* Revert "If available, use constexpr if" (this reverts commit 503d4dd335)
* Simplify code as per review comments
* Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh
* Clarify console output
* Revert "Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh" (this reverts commit 9191ebc776)
* To keep the changes to a bare minimum, revert the steps_since_last_learning_rate_shrink change after all (at least for now)
* Even empty out some of the previous test loss values
* Minor review fixes
* Can't use C++14 features here
* Do not use the struct name as a variable name
This commit is contained in:
parent
7f2be82e33
commit
d175c35074
dlib/dnn/loss.h

@@ -993,6 +993,36 @@ namespace dlib
            }
        }

+    inline std::ostream& operator<<(std::ostream& out, const std::vector<mmod_options::detector_window_details>& detector_windows)
+    {
+        // write detector windows grouped by label
+        // example output: aeroplane:74x30,131x30,70x45,54x70,198x30;bicycle:70x57,32x70,70x32,51x70,128x30,30x121;car:70x36,70x60,99x30,52x70,30x83,30x114,30x200
+
+        std::map<std::string, std::deque<mmod_options::detector_window_details>> detector_windows_by_label;
+        for (const auto& detector_window : detector_windows)
+            detector_windows_by_label[detector_window.label].push_back(detector_window);
+
+        size_t label_count = 0;
+        for (const auto& i : detector_windows_by_label)
+        {
+            const auto& label = i.first;
+            const auto& detector_windows = i.second;
+
+            if (label_count++ > 0)
+                out << ";";
+            out << label << ":";
+
+            for (size_t j = 0; j < detector_windows.size(); ++j)
+            {
+                out << detector_windows[j].width << "x" << detector_windows[j].height;
+                if (j + 1 < detector_windows.size())
+                    out << ",";
+            }
+        }
+
+        return out;
+    }
+
// ----------------------------------------------------------------------------------------

    class loss_mmod_
@@ -1395,15 +1425,10 @@ namespace dlib
        {
            out << "loss_mmod\t (";

-            out << "detector_windows:(";
            auto& opts = item.options;
-            for (size_t i = 0; i < opts.detector_windows.size(); ++i)
-            {
-                out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height;
-                if (i+1 < opts.detector_windows.size())
-                    out << ",";
-            }
-            out << ")";
+            out << "detector_windows:(" << opts.detector_windows << ")";

            out << ", loss per FA:" << opts.loss_per_false_alarm;
            out << ", loss per miss:" << opts.loss_per_missed_target;
            out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold;
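For reference, the new grouped-by-label printer can be exercised on its own. A minimal sketch (the window sizes are made up for illustration, and this assumes detector_window_details has a (width, height, label) constructor, as it does in current dlib):

#include <dlib/dnn.h>
#include <iostream>
#include <vector>

int main()
{
    // Illustrative windows only; real ones come from an mmod_options instance.
    std::vector<dlib::mmod_options::detector_window_details> windows;
    windows.emplace_back(74, 30, "aeroplane");
    windows.emplace_back(70, 57, "bicycle");
    std::cout << windows << std::endl;  // prints "aeroplane:74x30;bicycle:70x57"
}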
dlib/dnn/trainer.h

@@ -460,6 +460,7 @@ namespace dlib
                test_steps_without_progress = 0;
                previous_loss_values.clear();
                test_previous_loss_values.clear();
+                previous_loss_values_to_keep_until_disk_sync.clear();
            }
            learning_rate = lr;
            lr_schedule.set_size(0);
@@ -602,6 +603,11 @@ namespace dlib
            // discard really old loss values.
            while (previous_loss_values.size() > iter_without_progress_thresh)
                previous_loss_values.pop_front();
+
+            // separately keep another loss history until disk sync
+            // (but only if disk sync is enabled)
+            if (!sync_filename.empty())
+                previous_loss_values_to_keep_until_disk_sync.push_back(loss);
        }

        template <typename T>
@@ -700,10 +706,10 @@ namespace dlib
                    // optimization has flattened out, so drop the learning rate.
                    learning_rate = learning_rate_shrink*learning_rate;
                    test_steps_without_progress = 0;

-                    // Empty out some of the previous loss values so that test_steps_without_progress
-                    // will decrease below test_iter_without_progress_thresh.
-                    for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
-                        test_previous_loss_values.pop_front();
+                    drop_some_test_previous_loss_values();
                }
            }
        }
@@ -820,10 +826,10 @@ namespace dlib
                    // optimization has flattened out, so drop the learning rate.
                    learning_rate = learning_rate_shrink*learning_rate;
                    steps_without_progress = 0;

-                    // Empty out some of the previous loss values so that steps_without_progress
-                    // will decrease below iter_without_progress_thresh.
-                    for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
-                        previous_loss_values.pop_front();
+                    drop_some_previous_loss_values();
                }
            }
        }
@@ -895,7 +901,7 @@ namespace dlib
        friend void serialize(const dnn_trainer& item, std::ostream& out)
        {
            item.wait_for_thread_to_pause();
-            int version = 12;
+            int version = 13;
            serialize(version, out);

            size_t nl = dnn_trainer::num_layers;
@@ -924,14 +930,14 @@ namespace dlib
            serialize(item.test_previous_loss_values, out);
            serialize(item.previous_loss_values_dump_amount, out);
            serialize(item.test_previous_loss_values_dump_amount, out);
+
+            serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
        }
        friend void deserialize(dnn_trainer& item, std::istream& in)
        {
            item.wait_for_thread_to_pause();
            int version = 0;
            deserialize(version, in);
-            if (version != 12)
+            if (version != 13)
                throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");

            size_t num_layers = 0;
@@ -970,6 +976,7 @@ namespace dlib
            deserialize(item.test_previous_loss_values, in);
            deserialize(item.previous_loss_values_dump_amount, in);
            deserialize(item.test_previous_loss_values_dump_amount, in);
+            deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);

            if (item.devices.size() > 1)
            {
@@ -987,6 +994,20 @@ namespace dlib
            }
        }

+        // Empty out some of the previous loss values so that steps_without_progress will decrease below iter_without_progress_thresh.
+        void drop_some_previous_loss_values()
+        {
+            for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount + iter_without_progress_thresh / 10 && previous_loss_values.size() > 0; ++cnt)
+                previous_loss_values.pop_front();
+        }
+
+        // Empty out some of the previous test loss values so that test_steps_without_progress will decrease below test_iter_without_progress_thresh.
+        void drop_some_test_previous_loss_values()
+        {
+            for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10 && test_previous_loss_values.size() > 0; ++cnt)
+                test_previous_loss_values.pop_front();
+        }
+
        void sync_to_disk (
            bool do_it_now = false
        )
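The amount dropped by these helpers is dump_amount + thresh/10. With dlib's usual trainer defaults (dump amount 400 and progress threshold 2000, stated here as an assumption; check trainer.h for the actual values), each learning-rate shrink discards the 600 oldest loss samples. A standalone sketch of the same bookkeeping:

#include <deque>
#include <iostream>

int main()
{
    // Assumed-default trainer settings, for illustration only.
    const unsigned long dump_amount = 400;
    const unsigned long thresh = 2000;

    std::deque<double> losses(3000, 1.0);  // pretend loss history
    for (unsigned long cnt = 0; cnt < dump_amount + thresh/10 && !losses.empty(); ++cnt)
        losses.pop_front();

    std::cout << losses.size() << std::endl;  // 2400: the 600 oldest values were dropped
}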
@@ -1020,6 +1041,20 @@ namespace dlib
                    sync_file_reloaded = true;
                    if (verbose)
                        std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl;
+
+                    // Are we repeatedly hitting our head against the wall? If so, then we
+                    // might be better off giving up at this learning rate, and trying a
+                    // lower one instead.
+                    if (prob_loss_increasing_thresh >= prob_loss_increasing_thresh_max_value)
+                    {
+                        std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
+                        learning_rate = learning_rate_shrink * learning_rate;
+                        steps_without_progress = 0;
+                        test_steps_without_progress = 0;
+
+                        drop_some_previous_loss_values();
+                        drop_some_test_previous_loss_values();
+                    }
                }
                else
                {
@@ -1057,34 +1092,39 @@ namespace dlib
            if (!std::ifstream(newest_syncfile(), std::ios::binary))
                return false;

-            for (auto x : previous_loss_values)
+            // Now look at the data since a little before the last disk sync. We will
+            // check if the loss is getting better or worse.
+            while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
+                previous_loss_values_to_keep_until_disk_sync.pop_front();
+
+            running_gradient g;
+
+            for (auto x : previous_loss_values_to_keep_until_disk_sync)
            {
                // If we get a NaN value of loss assume things have gone horribly wrong and
                // we should reload the state of the trainer.
                if (std::isnan(x))
                    return true;
+
+                g.add(x);
            }

-            // if we haven't seen much data yet then just say false. Or, alternatively, if
-            // it's been too long since the last sync then don't reload either.
-            if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
+            // if we haven't seen much data yet then just say false.
+            if (gradient_updates_since_last_sync < 30)
                return false;

-            // Now look at the data since a little before the last disk sync. We will
-            // check if the loss is getting better or worse.
-            running_gradient g;
-            for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
-                g.add(previous_loss_values[i]);
-
            // if the loss is very likely to be increasing then return true
            const double prob = g.probability_gradient_greater_than(0);
-            if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
+            if (prob > prob_loss_increasing_thresh)
            {
                // Exponentially decay the threshold towards 1 so that if we keep finding
                // the loss to be increasing over and over we will make the test
                // progressively harder and harder until it fails, therefore ensuring we
                // can't get stuck reloading from a previous state over and over.
-                prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
+                prob_loss_increasing_thresh = std::min(
+                    0.1*prob_loss_increasing_thresh + 0.9*1,
+                    prob_loss_increasing_thresh_max_value
+                );
                return true;
            }
            else
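The decay step above maps the threshold t to min(0.1*t + 0.9, max). Starting from 0.99 (an illustrative value; the trainer's actual default may differ), the test saturates after a few consecutive reloads:

#include <algorithm>
#include <cstdio>

int main()
{
    double thresh = 0.99;              // illustrative starting threshold
    const double max_value = 0.99999;  // illustrative cap
    for (int reload = 1; reload <= 4; ++reload)
    {
        thresh = std::min(0.1*thresh + 0.9*1, max_value);
        std::printf("after reload %d: %.6f\n", reload, thresh);
    }
    // 0.999000, 0.999900, 0.999990, 0.999990 (capped): each reload makes the
    // "loss is increasing" test harder, so the trainer cannot reload forever.
}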
@@ -1247,6 +1287,8 @@ namespace dlib
        std::atomic<unsigned long> test_steps_without_progress;
        std::deque<double> test_previous_loss_values;

+        std::deque<double> previous_loss_values_to_keep_until_disk_sync;
+
        std::atomic<double> learning_rate_shrink;
        std::chrono::time_point<std::chrono::system_clock> last_sync_time;
        std::string sync_filename;
dlib/image_transforms/interpolation.h

@@ -865,10 +865,18 @@ namespace dlib
                float fout[4];
                out.store(fout);

-                out_img[r][c]   = static_cast<T>(fout[0]);
-                out_img[r][c+1] = static_cast<T>(fout[1]);
-                out_img[r][c+2] = static_cast<T>(fout[2]);
-                out_img[r][c+3] = static_cast<T>(fout[3]);
+                const auto convert_to_output_type = [](float value)
+                {
+                    if (std::is_integral<T>::value)
+                        return static_cast<T>(value + 0.5);
+                    else
+                        return static_cast<T>(value);
+                };
+
+                out_img[r][c]   = convert_to_output_type(fout[0]);
+                out_img[r][c+1] = convert_to_output_type(fout[1]);
+                out_img[r][c+2] = convert_to_output_type(fout[2]);
+                out_img[r][c+3] = convert_to_output_type(fout[3]);
            }
            x = -x_scale + c*x_scale;
            for (; c < out_img.nc(); ++c)
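The +0.5 above exists because a plain static_cast truncates, biasing interpolated integer pixels downward by half a level on average. A standalone illustration (not part of the patch):

#include <cstdint>
#include <iostream>

int main()
{
    const float interpolated = 199.7f;
    const auto truncated = static_cast<std::uint8_t>(interpolated);         // 199
    const auto rounded   = static_cast<std::uint8_t>(interpolated + 0.5f);  // 200
    std::cout << int(truncated) << " vs " << int(rounded) << std::endl;     // "199 vs 200"
}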
dlib/pixel.h (+36)

@@ -125,6 +125,13 @@ namespace dlib
        unsigned char red;
        unsigned char green;
        unsigned char blue;
+
+        bool operator == (const rgb_pixel& that) const
+        {
+            return this->red == that.red
+                && this->green == that.green
+                && this->blue == that.blue;
+        }
    };

// ----------------------------------------------------------------------------------------
@@ -151,6 +158,13 @@ namespace dlib
        unsigned char blue;
        unsigned char green;
        unsigned char red;
+
+        bool operator == (const bgr_pixel& that) const
+        {
+            return this->blue == that.blue
+                && this->green == that.green
+                && this->red == that.red;
+        }
    };

// ----------------------------------------------------------------------------------------
@@ -177,6 +191,14 @@ namespace dlib
        unsigned char green;
        unsigned char blue;
        unsigned char alpha;
+
+        bool operator == (const rgb_alpha_pixel& that) const
+        {
+            return this->red == that.red
+                && this->green == that.green
+                && this->blue == that.blue
+                && this->alpha == that.alpha;
+        }
    };

// ----------------------------------------------------------------------------------------
@@ -200,6 +222,13 @@ namespace dlib
        unsigned char h;
        unsigned char s;
        unsigned char i;
+
+        bool operator == (const hsi_pixel& that) const
+        {
+            return this->h == that.h
+                && this->s == that.s
+                && this->i == that.i;
+        }
    };
// ----------------------------------------------------------------------------------------

@@ -222,6 +251,13 @@ namespace dlib
        unsigned char l;
        unsigned char a;
        unsigned char b;
+
+        bool operator == (const lab_pixel& that) const
+        {
+            return this->l == that.l
+                && this->a == that.a
+                && this->b == that.b;
+        }
    };

// ----------------------------------------------------------------------------------------
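These comparisons are what lets the new instance segmentation examples match VOC label colors directly. A minimal usage sketch:

#include <dlib/pixel.h>
#include <iostream>

int main()
{
    // The PASCAL VOC2012 "void" color used for border regions and difficult objects.
    const dlib::rgb_pixel voc_void(224, 224, 192);
    const dlib::rgb_pixel p(224, 224, 192);

    if (p == voc_void)
        std::cout << "pixel carries the void label" << std::endl;
}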
examples/CMakeLists.txt

@@ -148,8 +148,10 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
   add_gui_example(dnn_mmod_find_cars2_ex)
   add_example(dnn_mmod_train_find_cars_ex)
   add_gui_example(dnn_semantic_segmentation_ex)
+   add_gui_example(dnn_instance_segmentation_ex)
   add_example(dnn_imagenet_train_ex)
   add_example(dnn_semantic_segmentation_train_ex)
+   add_example(dnn_instance_segmentation_train_ex)
   add_example(dnn_metric_learning_on_images_ex)
endif()
examples/dnn_instance_segmentation_ex.cpp (new file, +178)

@@ -0,0 +1,178 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to do instance segmentation on an image using net pretrained
    on the PASCAL VOC2012 dataset. For an introduction to what instance segmentation is,
    see the accompanying header file dnn_instance_segmentation_ex.h.

    Instructions how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_instance_segmentation_train_ex example program.
    3. Run:
       ./dnn_instance_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_instance_segmentation_ex example program.
    6. Run:
       ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images

    An alternative to steps 2-4 above is to download a pre-trained network
    from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#include "dnn_instance_segmentation_ex.h"
#include "pascal_voc_2012.h"

#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>

using namespace std;
using namespace dlib;

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "You call this program like this: " << endl;
        cout << "./dnn_instance_segmentation_train_ex /path/to/images" << endl;
        cout << endl;
        cout << "You will also need a trained '" << instance_segmentation_net_filename << "' file." << endl;
        cout << "You can either train it yourself (see example program" << endl;
        cout << "dnn_instance_segmentation_train_ex), or download a" << endl;
        cout << "copy from here: http://dlib.net/files/" << instance_segmentation_net_filename << endl;
        return 1;
    }

    // Read the file containing the trained networks from the working directory.
    det_anet_type det_net;
    std::map<std::string, seg_bnet_type> seg_nets_by_class;
    deserialize(instance_segmentation_net_filename) >> det_net >> seg_nets_by_class;

    // Show inference results in a window.
    image_window win;

    matrix<rgb_pixel> input_image;

    // Find supported image files.
    const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
        dlib::match_endings(".jpeg .jpg .png"));

    dlib::rand rnd;

    cout << "Found " << files.size() << " images, processing..." << endl;

    for (const file& file : files)
    {
        // Load the input image.
        load_image(input_image, file.full_name());

        // Draw largest objects last
        const auto sort_instances = [](const std::vector<mmod_rect>& input) {
            auto output = input;
            const auto compare_area = [](const mmod_rect& lhs, const mmod_rect& rhs) {
                return lhs.rect.area() < rhs.rect.area();
            };
            std::sort(output.begin(), output.end(), compare_area);
            return output;
        };

        // Find instances in the input image
        const auto instances = sort_instances(det_net(input_image));

        matrix<rgb_pixel> rgb_label_image;
        matrix<rgb_pixel> input_chip;

        rgb_label_image.set_size(input_image.nr(), input_image.nc());
        rgb_label_image = rgb_pixel(0, 0, 0);

        bool found_something = false;

        for (const auto& instance : instances)
        {
            if (!found_something)
            {
                cout << "Found ";
                found_something = true;
            }
            else
            {
                cout << ", ";
            }
            cout << instance.label;

            const auto cropping_rect = get_cropping_rect(instance.rect);
            const chip_details chip_details(cropping_rect, chip_dims(seg_dim, seg_dim));
            extract_image_chip(input_image, chip_details, input_chip, interpolate_bilinear());

            const auto i = seg_nets_by_class.find(instance.label);
            if (i == seg_nets_by_class.end())
            {
                // per-class segmentation net not found, so we must be using the same net for all classes
                // (see bool separate_seg_net_for_each_class in dnn_instance_segmentation_train_ex.cpp)
                DLIB_CASSERT(seg_nets_by_class.size() == 1);
                DLIB_CASSERT(seg_nets_by_class.begin()->first == "");
            }

            auto& seg_net = i != seg_nets_by_class.end()
                ? i->second // use the segmentation net trained for this class
                : seg_nets_by_class.begin()->second; // use the same segmentation net for all classes

            const auto mask = seg_net(input_chip);

            const rgb_pixel random_color(
                rnd.get_random_8bit_number(),
                rnd.get_random_8bit_number(),
                rnd.get_random_8bit_number()
            );

            dlib::matrix<uint16_t> resized_mask(
                static_cast<int>(chip_details.rect.height()),
                static_cast<int>(chip_details.rect.width())
            );

            dlib::resize_image(mask, resized_mask);

            for (int r = 0; r < resized_mask.nr(); ++r)
            {
                for (int c = 0; c < resized_mask.nc(); ++c)
                {
                    if (resized_mask(r, c))
                    {
                        const auto y = chip_details.rect.top() + r;
                        const auto x = chip_details.rect.left() + c;
                        if (y >= 0 && y < rgb_label_image.nr() && x >= 0 && x < rgb_label_image.nc())
                            rgb_label_image(y, x) = random_color;
                    }
                }
            }

            const Voc2012class& voc2012_class = find_voc2012_class(
                [&instance](const Voc2012class& candidate) {
                    return candidate.classlabel == instance.label;
                }
            );

            dlib::draw_rectangle(rgb_label_image, instance.rect, voc2012_class.rgb_label, 1);
        }

        // Show the input image on the left, and the predicted RGB labels on the right.
        win.set_image(join_rows(input_image, rgb_label_image));

        if (!instances.empty())
        {
            cout << " in " << file.name() << " - hit enter to process the next image";
            cin.get();
        }
    }
}
catch(std::exception& e)
{
    cout << e.what() << endl;
}
examples/dnn_instance_segmentation_ex.h (new file, +200)

@@ -0,0 +1,200 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    Instance segmentation using the PASCAL VOC2012 dataset.

    Instance segmentation sort-of combines object detection with semantic
    segmentation. While each dog, for example, is detected separately,
    the output is not only a bounding-box but a more accurate, per-pixel
    mask.

    For introductions to object detection and semantic segmentation, you
    can have a look at dnn_mmod_ex.cpp and dnn_semantic_segmentation.h,
    respectively.

    Instructions how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_instance_segmentation_train_ex example program.
    3. Run:
       ./dnn_instance_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_instance_segmentation_ex example program.
    6. Run:
       ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images

    An alternative to steps 2-4 above is to download a pre-trained network
    from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#ifndef DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
#define DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_

#include <dlib/dnn.h>

// ----------------------------------------------------------------------------------------

namespace {
    // Segmentation will be performed using patches having this size.
    constexpr int seg_dim = 227;
}

dlib::rectangle get_cropping_rect(const dlib::rectangle& rectangle)
{
    DLIB_ASSERT(!rectangle.is_empty());

    const auto center_point = dlib::center(rectangle);
    const auto max_dim = std::max(rectangle.width(), rectangle.height());
    const auto d = static_cast<long>(std::round(max_dim / 2.0 * 1.5)); // add +50%

    return dlib::rectangle(
        center_point.x() - d,
        center_point.y() - d,
        center_point.x() + d,
        center_point.y() + d
    );
}

// ----------------------------------------------------------------------------------------

// The object detection network.
// Adapted from dnn_mmod_train_find_cars_ex.cpp and friends.

template <long num_filters, typename SUBNET> using con5d = dlib::con<num_filters,5,5,2,2,SUBNET>;
template <long num_filters, typename SUBNET> using con5 = dlib::con<num_filters,5,5,1,1,SUBNET>;

template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<32,SUBNET>>>>>>>>>;
template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<32,SUBNET>>>>>>>>>;

template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<256,SUBNET>>>;
template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<256,SUBNET>>>;

using det_bnet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,brcon5<brcon5<brcon5<bdownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
using det_anet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,arcon5<arcon5<arcon5<adownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;

// The segmentation network.
// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h, although the network is made narrower (smaller feature maps).

template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;

template <int N, template <typename> class BN, int stride, typename SUBNET>
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual = dlib::add_prev1<block<N,BN,1,dlib::tag1<SUBNET>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_down = dlib::add_prev2<dlib::avg_pool<2,2,2,2,dlib::skip1<dlib::tag2<block<N,BN,2,dlib::tag1<SUBNET>>>>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_up = dlib::add_prev2<dlib::cont<N,2,2,2,2,dlib::skip1<dlib::tag2<blockt<N,BN,2,dlib::tag1<SUBNET>>>>>>;

template <int N, typename SUBNET> using res = dlib::relu<residual<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares = dlib::relu<residual<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_down = dlib::relu<residual_down<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_down = dlib::relu<residual_down<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_up = dlib::relu<residual_up<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block,N,dlib::affine,SUBNET>>;

// ----------------------------------------------------------------------------------------

template <typename SUBNET> using res16 = res<16,SUBNET>;
template <typename SUBNET> using res24 = res<24,SUBNET>;
template <typename SUBNET> using res32 = res<32,SUBNET>;
template <typename SUBNET> using res48 = res<48,SUBNET>;
template <typename SUBNET> using ares16 = ares<16,SUBNET>;
template <typename SUBNET> using ares24 = ares<24,SUBNET>;
template <typename SUBNET> using ares32 = ares<32,SUBNET>;
template <typename SUBNET> using ares48 = ares<48,SUBNET>;

template <typename SUBNET> using level1 = dlib::repeat<2,res16,res<16,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res24,res_down<24,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res32,res_down<32,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res48,res_down<48,SUBNET>>;

template <typename SUBNET> using alevel1 = dlib::repeat<2,ares16,ares<16,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares24,ares_down<24,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares32,ares_down<32,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares48,ares_down<48,SUBNET>>;

template <typename SUBNET> using level1t = dlib::repeat<2,res16,res_up<16,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res24,res_up<24,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res32,res_up<32,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res48,res_up<48,SUBNET>>;

template <typename SUBNET> using alevel1t = dlib::repeat<2,ares16,ares_up<16,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares24,ares_up<24,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares32,ares_up<32,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares48,ares_up<48,SUBNET>>;

// ----------------------------------------------------------------------------------------

template <
    template<typename> class TAGGED,
    template<typename> class PREV_RESIZED,
    typename SUBNET
>
using resize_and_concat = dlib::add_layer<
    dlib::concat_<TAGGED,PREV_RESIZED>,
    PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;

template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;

template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;

template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;

// ----------------------------------------------------------------------------------------

static const char* instance_segmentation_net_filename = "instance_segmentation_voc2012net.dnn";

// ----------------------------------------------------------------------------------------

// training network type
using seg_bnet_type = dlib::loss_multiclass_log_per_pixel<
    dlib::cont<2,1,1,1,1,
    dlib::relu<dlib::bn_con<dlib::cont<16,7,7,2,2,
    concat_utag1<level1t<
    concat_utag2<level2t<
    concat_utag3<level3t<
    concat_utag4<level4t<
    level4<utag4<
    level3<utag3<
    level2<utag2<
    level1<dlib::max_pool<3,3,2,2,utag1<
    dlib::input<dlib::matrix<dlib::rgb_pixel>>
    >>>>>>>>>>>>>>>>>>>>>>>>>;

// testing network type (replaced batch normalization with fixed affine transforms)
using seg_anet_type = dlib::loss_multiclass_log_per_pixel<
    dlib::cont<2,1,1,1,1,
    dlib::relu<dlib::affine<dlib::cont<16,7,7,2,2,
    concat_utag1<alevel1t<
    concat_utag2<alevel2t<
    concat_utag3<alevel3t<
    concat_utag4<alevel4t<
    alevel4<utag4<
    alevel3<utag3<
    alevel2<utag2<
    alevel1<dlib::max_pool<3,3,2,2,utag1<
    dlib::relu<dlib::affine<dlib::con<16,7,7,2,2,
    dlib::input<dlib::matrix<dlib::rgb_pixel>>
    >>>>>>>>>>>>>>>>>>>>>>>>>;

// ----------------------------------------------------------------------------------------

#endif // DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
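To make the geometry of get_cropping_rect concrete: for a 100x50 detection box, max_dim is 100, d = round(100/2 * 1.5) = 75, and the result is a 151x151 square (dlib rectangles include both endpoints) centered on the detection. A quick check, assuming the header above is on the include path:

#include "dnn_instance_segmentation_ex.h"
#include <iostream>

int main()
{
    const dlib::rectangle detection(0, 0, 99, 49);  // a 100x50 box
    const auto crop = get_cropping_rect(detection);
    std::cout << crop.width() << "x" << crop.height() << std::endl;  // "151x151"
}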
examples/dnn_instance_segmentation_train_ex.cpp (new file, +776)

@@ -0,0 +1,776 @@
|
||||
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
/*
|
||||
This example shows how to train a instance segmentation net using the PASCAL VOC2012
|
||||
dataset. For an introduction to what segmentation is, see the accompanying header file
|
||||
dnn_instance_segmentation_ex.h.
|
||||
|
||||
Instructions how to run the example:
|
||||
1. Download the PASCAL VOC2012 data, and untar it somewhere.
|
||||
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
|
||||
2. Build the dnn_instance_segmentation_train_ex example program.
|
||||
3. Run:
|
||||
./dnn_instance_segmentation_train_ex /path/to/VOC2012
|
||||
4. Wait while the network is being trained.
|
||||
5. Build the dnn_instance_segmentation_ex example program.
|
||||
6. Run:
|
||||
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
|
||||
|
||||
It would be a good idea to become familiar with dlib's DNN tooling before reading this
|
||||
example. So you should read dnn_introduction_ex.cpp, dnn_introduction2_ex.cpp,
|
||||
and dnn_semantic_segmentation_train_ex.cpp before reading this example program.
|
||||
*/
|
||||
|
||||
#include "dnn_instance_segmentation_ex.h"
|
||||
#include "pascal_voc_2012.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/image_transforms.h>
|
||||
#include <dlib/dir_nav.h>
|
||||
#include <iterator>
|
||||
#include <thread>
|
||||
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
|
||||
#include <execution>
|
||||
#endif // __cplusplus >= 201703L
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// A single training sample for detection. A mini-batch comprises many of these.
|
||||
struct det_training_sample
|
||||
{
|
||||
matrix<rgb_pixel> input_image;
|
||||
std::vector<dlib::mmod_rect> mmod_rects;
|
||||
};
|
||||
|
||||
// A single training sample for segmentation. A mini-batch comprises many of these.
|
||||
struct seg_training_sample
|
||||
{
|
||||
matrix<rgb_pixel> input_image;
|
||||
matrix<uint16_t> label_image; // The ground-truth label of each pixel.
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
bool is_instance_pixel(const dlib::rgb_pixel& rgb_label)
|
||||
{
|
||||
if (rgb_label == dlib::rgb_pixel(0, 0, 0))
|
||||
return false; // Background
|
||||
if (rgb_label == dlib::rgb_pixel(224, 224, 192))
|
||||
return false; // The cream-colored `void' label is used in border regions and to mask difficult objects
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Provide hash function for dlib::rgb_pixel
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<dlib::rgb_pixel>
|
||||
{
|
||||
std::size_t operator()(const dlib::rgb_pixel& p) const
|
||||
{
|
||||
return (static_cast<uint32_t>(p.red) << 16)
|
||||
| (static_cast<uint32_t>(p.green) << 8)
|
||||
| (static_cast<uint32_t>(p.blue));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct truth_instance
|
||||
{
|
||||
dlib::rgb_pixel rgb_label;
|
||||
dlib::mmod_rect mmod_rect;
|
||||
};
|
||||
|
||||
std::vector<truth_instance> rgb_label_images_to_truth_instances(
|
||||
const dlib::matrix<dlib::rgb_pixel>& instance_label_image,
|
||||
const dlib::matrix<dlib::rgb_pixel>& class_label_image
|
||||
)
|
||||
{
|
||||
std::unordered_map<dlib::rgb_pixel, mmod_rect> result_map;
|
||||
|
||||
DLIB_CASSERT(instance_label_image.nr() == class_label_image.nr());
|
||||
DLIB_CASSERT(instance_label_image.nc() == class_label_image.nc());
|
||||
|
||||
const auto nr = instance_label_image.nr();
|
||||
const auto nc = instance_label_image.nc();
|
||||
|
||||
for (int r = 0; r < nr; ++r)
|
||||
{
|
||||
for (int c = 0; c < nc; ++c)
|
||||
{
|
||||
const auto rgb_instance_label = instance_label_image(r, c);
|
||||
|
||||
if (!is_instance_pixel(rgb_instance_label))
|
||||
continue;
|
||||
|
||||
const auto rgb_class_label = class_label_image(r, c);
|
||||
const Voc2012class& voc2012_class = find_voc2012_class(rgb_class_label);
|
||||
|
||||
const auto i = result_map.find(rgb_instance_label);
|
||||
if (i == result_map.end())
|
||||
{
|
||||
// Encountered a new instance
|
||||
result_map[rgb_instance_label] = rectangle(c, r, c, r);
|
||||
result_map[rgb_instance_label].label = voc2012_class.classlabel;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Not the first occurrence - update the rect
|
||||
auto& rect = i->second.rect;
|
||||
|
||||
if (c < rect.left())
|
||||
rect.set_left(c);
|
||||
else if (c > rect.right())
|
||||
rect.set_right(c);
|
||||
|
||||
if (r > rect.bottom())
|
||||
rect.set_bottom(r);
|
||||
|
||||
DLIB_CASSERT(i->second.label == voc2012_class.classlabel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<truth_instance> flat_result;
|
||||
flat_result.reserve(result_map.size());
|
||||
|
||||
for (const auto& i : result_map) {
|
||||
flat_result.push_back(truth_instance{
|
||||
i.first, i.second
|
||||
});
|
||||
}
|
||||
|
||||
return flat_result;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct truth_image
|
||||
{
|
||||
image_info info;
|
||||
std::vector<truth_instance> truth_instances;
|
||||
};
|
||||
|
||||
std::vector<mmod_rect> extract_mmod_rects(
|
||||
const std::vector<truth_instance>& truth_instances
|
||||
)
|
||||
{
|
||||
std::vector<mmod_rect> mmod_rects(truth_instances.size());
|
||||
|
||||
std::transform(
|
||||
truth_instances.begin(),
|
||||
truth_instances.end(),
|
||||
mmod_rects.begin(),
|
||||
[](const truth_instance& truth) { return truth.mmod_rect; }
|
||||
);
|
||||
|
||||
return mmod_rects;
|
||||
};
|
||||
|
||||
std::vector<std::vector<mmod_rect>> extract_mmod_rect_vectors(
|
||||
const std::vector<truth_image>& truth_images
|
||||
)
|
||||
{
|
||||
std::vector<std::vector<mmod_rect>> mmod_rects(truth_images.size());
|
||||
|
||||
const auto extract_mmod_rects_from_truth_image = [](const truth_image& truth_image)
|
||||
{
|
||||
return extract_mmod_rects(truth_image.truth_instances);
|
||||
};
|
||||
|
||||
std::transform(
|
||||
truth_images.begin(),
|
||||
truth_images.end(),
|
||||
mmod_rects.begin(),
|
||||
extract_mmod_rects_from_truth_image
|
||||
);
|
||||
|
||||
return mmod_rects;
|
||||
}
|
||||
|
||||
det_bnet_type train_detection_network(
|
||||
const std::vector<truth_image>& truth_images,
|
||||
unsigned int det_minibatch_size
|
||||
)
|
||||
{
|
||||
const double initial_learning_rate = 0.1;
|
||||
const double weight_decay = 0.0001;
|
||||
const double momentum = 0.9;
|
||||
const double min_detector_window_overlap_iou = 0.65;
|
||||
|
||||
const int target_size = 70;
|
||||
const int min_target_size = 30;
|
||||
|
||||
mmod_options options(
|
||||
extract_mmod_rect_vectors(truth_images),
|
||||
target_size, min_target_size,
|
||||
min_detector_window_overlap_iou
|
||||
);
|
||||
|
||||
options.overlaps_ignore = test_box_overlap(0.5, 0.9);
|
||||
|
||||
det_bnet_type det_net(options);
|
||||
|
||||
det_net.subnet().layer_details().set_num_filters(options.detector_windows.size());
|
||||
|
||||
dlib::pipe<det_training_sample> data(200);
|
||||
auto f = [&data, &truth_images, target_size, min_target_size](time_t seed)
|
||||
{
|
||||
dlib::rand rnd(time(0) + seed);
|
||||
matrix<rgb_pixel> input_image;
|
||||
|
||||
random_cropper cropper;
|
||||
cropper.set_seed(time(0));
|
||||
cropper.set_chip_dims(350, 350);
|
||||
|
||||
// Usually you want to give the cropper whatever min sizes you passed to the
|
||||
// mmod_options constructor, or very slightly smaller sizes, which is what we do here.
|
||||
cropper.set_min_object_size(target_size - 2, min_target_size - 2);
|
||||
cropper.set_max_rotation_degrees(2);
|
||||
|
||||
det_training_sample temp;
|
||||
|
||||
while (data.is_enabled())
|
||||
{
|
||||
// Pick a random input image.
|
||||
const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
|
||||
const auto& truth_image = truth_images[random_index];
|
||||
|
||||
// Load the input image.
|
||||
load_image(input_image, truth_image.info.image_filename);
|
||||
|
||||
// Get a random crop of the input.
|
||||
const auto mmod_rects = extract_mmod_rects(truth_image.truth_instances);
|
||||
cropper(input_image, mmod_rects, temp.input_image, temp.mmod_rects);
|
||||
|
||||
disturb_colors(temp.input_image, rnd);
|
||||
|
||||
// Push the result to be used by the trainer.
|
||||
data.enqueue(temp);
|
||||
}
|
||||
};
|
||||
std::thread data_loader1([f]() { f(1); });
|
||||
std::thread data_loader2([f]() { f(2); });
|
||||
std::thread data_loader3([f]() { f(3); });
|
||||
std::thread data_loader4([f]() { f(4); });
|
||||
|
||||
const auto stop_data_loaders = [&]()
|
||||
{
|
||||
data.disable();
|
||||
data_loader1.join();
|
||||
data_loader2.join();
|
||||
data_loader3.join();
|
||||
data_loader4.join();
|
||||
};
|
||||
|
||||
dnn_trainer<det_bnet_type> det_trainer(det_net, sgd(weight_decay, momentum));
|
||||
|
||||
try
|
||||
{
|
||||
det_trainer.be_verbose();
|
||||
det_trainer.set_learning_rate(initial_learning_rate);
|
||||
det_trainer.set_synchronization_file("pascal_voc2012_det_trainer_state_file.dat", std::chrono::minutes(10));
|
||||
det_trainer.set_iterations_without_progress_threshold(5000);
|
||||
|
||||
// Output training parameters.
|
||||
cout << det_trainer << endl;
|
||||
|
||||
std::vector<matrix<rgb_pixel>> samples;
|
||||
std::vector<std::vector<mmod_rect>> labels;
|
||||
|
||||
// The main training loop. Keep making mini-batches and giving them to the trainer.
|
||||
// We will run until the learning rate becomes small enough.
|
||||
while (det_trainer.get_learning_rate() >= 1e-4)
|
||||
{
|
||||
samples.clear();
|
||||
labels.clear();
|
||||
|
||||
// make a mini-batch
|
||||
det_training_sample temp;
|
||||
while (samples.size() < det_minibatch_size)
|
||||
{
|
||||
data.dequeue(temp);
|
||||
|
||||
samples.push_back(std::move(temp.input_image));
|
||||
labels.push_back(std::move(temp.mmod_rects));
|
||||
}
|
||||
|
||||
det_trainer.train_one_step(samples, labels);
|
||||
}
|
||||
}
|
||||
catch (std::exception&)
|
||||
{
|
||||
stop_data_loaders();
|
||||
throw;
|
||||
}
|
||||
|
||||
// Training done, tell threads to stop and make sure to wait for them to finish before
|
||||
// moving on.
|
||||
stop_data_loaders();
|
||||
|
||||
// also wait for threaded processing to stop in the trainer.
|
||||
det_trainer.get_net();
|
||||
|
||||
det_net.clean();
|
||||
|
||||
return det_net;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
matrix<uint16_t> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_image, const rgb_pixel rgb_label)
|
||||
{
|
||||
const auto nr = rgb_label_image.nr();
|
||||
const auto nc = rgb_label_image.nc();
|
||||
|
||||
matrix<uint16_t> result(nr, nc);
|
||||
|
||||
for (long r = 0; r < nr; ++r)
|
||||
{
|
||||
for (long c = 0; c < nc; ++c)
|
||||
{
|
||||
const auto& index = rgb_label_image(r, c);
|
||||
if (index == rgb_label)
|
||||
result(r, c) = 1;
|
||||
else if (index == dlib::rgb_pixel(224, 224, 192))
|
||||
result(r, c) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore;
|
||||
else
|
||||
result(r, c) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
seg_bnet_type train_segmentation_network(
|
||||
const std::vector<truth_image>& truth_images,
|
||||
unsigned int seg_minibatch_size,
|
||||
const std::string& classlabel
|
||||
)
|
||||
{
|
||||
seg_bnet_type seg_net;
|
||||
|
||||
const double initial_learning_rate = 0.1;
|
||||
const double weight_decay = 0.0001;
|
||||
const double momentum = 0.9;
|
||||
|
||||
const std::string synchronization_file_name
|
||||
= "pascal_voc2012_seg_trainer_state_file"
|
||||
+ (classlabel.empty() ? "" : ("_" + classlabel))
|
||||
+ ".dat";
|
||||
|
||||
dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
|
||||
seg_trainer.be_verbose();
|
||||
seg_trainer.set_learning_rate(initial_learning_rate);
|
||||
seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
|
||||
seg_trainer.set_iterations_without_progress_threshold(2000);
|
||||
set_all_bn_running_stats_window_sizes(seg_net, 1000);
|
||||
|
||||
// Output training parameters.
|
||||
cout << seg_trainer << endl;
|
||||
|
||||
std::vector<matrix<rgb_pixel>> samples;
|
||||
std::vector<matrix<uint16_t>> labels;
|
||||
|
||||
// Start a bunch of threads that read images from disk and pull out random crops. It's
|
||||
// important to be sure to feed the GPU fast enough to keep it busy. Using multiple
|
||||
// thread for this kind of data preparation helps us do that. Each thread puts the
|
||||
// crops into the data queue.
|
||||
dlib::pipe<seg_training_sample> data(200);
|
||||
auto f = [&data, &truth_images](time_t seed)
|
||||
{
|
||||
dlib::rand rnd(time(0) + seed);
|
||||
matrix<rgb_pixel> input_image;
|
||||
matrix<rgb_pixel> rgb_label_image;
|
||||
matrix<rgb_pixel> rgb_label_chip;
|
||||
seg_training_sample temp;
|
||||
while (data.is_enabled())
|
||||
{
|
||||
// Pick a random input image.
|
||||
const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
|
||||
const auto& truth_image = truth_images[random_index];
|
||||
const auto image_truths = truth_image.truth_instances;
|
||||
|
||||
if (!image_truths.empty())
|
||||
{
|
||||
const image_info& info = truth_image.info;
|
||||
|
||||
// Load the input image.
|
||||
load_image(input_image, info.image_filename);
|
||||
|
||||
// Load the ground-truth (RGB) instance labels.
|
||||
load_image(rgb_label_image, info.instance_label_filename);
|
||||
|
||||
// Pick a random training instance.
|
||||
const auto& truth_instance = image_truths[rnd.get_random_32bit_number() % image_truths.size()];
|
||||
const auto& truth_rect = truth_instance.mmod_rect.rect;
|
||||
const auto cropping_rect = get_cropping_rect(truth_rect);
|
||||
|
||||
// Pick a random crop around the instance.
|
||||
const auto max_x_translate_amount = static_cast<long>(truth_rect.width() / 10.0);
|
||||
const auto max_y_translate_amount = static_cast<long>(truth_rect.height() / 10.0);
|
||||
|
||||
const auto random_translate = point(
|
||||
rnd.get_integer_in_range(-max_x_translate_amount, max_x_translate_amount + 1),
|
||||
rnd.get_integer_in_range(-max_y_translate_amount, max_y_translate_amount + 1)
|
||||
);
|
||||
|
||||
const rectangle random_rect(
|
||||
cropping_rect.left() + random_translate.x(),
|
||||
cropping_rect.top() + random_translate.y(),
|
||||
cropping_rect.right() + random_translate.x(),
|
||||
cropping_rect.bottom() + random_translate.y()
|
||||
);
|
||||
|
||||
const chip_details chip_details(random_rect, chip_dims(seg_dim, seg_dim));
|
||||
|
||||
// Crop the input image.
|
||||
extract_image_chip(input_image, chip_details, temp.input_image, interpolate_bilinear());
|
||||
|
||||
disturb_colors(temp.input_image, rnd);
|
||||
|
||||
// Crop the labels correspondingly. However, note that here bilinear
|
||||
// interpolation would make absolutely no sense - you wouldn't say that
|
||||
// a bicycle is half-way between an aeroplane and a bird, would you?
|
||||
extract_image_chip(rgb_label_image, chip_details, rgb_label_chip, interpolate_nearest_neighbor());
|
||||
|
||||
// Clear pixels not related to the current instance.
|
||||
temp.label_image = keep_only_current_instance(rgb_label_chip, truth_instance.rgb_label);
|
||||
|
||||
// Push the result to be used by the trainer.
|
||||
data.enqueue(temp);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use background samples as well
|
||||
}
|
||||
}
|
||||
};
|
||||
std::thread data_loader1([f]() { f(1); });
|
||||
std::thread data_loader2([f]() { f(2); });
|
||||
std::thread data_loader3([f]() { f(3); });
|
||||
std::thread data_loader4([f]() { f(4); });
|
||||
|
||||
const auto stop_data_loaders = [&]()
|
||||
{
|
||||
data.disable();
|
||||
data_loader1.join();
|
||||
data_loader2.join();
|
||||
data_loader3.join();
|
||||
data_loader4.join();
|
||||
};
|
||||
|
||||
try
|
||||
{
|
||||
// The main training loop. Keep making mini-batches and giving them to the trainer.
|
||||
// We will run until the learning rate has dropped by a factor of 1e-4.
|
||||
while (seg_trainer.get_learning_rate() >= 1e-4)
|
||||
{
|
||||
samples.clear();
|
||||
labels.clear();
|
||||
|
||||
// make a mini-batch
|
||||
seg_training_sample temp;
|
||||
while (samples.size() < seg_minibatch_size)
|
||||
{
|
||||
data.dequeue(temp);
|
||||
|
||||
samples.push_back(std::move(temp.input_image));
|
||||
labels.push_back(std::move(temp.label_image));
|
||||
}
|
||||
|
||||
seg_trainer.train_one_step(samples, labels);
|
||||
}
|
||||
}
|
||||
catch (std::exception&)
|
||||
{
|
||||
stop_data_loaders();
|
||||
throw;
|
||||
}
|
||||
|
||||
// Training done, tell threads to stop and make sure to wait for them to finish before
|
||||
// moving on.
|
||||
stop_data_loaders();
|
||||
|
||||
// also wait for threaded processing to stop in the trainer.
|
||||
seg_trainer.get_net();
|
||||
|
||||
seg_net.clean();
|
||||
|
||||
return seg_net;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int ignore_overlapped_boxes(
|
||||
std::vector<truth_instance>& truth_instances,
|
||||
const test_box_overlap& overlaps
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- Whenever two rectangles in boxes overlap, according to overlaps(), we set the
|
||||
smallest box to ignore.
|
||||
- returns the number of newly ignored boxes.
|
||||
!*/
|
||||
{
|
||||
int num_ignored = 0;
|
||||
for (size_t i = 0, end = truth_instances.size(); i < end; ++i)
|
||||
{
|
||||
auto& box_i = truth_instances[i].mmod_rect;
|
||||
if (box_i.ignore)
|
||||
continue;
|
||||
for (size_t j = i+1; j < end; ++j)
|
||||
{
|
||||
auto& box_j = truth_instances[j].mmod_rect;
|
||||
if (box_j.ignore)
|
||||
continue;
|
||||
if (overlaps(box_i, box_j))
|
||||
{
|
||||
++num_ignored;
|
||||
if(box_i.rect.area() < box_j.rect.area())
|
||||
box_i.ignore = true;
|
||||
else
|
||||
box_j.ignore = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_ignored;
|
||||
}
|
||||
|
||||
std::vector<truth_instance> load_truth_instances(const image_info& info)
|
||||
{
|
||||
matrix<rgb_pixel> instance_label_image;
|
||||
matrix<rgb_pixel> class_label_image;
|
||||
|
||||
load_image(instance_label_image, info.instance_label_filename);
|
||||
load_image(class_label_image, info.class_label_filename);
|
||||
|
||||
return rgb_label_images_to_truth_instances(instance_label_image, class_label_image);
|
||||
};
|
||||
|
||||
std::vector<std::vector<truth_instance>> load_all_truth_instances(const std::vector<image_info>& listing)
|
||||
{
|
||||
std::vector<std::vector<truth_instance>> truth_instances(listing.size());
|
||||
|
||||
std::transform(
|
||||
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
|
||||
std::execution::par,
|
||||
#endif // __cplusplus >= 201703L
|
||||
listing.begin(),
|
||||
listing.end(),
|
||||
truth_instances.begin(),
|
||||
load_truth_instances
|
||||
);
|
||||
|
||||
return truth_instances;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
std::vector<truth_image> filter_based_on_classlabel(
|
||||
const std::vector<truth_image>& truth_images,
|
||||
const std::vector<std::string>& desired_classlabels
|
||||
)
|
||||
{
|
||||
std::vector<truth_image> result;
|
||||
|
||||
const auto represents_desired_class = [&desired_classlabels](const truth_instance& truth_instance) {
|
||||
return std::find(
|
||||
desired_classlabels.begin(),
|
||||
desired_classlabels.end(),
|
||||
truth_instance.mmod_rect.label
|
||||
) != desired_classlabels.end();
|
||||
};
|
||||
|
||||
for (const auto& input : truth_images)
|
||||
{
|
||||
const auto has_desired_class = std::any_of(
|
||||
input.truth_instances.begin(),
|
||||
input.truth_instances.end(),
|
||||
represents_desired_class
|
||||
);
|
||||
|
||||
if (has_desired_class) {
|
||||
|
||||
// NB: This keeps only MMOD rects belonging to any of the desired classes.
|
||||
// A reasonable alternative could be to keep all rects, but mark those
|
||||
// belonging in other classes to be ignored during training.
|
||||
std::vector<truth_instance> temp;
|
||||
std::copy_if(
|
||||
input.truth_instances.begin(),
|
||||
input.truth_instances.end(),
|
||||
std::back_inserter(temp),
|
||||
represents_desired_class
|
||||
);
|
||||
|
||||
result.push_back(truth_image{ input.info, temp });
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Ignore truth boxes that overlap too much, are too small, or have a large aspect ratio.
|
||||
void ignore_some_truth_boxes(std::vector<truth_image>& truth_images)
|
||||
{
|
||||
for (auto& i : truth_images)
|
||||
{
|
||||
auto& truth_instances = i.truth_instances;
|
||||
|
||||
ignore_overlapped_boxes(truth_instances, test_box_overlap(0.90, 0.95));

        for (auto& truth : truth_instances)
        {
            if (truth.mmod_rect.ignore)
                continue;

            const auto& rect = truth.mmod_rect.rect;

            constexpr unsigned long min_width  = 35;
            constexpr unsigned long min_height = 35;
            if (rect.width() < min_width && rect.height() < min_height)
            {
                truth.mmod_rect.ignore = true;
                continue;
            }

            constexpr double max_aspect_ratio_width_to_height = 3.0;
            constexpr double max_aspect_ratio_height_to_width = 1.5;
            const double aspect_ratio_width_to_height = rect.width() / static_cast<double>(rect.height());
            const double aspect_ratio_height_to_width = 1.0 / aspect_ratio_width_to_height;
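            // E.g., a 300x80 box has a width-to-height ratio of 3.75 > 3.0, and a
            // 60x120 box has a height-to-width ratio of 2.0 > 1.5; both would be
            // marked ignored below.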
            const bool is_aspect_ratio_too_large
                = aspect_ratio_width_to_height > max_aspect_ratio_width_to_height
                || aspect_ratio_height_to_width > max_aspect_ratio_height_to_width;

            if (is_aspect_ratio_too_large)
                truth.mmod_rect.ignore = true;
        }
    }
}

// Keep only the images that still have at least one non-ignored truth instance.
std::vector<truth_image> filter_images_with_no_truth(const std::vector<truth_image>& truth_images)
{
    std::vector<truth_image> result;

    for (const auto& image : truth_images)
    {
        const auto ignored = [](const truth_instance& truth) { return truth.mmod_rect.ignore; };
        const auto& truth_instances = image.truth_instances;
        if (!std::all_of(truth_instances.begin(), truth_instances.end(), ignored))
            result.push_back(image);
    }

    return result;
}

int main(int argc, char** argv) try
{
    if (argc < 2)
    {
        cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
        cout << endl;
        cout << "You call this program like this: " << endl;
        cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ..." << endl;
        return 1;
    }

    cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;

    const auto listing = get_pascal_voc2012_train_listing(argv[1]);
    cout << "images in entire dataset: " << listing.size() << endl;
    if (listing.size() == 0)
    {
        cout << "Didn't find the VOC2012 dataset." << endl;
        return 1;
    }

    // Mini-batches smaller than the defaults can be used with GPUs that have less memory.
    const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 35;
    const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 100;
    cout << "det mini-batch size: " << det_minibatch_size << endl;
    cout << "seg mini-batch size: " << seg_minibatch_size << endl;

    std::vector<std::string> desired_classlabels;

    for (int arg = 4; arg < argc; ++arg)
        desired_classlabels.push_back(argv[arg]);

    if (desired_classlabels.empty())
    {
        desired_classlabels.push_back("bicycle");
        desired_classlabels.push_back("car");
        desired_classlabels.push_back("cat");
    }

    cout << "desired classlabels:";
    for (const auto& desired_classlabel : desired_classlabels)
        cout << " " << desired_classlabel;
    cout << endl;

    // Extract the MMOD rects.
    cout << endl << "Extracting all truth instances...";
    const auto truth_instances = load_all_truth_instances(listing);
    cout << " Done!" << endl << endl;

    DLIB_CASSERT(listing.size() == truth_instances.size());

    std::vector<truth_image> original_truth_images;
    for (size_t i = 0, end = listing.size(); i < end; ++i)
    {
        original_truth_images.push_back(truth_image{
            listing[i], truth_instances[i]
        });
    }

    auto truth_images_filtered_by_class = filter_based_on_classlabel(original_truth_images, desired_classlabels);

    cout << "images in dataset filtered by class: " << truth_images_filtered_by_class.size() << endl;

    ignore_some_truth_boxes(truth_images_filtered_by_class);
    const auto truth_images = filter_images_with_no_truth(truth_images_filtered_by_class);

    cout << "images in dataset after ignoring some truth boxes: " << truth_images.size() << endl;

    // First train an object detector network (loss_mmod).
    cout << endl << "Training detector network:" << endl;
    const auto det_net = train_detection_network(truth_images, det_minibatch_size);

    // Then train the mask predictors (segmentation).
    std::map<std::string, seg_bnet_type> seg_nets_by_class;

    // This flag controls whether a separate mask predictor is trained for each class.
    // Note that it would also be possible to train a separate mask predictor for
    // groups of somewhat similar classes -- for example, one mask predictor for
    // cars and buses, another for cats and dogs, and so on (see the sketch below).
    constexpr bool separate_seg_net_for_each_class = true;
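
    // A hypothetical sketch of such grouping (not used by this example):
    //
    //     const std::vector<std::vector<std::string>> class_groups = {
    //         { "car", "bus" }, { "cat", "dog" }
    //     };
    //     for (const auto& group : class_groups)
    //     {
    //         const auto group_images = filter_based_on_classlabel(truth_images, group);
    //         // ... train one segmentation network for the whole group ...
    //     }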

    if (separate_seg_net_for_each_class)
    {
        for (const auto& classlabel : desired_classlabels)
        {
            // Consider only the truth images belonging to this class.
            const auto class_images = filter_based_on_classlabel(truth_images, { classlabel });

            cout << endl << "Training segmentation network for class " << classlabel << ":" << endl;
            seg_nets_by_class[classlabel] = train_segmentation_network(class_images, seg_minibatch_size, classlabel);
        }
    }
    else
    {
        cout << "Training a single segmentation network:" << endl;
        seg_nets_by_class[""] = train_segmentation_network(truth_images, seg_minibatch_size, "");
    }

    cout << "Saving networks" << endl;
    serialize(instance_segmentation_net_filename) << det_net << seg_nets_by_class;
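
    // The matching inference program is expected to read the networks back with
    // the corresponding chained calls, roughly like this (assuming the detector
    // network type is named det_bnet_type in the shared header):
    //
    //     det_bnet_type det_net;
    //     std::map<std::string, seg_bnet_type> seg_nets_by_class;
    //     deserialize(instance_segmentation_net_filename) >> det_net >> seg_nets_by_class;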
}

catch (std::exception& e)
{
    cout << e.what() << endl;
}

examples/dnn_semantic_segmentation_ex.h
@@ -34,83 +34,7 @@
#define DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_

#include <dlib/dnn.h>

// ----------------------------------------------------------------------------------------

inline bool operator == (const dlib::rgb_pixel& a, const dlib::rgb_pixel& b)
{
    return a.red == b.red && a.green == b.green && a.blue == b.blue;
}

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We also associate each class with an index in
// the range [0, 20], used internally by the network.

struct Voc2012class {
    Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
        : index(index), rgb_label(rgb_label), classlabel(classlabel)
    {}

    // The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
    const uint16_t index = 0;

    // The corresponding RGB representation of the class.
    const dlib::rgb_pixel rgb_label;

    // The label of the class in plain text.
    const std::string classlabel;
};

namespace {
    constexpr int class_count = 21; // background + 20 classes

    const std::vector<Voc2012class> classes = {
        Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background

        // The cream-colored `void' label is used in border regions and to mask difficult objects
        // (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
        Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
            dlib::rgb_pixel(224, 224, 192), "border"),

        Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
        Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
        Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
        Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
        Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
        Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
        Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
        Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
        Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
        Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
        Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
        Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
        Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
        Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
        Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
        Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
        Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
        Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
        Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
        Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
    };
}

template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
    const auto i = std::find_if(classes.begin(), classes.end(), predicate);

    if (i != classes.end())
    {
        return *i;
    }
    else
    {
        throw std::runtime_error("Unable to find a matching VOC2012 class");
    }
}
#include "pascal_voc_2012.h"

// ----------------------------------------------------------------------------------------

examples/dnn_semantic_segmentation_train_ex.cpp
@@ -91,107 +91,6 @@ void randomly_crop_image (

// ----------------------------------------------------------------------------------------

// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
    string image_filename;
    string label_filename;
};

// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
    const std::string& voc2012_folder,
    const std::string& file = "train" // "train", "trainval", or "val"
)
{
    std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");

    std::vector<image_info> results;

    while (in)
    {
        std::string basename;
        in >> basename;

        if (!basename.empty())
        {
            image_info image_info;
            image_info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
            image_info.label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
            results.push_back(image_info);
        }
    }

    return results;
}

// Read the list of image files belonging to the "train" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_train_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "train");
}

// Read the list of image files belonging to the "val" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_val_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "val");
}

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We also associate each class with an index in
// the range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.

// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(
        [&rgb_label](const Voc2012class& voc2012class)
        {
            return rgb_label == voc2012class.rgb_label;
        }
    );
}

// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(rgb_label).index;
}

// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20].
void rgb_label_image_to_index_label_image(
    const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
    dlib::matrix<uint16_t>& index_label_image
)
{
    const long nr = rgb_label_image.nr();
    const long nc = rgb_label_image.nc();

    index_label_image.set_size(nr, nc);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            index_label_image(r, c) = rgb_label_to_index_label(rgb_label_image(r, c));
        }
    }
}

// ----------------------------------------------------------------------------------------

// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
@@ -209,14 +108,14 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
        load_image(input_image, image_info.image_filename);

        // Load the ground-truth (RGB) labels.
-       load_image(rgb_label_image, image_info.label_filename);
+       load_image(rgb_label_image, image_info.class_label_filename);

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = anet(input_image);

-       // Convert the indexes to RGB values.
+       // Convert the RGB values to indexes.
        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // Crop the net output to be exactly the same size as the input.
@@ -324,9 +223,9 @@ int main(int argc, char** argv) try
        load_image(input_image, image_info.image_filename);

        // Load the ground-truth (RGB) labels.
-       load_image(rgb_label_image, image_info.label_filename);
+       load_image(rgb_label_image, image_info.class_label_filename);

-       // Convert the indexes to RGB values.
+       // Convert the RGB values to indexes.
        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // Randomly pick a part of the image.

examples/pascal_voc_2012.h (new file, 180 lines)
@@ -0,0 +1,180 @@

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    Helper definitions for working with the PASCAL VOC2012 dataset.
*/

#ifndef PASCAL_VOC_2012_H_
#define PASCAL_VOC_2012_H_

#include <dlib/dnn.h>   // for dlib::loss_multiclass_log_per_pixel_ and dlib::matrix
#include <dlib/pixel.h>

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We also associate each class with an index in
// the range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.

struct Voc2012class {
    Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
        : index(index), rgb_label(rgb_label), classlabel(classlabel)
    {}

    // The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
    const uint16_t index = 0;

    // The corresponding RGB representation of the class.
    const dlib::rgb_pixel rgb_label;

    // The label of the class in plain text.
    const std::string classlabel;
};

namespace {
    constexpr int class_count = 21; // background + 20 classes

    const std::vector<Voc2012class> classes = {
        Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background

        // The cream-colored `void' label is used in border regions and to mask difficult objects
        // (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
        Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
            dlib::rgb_pixel(224, 224, 192), "border"),

        Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
        Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
        Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
        Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
        Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
        Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
        Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
        Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
        Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
        Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
        Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
        Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
        Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
        Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
        Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
        Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
        Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
        Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
        Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
        Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
    };
}

template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
    const auto i = std::find_if(classes.begin(), classes.end(), predicate);

    if (i != classes.end())
    {
        return *i;
    }
    else
    {
        throw std::runtime_error("Unable to find a matching VOC2012 class");
    }
}
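
// For example, to look up a class by its textual label:
//
//     const auto& cat = find_voc2012_class(
//         [](const Voc2012class& c) { return c.classlabel == "cat"; }
//     );
//     // cat.index == 8, cat.rgb_label == dlib::rgb_pixel(64, 0, 0)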

// ----------------------------------------------------------------------------------------

// The names of the input image and the associated RGB label images (class and instance)
// in the PASCAL VOC 2012 data set.
struct image_info
{
    std::string image_filename;
    std::string class_label_filename;
    std::string instance_label_filename;
};

// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
    const std::string& voc2012_folder,
    const std::string& file = "train" // "train", "trainval", or "val"
)
{
    std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");

    std::vector<image_info> results;

    while (in)
    {
        std::string basename;
        in >> basename;

        if (!basename.empty())
        {
            image_info info;
            info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
            info.class_label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
            info.instance_label_filename = voc2012_folder + "/SegmentationObject/" + basename + ".png";
            results.push_back(info);
        }
    }

    return results;
}

// Read the list of image files belonging to the "train" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_train_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "train");
}

// Read the list of image files belonging to the "val" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_val_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "val");
}

// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(
        [&rgb_label](const Voc2012class& voc2012class)
        {
            return rgb_label == voc2012class.rgb_label;
        }
    );
}

// ----------------------------------------------------------------------------------------

// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(rgb_label).index;
}

// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20].
void rgb_label_image_to_index_label_image(
    const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
    dlib::matrix<uint16_t>& index_label_image
)
{
    const long nr = rgb_label_image.nr();
    const long nc = rgb_label_image.nc();

    index_label_image.set_size(nr, nc);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            index_label_image(r, c) = rgb_label_to_index_label(rgb_label_image(r, c));
        }
    }
}
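
// Note that pixels whose RGB value maps to the "border" class above receive
// dlib::loss_multiclass_log_per_pixel_::label_to_ignore, so the per-pixel loss
// simply skips them during training.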

#endif // PASCAL_VOC_2012_H_