From b401185aa5a59bfff8eb5f4675a7e4802c37b070 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 23 Aug 2020 22:22:40 -0400 Subject: [PATCH] Fix a warning and add some more error handling. --- dlib/dnn/loss.h | 5 +++++ dlib/dnn/loss_abstract.h | 5 +++-- dlib/test/dnn.cpp | 6 ++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index e52b7d970..e4b913a3c 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -852,6 +852,11 @@ namespace dlib const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { + const long long num_label_categories = truth->size(); + DLIB_CASSERT(output_tensor.k() == num_label_categories, + "Number of label types should match the number of output channels. " + "output_tensor.k(): " << output_tensor.k() + << ", num_label_categories: "<< num_label_categories); for (long k = 0; k < output_tensor.k(); ++k) { const float y = (*truth)[k]; diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h index 3fdf4c02b..22d37fbfd 100644 --- a/dlib/dnn/loss_abstract.h +++ b/dlib/dnn/loss_abstract.h @@ -767,8 +767,9 @@ namespace dlib - sub.get_output().nc() == 1 - sub.get_output().num_samples() == input_tensor.num_samples() - sub.sample_expansion_factor() == 1 - - all values pointed to by truth are std::vectors of non-zero elements. - Nominally they should be +1 or -1, each indicating the desired class label. + - truth points to training_label_type elements, each of size sub.get_output.k(). + The elements of each truth training_label_type instance are nominally +1 or -1, + each representing a binary class label. !*/ }; diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index 1aad97c97..b24e98ff2 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -3391,9 +3391,11 @@ namespace { for (size_t j = 0; j < labels[i].size(); ++j) { - if (labels[i][j] == 1 && preds[i][j] < 0 || - labels[i][j] == 0 && preds[i][j] > 0) + if ((labels[i][j] == 1 && preds[i][j] < 0) || + (labels[i][j] == 0 && preds[i][j] > 0)) + { ++num_wrong; + } } } return num_wrong / labels.size() / dims;