Fix a warning and add some more error handling.

This commit is contained in:
Davis King 2020-08-23 22:22:40 -04:00
parent dd06c1169b
commit b401185aa5
3 changed files with 12 additions and 4 deletions

View File

@ -852,6 +852,11 @@ namespace dlib
const float* out_data = output_tensor.host();
for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
{
const long long num_label_categories = truth->size();
DLIB_CASSERT(output_tensor.k() == num_label_categories,
"Number of label types should match the number of output channels. "
"output_tensor.k(): " << output_tensor.k()
<< ", num_label_categories: "<< num_label_categories);
for (long k = 0; k < output_tensor.k(); ++k)
{
const float y = (*truth)[k];

View File

@ -767,8 +767,9 @@ namespace dlib
- sub.get_output().nc() == 1
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
- all values pointed to by truth are std::vectors of non-zero elements.
Nominally they should be +1 or -1, each indicating the desired class label.
- truth points to training_label_type elements, each of size sub.get_output.k().
The elements of each truth training_label_type instance are nominally +1 or -1,
each representing a binary class label.
!*/
};

View File

@ -3391,11 +3391,13 @@ namespace
{
for (size_t j = 0; j < labels[i].size(); ++j)
{
if (labels[i][j] == 1 && preds[i][j] < 0 ||
labels[i][j] == 0 && preds[i][j] > 0)
if ((labels[i][j] == 1 && preds[i][j] < 0) ||
(labels[i][j] == 0 && preds[i][j] > 0))
{
++num_wrong;
}
}
}
return num_wrong / labels.size() / dims;
};