mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Fix a warning and add some more error handling.
This commit is contained in:
parent
dd06c1169b
commit
b401185aa5
@ -852,6 +852,11 @@ namespace dlib
|
||||
const float* out_data = output_tensor.host();
|
||||
for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
|
||||
{
|
||||
const long long num_label_categories = truth->size();
|
||||
DLIB_CASSERT(output_tensor.k() == num_label_categories,
|
||||
"Number of label types should match the number of output channels. "
|
||||
"output_tensor.k(): " << output_tensor.k()
|
||||
<< ", num_label_categories: "<< num_label_categories);
|
||||
for (long k = 0; k < output_tensor.k(); ++k)
|
||||
{
|
||||
const float y = (*truth)[k];
|
||||
|
@ -767,8 +767,9 @@ namespace dlib
|
||||
- sub.get_output().nc() == 1
|
||||
- sub.get_output().num_samples() == input_tensor.num_samples()
|
||||
- sub.sample_expansion_factor() == 1
|
||||
- all values pointed to by truth are std::vectors of non-zero elements.
|
||||
Nominally they should be +1 or -1, each indicating the desired class label.
|
||||
- truth points to training_label_type elements, each of size sub.get_output.k().
|
||||
The elements of each truth training_label_type instance are nominally +1 or -1,
|
||||
each representing a binary class label.
|
||||
!*/
|
||||
|
||||
};
|
||||
|
@ -3391,11 +3391,13 @@ namespace
|
||||
{
|
||||
for (size_t j = 0; j < labels[i].size(); ++j)
|
||||
{
|
||||
if (labels[i][j] == 1 && preds[i][j] < 0 ||
|
||||
labels[i][j] == 0 && preds[i][j] > 0)
|
||||
if ((labels[i][j] == 1 && preds[i][j] < 0) ||
|
||||
(labels[i][j] == 0 && preds[i][j] > 0))
|
||||
{
|
||||
++num_wrong;
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_wrong / labels.size() / dims;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user