diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h
index 4099057cd..ca4c226e5 100644
--- a/dlib/dnn/loss.h
+++ b/dlib/dnn/loss.h
@@ -993,6 +993,36 @@ namespace dlib
         }
     }
 
+    inline std::ostream& operator<<(std::ostream& out, const std::vector<mmod_options::detector_window_details>& detector_windows)
+    {
+        // write detector windows grouped by label
+        // example output: aeroplane:74x30,131x30,70x45,54x70,198x30;bicycle:70x57,32x70,70x32,51x70,128x30,30x121;car:70x36,70x60,99x30,52x70,30x83,30x114,30x200
+
+        std::map<std::string, std::vector<mmod_options::detector_window_details>> detector_windows_by_label;
+        for (const auto& detector_window : detector_windows)
+            detector_windows_by_label[detector_window.label].push_back(detector_window);
+
+        size_t label_count = 0;
+        for (const auto& i : detector_windows_by_label)
+        {
+            const auto& label = i.first;
+            const auto& detector_windows = i.second;
+
+            if (label_count++ > 0)
+                out << ";";
+            out << label << ":";
+
+            for (size_t j = 0; j < detector_windows.size(); ++j)
+            {
+                out << detector_windows[j].width << "x" << detector_windows[j].height;
+                if (j + 1 < detector_windows.size())
+                    out << ",";
+            }
+        }
+
+        return out;
+    }
+
 // ----------------------------------------------------------------------------------------
 
     class loss_mmod_
@@ -1395,15 +1425,10 @@
     {
         out << "loss_mmod\t (";
 
-        out << "detector_windows:(";
         auto& opts = item.options;
-        for (size_t i = 0; i < opts.detector_windows.size(); ++i)
-        {
-            out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height;
-            if (i+1 < opts.detector_windows.size())
-                out << ",";
-        }
-        out << ")";
+
+        out << "detector_windows:(" << opts.detector_windows << ")";
+
         out << ", loss per FA:" << opts.loss_per_false_alarm;
         out << ", loss per miss:" << opts.loss_per_missed_target;
         out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold;
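
For reference, here is a minimal standalone sketch (not part of the patch) of what the new operator<< produces. It assumes the patch is applied; the window sizes and labels are made up, and it assumes mmod_options::detector_window_details has its usual (width, height, label) constructor.

    #include <dlib/dnn.h>
    #include <iostream>
    #include <vector>

    int main()
    {
        // Made-up detector windows; two labels so the grouping is visible.
        std::vector<dlib::mmod_options::detector_window_details> windows;
        windows.emplace_back(74, 30, "aeroplane");
        windows.emplace_back(70, 30, "bicycle");
        windows.emplace_back(30, 70, "bicycle");

        // With the operator<< added above, this prints:
        // aeroplane:74x30;bicycle:70x30,30x70
        std::cout << windows << std::endl;
    }
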
diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h
index 0ed18e662..86a20957f 100644
--- a/dlib/dnn/trainer.h
+++ b/dlib/dnn/trainer.h
@@ -460,6 +460,7 @@
                 test_steps_without_progress = 0;
                 previous_loss_values.clear();
                 test_previous_loss_values.clear();
+                previous_loss_values_to_keep_until_disk_sync.clear();
             }
             learning_rate = lr;
             lr_schedule.set_size(0);
@@ -602,6 +603,11 @@
             // discard really old loss values.
             while (previous_loss_values.size() > iter_without_progress_thresh)
                 previous_loss_values.pop_front();
+
+            // separately keep another loss history until disk sync
+            // (but only if disk sync is enabled)
+            if (!sync_filename.empty())
+                previous_loss_values_to_keep_until_disk_sync.push_back(loss);
         }
 
         template <typename T>
@@ -700,10 +706,10 @@
                     // optimization has flattened out, so drop the learning rate.
                     learning_rate = learning_rate_shrink*learning_rate;
                     test_steps_without_progress = 0;
+
                     // Empty out some of the previous loss values so that test_steps_without_progress
                     // will decrease below test_iter_without_progress_thresh.
-                    for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
-                        test_previous_loss_values.pop_front();
+                    drop_some_test_previous_loss_values();
                 }
             }
         }
@@ -820,10 +826,10 @@
                     // optimization has flattened out, so drop the learning rate.
                     learning_rate = learning_rate_shrink*learning_rate;
                     steps_without_progress = 0;
+
                     // Empty out some of the previous loss values so that steps_without_progress
                     // will decrease below iter_without_progress_thresh.
-                    for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
-                        previous_loss_values.pop_front();
+                    drop_some_previous_loss_values();
                 }
             }
         }
@@ -895,7 +901,7 @@
         friend void serialize(const dnn_trainer& item, std::ostream& out)
         {
             item.wait_for_thread_to_pause();
-            int version = 12;
+            int version = 13;
             serialize(version, out);
 
             size_t nl = dnn_trainer::num_layers;
@@ -924,14 +930,14 @@
             serialize(item.test_previous_loss_values, out);
             serialize(item.previous_loss_values_dump_amount, out);
             serialize(item.test_previous_loss_values_dump_amount, out);
-
+            serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
         }
 
         friend void deserialize(dnn_trainer& item, std::istream& in)
         {
             item.wait_for_thread_to_pause();
             int version = 0;
             deserialize(version, in);
-            if (version != 12)
+            if (version != 13)
                 throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
 
             size_t num_layers = 0;
@@ -970,6 +976,7 @@
             deserialize(item.test_previous_loss_values, in);
             deserialize(item.previous_loss_values_dump_amount, in);
             deserialize(item.test_previous_loss_values_dump_amount, in);
+            deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
 
             if (item.devices.size() > 1)
             {
@@ -987,6 +994,20 @@
             }
         }
 
+        // Empty out some of the previous loss values so that steps_without_progress will decrease below iter_without_progress_thresh.
+        void drop_some_previous_loss_values()
+        {
+            for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount + iter_without_progress_thresh / 10 && previous_loss_values.size() > 0; ++cnt)
+                previous_loss_values.pop_front();
+        }
+
+        // Empty out some of the previous test loss values so that test_steps_without_progress will decrease below test_iter_without_progress_thresh.
+        void drop_some_test_previous_loss_values()
+        {
+            for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10 && test_previous_loss_values.size() > 0; ++cnt)
+                test_previous_loss_values.pop_front();
+        }
+
         void sync_to_disk (
             bool do_it_now = false
         )
@@ -1020,6 +1041,20 @@
                     sync_file_reloaded = true;
                     if (verbose)
                         std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl;
+
+                    // Are we repeatedly hitting our head against the wall? If so, then we
+                    // might be better off giving up at this learning rate, and trying a
+                    // lower one instead.
+                    if (prob_loss_increasing_thresh >= prob_loss_increasing_thresh_max_value)
+                    {
+                        std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
+                        learning_rate = learning_rate_shrink * learning_rate;
+                        steps_without_progress = 0;
+                        test_steps_without_progress = 0;
+
+                        drop_some_previous_loss_values();
+                        drop_some_test_previous_loss_values();
+                    }
                 }
                 else
                 {
@@ -1057,34 +1092,39 @@
             if (!std::ifstream(newest_syncfile(), std::ios::binary))
                 return false;
 
-            for (auto x : previous_loss_values)
+            // Now look at the data since a little before the last disk sync. We will
+            // check if the loss is getting better or worse.
+            while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
+                previous_loss_values_to_keep_until_disk_sync.pop_front();
+
+            running_gradient g;
+
+            for (auto x : previous_loss_values_to_keep_until_disk_sync)
             {
                 // If we get a NaN value of loss assume things have gone horribly wrong and
                 // we should reload the state of the trainer.
                 if (std::isnan(x))
                     return true;
+
+                g.add(x);
             }
 
-            // if we haven't seen much data yet then just say false. Or, alternatively, if
-            // it's been too long since the last sync then don't reload either.
-            if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
+            // if we haven't seen much data yet then just say false.
+            if (gradient_updates_since_last_sync < 30)
                 return false;
 
-            // Now look at the data since a little before the last disk sync. We will
-            // check if the loss is getting bettor or worse.
-            running_gradient g;
-            for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
-                g.add(previous_loss_values[i]);
-
             // if the loss is very likely to be increasing then return true
             const double prob = g.probability_gradient_greater_than(0);
-            if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
+            if (prob > prob_loss_increasing_thresh)
             {
                 // Exponentially decay the threshold towards 1 so that if we keep finding
                 // the loss to be increasing over and over we will make the test
                 // progressively harder and harder until it fails, therefore ensuring we
                 // can't get stuck reloading from a previous state over and over.
-                prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
+                prob_loss_increasing_thresh = std::min(
+                    0.1*prob_loss_increasing_thresh + 0.9*1,
+                    prob_loss_increasing_thresh_max_value
+                );
                 return true;
             }
             else
@@ -1247,6 +1287,8 @@
         std::atomic<unsigned long long> test_steps_without_progress;
         std::deque<double> test_previous_loss_values;
 
+        std::deque<double> previous_loss_values_to_keep_until_disk_sync;
+
         std::atomic<double> learning_rate_shrink;
         std::chrono::time_point<std::chrono::system_clock> last_sync_time;
         std::string sync_filename;
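
A standalone sketch (not part of the patch) of the reload-throttling behavior above: each reload decays prob_loss_increasing_thresh towards 1, and the new std::min clamp makes it stop exactly at the max value, at which point the new branch in sync_to_disk also shrinks the learning rate. The starting value 0.99 and cap 0.99999 mirror dlib's defaults as I recall them; treat them as assumptions.

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        double thresh = 0.99;             // assumed prob_loss_increasing_thresh default
        const double thresh_max = 0.99999; // assumed prob_loss_increasing_thresh_max_value
        for (int reload = 1; reload <= 6; ++reload)
        {
            // Same update as the patched code: decay towards 1, clamped at the max.
            thresh = std::min(0.1*thresh + 0.9*1, thresh_max);
            std::printf("after reload %d: thresh = %.6f\n", reload, thresh);
        }
        // The threshold saturates after a few reloads, so the >= max check that
        // triggers a learning rate shrink is guaranteed to fire eventually.
    }
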
diff --git a/dlib/image_transforms/interpolation.h b/dlib/image_transforms/interpolation.h
index 2e19c0f6b..d7a26026c 100644
--- a/dlib/image_transforms/interpolation.h
+++ b/dlib/image_transforms/interpolation.h
@@ -865,10 +865,18 @@ namespace dlib
                 float fout[4];
                 out.store(fout);
 
-                out_img[r][c]   = static_cast<T>(fout[0]);
-                out_img[r][c+1] = static_cast<T>(fout[1]);
-                out_img[r][c+2] = static_cast<T>(fout[2]);
-                out_img[r][c+3] = static_cast<T>(fout[3]);
+                const auto convert_to_output_type = [](float value)
+                {
+                    if (std::is_integral<T>::value)
+                        return static_cast<T>(value + 0.5);
+                    else
+                        return static_cast<T>(value);
+                };
+
+                out_img[r][c]   = convert_to_output_type(fout[0]);
+                out_img[r][c+1] = convert_to_output_type(fout[1]);
+                out_img[r][c+2] = convert_to_output_type(fout[2]);
+                out_img[r][c+3] = convert_to_output_type(fout[3]);
             }
             x = -x_scale + c*x_scale;
             for (; c < out_img.nc(); ++c)
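
A tiny demonstration (not part of the patch) of why the resize code now adds 0.5 before casting for integral pixel types: a plain static_cast truncates, which biases interpolated values downward by up to one intensity level.

    #include <cstdio>

    int main()
    {
        const float value = 254.9f;
        const unsigned char truncated = static_cast<unsigned char>(value);       // old behavior: 254
        const unsigned char rounded   = static_cast<unsigned char>(value + 0.5); // new behavior: 255
        std::printf("truncated=%d rounded=%d\n", truncated, rounded);
    }
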
diff --git a/dlib/pixel.h b/dlib/pixel.h
index 50ead2c34..046131f13 100644
--- a/dlib/pixel.h
+++ b/dlib/pixel.h
@@ -125,6 +125,13 @@ namespace dlib
         unsigned char red;
         unsigned char green;
         unsigned char blue;
+
+        bool operator == (const rgb_pixel& that) const
+        {
+            return this->red == that.red
+                && this->green == that.green
+                && this->blue == that.blue;
+        }
     };
 
 // ----------------------------------------------------------------------------------------
 
@@ -151,6 +158,13 @@
         unsigned char blue;
         unsigned char green;
         unsigned char red;
+
+        bool operator == (const bgr_pixel& that) const
+        {
+            return this->blue == that.blue
+                && this->green == that.green
+                && this->red == that.red;
+        }
     };
 
 // ----------------------------------------------------------------------------------------
 
@@ -177,6 +191,14 @@
         unsigned char red;
         unsigned char green;
         unsigned char blue;
         unsigned char alpha;
+
+        bool operator == (const rgb_alpha_pixel& that) const
+        {
+            return this->red == that.red
+                && this->green == that.green
+                && this->blue == that.blue
+                && this->alpha == that.alpha;
+        }
     };
 
 // ----------------------------------------------------------------------------------------
 
@@ -200,6 +222,13 @@
         unsigned char h;
         unsigned char s;
         unsigned char i;
+
+        bool operator == (const hsi_pixel& that) const
+        {
+            return this->h == that.h
+                && this->s == that.s
+                && this->i == that.i;
+        }
     };
 
 // ----------------------------------------------------------------------------------------
 
@@ -222,6 +251,13 @@
         unsigned char l;
         unsigned char a;
         unsigned char b;
+
+        bool operator == (const lab_pixel& that) const
+        {
+            return this->l == that.l
+                && this->a == that.a
+                && this->b == that.b;
+        }
     };
 
 // ----------------------------------------------------------------------------------------
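
A short sketch (not part of the patch) of what the new comparison operators enable: direct equality tests and standard algorithms like std::find on containers of dlib pixels, which previously required a hand-written predicate.

    #include <dlib/pixel.h>
    #include <algorithm>
    #include <vector>

    int main()
    {
        const dlib::rgb_pixel black(0, 0, 0);
        std::vector<dlib::rgb_pixel> palette{ {255, 0, 0}, {0, 0, 0} };

        // Both of these rely on the operator== added by this patch.
        const bool is_black = palette[1] == black;
        const auto it = std::find(palette.begin(), palette.end(), black);
        return (is_black && it != palette.end()) ? 0 : 1;
    }
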
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 1f29f043d..a6446789a 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -148,8 +148,10 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
    add_gui_example(dnn_mmod_find_cars2_ex)
    add_example(dnn_mmod_train_find_cars_ex)
    add_gui_example(dnn_semantic_segmentation_ex)
+   add_gui_example(dnn_instance_segmentation_ex)
    add_example(dnn_imagenet_train_ex)
    add_example(dnn_semantic_segmentation_train_ex)
+   add_example(dnn_instance_segmentation_train_ex)
    add_example(dnn_metric_learning_on_images_ex)
 endif()
diff --git a/examples/dnn_instance_segmentation_ex.cpp b/examples/dnn_instance_segmentation_ex.cpp
new file mode 100644
index 000000000..6f3834470
--- /dev/null
+++ b/examples/dnn_instance_segmentation_ex.cpp
@@ -0,0 +1,178 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+/*
+    This example shows how to do instance segmentation on an image using a net pretrained
+    on the PASCAL VOC2012 dataset. For an introduction to what instance segmentation is,
+    see the accompanying header file dnn_instance_segmentation_ex.h.
+
+    Instructions for running the example:
+    1. Download the PASCAL VOC2012 data, and untar it somewhere.
+       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+    2. Build the dnn_instance_segmentation_train_ex example program.
+    3. Run:
+       ./dnn_instance_segmentation_train_ex /path/to/VOC2012
+    4. Wait while the network is being trained.
+    5. Build the dnn_instance_segmentation_ex example program.
+    6. Run:
+       ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
+
+    An alternative to steps 2-4 above is to download a pre-trained network
+    from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
+
+    It would be a good idea to become familiar with dlib's DNN tooling before reading this
+    example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
+    before reading this example program.
+*/
+
+#include "dnn_instance_segmentation_ex.h"
+#include "pascal_voc_2012.h"
+
+#include <iostream>
+#include <dlib/data_io.h>
+#include <dlib/gui_widgets.h>
+
+using namespace std;
+using namespace dlib;
+
+// ----------------------------------------------------------------------------------------
+
+int main(int argc, char** argv) try
+{
+    if (argc != 2)
+    {
+        cout << "You call this program like this: " << endl;
+        cout << "./dnn_instance_segmentation_ex /path/to/images" << endl;
+        cout << endl;
+        cout << "You will also need a trained '" << instance_segmentation_net_filename << "' file." << endl;
+        cout << "You can either train it yourself (see example program" << endl;
+        cout << "dnn_instance_segmentation_train_ex), or download a" << endl;
+        cout << "copy from here: http://dlib.net/files/" << instance_segmentation_net_filename << endl;
+        return 1;
+    }
+
+    // Read the file containing the trained networks from the working directory.
+    det_anet_type det_net;
+    std::map<std::string, seg_bnet_type> seg_nets_by_class;
+    deserialize(instance_segmentation_net_filename) >> det_net >> seg_nets_by_class;
+
+    // Show inference results in a window.
+    image_window win;
+
+    matrix<rgb_pixel> input_image;
+
+    // Find supported image files.
+    const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
+        dlib::match_endings(".jpeg .jpg .png"));
+
+    dlib::rand rnd;
+
+    cout << "Found " << files.size() << " images, processing..." << endl;
+
+    for (const file& file : files)
+    {
+        // Load the input image.
+        load_image(input_image, file.full_name());
+
+        // Draw largest objects last
+        const auto sort_instances = [](const std::vector<mmod_rect>& input) {
+            auto output = input;
+            const auto compare_area = [](const mmod_rect& lhs, const mmod_rect& rhs) {
+                return lhs.rect.area() < rhs.rect.area();
+            };
+            std::sort(output.begin(), output.end(), compare_area);
+            return output;
+        };
+
+        // Find instances in the input image
+        const auto instances = sort_instances(det_net(input_image));
+
+        matrix<rgb_pixel> rgb_label_image;
+        matrix<rgb_pixel> input_chip;
+
+        rgb_label_image.set_size(input_image.nr(), input_image.nc());
+        rgb_label_image = rgb_pixel(0, 0, 0);
+
+        bool found_something = false;
+
+        for (const auto& instance : instances)
+        {
+            if (!found_something)
+            {
+                cout << "Found ";
+                found_something = true;
+            }
+            else
+            {
+                cout << ", ";
+            }
+            cout << instance.label;
+
+            const auto cropping_rect = get_cropping_rect(instance.rect);
+            const chip_details chip_details(cropping_rect, chip_dims(seg_dim, seg_dim));
+            extract_image_chip(input_image, chip_details, input_chip, interpolate_bilinear());
+
+            const auto i = seg_nets_by_class.find(instance.label);
+            if (i == seg_nets_by_class.end())
+            {
+                // per-class segmentation net not found, so we must be using the same net for all classes
+                // (see bool separate_seg_net_for_each_class in dnn_instance_segmentation_train_ex.cpp)
+                DLIB_CASSERT(seg_nets_by_class.size() == 1);
+                DLIB_CASSERT(seg_nets_by_class.begin()->first == "");
+            }
+
+            auto& seg_net = i != seg_nets_by_class.end()
+                ? i->second // use the segmentation net trained for this class
+                : seg_nets_by_class.begin()->second; // use the same segmentation net for all classes
+
+            const auto mask = seg_net(input_chip);
+
+            const rgb_pixel random_color(
+                rnd.get_random_8bit_number(),
+                rnd.get_random_8bit_number(),
+                rnd.get_random_8bit_number()
+            );
+
+            dlib::matrix<uint16_t> resized_mask(
+                static_cast<long>(chip_details.rect.height()),
+                static_cast<long>(chip_details.rect.width())
+            );
+
+            dlib::resize_image(mask, resized_mask);
+
+            for (int r = 0; r < resized_mask.nr(); ++r)
+            {
+                for (int c = 0; c < resized_mask.nc(); ++c)
+                {
+                    if (resized_mask(r, c))
+                    {
+                        const auto y = chip_details.rect.top() + r;
+                        const auto x = chip_details.rect.left() + c;
+                        if (y >= 0 && y < rgb_label_image.nr() && x >= 0 && x < rgb_label_image.nc())
+                            rgb_label_image(y, x) = random_color;
+                    }
+                }
+            }
+
+            const Voc2012class& voc2012_class = find_voc2012_class(
+                [&instance](const Voc2012class& candidate) {
+                    return candidate.classlabel == instance.label;
+                }
+            );
+
+            dlib::draw_rectangle(rgb_label_image, instance.rect, voc2012_class.rgb_label, 1);
+        }
+
+        // Show the input image on the left, and the predicted RGB labels on the right.
+        win.set_image(join_rows(input_image, rgb_label_image));
+
+        if (!instances.empty())
+        {
+            cout << " in " << file.name() << " - hit enter to process the next image";
+            cin.get();
+        }
+    }
+}
+catch(std::exception& e)
+{
+    cout << e.what() << endl;
+}
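
A standalone sketch (not part of the patch) isolating the mask-painting loop's border handling above: get_cropping_rect may extend past the image, so each painted pixel is bounds-checked before writing. The rectangle and image size here are made up.

    #include <dlib/geometry.h>
    #include <cstdio>

    int main()
    {
        // A 100x100 crop whose top-left corner falls outside a 50x50 image.
        const dlib::rectangle rect(-10, -10, 89, 89);
        const long nr = 50, nc = 50;
        const long h = static_cast<long>(rect.height()); // 100
        const long w = static_cast<long>(rect.width());  // 100

        long painted = 0;
        for (long r = 0; r < h; ++r)
            for (long c = 0; c < w; ++c)
            {
                // Same mapping as the example: chip pixel (r, c) lands at
                // image pixel (rect.top() + r, rect.left() + c).
                const long y = rect.top() + r;
                const long x = rect.left() + c;
                if (y >= 0 && y < nr && x >= 0 && x < nc)
                    ++painted; // in the example this writes random_color
            }
        std::printf("painted %ld of %lu mask pixels\n", painted, rect.area());
    }
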
diff --git a/examples/dnn_instance_segmentation_ex.h b/examples/dnn_instance_segmentation_ex.h
new file mode 100644
index 000000000..1693c304e
--- /dev/null
+++ b/examples/dnn_instance_segmentation_ex.h
@@ -0,0 +1,200 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+/*
+    Instance segmentation using the PASCAL VOC2012 dataset.
+
+    Instance segmentation sort-of combines object detection with semantic
+    segmentation. While each dog, for example, is detected separately,
+    the output is not only a bounding box but a more accurate, per-pixel
+    mask.
+
+    For introductions to object detection and semantic segmentation, you
+    can have a look at dnn_mmod_ex.cpp and dnn_semantic_segmentation.h,
+    respectively.
+
+    Instructions for running the example:
+    1. Download the PASCAL VOC2012 data, and untar it somewhere.
+       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+    2. Build the dnn_instance_segmentation_train_ex example program.
+    3. Run:
+       ./dnn_instance_segmentation_train_ex /path/to/VOC2012
+    4. Wait while the network is being trained.
+    5. Build the dnn_instance_segmentation_ex example program.
+    6. Run:
+       ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
+
+    An alternative to steps 2-4 above is to download a pre-trained network
+    from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
+
+    It would be a good idea to become familiar with dlib's DNN tooling before reading this
+    example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
+    before reading this example program.
+*/
+
+#ifndef DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
+#define DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
+
+#include <dlib/dnn.h>
+
+// ----------------------------------------------------------------------------------------
+
+namespace {
+    // Segmentation will be performed using patches having this size.
+    constexpr int seg_dim = 227;
+}
+
+dlib::rectangle get_cropping_rect(const dlib::rectangle& rectangle)
+{
+    DLIB_ASSERT(!rectangle.is_empty());
+
+    const auto center_point = dlib::center(rectangle);
+    const auto max_dim = std::max(rectangle.width(), rectangle.height());
+    const auto d = static_cast<long>(std::round(max_dim / 2.0 * 1.5)); // add +50%
+
+    return dlib::rectangle(
+        center_point.x() - d,
+        center_point.y() - d,
+        center_point.x() + d,
+        center_point.y() + d
+    );
+}
+
+// ----------------------------------------------------------------------------------------
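
A worked example (not part of the patch) of get_cropping_rect's arithmetic, on a made-up 100x50 detection box.

    #include <dlib/geometry.h>
    #include <algorithm>
    #include <cmath>
    #include <iostream>

    int main()
    {
        // Same arithmetic as get_cropping_rect above.
        const dlib::rectangle detection(0, 0, 99, 49); // 100x50 box
        const auto center_point = dlib::center(detection);
        const auto max_dim = std::max(detection.width(), detection.height()); // 100
        const auto d = static_cast<long>(std::round(max_dim / 2.0 * 1.5));    // 75

        const dlib::rectangle crop(
            center_point.x() - d, center_point.y() - d,
            center_point.x() + d, center_point.y() + d);

        // Prints 151x151: a square crop centered on the detection, whose +50%
        // margin gives the segmentation net some context around the object.
        std::cout << crop.width() << "x" << crop.height() << std::endl;
    }
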
+
+// The object detection network.
+// Adapted from dnn_mmod_train_find_cars_ex.cpp and friends.
+
+template <long num_filters, typename SUBNET> using con5d = dlib::con<num_filters,5,5,2,2,SUBNET>;
+template <long num_filters, typename SUBNET> using con5  = dlib::con<num_filters,5,5,1,1,SUBNET>;
+
+template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<32, dlib::relu<dlib::bn_con<con5d<32, dlib::relu<dlib::bn_con<con5d<16,SUBNET>>>>>>>>>;
+template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<32, dlib::relu<dlib::affine<con5d<32, dlib::relu<dlib::affine<con5d<16,SUBNET>>>>>>>>>;
+
+template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<55,SUBNET>>>;
+template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<55,SUBNET>>>;
+
+using det_bnet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,brcon5<brcon5<brcon5<bdownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
+using det_anet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,arcon5<arcon5<arcon5<adownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
+
+// The segmentation network.
+// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h,
+// although the network is made narrower (smaller feature maps).
+
+template <int N, template <typename> class BN, int stride, typename SUBNET>
+using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
+
+template <int N, template <typename> class BN, int stride, typename SUBNET>
+using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
+
+template