mirror of https://github.com/davisking/dlib.git
Improve the data augmentation in the SSL example (#2684)
I was using the data augmentation recommended for the ImageNet dataset, which is not well suited for CIFAR-10. After switching to augmentation parameters tuned for CIFAR-10, the test accuracy increased by 1 point.
commit e5b2cedff8
parent 3d5fb6fc7f
@@ -104,41 +104,25 @@ rectangle make_random_cropping_rect(
     dlib::rand& rnd
 )
 {
-    const double mins = 7. / 15.;
-    const double maxs = 7. / 8.;
-    const auto scale = rnd.get_double_in_range(mins, maxs);
+    const auto scale = rnd.get_double_in_range(0.5, 1.0);
+    const auto ratio = exp(rnd.get_double_in_range(log(3.0 / 4.0), log(4.0 / 3.0)));
     const auto size = scale * min(image.nr(), image.nc());
-    const rectangle rect(size, size);
-    const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
-                       rnd.get_random_32bit_number() % (image.nr() - rect.height()));
+    const auto rect = move_rect(set_aspect_ratio(rectangle(size, size), ratio), 0, 0);
+    const point offset(rnd.get_integer(max(0, static_cast<int>(image.nc() - rect.width()))),
+                       rnd.get_integer(max(0, static_cast<int>(image.nr() - rect.height()))));
     return move_rect(rect, offset);
 }
 
 // A helper function to generate different kinds of augmentations depending on prime.
 matrix<rgb_pixel> augment(
     const matrix<rgb_pixel>& image,
-    const bool prime,
     dlib::rand& rnd
 )
 {
-    matrix<rgb_pixel> crop;
-    // blur
-    matrix<rgb_pixel> blurred;
-    const double sigma = rnd.get_double_in_range(0.1, 1.1);
-    if (!prime || (prime && rnd.get_random_double() < 0.1))
-    {
-        const auto rect = gaussian_blur(image, blurred, sigma);
-        extract_image_chip(blurred, rect, crop);
-        blurred = crop;
-    }
-    else
-    {
-        blurred = image;
-    }
-
     // randomly crop
+    matrix<rgb_pixel> crop;
     const auto rect = make_random_cropping_rect(image, rnd);
-    extract_image_chip(blurred, chip_details(rect, chip_dims(32, 32)), crop);
+    extract_image_chip(image, chip_details(rect, chip_dims(32, 32)), crop);
 
     // image left-right flip
     if (rnd.get_random_double() < 0.5)
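A note on the new cropping rect: the aspect ratio is drawn as exp(u) with u uniform in [log(3/4), log(4/3)]. That interval is symmetric around 0 in log space, so tall crops (ratio < 1) and wide crops (ratio > 1) come out equally likely, whereas sampling the ratio uniformly in [3/4, 4/3] would favor wide crops. A minimal standalone sketch of just this sampling step, using std::mt19937 instead of dlib::rand, so it is illustrative rather than the example's actual code:

    #include <cmath>
    #include <iostream>
    #include <random>

    int main()
    {
        std::mt19937 rng{42};
        // Uniform in log space over [log(3/4), log(4/3)]: symmetric around log(1) = 0.
        std::uniform_real_distribution<double> log_ratio{std::log(3.0 / 4.0), std::log(4.0 / 3.0)};
        int wide = 0, tall = 0;
        for (int i = 0; i < 100000; ++i)
        {
            const double ratio = std::exp(log_ratio(rng));
            (ratio > 1.0 ? wide : tall) += 1;
        }
        // Both counters come out near 50000: wide and tall crops are equally likely.
        std::cout << "wide: " << wide << "  tall: " << tall << '\n';
    }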
@@ -146,7 +130,7 @@ matrix<rgb_pixel> augment(
 
     // color augmentation
     if (rnd.get_random_double() < 0.8)
-        disturb_colors(crop, rnd, 0.5, 0.5);
+        disturb_colors(crop, rnd, 1.0, 0.5);
 
     // grayscale
     if (rnd.get_random_double() < 0.2)
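For context on the magnitudes: if I read dlib's random_color_transform_abstract.h correctly, disturb_colors takes the gamma magnitude first and the color-offset magnitude second, so this change doubles the strength of the random gamma jitter (0.5 to 1.0) while keeping the color offset at 0.5. A small hedged sketch; the color_jitter wrapper name is mine, not the example's:

    #include <dlib/image_transforms.h>
    #include <dlib/matrix.h>
    #include <dlib/rand.h>

    using namespace dlib;

    // Hypothetical wrapper illustrating the parameter order assumed above.
    void color_jitter(matrix<rgb_pixel>& crop, dlib::rand& rnd)
    {
        // 1.0: magnitude of the random gamma correction.
        // 0.5: magnitude of the random color offset.
        disturb_colors(crop, rnd, 1.0, 0.5);
    }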
@@ -155,20 +139,6 @@ matrix<rgb_pixel> augment(
         assign_image(gray, crop);
         assign_image(crop, gray);
     }
 
-    // solarize
-    if (prime && rnd.get_random_double() < 0.2)
-    {
-        for (auto& p : crop)
-        {
-            if (p.red > 128)
-                p.red = 255 - p.red;
-            if (p.green > 128)
-                p.green = 255 - p.green;
-            if (p.blue > 128)
-                p.blue = 255 - p.blue;
-        }
-    }
-
     return crop;
 }
@@ -259,8 +229,8 @@ try
     while (batch.size() < trainer.get_mini_batch_size())
     {
         const auto idx = rnd.get_random_32bit_number() % training_images.size();
-        auto image = training_images[idx];
-        batch.emplace_back(augment(image, false, rnd), augment(image, true, rnd));
+        const auto& image = training_images[idx];
+        batch.emplace_back(augment(image, rnd), augment(image, rnd));
     }
     trainer.train_one_step(batch);
 
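This loop is where the self-supervised pairing happens: each batch element holds two independent draws of augment() on the same image, and the Barlow Twins loss this example trains with pulls the two views' representations together. A self-contained sketch of the pattern follows; the augment() stub here is a deliberately trivial stand-in (random horizontal flip only), not the example's full pipeline:

    #include <iostream>
    #include <utility>
    #include <vector>
    #include <dlib/image_transforms.h>
    #include <dlib/matrix.h>
    #include <dlib/rand.h>

    using namespace dlib;

    // Stand-in for the example's augment(): flip half the time.
    matrix<rgb_pixel> augment(const matrix<rgb_pixel>& image, dlib::rand& rnd)
    {
        matrix<rgb_pixel> out = image;
        if (rnd.get_random_double() < 0.5)
            flip_image_left_right(out);
        return out;
    }

    int main()
    {
        dlib::rand rnd;
        // Dummy 32x32 images standing in for the CIFAR-10 training set.
        std::vector<matrix<rgb_pixel>> training_images(10, matrix<rgb_pixel>(32, 32));
        std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
        while (batch.size() < 4)
        {
            const auto idx = rnd.get_random_32bit_number() % training_images.size();
            const auto& image = training_images[idx];
            // Two independent augmentations of one image form a positive pair.
            batch.emplace_back(augment(image, rnd), augment(image, rnd));
        }
        std::cout << "built " << batch.size() << " pairs\n";
    }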
@@ -351,10 +321,9 @@ try
         cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
     };
 
-    // Using 10% of the training labels should result in training and testing accuracies
-    // of around 92% and 87%, respectively.
-    // Had we used all labels to train the multiclass SVM classifier, we would have got
-    // a training accuracy of around 93% and a testing accuracy of around 89%, instead.
+    // Using 10% of the training labels should result in a testing accuracy of
+    // around 88%. Had we used all labels to train the multiclass SVM classifier,
+    // we would have got a testing accuracy of around 90%, instead.
     cout << "\ntraining accuracy" << endl;
     compute_accuracy(features, training_labels);
     cout << "\ntesting accuracy" << endl;