mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a version of test_object_detection_function() for the DNN based MMOD detector.
This commit is contained in:
parent
9e290dce34
commit
d54597230b
@ -19,6 +19,7 @@
|
||||
#include "dnn/cpu_dlib.h"
|
||||
#include "dnn/tensor_tools.h"
|
||||
#include "dnn/utilities.h"
|
||||
#include "dnn/validation.h"
|
||||
|
||||
#endif // DLIB_DNn_
|
||||
|
||||
|
95
dlib/dnn/validation.h
Normal file
95
dlib/dnn/validation.h
Normal file
@ -0,0 +1,95 @@
|
||||
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_DNn_VALIDATION_H_
|
||||
#define DLIB_DNn_VALIDATION_H_
|
||||
|
||||
#include "validation_abstract.h"
|
||||
#include "../svm/cross_validate_object_detection_trainer.h"
|
||||
#include "layers.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
template <
|
||||
typename SUBNET,
|
||||
typename image_array_type
|
||||
>
|
||||
const matrix<double,1,3> test_object_detection_function (
|
||||
loss_binary_mmod<SUBNET>& detector,
|
||||
const image_array_type& images,
|
||||
const std::vector<std::vector<mmod_rect>>& truth_dets,
|
||||
const test_box_overlap& overlap_tester = test_box_overlap(),
|
||||
const double adjust_threshold = 0
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true ,
|
||||
"\t matrix test_object_detection_function()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
|
||||
<< "\n\t images.size(): " << images.size()
|
||||
);
|
||||
|
||||
|
||||
|
||||
double correct_hits = 0;
|
||||
double total_true_targets = 0;
|
||||
|
||||
std::vector<std::pair<double,bool> > all_dets;
|
||||
unsigned long missing_detections = 0;
|
||||
|
||||
resizable_tensor temp;
|
||||
|
||||
for (unsigned long i = 0; i < images.size(); ++i)
|
||||
{
|
||||
std::vector<mmod_rect> hits;
|
||||
detector.to_tensor(&images[i], &images[i]+1, temp);
|
||||
detector.subnet().forward(temp);
|
||||
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
|
||||
|
||||
|
||||
std::vector<full_object_detection> truth_boxes;
|
||||
std::vector<rectangle> ignore;
|
||||
std::vector<std::pair<double,rectangle>> boxes;
|
||||
// copy hits and truth_dets into the above three objects
|
||||
for (auto&& b : truth_dets[i])
|
||||
{
|
||||
if (b.ignore)
|
||||
ignore.push_back(b);
|
||||
else
|
||||
truth_boxes.push_back(full_object_detection(b.rect));
|
||||
}
|
||||
for (auto&& b : hits)
|
||||
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
|
||||
|
||||
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections);
|
||||
total_true_targets += truth_boxes.size();
|
||||
}
|
||||
|
||||
std::sort(all_dets.rbegin(), all_dets.rend());
|
||||
|
||||
double precision, recall;
|
||||
|
||||
double total_hits = all_dets.size();
|
||||
|
||||
if (total_hits == 0)
|
||||
precision = 1;
|
||||
else
|
||||
precision = correct_hits / total_hits;
|
||||
|
||||
if (total_true_targets == 0)
|
||||
recall = 1;
|
||||
else
|
||||
recall = correct_hits / total_true_targets;
|
||||
|
||||
matrix<double, 1, 3> res;
|
||||
res = precision, recall, average_precision(all_dets, missing_detections);
|
||||
return res;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_DNn_VALIDATION_H_
|
||||
|
72
dlib/dnn/validation_abstract.h
Normal file
72
dlib/dnn/validation_abstract.h
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_DNn_VALIDATION_ABSTRACT_H_
|
||||
#ifdef DLIB_DNn_VALIDATION_ABSTRACT_H_
|
||||
|
||||
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
|
||||
#include "layers_abstract.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
template <
|
||||
typename SUBNET,
|
||||
typename image_array_type
|
||||
>
|
||||
const matrix<double,1,3> test_object_detection_function (
|
||||
loss_binary_mmod<SUBNET>& detector,
|
||||
const image_array_type& images,
|
||||
const std::vector<std::vector<mmod_rect>>& truth_dets,
|
||||
const test_box_overlap& overlap_tester = test_box_overlap(),
|
||||
const double adjust_threshold = 0
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_learning_problem(images,truth_dets)
|
||||
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
|
||||
and it must contain objects which can be accepted by detector().
|
||||
ensures
|
||||
- This function is just like the test_object_detection_function() for
|
||||
object_detector's except it runs on CNNs that use loss_binary_mmod.
|
||||
- Tests the given detector against the supplied object detection problem and
|
||||
returns the precision, recall, and average precision. Note that the task is
|
||||
to predict, for each images[i], the set of object locations given by
|
||||
truth_dets[i]. Additionally, any detections on image[i] that match a box in
|
||||
truth_dets[i] that are marked ignore are ignored. That is, detections
|
||||
matching an ignore box do not count as a false alarm and similarly if any
|
||||
ignored box in truth_dets goes undetected it does not count as a missed
|
||||
detection.
|
||||
- In particular, returns a matrix M such that:
|
||||
- M(0) == the precision of the detector object. This is a number
|
||||
in the range [0,1] which measures the fraction of detector outputs
|
||||
which correspond to a real target. A value of 1 means the detector
|
||||
never produces any false alarms while a value of 0 means it only
|
||||
produces false alarms.
|
||||
- M(1) == the recall of the detector object. This is a number in the
|
||||
range [0,1] which measures the fraction of targets found by the detector.
|
||||
A value of 1 means the detector found all the non-ignore targets in
|
||||
truth_dets while a value of 0 means the detector didn't locate any of the
|
||||
targets.
|
||||
- M(2) == the average precision of the detector object. This is a number
|
||||
in the range [0,1] which measures the overall quality of the detector.
|
||||
We compute this by taking all the detections output by the detector and
|
||||
ordering them in descending order of their detection scores. Then we use
|
||||
the average_precision() routine to score the ranked listing and store the
|
||||
output into M(2).
|
||||
- This function considers a detector output D to match a rectangle T if and
|
||||
only if overlap_tester(T,D) returns true.
|
||||
- Note that you can use the adjust_threshold argument to raise or lower the
|
||||
detection threshold. This value is passed into the identically named
|
||||
argument to the detector object and therefore influences the number of
|
||||
output detections. It can be useful, for example, to lower the detection
|
||||
threshold because it results in more detections being output by the
|
||||
detector, and therefore provides more information in the ranking,
|
||||
possibly raising the average precision.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_DNn_VALIDATION_ABSTRACT_H_
|
||||
|
Loading…
Reference in New Issue
Block a user