Add basic image io and remove python C-API refs from numpy_returns.cpp (#1258)

* Fixed reference count issue

* Fixed refcount issue in Python dlib.jitter_image and dlib.get_face_chips

* Consolidation of https://github.com/davisking/dlib/pull/1249

* Fixed build issue

* Fixed: Paths in a pytest file should be relative to dlib root

* Skip numpy return tests for Python 2.7 or if Numpy is not installed

* Enabled numpy returns tests on Python 2.7 using cPickle.dumps
pull/1265/head
visionworkz 7 years ago committed by Davis E. King
parent d7dfd8ad26
commit e8faced822

@ -33,14 +33,12 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import sys
import dlib
from skimage import io
if len(sys.argv) < 3:
print(
@ -55,7 +53,7 @@ win = dlib.image_window()
for f in sys.argv[2:]:
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
# The 1 in the second argument indicates that we should upsample the image
# 1 time. This will make everything bigger and allow us to detect more
# faces.

@ -32,16 +32,14 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import os
import glob
import dlib
from skimage import io
# Path to the video frames
video_folder = os.path.join("..", "examples", "video_frames")
@ -54,7 +52,7 @@ win = dlib.image_window()
# We will track the frames as we load them off of disk
for k, f in enumerate(sorted(glob.glob(os.path.join(video_folder, "*.jpg")))):
print("Processing Frame {}".format(k))
img = io.imread(f)
img = dlib.load_rgb_image(f)
# We need to initialize the tracker on the first frame
if k == 0:

@ -21,16 +21,13 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires OpenCV and Numpy which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install opencv-python numpy
# Or downloaded from http://opencv.org/releases.html
# pip install numpy
import sys
import dlib
import cv2
import numpy as np
if len(sys.argv) != 3:
print(
@ -48,14 +45,8 @@ face_file_path = sys.argv[2]
detector = dlib.get_frontal_face_detector()
sp = dlib.shape_predictor(predictor_path)
# Load the image using OpenCV
bgr_img = cv2.imread(face_file_path)
if bgr_img is None:
print("Sorry, we could not load '{}' as an image".format(face_file_path))
exit()
# Convert to RGB since dlib uses RGB images
img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
# Load the image using Dlib
img = dlib.load_rgb_image(face_file_path)
# Ask the detector to find the bounding boxes of each face. The 1 in the
# second argument indicates that we should upsample the image 1 time. This
@ -72,20 +63,17 @@ faces = dlib.full_object_detections()
for detection in dets:
faces.append(sp(img, detection))
window = dlib.image_window()
# Get the aligned face images
# Optionally:
# images = dlib.get_face_chips(img, faces, size=160, padding=0.25)
images = dlib.get_face_chips(img, faces, size=320)
for image in images:
cv_bgr_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imshow('image',cv_bgr_img)
cv2.waitKey(0)
window.set_image(image)
dlib.hit_enter_to_continue()
# It is also possible to get a single chip
image = dlib.get_face_chip(img, faces[0])
cv_bgr_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imshow('image',cv_bgr_img)
cv2.waitKey(0)
cv2.destroyAllWindows()
window.set_image(image)
dlib.hit_enter_to_continue()

@ -28,16 +28,14 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import sys
import os
import dlib
import glob
from skimage import io
if len(sys.argv) != 5:
print(
@ -66,7 +64,7 @@ images = []
# Now find all the faces and compute 128D face descriptors for each face.
for f in glob.glob(os.path.join(faces_folder_path, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
# Ask the detector to find the bounding boxes of each face. The 1 in the
# second argument indicates that we should upsample the image 1 time. This

@ -37,23 +37,20 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import sys
import dlib
from skimage import io
detector = dlib.get_frontal_face_detector()
win = dlib.image_window()
for f in sys.argv[1:]:
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
# The 1 in the second argument indicates that we should upsample the image
# 1 time. This will make everything bigger and allow us to detect more
# faces.
@ -76,7 +73,7 @@ for f in sys.argv[1:]:
# Also, the idx tells you which of the face sub-detectors matched. This can be
# used to broadly identify faces in different orientations.
if (len(sys.argv[1:]) > 0):
img = io.imread(sys.argv[1])
img = dlib.load_rgb_image(sys.argv[1])
dets, scores, idx = detector.run(img, 1, -1)
for i, d in enumerate(dets):
print("Detection {}, score: {}, face_type:{}".format(

@ -25,26 +25,23 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires OpenCV and Numpy which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install opencv-python numpy
# pip install numpy
#
# The image file used in this example is in the public domain:
# https://commons.wikimedia.org/wiki/File:Tom_Cruise_avp_2014_4.jpg
import sys
import dlib
import cv2
import numpy as np
def show_jittered_images(jittered_images):
def show_jittered_images(window, jittered_images):
'''
Shows the specified jittered images one by one
'''
for img in jittered_images:
cv_bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imshow('image',cv_bgr_img)
cv2.waitKey(0)
window.set_image(img)
dlib.hit_enter_to_continue()
if len(sys.argv) != 2:
print(
@ -62,14 +59,8 @@ face_file_path = "../examples/faces/Tom_Cruise_avp_2014_4.jpg"
detector = dlib.get_frontal_face_detector()
sp = dlib.shape_predictor(predictor_path)
# Load the image using OpenCV
bgr_img = cv2.imread(face_file_path)
if bgr_img is None:
print("Sorry, we could not load '{}' as an image".format(face_file_path))
exit()
# Convert to RGB since dlib uses RGB images
img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
# Load the image using dlib
img = dlib.load_rgb_image(face_file_path)
# Ask the detector to find the bounding boxes of each face.
dets = detector(img)
@ -83,15 +74,14 @@ for detection in dets:
# Get the aligned face image and show it
image = dlib.get_face_chip(img, faces[0], size=320)
cv_bgr_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imshow('image',cv_bgr_img)
cv2.waitKey(0)
window = dlib.image_window()
window.set_image(image)
dlib.hit_enter_to_continue()
# Show 5 jittered images without data augmentation
jittered_images = dlib.jitter_image(image, num_jitters=5)
show_jittered_images(jittered_images)
show_jittered_images(window, jittered_images)
# Show 5 jittered images with data augmentation
jittered_images = dlib.jitter_image(image, num_jitters=5, disturb_colors=True)
show_jittered_images(jittered_images)
cv2.destroyAllWindows()
show_jittered_images(window, jittered_images)

@ -45,16 +45,14 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import sys
import os
import dlib
import glob
from skimage import io
if len(sys.argv) != 3:
print(
@ -76,7 +74,7 @@ win = dlib.image_window()
for f in glob.glob(os.path.join(faces_folder_path, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
win.clear_overlay()
win.set_image(img)

@ -40,16 +40,14 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import sys
import os
import dlib
import glob
from skimage import io
if len(sys.argv) != 4:
print(
@ -76,7 +74,7 @@ win = dlib.image_window()
# Now process all the images
for f in glob.glob(os.path.join(faces_folder_path, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
win.clear_overlay()
win.set_image(img)

@ -31,18 +31,14 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import dlib
from skimage import io
image_file = '../examples/faces/2009_004587.jpg'
img = io.imread(image_file)
img = dlib.load_rgb_image(image_file)
# Locations of candidate objects will be saved into rects
rects = []

@ -0,0 +1,59 @@
#!/usr/bin/python
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
#
# This example program shows how to find frontal human faces in a webcam stream using OpenCV.
# It is also meant to demonstrate that rgb images from Dlib can be used with opencv by just
# swapping the Red and Blue channels.
#
# You can run this program and see the detections from your webcam by executing the
# following command:
# ./opencv_face_detection.py
#
# This face detector is made using the now classic Histogram of Oriented
# Gradients (HOG) feature combined with a linear classifier, an image
# pyramid, and sliding window detection scheme. This type of object detector
# is fairly general and capable of detecting many types of semi-rigid objects
# in addition to human faces. Therefore, if you are interested in making
# your own object detectors then read the train_object_detector.py example
# program.
#
#
# COMPILING/INSTALLING THE DLIB PYTHON INTERFACE
# You can install dlib using the command:
# pip install dlib
#
# Alternatively, if you want to compile dlib yourself then go into the dlib
# root folder and run:
# python setup.py install
# or
# python setup.py install --yes USE_AVX_INSTRUCTIONS
# if you have a CPU that supports AVX instructions, since this makes some
# things run faster.
#
# Compiling dlib should work on any operating system so long as you have
# CMake installed. On Ubuntu, this can be done easily by running the
# command:
# sudo apt-get install cmake
#
# Also note that this example requires OpenCV and Numpy which can be installed
# via the command:
# pip install opencv-python numpy
import sys
import dlib
import cv2
# The detector and the capture device are created once and reused for
# every frame grabbed from the webcam.
detector = dlib.get_frontal_face_detector()
cam = cv2.VideoCapture(0)
color_green = (0, 255, 0)
line_width = 3

try:
    while True:
        ret_val, img = cam.read()
        if not ret_val:
            # read() signals a failed grab through its first return value;
            # the frame must not be used in that case.
            print("Could not read a frame from the webcam")
            break

        # OpenCV delivers BGR frames while dlib works on RGB images, so the
        # red and blue channels are swapped before running the detector.
        rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        dets = detector(rgb_image)

        # Draw the detections on the original (BGR) frame that gets displayed.
        for det in dets:
            cv2.rectangle(img, (det.left(), det.top()), (det.right(), det.bottom()), color_green, line_width)

        cv2.imshow('my webcam', img)
        if cv2.waitKey(1) == 27:
            break  # esc to quit
finally:
    # Always hand the camera back and close the preview window, even if the
    # loop was left through an exception.
    cam.release()
    cv2.destroyAllWindows()

@ -1,3 +1,2 @@
scikit-image>=0.9.3
opencv-python
numpy

@ -25,18 +25,15 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import os
import sys
import glob
import dlib
from skimage import io
# In this example we are going to train a face detector based on the small
# faces dataset in the examples/faces directory. This means you need to supply
@ -116,7 +113,7 @@ print("Showing detections on the images in the faces folder...")
win = dlib.image_window()
for f in glob.glob(os.path.join(faces_folder, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
dets = detector(img)
print("Number of faces detected: {}".format(len(dets)))
for k, d in enumerate(dets):
@ -128,9 +125,6 @@ for f in glob.glob(os.path.join(faces_folder, "*.jpg")):
win.add_overlay(dets)
dlib.hit_enter_to_continue()
# Next, suppose you have trained multiple detectors and you want to run them
# efficiently as a group. You can do this as follows:
detector1 = dlib.fhog_object_detector("detector.svm")
@ -140,22 +134,19 @@ detector2 = dlib.fhog_object_detector("detector.svm")
# make a list of all the detectors you want to run. Here we have 2, but you
# could have any number.
detectors = [detector1, detector2]
image = io.imread(faces_folder + '/2008_002506.jpg')
image = dlib.load_rgb_image(faces_folder + '/2008_002506.jpg')
[boxes, confidences, detector_idxs] = dlib.fhog_object_detector.run_multiple(detectors, image, upsample_num_times=1, adjust_threshold=0.0)
for i in range(len(boxes)):
print("detector {} found box {} with confidence {}.".format(detector_idxs[i], boxes[i], confidences[i]))
# Finally, note that you don't have to use the XML based input to
# train_simple_object_detector(). If you have already loaded your training
# images and bounding boxes for the objects then you can call it as shown
# below.
# You just need to put your images into a list.
images = [io.imread(faces_folder + '/2008_002506.jpg'),
io.imread(faces_folder + '/2009_004587.jpg')]
images = [dlib.load_rgb_image(faces_folder + '/2008_002506.jpg'),
dlib.load_rgb_image(faces_folder + '/2009_004587.jpg')]
# Then for each image you make a list of rectangles which give the pixel
# locations of the edges of the boxes.
boxes_img1 = ([dlib.rectangle(left=329, top=78, right=437, bottom=186),

@ -33,18 +33,15 @@
# command:
# sudo apt-get install cmake
#
# Also note that this example requires scikit-image which can be installed
# Also note that this example requires Numpy which can be installed
# via the command:
# pip install scikit-image
# Or downloaded from http://scikit-image.org/download.html.
# pip install numpy
import os
import sys
import glob
import dlib
from skimage import io
# In this example we are going to train a face detector based on the small
# faces dataset in the examples/faces directory. This means you need to supply
@ -110,7 +107,7 @@ print("Showing detections and predictions on the images in the faces folder...")
win = dlib.image_window()
for f in glob.glob(os.path.join(faces_folder, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
img = dlib.load_rgb_image(f)
win.clear_overlay()
win.set_image(img)

@ -37,16 +37,12 @@ find_package(PythonInterp)
if(PYTHONINTERP_FOUND)
execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import numpy" OUTPUT_QUIET ERROR_QUIET RESULT_VARIABLE NUMPYRC)
if(NUMPYRC EQUAL 1)
message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!")
message(WARNING "Numpy not found. Functions that return numpy arrays will not work without Numpy installed!")
else()
message(STATUS "Found Python with installed numpy package")
execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import sys; from numpy import get_include; sys.stdout.write(get_include())" OUTPUT_VARIABLE NUMPY_INCLUDE_PATH)
message(STATUS "Numpy include path '${NUMPY_INCLUDE_PATH}'")
include_directories(${NUMPY_INCLUDE_PATH})
endif()
else()
message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!")
set(NUMPYRC 1)
message(WARNING "Numpy not found. Functions that return numpy arrays will not work without Numpy installed!")
endif()
add_definitions(-DDLIB_VERSION=${DLIB_VERSION})
@ -73,15 +69,9 @@ set(python_srcs
src/cnn_face_detector.cpp
src/global_optimization.cpp
src/image_dataset_metadata.cpp
src/numpy_returns.cpp
)
# Only add the Numpy returning functions if Numpy is present
if(NUMPYRC EQUAL 1)
list(APPEND python_srcs src/numpy_returns_stub.cpp)
else()
list(APPEND python_srcs src/numpy_returns.cpp)
endif()
# Only add the GUI module if requested
if(NOT ${DLIB_NO_GUI_SUPPORT})
list(APPEND python_srcs src/gui.cpp)

@ -2,16 +2,76 @@
#include <dlib/python.h>
#include "dlib/pixel.h"
#include <dlib/image_transforms.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include <dlib/image_io.h>
#include <pybind11/numpy.h>
using namespace dlib;
using namespace std;
namespace py = pybind11;
// Moves the pixel buffer of rgb_image into a (rows, cols, 3) uint8 numpy array
// without copying the pixels.
//
// The buffer is obtained through matrix::steal_memory() and its lifetime is
// tied to the returned array through a capsule.  The strides treat every
// rgb_pixel as 3 consecutive bytes (NOTE(review): relies on
// sizeof(rgb_pixel) == 3 -- confirm against dlib/pixel.h).
py::array_t<uint8_t> convert_to_numpy(matrix<rgb_pixel> &rgb_image)
{
    const size_t dtype_size = sizeof(uint8_t);
    const size_t rows = static_cast<size_t>(num_rows(rgb_image));
    const size_t cols = static_cast<size_t>(num_columns(rgb_image));
    const size_t channels = 3;

    // Keep the buffer in the unique_ptr until the capsule exists, so it is
    // freed if creating the capsule throws.
    unique_ptr<rgb_pixel[]> buffer = rgb_image.steal_memory();

    // The memory was allocated as rgb_pixel[], so it has to be released as
    // rgb_pixel[] as well; deleting it through a uint8_t* is undefined behavior.
    pybind11::capsule owner(buffer.get(), [](void *pixels_p) {
        delete[] reinterpret_cast<rgb_pixel *>(pixels_p);
    });
    rgb_pixel *pixels = buffer.release();   // the capsule owns the buffer from here on

    return pybind11::array_t<uint8_t>(
        {rows, cols, channels},                                                     // shape
        {dtype_size * cols * channels, dtype_size * channels, dtype_size},          // strides
        reinterpret_cast<uint8_t *>(pixels),                                        // pointer
        owner                                                                       // frees the buffer with the array
    );
}
// -------------------------------- Basic Image IO ----------------------------------------
// Reads the image file at `path` into an RGB dlib matrix and returns it
// converted to a numpy array via convert_to_numpy().
py::array_t<uint8_t> load_rgb_image (const std::string &path)
{
    matrix<rgb_pixel> loaded;
    load_image(loaded, path);

    py::array_t<uint8_t> as_numpy = convert_to_numpy(loaded);
    return as_numpy;
}
// Returns true when `full_string` ends with `ending` (an empty `ending`
// always matches).  Both arguments are taken by const reference so no copy of
// the string is made per call.
bool has_ending (std::string const &full_string, std::string const &ending) {
    // A string shorter than the suffix can never end with it.
    if (full_string.length() < ending.length())
        return false;

    const std::string::size_type tail_start = full_string.length() - ending.length();
    return full_string.compare(tail_start, ending.length(), ending) == 0;
}
// Saves the RGB image `img` to `path`.  The output format is chosen from the
// (case-insensitive) file extension: .bmp, .dng, .png, .jpg or .jpeg.
//
// Throws dlib::error when `img` is not an RGB image or when the extension is
// not one of the supported ones.
void save_rgb_image(py::object img, const std::string &path)
{
    if (!is_rgb_python_image(img))
        throw dlib::error("Unsupported image type, must be RGB image.");

    // Lower-case a copy of the path so that e.g. ".PNG" is recognised too.
    // The characters are passed to tolower as unsigned char: handing it a
    // negative char value (non-ASCII bytes in the path) is undefined behavior.
    std::string lowered_path = path;
    std::transform(lowered_path.begin(), lowered_path.end(), lowered_path.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });

    if(has_ending(lowered_path, ".bmp")) {
        save_bmp(numpy_rgb_image(img), path);
    } else if(has_ending(lowered_path, ".dng")) {
        save_dng(numpy_rgb_image(img), path);
    } else if(has_ending(lowered_path, ".png")) {
        save_png(numpy_rgb_image(img), path);
    } else if(has_ending(lowered_path, ".jpg") || has_ending(lowered_path, ".jpeg")) {
        save_jpeg(numpy_rgb_image(img), path);
    } else {
        throw dlib::error("Unsupported image type, image path must end with one of [.bmp, .png, .dng, .jpg, .jpeg]");
    }
}
// ----------------------------------------------------------------------------------------
py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_colors = false)
@ -27,12 +87,6 @@ py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_
// The top level list (containing 1 or more images) to return to python
py::list jitter_list;
size_t rows = num_rows(img_mat);
size_t cols = num_columns(img_mat);
// Size of the numpy array
npy_intp dims[3] = { num_rows(img_mat), num_columns(img_mat), 3};
for (int i = 0; i < num_jitters; ++i) {
// Get a jittered crop
matrix<rgb_pixel> crop = dlib::jitter_image(img_mat, rnd_jitter);
@ -40,14 +94,11 @@ py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_
if(disturb_colors)
dlib::disturb_colors(crop, rnd_jitter);
PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8);
npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject*) arr);
memcpy(outdata, image_data(crop), rows * width_step(crop));
// Convert image to Numpy array
py::array_t<uint8_t> arr = convert_to_numpy(crop);
py::handle handle = arr;
// Append image to jittered image list
jitter_list.append(handle);
Py_DECREF(arr);
jitter_list.append(arr);
}
return jitter_list;
@ -77,27 +128,18 @@ py::list get_face_chips (
dlib::array<matrix<rgb_pixel>> face_chips;
extract_image_chips(numpy_rgb_image(img), dets, face_chips);
npy_intp rows = size;
npy_intp cols = size;
// Size of the numpy array
npy_intp dims[3] = { rows, cols, 3};
for (auto& chip : face_chips)
{
PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8);
npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject*) arr);
memcpy(outdata, image_data(chip), rows * width_step(chip));
py::handle handle = arr;
// Convert image to Numpy array
py::array_t<uint8_t> arr = convert_to_numpy(chip);
// Append image to chips list
chips_list.append(handle);
Py_DECREF(arr);
chips_list.append(arr);
}
return chips_list;
}
py::object get_face_chip (
py::array_t<uint8_t> get_face_chip (
py::object img,
const full_object_detection& face,
size_t size = 150,
@ -109,36 +151,22 @@ py::object get_face_chip (
matrix<rgb_pixel> chip;
extract_image_chip(numpy_rgb_image(img), get_face_chip_details(face, size, padding), chip);
// Size of the numpy array
npy_intp dims[3] = { num_rows(chip), num_columns(chip), 3};
PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8);
npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject *) arr);
memcpy(outdata, image_data(chip), num_rows(chip) * width_step(chip));
return py::reinterpret_steal<py::object>(arr);
return convert_to_numpy(chip);
}
// ----------------------------------------------------------------------------------------
// we need this wonky stuff because different versions of numpy's import_array macro
// contain differently typed return statements inside import_array().
#if PY_VERSION_HEX >= 0x03000000
#define DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE void*
#define DLIB_NUMPY_IMPORT_RETURN return 0
#else
#define DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE void
#define DLIB_NUMPY_IMPORT_RETURN return
#endif
DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE import_numpy_stuff()
{
import_array();
DLIB_NUMPY_IMPORT_RETURN;
}
void bind_numpy_returns(py::module &m)
{
import_numpy_stuff();
m.def("load_rgb_image", &load_rgb_image,
"Takes a path and returns a numpy array (RGB) containing the image",
py::arg("path")
);
m.def("save_rgb_image", &save_rgb_image,
"Saves the given (RGB) image to the specified path. Determines the file type from the file extension specified in the path",
py::arg("img"), py::arg("path")
);
m.def("jitter_image", &get_jitter_images,
"Takes an image and returns a list of jittered images."

@ -1,59 +0,0 @@
#include "opaque_types.h"
#include <dlib/python.h>
#include "dlib/pixel.h"
#include <dlib/image_transforms.h>
using namespace dlib;
using namespace std;
namespace py = pybind11;
// ----------------------------------------------------------------------------------------
// Stub variant of jitter_image: it performs no work and always throws
// dlib::error with the message below.
py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_colors = false)
{
    const char* const message =
        "jitter_image is only supported if you compiled dlib with numpy installed!";
    throw dlib::error(message);
}
// ----------------------------------------------------------------------------------------
// Stub variant of get_face_chips: it performs no work and always throws
// dlib::error with the message below.
py::list get_face_chips (
    py::object img,
    const std::vector<full_object_detection>& faces,
    size_t size = 150,
    float padding = 0.25
)
{
    const char* const message =
        "get_face_chips is only supported if you compiled dlib with numpy installed!";
    throw dlib::error(message);
}
// Stub variant of get_face_chip: it performs no work and always throws
// dlib::error with the message below.
py::object get_face_chip (
    py::object img,
    const full_object_detection& face,
    size_t size = 150,
    float padding = 0.25
)
{
    const char* const message =
        "get_face_chip is only supported if you compiled dlib with numpy installed!";
    throw dlib::error(message);
}
// ----------------------------------------------------------------------------------------
// Registers jitter_image, get_face_chip and get_face_chips on module `m`,
// bound to the functions of the same purpose defined in this file, together
// with their docstrings and default argument values.
void bind_numpy_returns(py::module &m)
{
    // Adjacent string literals are concatenated by the compiler, so every
    // sentence except the last ends with a space to keep the docstring readable.
    m.def("jitter_image", &get_jitter_images,
    "Takes an image and returns a list of jittered images. "
    "The returned list contains num_jitters images (default is 1). "
    "If disturb_colors is set to True, the colors of the image are disturbed (default is False)",
    py::arg("img"), py::arg("num_jitters")=1, py::arg("disturb_colors")=false
    );
    m.def("get_face_chip", &get_face_chip,
    "Takes an image and a full_object_detection that references a face in that image and returns the face as a Numpy array representing the image. The face will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.",
    py::arg("img"), py::arg("face"), py::arg("size")=150, py::arg("padding")=0.25
    );
    m.def("get_face_chips", &get_face_chips,
    "Takes an image and a full_object_detections object that reference faces in that image and returns the faces as a list of Numpy arrays representing the image. The faces will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.",
    py::arg("img"), py::arg("faces"), py::arg("size")=150, py::arg("padding")=0.25
    );
}

@ -0,0 +1,33 @@
#!/usr/bin/python
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
#
# This utility generates the test data required for the tests contained in test_numpy_returns.py
#
# Also note that this utility requires Numpy which can be installed
# via the command:
# pip install numpy
import sys
import dlib
import numpy as np
import utils
# Exactly one argument is expected: the path to the shape predictor model.
if len(sys.argv) != 2:
    print(
        "Call this program like this:\n"
        " ./generate_numpy_returns_test_data.py shape_predictor_5_face_landmarks.dat\n"
        "You can download a trained facial shape predictor from:\n"
        " http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2\n")
    # sys.exit is used instead of the site-provided exit() helper, and a
    # non-zero status reports the usage error to the caller.
    sys.exit(1)

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(sys.argv[1])

img = dlib.load_rgb_image("../../../examples/faces/Tom_Cruise_avp_2014_4.jpg")
dets = detector(img)
if len(dets) == 0:
    # Without a detection there is nothing to build the test data from;
    # stop with a clear message instead of failing on dets[0].
    print("No face was detected in the test image, no test data was generated")
    sys.exit(1)

# The landmark shape of the first detected face is stored for the tests ...
shape = predictor(img, dets[0])
utils.save_pickled_compatible(shape, "shape.pkl")

# ... together with the face chip extracted for that shape.
face_chip = dlib.get_face_chip(img, shape)
np.save("test_face_chip", face_chip)

@ -0,0 +1,3 @@
cdlib
full_object_detection
q)<29>qU+'åÜŽOŒ¡wä q…b.

@ -0,0 +1,66 @@
import sys
import pickle
import dlib
import pytest
import utils
# Paths are relative to dlib root
image_path = "examples/faces/Tom_Cruise_avp_2014_4.jpg"
shape_path = "tools/python/test/shape.pkl"
face_chip_path = "tools/python/test/test_face_chip.npy"
def get_test_image_and_shape():
    """Return the test image and the pickled shape as an ``(img, shape)`` tuple.

    The image is read from ``image_path`` with ``dlib.load_rgb_image`` and the
    shape is restored from ``shape_path`` with
    ``utils.load_pickled_compatible``.
    """
    return (
        dlib.load_rgb_image(image_path),
        utils.load_pickled_compatible(shape_path),
    )
def get_test_face_chips():
    """Return ``dlib.get_face_chips`` for the test image and its single shape."""
    image, landmark_shape = get_test_image_and_shape()

    # get_face_chips is called with a full_object_detections container, so
    # the single test shape is wrapped in one.
    detections = dlib.full_object_detections()
    detections.append(landmark_shape)

    chips = dlib.get_face_chips(image, detections)
    return chips
def get_test_face_chip():
    """Return ``dlib.get_face_chip`` for the test image and its shape."""
    image, landmark_shape = get_test_image_and_shape()
    chip = dlib.get_face_chip(image, landmark_shape)
    return chip
# The tests below will be skipped if Numpy is not installed
@pytest.mark.skipif(not utils.is_numpy_installed(), reason="requires numpy")
def test_get_face_chip():
    """The chip from get_face_chip must equal the stored reference array."""
    import numpy as np

    actual_chip = get_test_face_chip()
    reference_chip = np.load(face_chip_path)
    assert np.array_equal(actual_chip, reference_chip)
@pytest.mark.skipif(not utils.is_numpy_installed(), reason="requires numpy")
def test_get_face_chips():
    """The first chip from get_face_chips must equal the stored reference array."""
    import numpy as np

    actual_chips = get_test_face_chips()
    reference_chip = np.load(face_chip_path)
    assert np.array_equal(actual_chips[0], reference_chip)
@pytest.mark.skipif(not utils.is_numpy_installed(), reason="requires numpy")
def test_regression_issue_1220_get_face_chip():
    """
    Regression test for the memory leak in Python get_face_chip
    https://github.com/davisking/dlib/issues/1220
    """
    chip = get_test_face_chip()

    # Exactly two references are expected:
    #   1.) the local name `chip`
    #   2.) the temporary created for the getrefcount argument
    # Any extra reference means the returned array would never be freed.
    expected_references = 2
    assert sys.getrefcount(chip) == expected_references
@pytest.mark.skipif(not utils.is_numpy_installed(), reason="requires numpy")
def test_regression_issue_1220_get_face_chips():
    """
    Regression test for the memory leak in Python get_face_chips
    https://github.com/davisking/dlib/issues/1220
    """
    chips = get_test_face_chips()

    # The list itself: the local name plus the getrefcount argument.
    assert sys.getrefcount(chips) == 2

    # The first chip: the slot in the list plus the getrefcount argument.
    # The element is indexed inline on purpose; binding it to a local name
    # first would add a third reference.
    assert sys.getrefcount(chips[0]) == 2

@ -0,0 +1,48 @@
import pkgutil
import sys
def save_pickled_compatible(obj_to_pickle, file_name):
    '''
    Pickle obj_to_pickle with protocol 2 and write the resulting bytes to
    file_name, so the file is stored in a backward compatible way for
    Pybind objects. See:
    http://pybind11.readthedocs.io/en/stable/advanced/classes.html#pickling-support
    and https://github.com/pybind/pybind11/issues/271
    '''
    try:
        import cPickle as pickle_module  # Use cPickle on Python 2.7
    except ImportError:
        import pickle as pickle_module

    # Serialize first, so the file is only opened once pickling succeeded.
    payload = pickle_module.dumps(obj_to_pickle, 2)
    with open(file_name, "wb") as out_file:
        out_file.write(payload)
def load_pickled_compatible(file_name):
    '''
    Read file_name and return the object unpickled from its contents.

    On Python 3 the data is unpickled with encoding="bytes"; on Python 2
    the default loads() call is used.
    '''
    try:
        import cPickle as pickle_module  # Use cPickle on Python 2.7
    except ImportError:
        import pickle as pickle_module

    with open(file_name, "rb") as in_file:
        raw = in_file.read()

    # NOTE: unpickling executes code chosen by the file's author, so this
    # must only be used on trusted files.
    if is_python3():
        return pickle_module.loads(raw, encoding="bytes")
    return pickle_module.loads(raw)
def is_numpy_installed():
    '''
    Returns True if Numpy is installed (i.e. it can be imported) otherwise False
    '''
    # pkgutil.find_loader is deprecated since Python 3.12 and removed in 3.14.
    # A plain import attempt works on every Python version, including 2.7.
    try:
        import numpy  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    return True
def is_python3():
    '''
    Returns True when the running interpreter's major version is 3 or
    above, otherwise False
    '''
    major_version = sys.version_info[0]
    return major_version >= 3
Loading…
Cancel
Save