From 1380e6b95ff2c307cd1d11203eeb8ab482e57c59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com>
Date: Wed, 18 Mar 2020 21:33:54 +0900
Subject: [PATCH] add loss multiclass log weighted (#2022)

* add loss_multiclass_log_weighted
* fix class name in loss_abstract
* add loss_multiclass_log_weighted test
* rename test function to match class name
* fix typo
* reuse the weighted label struct across weighted losses
* do not break compatibility with loss_multiclass_log_per_pixel_weighted
* actually test the loss and fix docs
* fix build with gcc 9
---
 dlib/dnn/loss.h          | 149 +++++++++++++++++++++++++++++++++++----
 dlib/dnn/loss_abstract.h | 123 ++++++++++++++++++++++++++------
 dlib/test/dnn.cpp        |  81 +++++++++++++++++++++
 3 files changed, 318 insertions(+), 35 deletions(-)

diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h
index 0b5d974da..c3a8b701c 100644
--- a/dlib/dnn/loss.h
+++ b/dlib/dnn/loss.h
@@ -367,6 +367,140 @@ namespace dlib
     template <typename SUBNET>
     using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <typename label_type>
+    struct weighted_label
+    {
+        weighted_label()
+        {}
+
+        weighted_label(label_type label, float weight = 1.f)
+            : label(label), weight(weight)
+        {}
+
+        label_type label{};
+        float weight = 1.f;
+    };
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_multiclass_log_weighted_
+    {
+    public:
+
+        typedef weighted_label<unsigned long> training_label_type;
+        typedef unsigned long output_label_type;
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(output_tensor.nr() == 1 &&
+                         output_tensor.nc() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+            // Note that output_tensor.k() should match the number of labels.
+
+            for (long i = 0; i < output_tensor.num_samples(); ++i)
+            {
+                // The index of the largest output for this sample is the label.
+                *iter++ = index_of_max(rowm(mat(output_tensor),i));
+            }
+        }
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            tensor& grad = sub.get_gradient_input();
+
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() != 0);
+            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+            DLIB_CASSERT(output_tensor.nr() == 1 &&
+                         output_tensor.nc() == 1);
+            DLIB_CASSERT(grad.nr() == 1 &&
+                         grad.nc() == 1);
+
+            tt::softmax(grad, output_tensor);
+
+            // The loss we output is the average loss over the mini-batch.
+            const double scale = 1.0/output_tensor.num_samples();
+            double loss = 0;
+            float* g = grad.host();
+            for (long i = 0; i < output_tensor.num_samples(); ++i)
+            {
+                const auto wl = *truth++;
+                const long y = wl.label;
+                const float weight = wl.weight;
+                // The network must produce a number of outputs that is equal to the number
+                // of labels when using this type of loss.
+                DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k());
+                for (long k = 0; k < output_tensor.k(); ++k)
+                {
+                    const unsigned long idx = i*output_tensor.k()+k;
+                    if (k == y)
+                    {
+                        loss += weight*scale*-safe_log(g[idx]);
+                        g[idx] = weight*scale*(g[idx]-1);
+                    }
+                    else
+                    {
+                        g[idx] = weight*scale*g[idx];
+                    }
+                }
+            }
+            return loss;
+        }
+
+        friend void serialize(const loss_multiclass_log_weighted_& , std::ostream& out)
+        {
+            serialize("loss_multiclass_log_weighted_", out);
+        }
+
+        friend void deserialize(loss_multiclass_log_weighted_& , std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "loss_multiclass_log_weighted_")
+                throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_weighted_.");
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_weighted_& )
+        {
+            out << "loss_multiclass_log_weighted";
+            return out;
+        }
+
+        friend void to_xml(const loss_multiclass_log_weighted_& /*item*/, std::ostream& out)
+        {
+            out << "<loss_multiclass_log_weighted/>";
+        }
+
+    };
+
+    template <typename SUBNET>
+    using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class loss_multimulticlass_log_
@@ -2832,20 +2966,7 @@ namespace dlib
     {
     public:
 
-        struct weighted_label
-        {
-            weighted_label()
-            {}
-
-            weighted_label(uint16_t label, float weight = 1.f)
-                : label(label), weight(weight)
-            {}
-
-            // In semantic segmentation, 65536 classes ought to be enough for anybody.
-            uint16_t label = 0;
-            float weight = 1.f;
-        };
-
+        typedef dlib::weighted_label<uint16_t> weighted_label;
         typedef matrix<weighted_label> training_label_type;
         typedef matrix<uint16_t> output_label_type;
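For intuition, the training-time loop above is ordinary softmax cross-entropy in which each sample's loss and gradient are additionally scaled by that sample's weight (and by scale = 1/num_samples to average over the mini-batch). The following standalone sketch, which is illustrative only and not part of the patch, reproduces the per-sample arithmetic without any dlib dependency:

    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Weighted softmax log-loss for one sample, mirroring the loop body of
    // loss_multiclass_log_weighted_::compute_loss_value_and_gradient (minus
    // the 1/num_samples mini-batch averaging).
    int main()
    {
        const std::vector<double> scores = {1.0, 2.0, 0.5}; // raw network outputs, one per class
        const long y = 1;                                   // ground truth label
        const double weight = 1.4;                          // per-sample weight

        // softmax over the raw scores
        std::vector<double> p(scores.size());
        double sum = 0;
        for (std::size_t k = 0; k < p.size(); ++k)
            sum += p[k] = std::exp(scores[k]);
        for (auto& v : p)
            v /= sum;

        // this sample's loss contribution and its gradient w.r.t. the scores
        const double loss = -weight*std::log(p[y]);
        std::cout << "loss: " << loss << '\n';
        for (std::size_t k = 0; k < p.size(); ++k)
        {
            const double g = weight*(p[k] - (static_cast<long>(k) == y ? 1.0 : 0.0));
            std::cout << "grad[" << k << "]: " << g << '\n';
        }
    }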
diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h
index f859f368e..c2b539594 100644
--- a/dlib/dnn/loss_abstract.h
+++ b/dlib/dnn/loss_abstract.h
@@ -369,6 +369,107 @@ namespace dlib
     template <typename SUBNET>
     using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <typename label_type>
+    struct weighted_label
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object represents the truth label of a single sample, together with
+                an associated weight (the higher the weight, the more emphasis the
+                corresponding sample is given during training).
+
+                This object is used in the following loss layers:
+                    - loss_multiclass_log_weighted_ with unsigned long as label_type
+                    - loss_multiclass_log_per_pixel_weighted_ with uint16_t as label_type,
+                      since, in semantic segmentation, 65536 classes ought to be enough for
+                      anybody.
+        !*/
+
+        weighted_label()
+        {}
+
+        weighted_label(label_type label, float weight = 1.f)
+            : label(label), weight(weight)
+        {}
+
+        // The ground truth label
+        label_type label{};
+
+        // The weight of the corresponding sample
+        float weight = 1.f;
+    };
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_multiclass_log_weighted_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object implements the loss layer interface defined above by
+                EXAMPLE_LOSS_LAYER_.  In particular, it implements the multiclass logistic
+                regression loss (i.e. negative log-likelihood loss), which is appropriate
+                for multiclass classification problems.  It is basically just like
+                loss_multiclass_log_ except that it lets you define per-sample weights,
+                which might be useful, e.g., if you want to emphasize rare classes while
+                training.  If the classification problem is difficult, a flat weight
+                structure may lead the network to always predict the most common label,
+                in particular if the degree of imbalance is high.  To emphasize a certain
+                class or classes, simply increase the weights of the corresponding samples
+                relative to the weights of the other samples.
+
+                Note that if you set all the weights equal to 1, then you get
+                loss_multiclass_log_ as a special case.
+        !*/
+
+    public:
+
+        typedef weighted_label<unsigned long> training_label_type;
+        typedef unsigned long output_label_type;
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+            it has the additional calling requirements that:
+                - sub.get_output().nr() == 1
+                - sub.get_output().nc() == 1
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.sample_expansion_factor() == 1
+            and the output label is the predicted class for each classified object.  The
+            number of possible output classes is sub.get_output().k().
+        !*/
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+            except it has the additional calling requirements that:
+                - sub.get_output().nr() == 1
+                - sub.get_output().nc() == 1
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.sample_expansion_factor() == 1
+                - all values pointed to by truth are < sub.get_output().k()
+        !*/
+
+    };
+
+    template <typename SUBNET>
+    using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
 // ----------------------------------------------------------------------------------------
 
     class loss_multimulticlass_log_
@@ -1434,27 +1535,7 @@ namespace dlib
     !*/
 
     public:
 
-        struct weighted_label
-        {
-            /*!
-                WHAT THIS OBJECT REPRESENTS
-                    This object represents the truth label of a single pixel, together with
-                    an associated weight (the higher the weight, the more emphasis the
-                    corresponding pixel is given during the training).
-            !*/
-
-            weighted_label();
-            weighted_label(uint16_t label, float weight = 1.f);
-
-            // The ground-truth label.  In semantic segmentation, 65536 classes ought to be
-            // enough for anybody.
-            uint16_t label = 0;
-
-            // The weight of the corresponding pixel.
-            float weight = 1.f;
-        };
-
-        typedef matrix<weighted_label> training_label_type;
+        typedef matrix<dlib::weighted_label<uint16_t>> training_label_type;
         typedef matrix<uint16_t> output_label_type;
 
         template <
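Putting the interface above to work, a minimal usage sketch looks like the following. It is illustrative only and not part of the patch: the one-layer network, the zero-filled samples, and the hyperparameters are made up, but the types and calls mirror the test added below.

    #include <dlib/dnn.h>

    using namespace dlib;

    int main()
    {
        // A toy classifier: one fully connected layer over a flat float matrix.
        using net_type = loss_multiclass_log_weighted<fc<3, input<matrix<float>>>>;

        // Two dummy samples; real code would load actual training data.
        const matrix<float> m = zeros_matrix<float>(5, 7);
        std::vector<matrix<float>> samples = {m, m};

        // Per-sample weighted labels: emphasize the rare class 2 over class 0.
        std::vector<weighted_label<unsigned long>> labels = {
            weighted_label<unsigned long>(2, 2.f),
            weighted_label<unsigned long>(0, 1.f)
        };

        net_type net;
        dnn_trainer<net_type> trainer(net, sgd(0, 0.9));
        trainer.set_learning_rate(0.1);
        trainer.set_mini_batch_size(2);
        trainer.set_max_num_epochs(1);
        trainer.train(samples, labels);

        // Predictions come back as plain unsigned long class indices.
        const std::vector<unsigned long> predictions = net(samples);
        return predictions.empty();
    }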
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 9fe97ec00..8b3e0e85d 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -3224,6 +3224,86 @@ namespace
         }
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_loss_multiclass_log_weighted()
+    {
+        print_spinner();
+
+        constexpr int input_height = 5;
+        constexpr int input_width = 7;
+        const size_t num_samples = 1000;
+        const size_t num_classes = 4;
+
+        ::std::vector<matrix<float>> x(num_samples);
+        ::std::vector<unsigned long> y(num_samples);
+
+        matrix<float> xtmp(input_height, input_width);
+
+        dlib::rand rnd;
+        // Generate input data
+        for (size_t ii = 0; ii < num_samples; ++ii)
+        {
+            for (int jj = 0; jj < input_height; ++jj)
+            {
+                for (int kk = 0; kk < input_width; ++kk)
+                {
+                    xtmp(jj, kk) = rnd.get_random_float();
+                }
+            }
+            x[ii] = xtmp;
+            y[ii] = rnd.get_integer_in_range(0, num_classes);
+        }
+
+        using net_type = loss_multiclass_log_weighted<fc<num_classes, input<matrix<float>>>>;
+
+        ::std::vector<weighted_label<unsigned long>> y_weighted(num_samples);
+
+        for (size_t weighted_class = 0; weighted_class < num_classes; ++weighted_class)
+        {
+            print_spinner();
+
+            // Assign weights
+            for (size_t ii = 0; ii < num_samples; ++ii)
+            {
+                const unsigned long label = y[ii];
+                const float weight
+                    = label == weighted_class
+                    ? 1.4f
+                    : 0.6f;
+                y_weighted[ii] = weighted_label<unsigned long>(label, weight);
+            }
+
+            net_type net;
+            sgd defsolver(0, 0.9);
+            dnn_trainer<net_type> trainer(net, defsolver);
+            trainer.set_learning_rate(0.1);
+            trainer.set_min_learning_rate(0.01);
+            trainer.set_mini_batch_size(10);
+            trainer.set_max_num_epochs(10);
+            trainer.train(x, y_weighted);
+
+            const ::std::vector<unsigned long> predictions = net(x);
+
+            int num_weighted_class = 0;
+            int num_not_weighted_class = 0;
+
+            for (size_t ii = 0; ii < num_samples; ++ii)
+            {
+                if (predictions[ii] == weighted_class)
+                    ++num_weighted_class;
+                else
+                    ++num_not_weighted_class;
+            }
+
+            DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class,
+                          "The weighted class (" << weighted_class << ") does not dominate: "
+                          << num_weighted_class << " <= " << num_not_weighted_class);
+        }
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void test_tensor_resize_bilinear(long samps, long k, long nr, long nc, long onr, long onc)
@@ -3645,6 +3725,7 @@ namespace
         test_loss_multiclass_per_pixel_outputs_on_trivial_task();
         test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore();
         test_loss_multiclass_per_pixel_weighted();
+        test_loss_multiclass_log_weighted();
         test_serialization();
         test_loss_dot();
         test_loss_multimulticlass_log();
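Since the patch also adds serialize/deserialize overloads keyed on the version string "loss_multiclass_log_weighted_", nets using the new loss round-trip through a file like any other dlib network. A short sketch (again illustrative only, with a made-up file name):

    #include <dlib/dnn.h>

    using namespace dlib;

    int main()
    {
        using net_type = loss_multiclass_log_weighted<fc<4, input<matrix<float>>>>;

        net_type net;
        serialize("weighted_net.dat") << net;        // writes the "loss_multiclass_log_weighted_" tag

        net_type restored;
        deserialize("weighted_net.dat") >> restored; // throws serialization_error on a tag mismatch
    }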