updated object detector tests to do some movable part stuff.

This commit is contained in:
Davis King 2012-08-24 23:09:58 -04:00
parent 1468d153ce
commit 2a5072d116

View File

@ -71,15 +71,6 @@ namespace
double check_score = dot(psi,detector.get_w()) - thresh;
DLIB_TEST(std::abs(check_score - dets2[j].first) < 1e-10);
// Make sure fdet works the way it is supposed to with get_feature_vector().
psi2 = 0;
detector.get_scanner().get_feature_vector(fdet, psi2);
check_score = dot(psi2,detector.get_w()) - thresh;
DLIB_TEST(std::abs(check_score - dets2[j].first) < 1e-10);
DLIB_TEST(max(abs(psi-psi2)) < 1e-10);
}
}
@ -259,6 +250,90 @@ namespace
object_locations.push_back(temp);
}
template <
typename image_array_type
>
void make_simple_test_data (
image_array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations
)
{
images.clear();
object_locations.clear();
images.resize(3);
images[0].set_size(400,400);
images[1].set_size(400,400);
images[2].set_size(400,400);
// set all the pixel values to black
assign_all_pixels(images[0], 0);
assign_all_pixels(images[1], 0);
assign_all_pixels(images[2], 0);
// Now make some squares and draw them onto our black images. All the
// squares will be 70 pixels wide and tall.
const int shrink = 0;
std::vector<full_object_detection> temp;
temp.push_back(full_object_detection(centered_rect(point(100,100), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[0],temp.back().rect,255); // Paint the square white
temp.push_back(full_object_detection(centered_rect(point(200,300), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[0],temp.back().rect,255); // Paint the square white
object_locations.push_back(temp);
temp.clear();
temp.push_back(full_object_detection(centered_rect(point(140,200), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[1],temp.back().rect,255); // Paint the square white
temp.push_back(full_object_detection(centered_rect(point(303,200), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[1],temp.back().rect,255); // Paint the square white
object_locations.push_back(temp);
temp.clear();
temp.push_back(full_object_detection(centered_rect(point(123,121), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[2],temp.back().rect,255); // Paint the square white
object_locations.push_back(temp);
// corrupt each image with random noise just to make this a little more
// challenging
dlib::rand rnd;
for (unsigned long i = 0; i < images.size(); ++i)
{
for (long r = 0; r < images[i].nr(); ++r)
{
for (long c = 0; c < images[i].nc(); ++c)
{
images[i][r][c] = put_in_range(0,255,images[i][r][c] + 40*rnd.get_random_gaussian());
}
}
}
}
// ----------------------------------------------------------------------------------------
void test_1 (
@ -293,7 +368,55 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
validate_some_object_detector_stuff(images, detector);
}
}
// ----------------------------------------------------------------------------------------
void test_1m (
)
{
print_spinner();
dlog << LINFO << "test_1m()";
typedef array<array2d<unsigned char> > grayscale_image_array_type;
grayscale_image_array_type images;
std::vector<std::vector<full_object_detection> > object_locations;
make_simple_test_data(images, object_locations);
typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
typedef scan_image_pyramid<pyramid_down, feature_extractor_type> image_scanner_type;
image_scanner_type scanner;
const rectangle object_box = compute_box_dimensions(1,35*35);
std::vector<rectangle> mboxes;
const int mbox_size = 20;
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
scanner.add_detection_template(object_box, create_grid_detection_template(object_box,1,1), mboxes);
setup_hashed_features(scanner, images, 9);
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(4);
trainer.set_overlap_tester(test_box_overlap(0,0));
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
{
ostringstream sout;
serialize(detector, sout);
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
@ -335,7 +458,7 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
@ -377,7 +500,55 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
validate_some_object_detector_stuff(images, detector);
}
}
// ----------------------------------------------------------------------------------------
void test_1m_poly (
)
{
print_spinner();
dlog << LINFO << "test_1_poly()";
typedef array<array2d<unsigned char> > grayscale_image_array_type;
grayscale_image_array_type images;
std::vector<std::vector<full_object_detection> > object_locations;
make_simple_test_data(images, object_locations);
typedef hashed_feature_image<poly_image<2> > feature_extractor_type;
typedef scan_image_pyramid<pyramid_down_3_2, feature_extractor_type> image_scanner_type;
image_scanner_type scanner;
const rectangle object_box = compute_box_dimensions(1,35*35);
std::vector<rectangle> mboxes;
const int mbox_size = 20;
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2), mboxes);
setup_hashed_features(scanner, images, 9);
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(4);
trainer.set_overlap_tester(test_box_overlap(0,0));
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
{
ostringstream sout;
serialize(detector, sout);
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
@ -423,7 +594,7 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
@ -467,7 +638,7 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
validate_some_object_detector_stuff(images, detector);
@ -559,7 +730,7 @@ namespace
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 2);
}
@ -580,8 +751,10 @@ namespace
)
{
test_1();
test_1m();
test_1_fine_hog();
test_1_poly();
test_1m_poly();
test_1_poly_nn();
test_2();
test_3();