improvements to cnn face detection python interface (#780)

* improvements to cnn face detection interface

* mmod rectangle object renaming. possibility to set batch size in multi image detection. Added check to make sure images are all the same size.
pull/2019/head
Guillaume Ramé 7 years ago committed by Davis E. King
parent 306fd3af10
commit bbf3d987b2

@ -140,6 +140,12 @@ namespace dlib
bool ignore = false;
operator rectangle() const { return rect; }
// Equality comparison: two mmod_rect objects compare equal only when their
// rectangle, detection confidence and ignore flag are all equal.
bool operator == (const mmod_rect& rhs) const
{
    // Bail out on the first member that differs.
    if (!(rect == rhs.rect))
        return false;
    if (!(detection_confidence == rhs.detection_confidence))
        return false;
    return ignore == rhs.ignore;
}
};
inline mmod_rect ignored_mmod_rect(const rectangle& r)

@ -39,7 +39,6 @@
# Or downloaded from http://scikit-image.org/download.html.
import sys
import dlib
from skimage import io
@ -51,7 +50,7 @@ if len(sys.argv) < 3:
" http://dlib.net/files/mmod_human_face_detector.dat.bz2")
exit()
cnn_face_detection_model = dlib.cnn_face_detection_model_v1(sys.argv[1])
cnn_face_detector = dlib.cnn_face_detection_model_v1(sys.argv[1])
win = dlib.image_window()
for f in sys.argv[2:]:
@ -60,13 +59,27 @@ for f in sys.argv[2:]:
# The 1 in the second argument indicates that we should upsample the image
# 1 time. This will make everything bigger and allow us to detect more
# faces.
dets = cnn_face_detection_model.cnn_face_detector(img, 1)
dets = cnn_face_detector(img, 1)
'''
This detector returns a mmod_rectangles object, which contains a list of mmod_rectangle objects.
These objects can be accessed by simply iterating over the mmod_rectangles object.
Each mmod_rectangle object has two member variables: rect, a dlib.rectangle object, and confidence, the detection confidence score.
It is also possible to pass a list of images to the detector.
- like this: dets = cnn_face_detector([image list], upsample_num_times, batch_size = 128)
In this case it will return a mmod_rectangless object.
This object behaves just like a list of lists and can be iterated over.
'''
print("Number of faces detected: {}".format(len(dets)))
for i, d in enumerate(dets):
print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
i, d.left(), d.top(), d.right(), d.bottom()))
print("Detection {}: Left: {} Top: {} Right: {} Bottom: {} Confidence: {}".format(
i, d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom(), d.confidence))
rects = dlib.rectangles()
rects.extend([d.rect for d in dets])
win.clear_overlay()
win.set_image(img)
win.add_overlay(dets)
dlib.hit_enter_to_continue()
win.add_overlay(rects)
dlib.hit_enter_to_continue()

