From bd6994cc66915655774155c078aa8ca6496eddd1 Mon Sep 17 00:00:00 2001 From: Juha Reunanen Date: Mon, 20 Jan 2020 14:47:47 +0200 Subject: [PATCH] Add new loss layer for binary loss per pixel (#1976) * Add new loss layer for binary loss per pixel --- dlib/cuda/cuda_dlib.cu | 77 ++++++- dlib/cuda/cuda_dlib.h | 74 ++++++- dlib/dnn/loss.h | 158 ++++++++++++++ dlib/dnn/loss_abstract.h | 62 ++++++ dlib/test/dnn.cpp | 199 ++++++++++++++++++ examples/dnn_instance_segmentation_ex.cpp | 44 ++-- examples/dnn_instance_segmentation_ex.h | 12 +- .../dnn_instance_segmentation_train_ex.cpp | 14 +- 8 files changed, 599 insertions(+), 41 deletions(-) diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 43a0425b6..973f48d8d 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -1681,6 +1681,48 @@ namespace dlib } } + // ---------------------------------------------------------------------------------------- + + __device__ float cuda_log1pexp(float x) + { + if (x <= -18) + return std::exp(x); + else if (-18 < x && x <= 9) + return std::log1p(std::exp(x)); + else if (9 < x && x <= 16) + return x + std::exp(-x); + else + return x; + } + + __global__ void _cuda_compute_loss_binary_log_per_pixel(float* loss_out, float* g, const float* truth, const float* out_data, size_t n, const float scale) + { + float loss = 0; + for(auto i : grid_stride_range(0, n)) + { + const float y = truth[i]; + + if (y > 0.f) + { + const float temp = cuda_log1pexp(-out_data[i]); + loss += y*temp; + g[i] = y*scale*(g[i]-1); + } + else if (y < 0.f) + { + const float temp = -(-out_data[i]-cuda_log1pexp(-out_data[i])); + loss += -y*temp; + g[i] = -y*scale*g[i]; + } + else + { + g[i] = 0.f; + } + } + + warp_reduce_atomic_add(*loss_out, loss); + } + // ---------------------------------------------------------------------------------------- __device__ float cuda_safe_log(float x, float epsilon = 1e-10) @@ -1720,29 +1762,52 @@ namespace dlib warp_reduce_atomic_add(*loss_out, loss); } + // ---------------------------------------------------------------------------------------- + + void compute_loss_binary_log_per_pixel:: + do_work( + cuda_data_ptr loss_work_buffer, + cuda_data_ptr truth_buffer, + const tensor& subnetwork_output, + tensor& gradient, + double& loss + ) + { + CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float))); + sigmoid(gradient, subnetwork_output); + + // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. + const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc()); + + launch_kernel(_cuda_compute_loss_binary_log_per_pixel, max_jobs(gradient.size()), + loss_work_buffer.data(), gradient.device(), truth_buffer.data(), subnetwork_output.device(), gradient.size(), scale); + + float floss; + dlib::cuda::memcpy(&floss, loss_work_buffer); + loss = scale*floss; + } void compute_loss_multiclass_log_per_pixel:: do_work( - float* loss_cuda_work_buffer, - const uint16_t* truth_buffer, + cuda_data_ptr loss_work_buffer, + cuda_data_ptr truth_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss ) { - CHECK_CUDA(cudaMemset(loss_cuda_work_buffer, 0, sizeof(float))); + CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float))); softmax(gradient, subnetwork_output); static const uint16_t label_to_ignore = std::numeric_limits::max(); // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc()); - launch_kernel(_cuda_compute_loss_multiclass_log_per_pixel, max_jobs(gradient.size()), - loss_cuda_work_buffer, gradient.device(), truth_buffer, gradient.size(), gradient.nr()*gradient.nc(), gradient.nr()*gradient.nc()*gradient.k(), gradient.k(), label_to_ignore, scale); + loss_work_buffer.data(), gradient.device(), truth_buffer.data(), gradient.size(), gradient.nr()*gradient.nc(), gradient.nr()*gradient.nc()*gradient.k(), gradient.k(), label_to_ignore, scale); float floss; - CHECK_CUDA(cudaMemcpy(&floss, loss_cuda_work_buffer, sizeof(float), cudaMemcpyDefault)); + dlib::cuda::memcpy(&floss, loss_work_buffer); loss = scale*floss; } diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h index 62fb2f256..d0e83a47e 100644 --- a/dlib/cuda/cuda_dlib.h +++ b/dlib/cuda/cuda_dlib.h @@ -424,11 +424,70 @@ namespace dlib // ---------------------------------------------------------------------------------------- + class compute_loss_binary_log_per_pixel + { + /*! + The point of this class is to compute the loss computed by + loss_binary_log_per_pixel_, but to do so with CUDA. + !*/ + public: + + compute_loss_binary_log_per_pixel( + ) + { + } + + template < + typename const_label_iterator + > + void operator() ( + const_label_iterator truth, + const tensor& subnetwork_output, + tensor& gradient, + double& loss + ) const + { + const auto image_size = subnetwork_output.nr()*subnetwork_output.nc(); + const size_t bytes_per_plane = image_size*sizeof(float); + // Allocate a cuda buffer to store all the truth images and also one float + // for the scalar loss output. + buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float)); + + cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); + buf = buf+sizeof(float); + + // copy the truth data into a cuda buffer. + for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth) + { + const matrix& t = *truth; + DLIB_ASSERT(t.nr() == subnetwork_output.nr()); + DLIB_ASSERT(t.nc() == subnetwork_output.nc()); + memcpy(buf + i*bytes_per_plane, &t(0,0), bytes_per_plane); + } + + auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); + + do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss); + } + + private: + + static void do_work( + cuda_data_ptr loss_work_buffer, + cuda_data_ptr truth_buffer, + const tensor& subnetwork_output, + tensor& gradient, + double& loss + ); + + mutable cuda_data_void_ptr buf; + }; + class compute_loss_multiclass_log_per_pixel { /*! The point of this class is to compute the loss computed by - loss_multiclass_log_per_pixel, but to do so with CUDA. + loss_multiclass_log_per_pixel_, but to do so with CUDA. !*/ public: @@ -447,12 +506,13 @@ namespace dlib double& loss ) const { - const size_t bytes_per_plane = subnetwork_output.nr()*subnetwork_output.nc()*sizeof(uint16_t); + const auto image_size = subnetwork_output.nr()*subnetwork_output.nc(); + const size_t bytes_per_plane = image_size*sizeof(uint16_t); // Allocate a cuda buffer to store all the truth images and also one float // for the scalar loss output. buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float)); - cuda_data_void_ptr loss_buf = buf; + cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); buf = buf+sizeof(float); // copy the truth data into a cuda buffer. @@ -464,14 +524,16 @@ namespace dlib memcpy(buf + i*bytes_per_plane, &t(0,0), bytes_per_plane); } - do_work(static_cast(loss_buf.data()), static_cast(buf.data()), subnetwork_output, gradient, loss); + auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); + + do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss); } private: static void do_work( - float* loss_cuda_work_buffer, - const uint16_t* truth_buffer, + cuda_data_ptr loss_work_buffer, + cuda_data_ptr truth_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index ca4c226e5..0b5d974da 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -2481,6 +2481,164 @@ namespace dlib template using loss_mean_squared_multioutput = add_loss_layer; +// ---------------------------------------------------------------------------------------- + + class loss_binary_log_per_pixel_ + { + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + static void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* const out_data = output_tensor.host(); + + for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) + { + iter->set_size(output_tensor.nr(), output_tensor.nc()); + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, r, c)]; + } + } + } + } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.k() == 1); + DLIB_CASSERT(output_tensor.nr() == grad.nr() && + output_tensor.nc() == grad.nc() && + output_tensor.k() == grad.k()); + for (long idx = 0; idx < output_tensor.num_samples(); ++idx) + { + const_label_iterator truth_matrix_ptr = (truth + idx); + DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && + truth_matrix_ptr->nc() == output_tensor.nc(), + "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " + "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); + } + +#ifdef DLIB_USE_CUDA + double loss; + cuda_compute(truth, output_tensor, grad, loss); + return loss; +#else + + tt::sigmoid(grad, output_tensor); + + // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. + const double scale = 1.0/(output_tensor.num_samples()*output_tensor.nr()*output_tensor.nc()); + double loss = 0; + float* const g = grad.host(); + const float* const out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) + { + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + const float y = truth->operator()(r, c); + const size_t idx = tensor_index(output_tensor, i, r, c); + + if (y > 0.f) + { + const float temp = log1pexp(-out_data[idx]); + loss += y*scale*temp; + g[idx] = y*scale*(g[idx]-1); + } + else if (y < 0.f) + { + const float temp = -(-out_data[idx]-log1pexp(-out_data[idx])); + loss += -y*scale*temp; + g[idx] = -y*scale*g[idx]; + } + else + { + g[idx] = 0.f; + } + } + } + } + return loss; +#endif + } + + friend void serialize(const loss_binary_log_per_pixel_& , std::ostream& out) + { + serialize("loss_binary_log_per_pixel_", out); + } + + friend void deserialize(loss_binary_log_per_pixel_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_binary_log_per_pixel_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_log_per_pixel_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_binary_log_per_pixel_& ) + { + out << "loss_binary_log_per_pixel"; + return out; + } + + friend void to_xml(const loss_binary_log_per_pixel_& /*item*/, std::ostream& out) + { + out << ""; + } + + private: + static size_t tensor_index(const tensor& t, long sample, long row, long column) + { + DLIB_ASSERT(t.k() == 1); + + // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38 + return (sample * t.nr() + row) * t.nc() + column; + } + +#ifdef DLIB_USE_CUDA + cuda::compute_loss_binary_log_per_pixel cuda_compute; +#endif + }; + + template + using loss_binary_log_per_pixel = add_loss_layer; + // ---------------------------------------------------------------------------------------- class loss_multiclass_log_per_pixel_ diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h index bbaff6f6d..f859f368e 100644 --- a/dlib/dnn/loss_abstract.h +++ b/dlib/dnn/loss_abstract.h @@ -1283,6 +1283,68 @@ namespace dlib template using loss_mean_squared_multioutput = add_loss_layer; +// ---------------------------------------------------------------------------------------- + + class loss_binary_log_per_pixel_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the log loss, which is + appropriate for binary classification problems. It is basically just like + loss_binary_log_ except that it lets you define matrix outputs instead + of scalar outputs. It should be useful, for example, in segmentation + where we want to classify each pixel of an image, and also get at least + some sort of confidence estimate for each pixel. + !*/ + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the raw score for each classified object. If the score + is > 0 then the classifier is predicting the +1 class, otherwise it is + predicting the -1 class. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all pixel values pointed to by truth correspond to the desired target values. + Nominally they should be +1 or -1, each indicating the desired class label, + or 0 to indicate that the corresponding pixel is to be ignored. + !*/ + + }; + + template + using loss_binary_log_per_pixel = add_loss_layer; + // ---------------------------------------------------------------------------------------- class loss_multiclass_log_per_pixel_ diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index b04102a99..a1820c825 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -2587,6 +2587,202 @@ namespace DLIB_TEST_MSG(error_after < error_before, "multi channel error increased after training"); } +// ---------------------------------------------------------------------------------------- + + void test_loss_binary_log_per_pixel_learned_params_on_trivial_two_pixel_task() + { + print_spinner(); + + ::std::vector> x({ matrix({ -1, 1 }) }); + ::std::vector> y({ matrix({ -1, 1 }) }); + + using net_type = loss_binary_log_per_pixel>>>; + net_type net; + + dnn_trainer trainer(net, sgd(0,0)); + trainer.set_learning_rate(1e7); + trainer.set_max_num_epochs(1); + trainer.train(x, y); + + const tensor& learned_params = layer<1>(net).layer_details().get_layer_params(); + const float* learned_params_data = learned_params.host(); + + DLIB_TEST(learned_params_data[0] > 1e5); + DLIB_TEST(abs(learned_params_data[1]) < 1); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_binary_log_per_pixel_outputs_on_trivial_task() + { + print_spinner(); + + constexpr int input_height = 7; + constexpr int input_width = 5; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + constexpr int num_samples = 7; + constexpr int filter_height = 3; + constexpr int filter_width = 3; + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + ::std::default_random_engine generator(16); + ::std::normal_distribution n01(0); + + const auto z = 0.674490; // This should give us a 50/50 split between the classes + + // Generate training data: random inputs x, and the corresponding target outputs y + for (int ii = 0; ii < num_samples; ++ii) { + for (int jj = 0; jj < input_height; ++jj) { + for (int kk = 0; kk < input_width; ++kk) { + xtmp(jj, kk) = n01(generator); + ytmp(jj, kk) = std::abs(xtmp(jj, kk)) > z ? 1.f : -1.f; + } + } + x[ii] = xtmp; + y[ii] = ytmp; + } + + using net_type = loss_binary_log_per_pixel>>>>>; + net_type net; + + dnn_trainer trainer(net, sgd(0, 0.9)); + trainer.set_learning_rate(1); + trainer.set_max_num_epochs(800); + trainer.train(x, y); + + // The learning task is easy, so the net should have no problem + // getting all the outputs right. + const auto response = net(x); + for (int ii = 0; ii < num_samples; ++ii) + for (int jj = 0; jj < output_height; ++jj) + for (int kk = 0; kk < output_width; ++kk) + DLIB_TEST((response[ii](jj,kk) > 0) == (y[ii](jj,kk) > 0)); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_binary_log_per_pixel_with_noise_and_pixels_to_ignore() + { + // Test learning when some pixels are to be ignored, etc. + + print_spinner(); + + constexpr int input_height = 5; + constexpr int input_width = 7; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + const int num_samples = 1000; + const double ignore_probability = 0.5; + const double noise_probability = 0.05; + + ::std::default_random_engine generator(16); + ::std::bernoulli_distribution ignore(ignore_probability); + ::std::bernoulli_distribution noise_occurrence(noise_probability); + ::std::bernoulli_distribution noisy_label(0.5); + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + ::std::vector truth_histogram(2); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + // The function to be learned. + const auto ground_truth = [](const matrix& x, int row, int column) { + double sum = 0.0; + const int first_column = std::max(0, column - 1); + const int last_column = std::min(static_cast(x.nc() - 1), column + 1); + for (int c = first_column; c <= last_column; ++c) { + sum += x(row, c); + } + DLIB_TEST(sum < 2.0 * (last_column - first_column + 1)); + return sum > (last_column - first_column + 1); + }; + + for ( int ii = 0; ii < num_samples; ++ii ) { + for ( int jj = 0; jj < input_height; ++jj ) { + for ( int kk = 0; kk < input_width; ++kk ) { + // Generate numbers between 0 and 2. + double value = static_cast(ii + jj + kk) / 10.0; + value -= (static_cast(value) / 2) * 2; + DLIB_TEST(value >= 0.0 && value < 2.0); + xtmp(jj, kk) = value; + } + } + x[ii] = xtmp; + + for ( int jj = 0; jj < output_height; ++jj ) { + for ( int kk = 0; kk < output_width; ++kk ) { + const bool truth = ground_truth(x[ii], jj, kk); + ++truth_histogram[truth]; + if (ignore(generator)) { + ytmp(jj, kk) = 0.f; + } + else if (noise_occurrence(generator)) { + ytmp(jj, kk) = noisy_label(generator) ? 1.f : -1.f; + } + else { + ytmp(jj, kk) = truth ? 1.f : -1.f; + } + } + } + + y[ii] = ytmp; + } + + const int num_total_elements = num_samples * output_height * output_width; + + { // Require a reasonably balanced truth histogram in order to make sure that a trivial classifier is not enough + const int required_min_histogram_value = static_cast(::std::ceil(num_total_elements / 2.0 * 0.375)); + for (auto histogram_value : truth_histogram) { + DLIB_TEST_MSG(histogram_value >= required_min_histogram_value, + "Histogram value = " << histogram_value << ", required = " << required_min_histogram_value); + } + } + + using net_type = loss_binary_log_per_pixel>>>; + net_type net; + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(0.1); + trainer.set_min_learning_rate(0.01); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(170); + trainer.train(x, y); + + const ::std::vector> predictions = net(x); + + int num_correct = 0; + + for ( int ii = 0; ii < num_samples; ++ii ) { + const matrix& prediction = predictions[ii]; + DLIB_TEST(prediction.nr() == output_height); + DLIB_TEST(prediction.nc() == output_width); + for ( int jj = 0; jj < output_height; ++jj ) + for ( int kk = 0; kk < output_width; ++kk ) + if ( (prediction(jj, kk) > 0.f) == ground_truth(x[ii], jj, kk) ) + ++num_correct; + } + + // First some sanity checks. + const int num_correct_max = num_total_elements; + DLIB_TEST(num_correct_max == ::std::accumulate(truth_histogram.begin(), truth_histogram.end(), 0)); + DLIB_TEST_MSG(num_correct <= num_correct_max, + "Number of correctly classified elements = " << num_correct << ", max = " << num_correct_max); + + // This is the real test, verifying that we have actually learned something. + const int num_correct_required = static_cast(::std::ceil(0.9 * num_correct_max)); + DLIB_TEST_MSG(num_correct >= num_correct_required, + "Number of correctly classified elements = " << num_correct << ", required = " << num_correct_required); + } + // ---------------------------------------------------------------------------------------- void test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task() @@ -3429,6 +3625,9 @@ namespace test_multioutput_linear_regression(); test_simple_autoencoder(); test_loss_mean_squared_per_channel_and_pixel(); + test_loss_binary_log_per_pixel_learned_params_on_trivial_two_pixel_task(); + test_loss_binary_log_per_pixel_outputs_on_trivial_task(); + test_loss_binary_log_per_pixel_with_noise_and_pixels_to_ignore(); test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task(); test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task(); test_loss_multiclass_per_pixel_outputs_on_trivial_task(); diff --git a/examples/dnn_instance_segmentation_ex.cpp b/examples/dnn_instance_segmentation_ex.cpp index 6f3834470..b864015b9 100644 --- a/examples/dnn_instance_segmentation_ex.cpp +++ b/examples/dnn_instance_segmentation_ex.cpp @@ -16,7 +16,7 @@ ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images An alternative to steps 2-4 above is to download a pre-trained network - from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn + from here: http://dlib.net/files/instance_segmentation_voc2012net_v2.dnn It would be a good idea to become familiar with dlib's DNN tooling before reading this example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp @@ -71,26 +71,21 @@ int main(int argc, char** argv) try { // Load the input image. load_image(input_image, file.full_name()); - - // Draw largest objects last - const auto sort_instances = [](const std::vector& input) { - auto output = input; - const auto compare_area = [](const mmod_rect& lhs, const mmod_rect& rhs) { - return lhs.rect.area() < rhs.rect.area(); - }; - std::sort(output.begin(), output.end(), compare_area); - return output; - }; - + // Find instances in the input image - const auto instances = sort_instances(det_net(input_image)); + const auto instances = det_net(input_image); matrix rgb_label_image; + matrix label_image_confidence; + matrix input_chip; rgb_label_image.set_size(input_image.nr(), input_image.nc()); rgb_label_image = rgb_pixel(0, 0, 0); + label_image_confidence.set_size(input_image.nr(), input_image.nc()); + label_image_confidence = 0.0; + bool found_something = false; for (const auto& instance : instances) @@ -131,7 +126,7 @@ int main(int argc, char** argv) try rnd.get_random_8bit_number() ); - dlib::matrix resized_mask( + dlib::matrix resized_mask( static_cast(chip_details.rect.height()), static_cast(chip_details.rect.width()) ); @@ -142,12 +137,29 @@ int main(int argc, char** argv) try { for (int c = 0; c < resized_mask.nc(); ++c) { - if (resized_mask(r, c)) + const auto new_confidence = resized_mask(r, c); + if (new_confidence > 0) { const auto y = chip_details.rect.top() + r; const auto x = chip_details.rect.left() + c; if (y >= 0 && y < rgb_label_image.nr() && x >= 0 && x < rgb_label_image.nc()) - rgb_label_image(y, x) = random_color; + { + auto& current_confidence = label_image_confidence(y, x); + if (new_confidence > current_confidence) + { + auto rgb_label = random_color; + const auto baseline_confidence = 5; + if (new_confidence < baseline_confidence) + { + // Scale label intensity if confidence isn't high + rgb_label.red *= new_confidence / baseline_confidence; + rgb_label.green *= new_confidence / baseline_confidence; + rgb_label.blue *= new_confidence / baseline_confidence; + } + rgb_label_image(y, x) = rgb_label; + current_confidence = new_confidence; + } + } } } } diff --git a/examples/dnn_instance_segmentation_ex.h b/examples/dnn_instance_segmentation_ex.h index 1693c304e..e26ab2973 100644 --- a/examples/dnn_instance_segmentation_ex.h +++ b/examples/dnn_instance_segmentation_ex.h @@ -23,7 +23,7 @@ ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images An alternative to steps 2-4 above is to download a pre-trained network - from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn + from here: http://dlib.net/files/instance_segmentation_voc2012net_v2.dnn It would be a good idea to become familiar with dlib's DNN tooling before reading this example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp @@ -159,13 +159,13 @@ template using concat_utag4 = resize_and_concat>>>>>>>>>>>>>>>>>>>>>>>>; // testing network type (replaced batch normalization with fixed affine transforms) -using seg_anet_type = dlib::loss_multiclass_log_per_pixel< - dlib::cont<2,1,1,1,1, +using seg_anet_type = dlib::loss_binary_log_per_pixel< + dlib::cont<1,1,1,1,1, dlib::relu input_image; - matrix label_image; // The ground-truth label of each pixel. + matrix label_image; // The ground-truth label of each pixel. (+1 or -1) }; // ---------------------------------------------------------------------------------------- @@ -321,12 +321,12 @@ det_bnet_type train_detection_network( // ---------------------------------------------------------------------------------------- -matrix keep_only_current_instance(const matrix& rgb_label_image, const rgb_pixel rgb_label) +matrix keep_only_current_instance(const matrix& rgb_label_image, const rgb_pixel rgb_label) { const auto nr = rgb_label_image.nr(); const auto nc = rgb_label_image.nc(); - matrix result(nr, nc); + matrix result(nr, nc); for (long r = 0; r < nr; ++r) { @@ -334,11 +334,11 @@ matrix keep_only_current_instance(const matrix& rgb_label_i { const auto& index = rgb_label_image(r, c); if (index == rgb_label) - result(r, c) = 1; + result(r, c) = +1; else if (index == dlib::rgb_pixel(224, 224, 192)) - result(r, c) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore; - else result(r, c) = 0; + else + result(r, c) = -1; } } @@ -373,7 +373,7 @@ seg_bnet_type train_segmentation_network( cout << seg_trainer << endl; std::vector> samples; - std::vector> labels; + std::vector> labels; // Start a bunch of threads that read images from disk and pull out random crops. It's // important to be sure to feed the GPU fast enough to keep it busy. Using multiple