mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Updated this example to use the scan_fhog_pyramid version of the object
detector since it is much more user friendly.
This commit is contained in:
parent
e79a764834
commit
8e1e548a70
@ -5,7 +5,7 @@
|
||||
functional command line tool for object detection. This example assumes
|
||||
you are familiar with the contents of at least the following example
|
||||
programs:
|
||||
- object_detector_ex.cpp
|
||||
- fhog_object_detector_ex.cpp
|
||||
- compress_stream_ex.cpp
|
||||
|
||||
|
||||
@ -35,7 +35,11 @@
|
||||
holding the shift key, left clicking, and dragging the mouse will allow you to
|
||||
draw boxes around the objects you wish to detect. So next, label all the objects
|
||||
with boxes. Note that it is important to label all the objects since any object
|
||||
not labeled is implicitly assumed to be not an object we should detect.
|
||||
not labeled is implicitly assumed to be not an object we should detect. If there
|
||||
are objects you are not sure about you should draw a box around them, then double
|
||||
click the box and press i. This will cross out the box and mark it as "ignore".
|
||||
The training code in dlib will then simply ignore detections matching that box.
|
||||
|
||||
|
||||
Once you finish labeling objects go to the file menu, click save, and then close
|
||||
the program. This will save the object boxes back to mydataset.xml. You can verify
|
||||
@ -53,18 +57,20 @@
|
||||
This command will display some_image.png in a window and any detected objects will
|
||||
be indicated by a red box.
|
||||
|
||||
|
||||
There are a number of other useful command line options in the current example
|
||||
program which you can explore below.
|
||||
Finally, to make running this example easy dlib includes some training data in the
|
||||
examples/faces folder. Therefore, you can test this program out by running the
|
||||
following sequence of commands:
|
||||
./train_object_detector -tv examples/faces/training.xml -u1 --flip
|
||||
./train_object_detector --test examples/faces/testing.xml -u1
|
||||
./train_object_detector examples/faces/*.jpg -u1
|
||||
That will make a face detector that performs perfectly on the test images listed in
|
||||
testing.xml and then it will show you the detections on all the images.
|
||||
*/
|
||||
|
||||
|
||||
#include <dlib/svm_threaded.h>
|
||||
#include <dlib/string.h>
|
||||
#include <dlib/gui_widgets.h>
|
||||
#include <dlib/array.h>
|
||||
#include <dlib/array2d.h>
|
||||
#include <dlib/image_keypoint.h>
|
||||
#include <dlib/image_processing.h>
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/cmd_line_parser.h>
|
||||
@ -79,44 +85,131 @@ using namespace dlib;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void pick_best_window_size (
|
||||
const std::vector<std::vector<rectangle> >& boxes,
|
||||
unsigned long& width,
|
||||
unsigned long& height,
|
||||
const unsigned long target_size
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- Finds the average aspect ratio of the elements of boxes and outputs a width
|
||||
and height such that the aspect ratio is equal to the average and also the
|
||||
area is equal to target_size. That is, the following will be approximately true:
|
||||
- #width*#height == target_size
|
||||
- #width/#height == the average aspect ratio of the elements of boxes.
|
||||
!*/
|
||||
{
|
||||
// find the average width and height
|
||||
running_stats<double> avg_width, avg_height;
|
||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||
{
|
||||
for (unsigned long j = 0; j < boxes[i].size(); ++j)
|
||||
{
|
||||
avg_width.add(boxes[i][j].width());
|
||||
avg_height.add(boxes[i][j].height());
|
||||
}
|
||||
}
|
||||
|
||||
// now adjust the box size so that it is about target_pixels pixels in size
|
||||
double size = avg_width.mean()*avg_height.mean();
|
||||
double scale = std::sqrt(target_size/size);
|
||||
|
||||
width = (unsigned long)(avg_width.mean()*scale+0.5);
|
||||
height = (unsigned long)(avg_height.mean()*scale+0.5);
|
||||
// make sure the width and height never round to zero.
|
||||
if (width == 0)
|
||||
width = 1;
|
||||
if (height == 0)
|
||||
height = 1;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
bool contains_any_boxes (
|
||||
const std::vector<std::vector<rectangle> >& boxes
|
||||
)
|
||||
{
|
||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||
{
|
||||
if (boxes[i].size() != 0)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void throw_invalid_box_error_message (
|
||||
const std::string& dataset_filename,
|
||||
const std::vector<std::vector<rectangle> >& removed,
|
||||
const unsigned long target_size
|
||||
)
|
||||
{
|
||||
image_dataset_metadata::dataset data;
|
||||
load_image_dataset_metadata(data, dataset_filename);
|
||||
|
||||
std::ostringstream sout;
|
||||
sout << "Error! An impossible set of object boxes was given for training. ";
|
||||
sout << "All the boxes need to have a similar aspect ratio and also not be ";
|
||||
sout << "smaller than about " << target_size << " pixels in area. ";
|
||||
sout << "The following images contain invalid boxes:\n";
|
||||
std::ostringstream sout2;
|
||||
for (unsigned long i = 0; i < removed.size(); ++i)
|
||||
{
|
||||
if (removed[i].size() != 0)
|
||||
{
|
||||
const std::string imgname = data.images[i].filename;
|
||||
sout2 << " " << imgname << "\n";
|
||||
}
|
||||
}
|
||||
throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str());
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
try
|
||||
{
|
||||
command_line_parser parser;
|
||||
parser.add_option("h","Display this help message.");
|
||||
parser.add_option("v","Be verbose.");
|
||||
parser.add_option("t","Train an object detector and save the detector to disk.");
|
||||
parser.add_option("cross-validate",
|
||||
"Perform cross-validation on an image dataset and print the results.");
|
||||
parser.add_option("test", "Test a trained detector on an image dataset and print the results.");
|
||||
parser.add_option("u", "Upsample each input image <arg> times. Each upsampling quadruples the number of pixels in the image (default: 0).", 1);
|
||||
|
||||
parser.set_group_name("training/cross-validation sub-options");
|
||||
parser.add_option("v","Be verbose.");
|
||||
parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1);
|
||||
parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1);
|
||||
parser.add_option("threads", "Use <arg> threads for training <arg> (default: 4).",1);
|
||||
parser.add_option("grid-size", "Extract features in a detection window from an <arg> by <arg> grid. (default: 2).",1);
|
||||
parser.add_option("hash-bits", "Use <arg> bits for the feature hashing (default: 10).", 1);
|
||||
parser.add_option("test", "Test a trained detector on an image dataset and print the results.");
|
||||
parser.add_option("eps", "Set training epsilon to <arg> (default: 0.3).", 1);
|
||||
parser.add_option("eps", "Set training epsilon to <arg> (default: 0.01).", 1);
|
||||
parser.add_option("target-size", "Set size of the sliding window to about <arg> pixels in area (default: 80*80).", 1);
|
||||
parser.add_option("flip", "Add left/right flipped copies of the images into the training dataset. Useful when the objects "
|
||||
"you want to detect are left/right symmetric.");
|
||||
|
||||
|
||||
parser.parse(argc, argv);
|
||||
|
||||
// Now we do a little command line validation. Each of the following functions
|
||||
// checks something and throws an exception if the test fails.
|
||||
const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "grid-size", "hash-bits",
|
||||
"folds", "test", "eps"};
|
||||
const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "target-size",
|
||||
"folds", "test", "eps", "u", "flip"};
|
||||
parser.check_one_time_options(one_time_opts); // Can't give an option more than once
|
||||
// Make sure the arguments to these options are within valid ranges if they are supplied by the user.
|
||||
parser.check_option_arg_range("c", 1e-12, 1e12);
|
||||
parser.check_option_arg_range("eps", 1e-5, 1e4);
|
||||
parser.check_option_arg_range("threads", 1, 1000);
|
||||
parser.check_option_arg_range("grid-size", 1, 100);
|
||||
parser.check_option_arg_range("hash-bits", 1, 32);
|
||||
parser.check_option_arg_range("folds", 2, 100);
|
||||
parser.check_option_arg_range("u", 0, 8);
|
||||
parser.check_option_arg_range("target-size", 4*4, 10000*10000);
|
||||
const char* incompatible[] = {"t", "cross-validate", "test"};
|
||||
parser.check_incompatible_options(incompatible);
|
||||
// You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate.
|
||||
const char* training_ops[] = {"t", "cross-validate"};
|
||||
const char* training_sub_ops[] = {"v", "c", "threads", "grid-size", "hash-bits"};
|
||||
const char* training_sub_ops[] = {"v", "c", "threads", "target-size", "eps", "flip"};
|
||||
parser.check_sub_options(training_ops, training_sub_ops);
|
||||
parser.check_sub_option("cross-validate", "folds");
|
||||
|
||||
@ -130,10 +223,9 @@ int main(int argc, char** argv)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
typedef hashed_feature_image<hog_image<4,4,1,9,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
|
||||
typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type;
|
||||
typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type;
|
||||
// Get the upsample option from the user but use 0 if it wasn't given.
|
||||
const unsigned long upsample_amount = get_option(parser, "u", 0);
|
||||
|
||||
if (parser.option("t") || parser.option("cross-validate"))
|
||||
{
|
||||
@ -145,43 +237,58 @@ int main(int argc, char** argv)
|
||||
}
|
||||
|
||||
dlib::array<array2d<unsigned char> > images;
|
||||
std::vector<std::vector<rectangle> > object_locations;
|
||||
std::vector<std::vector<rectangle> > object_locations, ignore;
|
||||
|
||||
cout << "Loading image dataset from metadata file " << parser[0] << endl;
|
||||
load_image_dataset(images, object_locations, parser[0]);
|
||||
|
||||
ignore = load_image_dataset(images, object_locations, parser[0]);
|
||||
cout << "Number of images loaded: " << images.size() << endl;
|
||||
|
||||
// Get the value of the hash-bits option if the user supplied it. Otherwise
|
||||
// use the default value of 10.
|
||||
const int hash_bits = get_option(parser, "hash-bits", 10);
|
||||
const int grid_size = get_option(parser, "grid-size", 2);
|
||||
// Get the options from the user, but use default values if they are not
|
||||
// supplied.
|
||||
const int threads = get_option(parser, "threads", 4);
|
||||
const double C = get_option(parser, "c", 1.0);
|
||||
const double eps = get_option(parser, "eps", 0.3);
|
||||
const double eps = get_option(parser, "eps", 0.01);
|
||||
unsigned int num_folds = get_option(parser, "folds", 3);
|
||||
const unsigned long target_size = get_option(parser, "target-size", 80*80);
|
||||
// You can't do more folds than there are images.
|
||||
if (num_folds > images.size())
|
||||
num_folds = images.size();
|
||||
|
||||
// Upsample images if the user asked us to do that.
|
||||
for (unsigned long i = 0; i < upsample_amount; ++i)
|
||||
upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);
|
||||
|
||||
|
||||
image_scanner_type scanner;
|
||||
setup_grid_detection_templates_verbose(scanner, object_locations, grid_size, grid_size);
|
||||
setup_hashed_features(scanner, images, hash_bits);
|
||||
unsigned long width, height;
|
||||
pick_best_window_size(object_locations, width, height, target_size);
|
||||
scanner.set_detection_window_size(width, height);
|
||||
|
||||
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
|
||||
trainer.set_num_threads(threads);
|
||||
|
||||
if (parser.option("v"))
|
||||
trainer.be_verbose();
|
||||
|
||||
trainer.set_c(C);
|
||||
trainer.set_epsilon(eps);
|
||||
|
||||
// Now make sure all the boxes are obtainable by the scanner.
|
||||
std::vector<std::vector<rectangle> > removed;
|
||||
removed = remove_unobtainable_rectangles(trainer, images, object_locations);
|
||||
// if we weren't able to get all the boxes to match then throw an error
|
||||
if (contains_any_boxes(removed))
|
||||
{
|
||||
unsigned long scale = upsample_amount+1;
|
||||
scale = scale*scale;
|
||||
throw_invalid_box_error_message(parser[0], removed, target_size/scale);
|
||||
}
|
||||
|
||||
if (parser.option("flip"))
|
||||
add_image_left_right_flips(images, object_locations, ignore);
|
||||
|
||||
if (parser.option("t"))
|
||||
{
|
||||
// Do the actual training and save the results into the detector object.
|
||||
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
|
||||
object_detector<image_scanner_type> detector = trainer.train(images, object_locations, ignore);
|
||||
|
||||
cout << "Saving trained detector to object_detector.svm" << endl;
|
||||
ofstream fout("object_detector.svm", ios::binary);
|
||||
@ -197,15 +304,19 @@ int main(int argc, char** argv)
|
||||
randomize_samples(images, object_locations);
|
||||
|
||||
cout << num_folds << "-fold cross validation (precision,recall,AP): "
|
||||
<< cross_validate_object_detection_trainer(trainer, images, object_locations, num_folds) << endl;
|
||||
<< cross_validate_object_detection_trainer(trainer, images, object_locations, ignore, num_folds) << endl;
|
||||
}
|
||||
|
||||
cout << "Parameters used: " << endl;
|
||||
cout << " hash-bits: "<< hash_bits << endl;
|
||||
cout << " grid-size: "<< grid_size << endl;
|
||||
cout << " threads: "<< threads << endl;
|
||||
cout << " C: "<< C << endl;
|
||||
cout << " eps: "<< eps << endl;
|
||||
cout << " threads: "<< threads << endl;
|
||||
cout << " C: "<< C << endl;
|
||||
cout << " eps: "<< eps << endl;
|
||||
cout << " target-size: "<< target_size << endl;
|
||||
cout << " detection window width: "<< width << endl;
|
||||
cout << " detection window height: "<< height << endl;
|
||||
cout << " upsample this many times : "<< upsample_amount << endl;
|
||||
if (parser.option("flip"))
|
||||
cout << " trained using left/right flips." << endl;
|
||||
if (parser.option("cross-validate"))
|
||||
cout << " num_folds: "<< num_folds << endl;
|
||||
cout << endl;
|
||||
@ -215,10 +326,12 @@ int main(int argc, char** argv)
|
||||
|
||||
|
||||
|
||||
// The rest of the code is devoted to testing out an already trained
|
||||
// object detector.
|
||||
|
||||
|
||||
|
||||
|
||||
// The rest of the code is devoted to testing an already trained object detector.
|
||||
|
||||
if (parser.number_of_arguments() == 0)
|
||||
{
|
||||
cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl;
|
||||
@ -243,15 +356,19 @@ int main(int argc, char** argv)
|
||||
// Check if the command line argument is an XML file
|
||||
if (tolower(right_substr(parser[0],".")) == "xml")
|
||||
{
|
||||
std::vector<std::vector<rectangle> > object_locations;
|
||||
std::vector<std::vector<rectangle> > object_locations, ignore;
|
||||
cout << "Loading image dataset from metadata file " << parser[0] << endl;
|
||||
load_image_dataset(images, object_locations, parser[0]);
|
||||
ignore = load_image_dataset(images, object_locations, parser[0]);
|
||||
cout << "Number of images loaded: " << images.size() << endl;
|
||||
|
||||
// Upsample images if the user asked us to do that.
|
||||
for (unsigned long i = 0; i < upsample_amount; ++i)
|
||||
upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);
|
||||
|
||||
if (parser.option("test"))
|
||||
{
|
||||
cout << "Testing detector on data..." << endl;
|
||||
cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations) << endl;
|
||||
cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl;
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
}
|
||||
@ -262,6 +379,13 @@ int main(int argc, char** argv)
|
||||
images.resize(parser.number_of_arguments());
|
||||
for (unsigned long i = 0; i < images.size(); ++i)
|
||||
load_image(images[i], parser[i]);
|
||||
|
||||
// Upsample images if the user asked us to do that.
|
||||
for (unsigned long i = 0; i < upsample_amount; ++i)
|
||||
{
|
||||
for (unsigned long j = 0; j < images.size(); ++j)
|
||||
pyramid_up(images[j]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user