@ -2,10 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/python.h>
#include <boost/shared_ptr.hpp>
#include <dlib/matrix.h>
#include <boost/python/slice.hpp>
#include <dlib/geometry/vector.h>
#include <dlib/dnn.h>
#include <dlib/image_transforms.h>
#include "indexing.h"
@ -14,9 +11,6 @@ using namespace dlib;
using namespace std;
using namespace boost::python;
typedef matrix<double,0,1> cv;
class cnn_face_detection_model_v1
{
@ -27,13 +21,13 @@ public:
deserialize(model_filename) >> net;
}
std::vector<rectangle> cnn_face_detector (
std::vector<mmod_rect> detect (
object pyimage,
const int upsample_num_times
)
{
pyramid_down<2> pyr;
std::vector<rectangle> rects;
std::vector<mmod_rect> rects;
// Copy the data into dlib based objects
matrix<rgb_pixel> image;
@ -59,12 +53,69 @@ public:
// if the image was upscaled.
for (auto&& d : dets) {
d.rect = pyr.rect_down(d.rect, upsample_num_times);
rects.push_back(d.rect);
rects.push_back(d);
}
return rects;
}
// Runs the face detection network on a whole list of images in one call.
//
// Parameters:
//   imgs               - Python list of images.  Every element must be an 8bit gray or
//                        RGB image, and all images must have the same dimensions.
//   upsample_num_times - how many times pyramid_up() is applied to each image before
//                        detection.  Must not be negative.
//   batch_size         - batch size forwarded to the network.  Must be at least 1.
//
// Returns:
//   A 2D vector holding one inner vector of mmod_rect for each entry the network
//   returned for the batch.  Every rectangle has been passed through pyr.rect_down()
//   upsample_num_times times so that it refers to the original (non-upsampled) image.
//
// Throws:
//   dlib::error if an argument is out of range, if an image is neither 8bit gray nor
//   RGB, or if the images do not all share the same dimensions.
std::vector<std::vector<mmod_rect> > detect_mult (
    boost::python::list& imgs,
    const int upsample_num_times,
    const int batch_size = 128
)
{
    // Negative counts and non-positive batch sizes are meaningless; reject them up
    // front instead of handing them to the pyramid and network code.
    if (upsample_num_times < 0)
        throw dlib::error("upsample_num_times must not be negative.");
    if (batch_size < 1)
        throw dlib::error("batch_size must be at least 1.");

    pyramid_down<2> pyr;
    std::vector<matrix<rgb_pixel> > dimgs;

    // Query the list length once rather than on every loop iteration.
    const long num_images = static_cast<long>(len(imgs));
    dimgs.reserve(num_images);

    for (long img_idx = 0; img_idx < num_images; ++img_idx)
    {
        // Build the dlib image directly inside the vector so the pixel data is not
        // copied a second time when it is stored.
        dimgs.emplace_back();
        matrix<rgb_pixel>& image = dimgs.back();

        // Copy the data into dlib based objects
        object tmp = boost::python::extract<object>(imgs[img_idx]);
        if (is_gray_python_image(tmp))
            assign_image(image, numpy_gray_image(tmp));
        else if (is_rgb_python_image(tmp))
            assign_image(image, numpy_rgb_image(tmp));
        else
            throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");

        for (int level = 0; level < upsample_num_times; ++level)
        {
            pyramid_up(image);
        }
    }

    // The images are processed together by the network, so they must all have the
    // same size.  Comparing every image against the first one is sufficient.
    for (size_t img_idx = 1; img_idx < dimgs.size(); ++img_idx)
    {
        if
        (
            dimgs[0].nc() != dimgs[img_idx].nc() ||
            dimgs[0].nr() != dimgs[img_idx].nr()
        )
            throw dlib::error("Images in list must all have the same dimensions.");
    }

    auto dets = net(dimgs, batch_size);

    std::vector<std::vector<mmod_rect> > all_rects;
    all_rects.reserve(dets.size());

    for (auto&& im_dets : dets)
    {
        // Fill the per-image result in place to avoid copying it into all_rects.
        all_rects.emplace_back();
        std::vector<mmod_rect>& rects = all_rects.back();
        rects.reserve(im_dets.size());

        // Map each detection back to the coordinates of the original image in case
        // the image was upsampled.
        for (auto&& d : im_dets) {
            d.rect = pyr.rect_down(d.rect, upsample_num_times);
            rects.push_back(d);
        }
    }

    return all_rects;
}
private:
template <long num_filters, typename SUBNET> using con5d = con<num_filters,5,5,2,2,SUBNET>;
@ -78,7 +129,6 @@ private:
net_type net;
};
// ----------------------------------------------------------------------------------------
void bind_cnn_face_detection()
@ -86,10 +136,35 @@ void bind_cnn_face_detection()
using boost::python::arg;
{
class_<cnn_face_detection_model_v1>("cnn_face_detection_model_v1", "This object detects human faces in an image. The constructor loads the face detection model from a file. You can download a pre-trained model from http://dlib.net/files/mmod_human_face_detector.dat.bz2.", init<std::string>())
.def("cnn_face_detector", &cnn_face_detection_model_v1::cnn_face_detector, (arg("img"), arg("upsample_num_times")=0),
.def(
"__call__",
&cnn_face_detection_model_v1::detect,
(arg("img"), arg("upsample_num_times")=0),
"Find faces in an image using a deep learning model.\n\
- Upsamples the image upsample_num_times before running the face \n\
detector."
)
.def(
"__call__",
&cnn_face_detection_model_v1::detect_mult,
(arg("imgs"), arg("upsample_num_times")=0, arg("batch_size")=128),
"takes a list of images as input returning a 2d list of mmod rectangles"
);
}
{
typedef mmod_rect type;
class_<type>("mmod_rectangle", "Wrapper around a rectangle object and a detection confidence score.")
.def_readwrite("rect", &type::rect)
.def_readwrite("confidence", &type::detection_confidence);
}
{
typedef std::vector<mmod_rect> type;
class_<type>("mmod_rectangles", "An array of mmod rectangle objects.")
.def(vector_indexing_suite<type>());
}
{
typedef std::vector<std::vector<mmod_rect> > type;
class_<type>("mmod_rectangless", "A 2D array of mmod rectangle objects.")
.def(vector_indexing_suite<type>());
}
}

Loading…
Cancel
Save