Replace fc classifier with svm_multiclass_linear_trainer

Adrià Arrufat 2021-11-02 02:48:18 +09:00
parent a41b3d7ce8
commit c8c810f221


@@ -40,6 +40,7 @@
#include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h>
#include <dlib/gui_widgets.h>
+#include <dlib/svm_threaded.h>
using namespace std;
using namespace dlib;
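
The new <dlib/svm_threaded.h> include brings in svm_multiclass_linear_trainer, which replaces the fc classification head below. As a minimal sketch of that API on toy data (the samples and labels here are made up for illustration):

typedef matrix<float, 0, 1> sample_type;
sample_type s1(2), s2(2);
s1 = -1, -1;  // toy 2-D feature vectors
s2 = 1, 1;
std::vector<sample_type> samples = {s1, s2};
std::vector<unsigned long> labels = {0, 1};
svm_multiclass_linear_trainer<linear_kernel<sample_type>, unsigned long> trainer;
trainer.set_c(1);  // regularization parameter, see the "c" option below
const auto df = trainer.train(samples, labels);
const unsigned long pred = df(samples[0]);  // the decision function returns a class label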
@@ -82,14 +83,12 @@ namespace resnet50
// This model namespace contains the definitions for:
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
-// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
+// - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
namespace model
{
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
-template <typename SUBNET> using classifier = fc<10, SUBNET>;
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
-using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
+using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
}
rectangle make_random_cropping_rect(
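
For reference, the new feats network turns the backbone into a plain feature extractor: loss_metric simply forwards the backbone output, so calling the network on an image yields its embedding. A hedged sketch, assuming the backbone weights have been assigned elsewhere:

model::feats fnet;                           // backbone weights assigned elsewhere
matrix<rgb_pixel> image(32, 32);             // a CIFAR-10 sized input
assign_all_pixels(image, rgb_pixel(0, 0, 0));
const matrix<float, 0, 1> embedding = fnet(image);  // one feature vector per call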
@@ -170,6 +169,7 @@ try
// The default settings are fine for the example already.
command_line_parser parser;
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
parser.add_option("c", "SVM C parameter for the linear classifier (default: 5000)", 1);
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
@@ -196,6 +196,7 @@ try
const double lambda = get_option(parser, "lambda", 1.0 / dims);
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
+const double svm_c = get_option(parser, "c", 5000);
// Load the CIFAR-10 dataset into memory.
std::vector<matrix<rgb_pixel>> training_images, testing_images;
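
These vectors are presumably filled just after this hunk; in dlib that is typically done with load_cifar_10_dataset, roughly as follows (the folder path is a placeholder):

std::vector<unsigned long> training_labels, testing_labels;
load_cifar_10_dataset("cifar-10-batches-bin",  // placeholder path to the binary CIFAR-10 files
                      training_images, training_labels,
                      testing_images, testing_labels);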
@@ -288,73 +289,47 @@ try
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
}
-// To check the quality of the learned feature representations, we will train a linear
-// classifier on top of the frozen backbone.
-model::infer inet;
-// Assign the network, without the projector, which is only used for the self-supervised
-// training.
-layer<2>(inet) = layer<5>(net);
-// Freeze the backbone
-set_all_learning_rate_multipliers(layer<2>(inet), 0);
-// Train the network
-{
-dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
-// Since this model doesn't train with pairs, just single images, we can increase
-// the batch size.
-trainer.set_mini_batch_size(2 * batch_size);
-trainer.set_learning_rate(learning_rate);
-trainer.set_min_learning_rate(min_learning_rate);
-trainer.set_iterations_without_progress_threshold(5000);
-trainer.set_synchronization_file("cifar_10_sync");
-trainer.be_verbose();
-cout << trainer << endl;
-std::vector<matrix<rgb_pixel>> images;
-std::vector<unsigned long> labels;
-while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
-{
-images.clear();
-labels.clear();
-while (images.size() < trainer.get_mini_batch_size())
-{
-const auto idx = rnd.get_random_32bit_number() % training_images.size();
-images.push_back(augment(training_images[idx], false, rnd));
-labels.push_back(training_labels[idx]);
-}
-trainer.train_one_step(images, labels);
-}
-trainer.get_net();
-inet.clean();
-serialize("resnet50_cifar_10.dnn") << inet;
-}
+// Now, we initialize the feature extractor model with the backbone we have just learned.
+model::feats fnet(layer<5>(net));
+// And we will generate all the features for the training set to train a multiclass SVM
+// classifier.
+std::vector<matrix<float, 0, 1>> features;
+cout << "Extracting features for linear classifier..." << endl;
+features = fnet(training_images, 4 * batch_size);
+svm_multiclass_linear_trainer<linear_kernel<matrix<float,0,1>>, unsigned long> trainer;
+trainer.set_num_threads(std::thread::hardware_concurrency());
+trainer.set_c(svm_c);
+cout << "Training Multiclass SVM..." << endl;
+const auto df = trainer.train(features, training_labels);
+serialize("multiclass_svm_cifar_10.dat") << df;
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
-auto compute_accuracy = [&inet, batch_size](
-const std::vector<matrix<rgb_pixel>>& images,
+auto compute_accuracy = [&fnet, &df, batch_size](
+const std::vector<matrix<float, 0, 1>>& samples,
const std::vector<unsigned long>& labels
)
{
size_t num_right = 0;
size_t num_wrong = 0;
-const auto preds = inet(images, batch_size * 2);
for (size_t i = 0; i < labels.size(); ++i)
{
-if (labels[i] == preds[i])
+if (labels[i] == df(samples[i]))
++num_right;
else
++num_wrong;
}
cout << "num right: " << num_right << endl;
cout << "num wrong: " << num_wrong << endl;
cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
cout << " num right: " << num_right << endl;
cout << " num wrong: " << num_wrong << endl;
cout << " accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
};
-// If everything works as expected, we should get accuracies that are between 87% and 90%.
-cout << "training accuracy" << endl;
-compute_accuracy(training_images, training_labels);
+// We should get a training accuracy of around 93% and a testing accuracy of around 88%.
+cout << "\ntraining accuracy" << endl;
+compute_accuracy(features, training_labels);
cout << "\ntesting accuracy" << endl;
-compute_accuracy(testing_images, testing_labels);
+features = fnet(testing_images, 4 * batch_size);
+compute_accuracy(features, testing_labels);
return EXIT_SUCCESS;
}
catch (const exception& e)