diff --git a/dlib/dnn.h b/dlib/dnn.h index 7d8678cdc..36f701542 100644 --- a/dlib/dnn.h +++ b/dlib/dnn.h @@ -19,6 +19,7 @@ #include "dnn/cpu_dlib.h" #include "dnn/tensor_tools.h" #include "dnn/utilities.h" +#include "dnn/validation.h" #endif // DLIB_DNn_ diff --git a/dlib/dnn/validation.h b/dlib/dnn/validation.h new file mode 100644 index 000000000..b347e005c --- /dev/null +++ b/dlib/dnn/validation.h @@ -0,0 +1,95 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_VALIDATION_H_ +#define DLIB_DNn_VALIDATION_H_ + +#include "validation_abstract.h" +#include "../svm/cross_validate_object_detection_trainer.h" +#include "layers.h" + +namespace dlib +{ + + template < + typename SUBNET, + typename image_array_type + > + const matrix test_object_detection_function ( + loss_binary_mmod& detector, + const image_array_type& images, + const std::vector>& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( is_learning_problem(images,truth_dets) == true , + "\t matrix test_object_detection_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets) + << "\n\t images.size(): " << images.size() + ); + + + + double correct_hits = 0; + double total_true_targets = 0; + + std::vector > all_dets; + unsigned long missing_detections = 0; + + resizable_tensor temp; + + for (unsigned long i = 0; i < images.size(); ++i) + { + std::vector hits; + detector.to_tensor(&images[i], &images[i]+1, temp); + detector.subnet().forward(temp); + detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold); + + + std::vector truth_boxes; + std::vector ignore; + std::vector> boxes; + // copy hits and truth_dets into the above three objects + for (auto&& b : truth_dets[i]) + { + if (b.ignore) + ignore.push_back(b); + else + truth_boxes.push_back(full_object_detection(b.rect)); + } + for (auto&& b : hits) + boxes.push_back(std::make_pair(b.detection_confidence, b.rect)); + + correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections); + total_true_targets += truth_boxes.size(); + } + + std::sort(all_dets.rbegin(), all_dets.rend()); + + double precision, recall; + + double total_hits = all_dets.size(); + + if (total_hits == 0) + precision = 1; + else + precision = correct_hits / total_hits; + + if (total_true_targets == 0) + recall = 1; + else + recall = correct_hits / total_true_targets; + + matrix res; + res = precision, recall, average_precision(all_dets, missing_detections); + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_VALIDATION_H_ + diff --git a/dlib/dnn/validation_abstract.h b/dlib/dnn/validation_abstract.h new file mode 100644 index 000000000..5ff77eb60 --- /dev/null +++ b/dlib/dnn/validation_abstract.h @@ -0,0 +1,72 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_VALIDATION_ABSTRACT_H_ +#ifdef DLIB_DNn_VALIDATION_ABSTRACT_H_ + +#include "../svm/cross_validate_object_detection_trainer_abstract.h" +#include "layers_abstract.h" + +namespace dlib +{ + + template < + typename SUBNET, + typename image_array_type + > + const matrix test_object_detection_function ( + loss_binary_mmod& detector, + const image_array_type& images, + const std::vector>& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - is_learning_problem(images,truth_dets) + - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h + and it must contain objects which can be accepted by detector(). + ensures + - This function is just like the test_object_detection_function() for + object_detector's except it runs on CNNs that use loss_binary_mmod. + - Tests the given detector against the supplied object detection problem and + returns the precision, recall, and average precision. Note that the task is + to predict, for each images[i], the set of object locations given by + truth_dets[i]. Additionally, any detections on image[i] that match a box in + truth_dets[i] that are marked ignore are ignored. That is, detections + matching an ignore box do not count as a false alarm and similarly if any + ignored box in truth_dets goes undetected it does not count as a missed + detection. + - In particular, returns a matrix M such that: + - M(0) == the precision of the detector object. This is a number + in the range [0,1] which measures the fraction of detector outputs + which correspond to a real target. A value of 1 means the detector + never produces any false alarms while a value of 0 means it only + produces false alarms. + - M(1) == the recall of the detector object. This is a number in the + range [0,1] which measures the fraction of targets found by the detector. + A value of 1 means the detector found all the non-ignore targets in + truth_dets while a value of 0 means the detector didn't locate any of the + targets. + - M(2) == the average precision of the detector object. This is a number + in the range [0,1] which measures the overall quality of the detector. + We compute this by taking all the detections output by the detector and + ordering them in descending order of their detection scores. Then we use + the average_precision() routine to score the ranked listing and store the + output into M(2). + - This function considers a detector output D to match a rectangle T if and + only if overlap_tester(T,D) returns true. + - Note that you can use the adjust_threshold argument to raise or lower the + detection threshold. This value is passed into the identically named + argument to the detector object and therefore influences the number of + output detections. It can be useful, for example, to lower the detection + threshold because it results in more detections being output by the + detector, and therefore provides more information in the ranking, + possibly raising the average precision. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_VALIDATION_ABSTRACT_H_ +