Improve the data augmentation in the SSL example (#2684)

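I was using the data augmentation recommended for the ImageNet dataset, which is not well suited
for CIFAR-10.
After switching to a simpler augmentation better suited to CIFAR-10 (no blur or solarization, and
adjusted random cropping and color jitter), the test accuracy increased by about 1 percentage point.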
This commit is contained in:
Adria Arrufat 2022-11-10 12:07:00 +09:00 committed by GitHub
parent 3d5fb6fc7f
commit e5b2cedff8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -104,41 +104,25 @@ rectangle make_random_cropping_rect(
dlib::rand& rnd
)
{
const double mins = 7. / 15.;
const double maxs = 7. / 8.;
const auto scale = rnd.get_double_in_range(mins, maxs);
const auto scale = rnd.get_double_in_range(0.5, 1.0);
const auto ratio = exp(rnd.get_double_in_range(log(3.0 / 4.0), log(4.0 / 3.0)));
const auto size = scale * min(image.nr(), image.nc());
const rectangle rect(size, size);
const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
rnd.get_random_32bit_number() % (image.nr() - rect.height()));
const auto rect = move_rect(set_aspect_ratio(rectangle(size, size), ratio), 0, 0);
const point offset(rnd.get_integer(max(0, static_cast<int>(image.nc() - rect.width()))),
rnd.get_integer(max(0, static_cast<int>(image.nr() - rect.height()))));
return move_rect(rect, offset);
}
// A helper function to generate different kinds of augmentations depending on prime.
matrix<rgb_pixel> augment(
const matrix<rgb_pixel>& image,
const bool prime,
dlib::rand& rnd
)
{
matrix<rgb_pixel> crop;
// blur
matrix<rgb_pixel> blurred;
const double sigma = rnd.get_double_in_range(0.1, 1.1);
if (!prime || (prime && rnd.get_random_double() < 0.1))
{
const auto rect = gaussian_blur(image, blurred, sigma);
extract_image_chip(blurred, rect, crop);
blurred = crop;
}
else
{
blurred = image;
}
// randomly crop
matrix<rgb_pixel> crop;
const auto rect = make_random_cropping_rect(image, rnd);
extract_image_chip(blurred, chip_details(rect, chip_dims(32, 32)), crop);
extract_image_chip(image, chip_details(rect, chip_dims(32, 32)), crop);
// image left-right flip
if (rnd.get_random_double() < 0.5)
@ -146,7 +130,7 @@ matrix<rgb_pixel> augment(
// color augmentation
if (rnd.get_random_double() < 0.8)
disturb_colors(crop, rnd, 0.5, 0.5);
disturb_colors(crop, rnd, 1.0, 0.5);
// grayscale
if (rnd.get_random_double() < 0.2)
@ -155,20 +139,6 @@ matrix<rgb_pixel> augment(
assign_image(gray, crop);
assign_image(crop, gray);
}
// solarize
if (prime && rnd.get_random_double() < 0.2)
{
for (auto& p : crop)
{
if (p.red > 128)
p.red = 255 - p.red;
if (p.green > 128)
p.green = 255 - p.green;
if (p.blue > 128)
p.blue = 255 - p.blue;
}
}
return crop;
}
@ -259,8 +229,8 @@ try
while (batch.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
auto image = training_images[idx];
batch.emplace_back(augment(image, false, rnd), augment(image, true, rnd));
const auto& image = training_images[idx];
batch.emplace_back(augment(image, rnd), augment(image, rnd));
}
trainer.train_one_step(batch);
@ -351,10 +321,9 @@ try
cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
};
// Using 10% of the training labels should result in training and testing accuracies
// of around 92% and 87%, respectively.
// Had we used all labels to train the multiclass SVM classifier, we would have got
// a training accuracy of around 93% and a testing accuracy of around 89%, instead.
// Using 10% of the training labels should result in a testing accuracy of
// around 88%. Had we used all labels to train the multiclass SVM classifier,
// we would have got a testing accuracy of around 90%, instead.
cout << "\ntraining accuracy" << endl;
compute_accuracy(features, training_labels);
cout << "\ntesting accuracy" << endl;