Replace sgd-based fc classifier with svm_multiclass_linear_trainer (#2452)

* Replace fc classifier with svm_multiclass_linear_trainer

* Mention about find_max_global()

Co-authored-by: Davis E. King <davis@dlib.net>

* Use double instead of float for extracted features

Co-authored-by: Davis E. King <davis@dlib.net>

* fix compilation with double features

* Revert "fix compilation with double features"

This reverts commit 76ebab4b91.

* Revert "Use double instead of float for extracted features"

This reverts commit 9a50809ebf.

* Find best C using global optimization

Co-authored-by: Davis E. King <davis@dlib.net>
This commit is contained in:
Adrià Arrufat 2021-11-06 23:33:31 +01:00 committed by GitHub
parent f77189db03
commit 5091e9c880
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,10 +36,12 @@
a max pooling layer afterwards, like the paper does.
*/
#include <dlib/dnn.h>
#include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h>
#include <dlib/data_io.h>
#include <dlib/dnn.h>
#include <dlib/global_optimization.h>
#include <dlib/gui_widgets.h>
#include <dlib/svm_threaded.h>
using namespace std;
using namespace dlib;
@ -82,14 +84,12 @@ namespace resnet50
// This model namespace contains the definitions for:
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
// - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
namespace model
{
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
template <typename SUBNET> using classifier = fc<10, SUBNET>;
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
}
rectangle make_random_cropping_rect(
@ -288,58 +288,49 @@ try
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
}
// To check the quality of the learned feature representations, we will train a linear
// classififer on top of the frozen backbone.
model::infer inet;
// Assign the network, without the projector, which is only used for the self-supervised
// training.
layer<2>(inet) = layer<5>(net);
// Freeze the backbone
set_all_learning_rate_multipliers(layer<2>(inet), 0);
// Train the network
{
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
// Since this model doesn't train with pairs, just single images, we can increase
// the batch size.
trainer.set_mini_batch_size(2 * batch_size);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(min_learning_rate);
trainer.set_iterations_without_progress_threshold(5000);
trainer.set_synchronization_file("cifar_10_sync");
trainer.be_verbose();
cout << trainer << endl;
// Now, we initialize the feature extractor model with the backbone we have just learned.
model::feats fnet(layer<5>(net));
// And we will generate all the features for the training set to train a multiclass SVM
// classifier.
std::vector<matrix<float, 0, 1>> features;
cout << "Extracting features for linear classifier..." << endl;
features = fnet(training_images, 4 * batch_size);
std::vector<matrix<rgb_pixel>> images;
std::vector<unsigned long> labels;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
// Find the most appropriate C setting using find_max_global.
auto cross_validation_score = [&](const double c)
{
images.clear();
labels.clear();
while (images.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
images.push_back(augment(training_images[idx], false, rnd));
labels.push_back(training_labels[idx]);
}
trainer.train_one_step(images, labels);
}
trainer.get_net();
inet.clean();
serialize("resnet50_cifar_10.dnn") << inet;
}
svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
trainer.set_num_threads(std::thread::hardware_concurrency());
trainer.set_c(c);
cout << "C: " << c << endl;
const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
const double accuracy = sum(diag(cm)) / sum(cm);
cout << "cross validation accuracy: " << accuracy << endl;;
cout << "confusion matrix:\n " << cm << endl;
return accuracy;
};
const auto result = find_max_global(cross_validation_score, 1e-4, 10000, max_function_calls(50));
cout << "Best C: " << result.x(0) << endl;
// Proceed to train the SVM classifier with the best C.
svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
trainer.set_num_threads(std::thread::hardware_concurrency());
trainer.set_c(result.x(0));
cout << "Training Multiclass SVM..." << endl;
const auto df = trainer.train(features, training_labels);
serialize("multiclass_svm_cifar_10.dat") << df;
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
auto compute_accuracy = [&inet, batch_size](
const std::vector<matrix<rgb_pixel>>& images,
auto compute_accuracy = [&fnet, &df, batch_size](
const std::vector<matrix<float, 0, 1>>& samples,
const std::vector<unsigned long>& labels
)
{
size_t num_right = 0;
size_t num_wrong = 0;
const auto preds = inet(images, batch_size * 2);
for (size_t i = 0; i < labels.size(); ++i)
{
if (labels[i] == preds[i])
if (labels[i] == df(samples[i]))
++num_right;
else
++num_wrong;
@ -350,11 +341,12 @@ try
cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
};
// If everything works as expected, we should get accuracies that are between 87% and 90%.
cout << "training accuracy" << endl;
compute_accuracy(training_images, training_labels);
// We should get a training accuracy of around 93% and a testing accuracy of around 88%.
cout << "\ntraining accuracy" << endl;
compute_accuracy(features, training_labels);
cout << "\ntesting accuracy" << endl;
compute_accuracy(testing_images, testing_labels);
features = fnet(testing_images, 4 * batch_size);
compute_accuracy(features, testing_labels);
return EXIT_SUCCESS;
}
catch (const exception& e)