Added a version of test_object_detection_function() for the DNN based MMOD detector.

This commit is contained in:
Davis King 2016-09-05 09:37:30 -04:00
parent 9e290dce34
commit d54597230b
3 changed files with 168 additions and 0 deletions

View File

@ -19,6 +19,7 @@
#include "dnn/cpu_dlib.h"
#include "dnn/tensor_tools.h"
#include "dnn/utilities.h"
#include "dnn/validation.h"
#endif // DLIB_DNn_

95
dlib/dnn/validation.h Normal file
View File

@ -0,0 +1,95 @@
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_VALIDATION_H_
#define DLIB_DNn_VALIDATION_H_
#include "validation_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
namespace dlib
{
template <
typename SUBNET,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
loss_binary_mmod<SUBNET>& detector,
const image_array_type& images,
const std::vector<std::vector<mmod_rect>>& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
// make sure requires clause is not broken
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true ,
"\t matrix test_object_detection_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
<< "\n\t images.size(): " << images.size()
);
double correct_hits = 0;
double total_true_targets = 0;
std::vector<std::pair<double,bool> > all_dets;
unsigned long missing_detections = 0;
resizable_tensor temp;
for (unsigned long i = 0; i < images.size(); ++i)
{
std::vector<mmod_rect> hits;
detector.to_tensor(&images[i], &images[i]+1, temp);
detector.subnet().forward(temp);
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
std::vector<full_object_detection> truth_boxes;
std::vector<rectangle> ignore;
std::vector<std::pair<double,rectangle>> boxes;
// copy hits and truth_dets into the above three objects
for (auto&& b : truth_dets[i])
{
if (b.ignore)
ignore.push_back(b);
else
truth_boxes.push_back(full_object_detection(b.rect));
}
for (auto&& b : hits)
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections);
total_true_targets += truth_boxes.size();
}
std::sort(all_dets.rbegin(), all_dets.rend());
double precision, recall;
double total_hits = all_dets.size();
if (total_hits == 0)
precision = 1;
else
precision = correct_hits / total_hits;
if (total_true_targets == 0)
recall = 1;
else
recall = correct_hits / total_true_targets;
matrix<double, 1, 3> res;
res = precision, recall, average_precision(all_dets, missing_detections);
return res;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_VALIDATION_H_

View File

@ -0,0 +1,72 @@
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_DNn_VALIDATION_ABSTRACT_H_
#ifdef DLIB_DNn_VALIDATION_ABSTRACT_H_
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "layers_abstract.h"
namespace dlib
{
template <
typename SUBNET,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
loss_binary_mmod<SUBNET>& detector,
const image_array_type& images,
const std::vector<std::vector<mmod_rect>>& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- is_learning_problem(images,truth_dets)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- This function is just like the test_object_detection_function() for
object_detector's except it runs on CNNs that use loss_binary_mmod.
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and average precision. Note that the task is
to predict, for each images[i], the set of object locations given by
truth_dets[i]. Additionally, any detections on image[i] that match a box in
truth_dets[i] that are marked ignore are ignored. That is, detections
matching an ignore box do not count as a false alarm and similarly if any
ignored box in truth_dets goes undetected it does not count as a missed
detection.
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
which correspond to a real target. A value of 1 means the detector
never produces any false alarms while a value of 0 means it only
produces false alarms.
- M(1) == the recall of the detector object. This is a number in the
range [0,1] which measures the fraction of targets found by the detector.
A value of 1 means the detector found all the non-ignore targets in
truth_dets while a value of 0 means the detector didn't locate any of the
targets.
- M(2) == the average precision of the detector object. This is a number
in the range [0,1] which measures the overall quality of the detector.
We compute this by taking all the detections output by the detector and
ordering them in descending order of their detection scores. Then we use
the average_precision() routine to score the ranked listing and store the
output into M(2).
- This function considers a detector output D to match a rectangle T if and
only if overlap_tester(T,D) returns true.
- Note that you can use the adjust_threshold argument to raise or lower the
detection threshold. This value is passed into the identically named
argument to the detector object and therefore influences the number of
output detections. It can be useful, for example, to lower the detection
threshold because it results in more detections being output by the
detector, and therefore provides more information in the ranking,
possibly raising the average precision.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_VALIDATION_ABSTRACT_H_