mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
e0a6e30581
pyramid_down and you give it the downsampling amount as a template argument.
301 lines
13 KiB
C++
301 lines
13 KiB
C++
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
|
/*
|
|
|
|
This is an example showing how you might use dlib to create a reasonably
|
|
functional command line tool for object detection. This example assumes
|
|
you are familiar with the contents of at least the following example
|
|
programs:
|
|
- object_detector_ex.cpp
|
|
- compress_stream_ex.cpp
|
|
|
|
|
|
|
|
|
|
This program is a command line tool for learning to detect objects in images.
|
|
Therefore, to create an object detector it requires a set of annotated training
|
|
images. To create this annotated data you will need to use the imglab tool
|
|
included with dlib. It is located in the tools/imglab folder and can be compiled
|
|
using the following commands.
|
|
cd tools/imglab
|
|
mkdir build
|
|
cd build
|
|
cmake ..
|
|
cmake --build . --config Release
|
|
Note that you may need to install CMake (www.cmake.org) for this to work.
|
|
|
|
Next, let's assume you have a folder of images called /tmp/images. These images
|
|
should contain examples of the objects you want to learn to detect. You will
|
|
use the imglab tool to label these objects. Do this by typing the following
|
|
./imglab -c mydataset.xml /tmp/images
|
|
This will create a file called mydataset.xml which simply lists the images in
|
|
/tmp/images. To annotate them run
|
|
./imglab mydataset.xml
|
|
A window will appear showing all the images. You can use the up and down arrow
|
|
keys to cycle through the images and the mouse to label objects. In particular,
|
|
holding the shift key, left clicking, and dragging the mouse will allow you to
|
|
draw boxes around the objects you wish to detect. So next, label all the objects
|
|
with boxes. Note that it is important to label all the objects since any object
|
|
not labeled is implicitly assumed to be not an object we should detect.
|
|
|
|
Once you finish labeling objects go to the file menu, click save, and then close
|
|
the program. This will save the object boxes back to mydataset.xml. You can verify
|
|
this by opening the tool again with
|
|
./imglab mydataset.xml
|
|
and observing that the boxes are present.
|
|
|
|
Returning to the present example program, we can compile it using cmake just as we
|
|
did with the imglab tool. Once compiled, we can issue the command
|
|
./train_object_detector -tv mydataset.xml
|
|
which will train an object detection model based on our labeled data. The model
|
|
will be saved to the file object_detector.svm. Once this has finished we can use
|
|
the object detector to locate objects in new images with a command like
|
|
./train_object_detector some_image.png
|
|
This command will display some_image.png in a window and any detected objects will
|
|
be indicated by a red box.
|
|
|
|
|
|
There are a number of other useful command line options in the current example
|
|
program which you can explore below.
|
|
*/
|
|
|
|
|
|
#include <dlib/svm_threaded.h>
|
|
#include <dlib/string.h>
|
|
#include <dlib/gui_widgets.h>
|
|
#include <dlib/array.h>
|
|
#include <dlib/array2d.h>
|
|
#include <dlib/image_keypoint.h>
|
|
#include <dlib/image_processing.h>
|
|
#include <dlib/data_io.h>
|
|
#include <dlib/cmd_line_parser.h>
|
|
|
|
|
|
#include <iostream>
|
|
#include <fstream>
|
|
|
|
|
|
using namespace std;
|
|
using namespace dlib;
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
int main(int argc, char** argv)
|
|
{
|
|
try
|
|
{
|
|
command_line_parser parser;
|
|
parser.add_option("h","Display this help message.");
|
|
parser.add_option("v","Be verbose.");
|
|
parser.add_option("t","Train an object detector and save the detector to disk.");
|
|
parser.add_option("cross-validate",
|
|
"Perform cross-validation on an image dataset and print the results.");
|
|
parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1);
|
|
parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1);
|
|
parser.add_option("threads", "Use <arg> threads for training <arg> (default: 4).",1);
|
|
parser.add_option("grid-size", "Extract features in a detection window from an <arg> by <arg> grid. (default: 2).",1);
|
|
parser.add_option("hash-bits", "Use <arg> bits for the feature hashing (default: 10).", 1);
|
|
parser.add_option("test", "Test a trained detector on an image dataset and print the results.");
|
|
parser.add_option("eps", "Set training epsilon to <arg> (default: 0.3).", 1);
|
|
|
|
|
|
parser.parse(argc, argv);
|
|
|
|
// Now we do a little command line validation. Each of the following functions
|
|
// checks something and throws an exception if the test fails.
|
|
const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "grid-size", "hash-bits",
|
|
"folds", "test", "eps"};
|
|
parser.check_one_time_options(one_time_opts); // Can't give an option more than once
|
|
// Make sure the arguments to these options are within valid ranges if they are supplied by the user.
|
|
parser.check_option_arg_range("c", 1e-12, 1e12);
|
|
parser.check_option_arg_range("eps", 1e-5, 1e4);
|
|
parser.check_option_arg_range("threads", 1, 1000);
|
|
parser.check_option_arg_range("grid-size", 1, 100);
|
|
parser.check_option_arg_range("hash-bits", 1, 32);
|
|
parser.check_option_arg_range("folds", 2, 100);
|
|
const char* incompatible[] = {"t", "cross-validate", "test"};
|
|
parser.check_incompatible_options(incompatible);
|
|
// You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate.
|
|
const char* training_ops[] = {"t", "cross-validate"};
|
|
const char* training_sub_ops[] = {"v", "c", "threads", "grid-size", "hash-bits"};
|
|
parser.check_sub_options(training_ops, training_sub_ops);
|
|
parser.check_sub_option("cross-validate", "folds");
|
|
|
|
|
|
if (parser.option("h"))
|
|
{
|
|
cout << "Usage: train_object_detector [options] <image dataset file|image file>\n";
|
|
parser.print_options();
|
|
|
|
return EXIT_SUCCESS;
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef hashed_feature_image<hog_image<4,4,1,9,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
|
|
typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type;
|
|
|
|
if (parser.option("t") || parser.option("cross-validate"))
|
|
{
|
|
if (parser.number_of_arguments() != 1)
|
|
{
|
|
cout << "You must give an image dataset metadata XML file produced by the imglab tool." << endl;
|
|
cout << "\nTry the -h option for more information." << endl;
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
dlib::array<array2d<unsigned char> > images;
|
|
std::vector<std::vector<rectangle> > object_locations;
|
|
|
|
cout << "Loading image dataset from metadata file " << parser[0] << endl;
|
|
load_image_dataset(images, object_locations, parser[0]);
|
|
|
|
cout << "Number of images loaded: " << images.size() << endl;
|
|
|
|
// Get the value of the hash-bits option if the user supplied it. Otherwise
|
|
// use the default value of 10.
|
|
const int hash_bits = get_option(parser, "hash-bits", 10);
|
|
const int grid_size = get_option(parser, "grid-size", 2);
|
|
const int threads = get_option(parser, "threads", 4);
|
|
const double C = get_option(parser, "c", 1.0);
|
|
const double eps = get_option(parser, "eps", 0.3);
|
|
unsigned int num_folds = get_option(parser, "folds", 3);
|
|
// You can't do more folds than there are images.
|
|
if (num_folds > images.size())
|
|
num_folds = images.size();
|
|
|
|
|
|
image_scanner_type scanner;
|
|
setup_grid_detection_templates_verbose(scanner, object_locations, grid_size, grid_size);
|
|
setup_hashed_features(scanner, images, hash_bits);
|
|
|
|
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
|
|
trainer.set_num_threads(threads);
|
|
|
|
if (parser.option("v"))
|
|
trainer.be_verbose();
|
|
|
|
trainer.set_c(C);
|
|
trainer.set_epsilon(eps);
|
|
|
|
if (parser.option("t"))
|
|
{
|
|
// Do the actual training and save the results into the detector object.
|
|
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
|
|
|
|
cout << "Saving trained detector to object_detector.svm" << endl;
|
|
ofstream fout("object_detector.svm", ios::binary);
|
|
serialize(detector, fout);
|
|
fout.close();
|
|
|
|
cout << "Testing detector on training data..." << endl;
|
|
cout << "Test detector (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations) << endl;
|
|
}
|
|
else
|
|
{
|
|
// shuffle the order of the training images
|
|
randomize_samples(images, object_locations);
|
|
|
|
cout << num_folds << "-fold cross validation (precision,recall,AP): "
|
|
<< cross_validate_object_detection_trainer(trainer, images, object_locations, num_folds) << endl;
|
|
}
|
|
|
|
cout << "Parameters used: " << endl;
|
|
cout << " hash-bits: "<< hash_bits << endl;
|
|
cout << " grid-size: "<< grid_size << endl;
|
|
cout << " threads: "<< threads << endl;
|
|
cout << " C: "<< C << endl;
|
|
cout << " eps: "<< eps << endl;
|
|
if (parser.option("cross-validate"))
|
|
cout << " num_folds: "<< num_folds << endl;
|
|
cout << endl;
|
|
|
|
return EXIT_SUCCESS;
|
|
}
|
|
|
|
|
|
|
|
// The rest of the code is devoted to testing out an already trained
|
|
// object detector.
|
|
|
|
|
|
if (parser.number_of_arguments() == 0)
|
|
{
|
|
cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl;
|
|
cout << "\nTry the -h option for more information." << endl;
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
// load a previously trained object detector and try it out on some data
|
|
ifstream fin("object_detector.svm", ios::binary);
|
|
if (!fin)
|
|
{
|
|
cout << "Can't find a trained object detector file object_detector.svm. " << endl;
|
|
cout << "You need to train one using the -t option." << endl;
|
|
cout << "\nTry the -h option for more information." << endl;
|
|
return EXIT_FAILURE;
|
|
|
|
}
|
|
object_detector<image_scanner_type> detector;
|
|
deserialize(detector, fin);
|
|
|
|
dlib::array<array2d<unsigned char> > images;
|
|
// Check if the command line argument is an XML file
|
|
if (tolower(right_substr(parser[0],".")) == "xml")
|
|
{
|
|
std::vector<std::vector<rectangle> > object_locations;
|
|
cout << "Loading image dataset from metadata file " << parser[0] << endl;
|
|
load_image_dataset(images, object_locations, parser[0]);
|
|
cout << "Number of images loaded: " << images.size() << endl;
|
|
|
|
if (parser.option("test"))
|
|
{
|
|
cout << "Testing detector on data..." << endl;
|
|
cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations) << endl;
|
|
return EXIT_SUCCESS;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// In this case, the user should have given some image files. So just
|
|
// load them.
|
|
images.resize(parser.number_of_arguments());
|
|
for (unsigned long i = 0; i < images.size(); ++i)
|
|
load_image(images[i], parser[i]);
|
|
}
|
|
|
|
|
|
// Test the detector on the images we loaded and display the results
|
|
// in a window.
|
|
image_window win;
|
|
for (unsigned long i = 0; i < images.size(); ++i)
|
|
{
|
|
// Run the detector on images[i]
|
|
const std::vector<rectangle> rects = detector(images[i]);
|
|
cout << "Number of detections: "<< rects.size() << endl;
|
|
|
|
// Put the image and detections into the window.
|
|
win.clear_overlay();
|
|
win.set_image(images[i]);
|
|
win.add_overlay(rects, rgb_pixel(255,0,0));
|
|
|
|
cout << "Hit enter to see the next image.";
|
|
cin.get();
|
|
}
|
|
|
|
|
|
}
|
|
catch (exception& e)
|
|
{
|
|
cout << "\nexception thrown!" << endl;
|
|
cout << e.what() << endl;
|
|
cout << "\nTry the -h option for more information." << endl;
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
return EXIT_SUCCESS;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|