mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Replace sgd-based fc classifier with svm_multiclass_linear_trainer (#2452)
* Replace fc classifier with svm_multiclass_linear_trainer * Mention about find_max_global() Co-authored-by: Davis E. King <davis@dlib.net> * Use double instead of float for extracted features Co-authored-by: Davis E. King <davis@dlib.net> * fix compilation with double features * Revert "fix compilation with double features" This reverts commit76ebab4b91
. * Revert "Use double instead of float for extracted features" This reverts commit9a50809ebf
. * Find best C using global optimization Co-authored-by: Davis E. King <davis@dlib.net>
This commit is contained in:
parent
f77189db03
commit
5091e9c880
@ -36,10 +36,12 @@
|
||||
a max pooling layer afterwards, like the paper does.
|
||||
*/
|
||||
|
||||
#include <dlib/dnn.h>
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/cmd_line_parser.h>
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/dnn.h>
|
||||
#include <dlib/global_optimization.h>
|
||||
#include <dlib/gui_widgets.h>
|
||||
#include <dlib/svm_threaded.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
@ -82,14 +84,12 @@ namespace resnet50
|
||||
|
||||
// This model namespace contains the definitions for:
|
||||
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
|
||||
// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
|
||||
// - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
|
||||
namespace model
|
||||
{
|
||||
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
|
||||
template <typename SUBNET> using classifier = fc<10, SUBNET>;
|
||||
|
||||
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
|
||||
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
|
||||
using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
|
||||
}
|
||||
|
||||
rectangle make_random_cropping_rect(
|
||||
@ -288,58 +288,49 @@ try
|
||||
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
|
||||
}
|
||||
|
||||
// To check the quality of the learned feature representations, we will train a linear
|
||||
// classififer on top of the frozen backbone.
|
||||
model::infer inet;
|
||||
// Assign the network, without the projector, which is only used for the self-supervised
|
||||
// training.
|
||||
layer<2>(inet) = layer<5>(net);
|
||||
// Freeze the backbone
|
||||
set_all_learning_rate_multipliers(layer<2>(inet), 0);
|
||||
// Train the network
|
||||
{
|
||||
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
|
||||
// Since this model doesn't train with pairs, just single images, we can increase
|
||||
// the batch size.
|
||||
trainer.set_mini_batch_size(2 * batch_size);
|
||||
trainer.set_learning_rate(learning_rate);
|
||||
trainer.set_min_learning_rate(min_learning_rate);
|
||||
trainer.set_iterations_without_progress_threshold(5000);
|
||||
trainer.set_synchronization_file("cifar_10_sync");
|
||||
trainer.be_verbose();
|
||||
cout << trainer << endl;
|
||||
// Now, we initialize the feature extractor model with the backbone we have just learned.
|
||||
model::feats fnet(layer<5>(net));
|
||||
// And we will generate all the features for the training set to train a multiclass SVM
|
||||
// classifier.
|
||||
std::vector<matrix<float, 0, 1>> features;
|
||||
cout << "Extracting features for linear classifier..." << endl;
|
||||
features = fnet(training_images, 4 * batch_size);
|
||||
|
||||
std::vector<matrix<rgb_pixel>> images;
|
||||
std::vector<unsigned long> labels;
|
||||
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
|
||||
// Find the most appropriate C setting using find_max_global.
|
||||
auto cross_validation_score = [&](const double c)
|
||||
{
|
||||
images.clear();
|
||||
labels.clear();
|
||||
while (images.size() < trainer.get_mini_batch_size())
|
||||
{
|
||||
const auto idx = rnd.get_random_32bit_number() % training_images.size();
|
||||
images.push_back(augment(training_images[idx], false, rnd));
|
||||
labels.push_back(training_labels[idx]);
|
||||
}
|
||||
trainer.train_one_step(images, labels);
|
||||
}
|
||||
trainer.get_net();
|
||||
inet.clean();
|
||||
serialize("resnet50_cifar_10.dnn") << inet;
|
||||
}
|
||||
svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
|
||||
trainer.set_num_threads(std::thread::hardware_concurrency());
|
||||
trainer.set_c(c);
|
||||
cout << "C: " << c << endl;
|
||||
const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
|
||||
const double accuracy = sum(diag(cm)) / sum(cm);
|
||||
cout << "cross validation accuracy: " << accuracy << endl;;
|
||||
cout << "confusion matrix:\n " << cm << endl;
|
||||
return accuracy;
|
||||
};
|
||||
const auto result = find_max_global(cross_validation_score, 1e-4, 10000, max_function_calls(50));
|
||||
cout << "Best C: " << result.x(0) << endl;
|
||||
|
||||
// Proceed to train the SVM classifier with the best C.
|
||||
svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
|
||||
trainer.set_num_threads(std::thread::hardware_concurrency());
|
||||
trainer.set_c(result.x(0));
|
||||
cout << "Training Multiclass SVM..." << endl;
|
||||
const auto df = trainer.train(features, training_labels);
|
||||
serialize("multiclass_svm_cifar_10.dat") << df;
|
||||
|
||||
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
|
||||
auto compute_accuracy = [&inet, batch_size](
|
||||
const std::vector<matrix<rgb_pixel>>& images,
|
||||
auto compute_accuracy = [&fnet, &df, batch_size](
|
||||
const std::vector<matrix<float, 0, 1>>& samples,
|
||||
const std::vector<unsigned long>& labels
|
||||
)
|
||||
{
|
||||
size_t num_right = 0;
|
||||
size_t num_wrong = 0;
|
||||
const auto preds = inet(images, batch_size * 2);
|
||||
for (size_t i = 0; i < labels.size(); ++i)
|
||||
{
|
||||
if (labels[i] == preds[i])
|
||||
if (labels[i] == df(samples[i]))
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
@ -350,11 +341,12 @@ try
|
||||
cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
|
||||
};
|
||||
|
||||
// If everything works as expected, we should get accuracies that are between 87% and 90%.
|
||||
cout << "training accuracy" << endl;
|
||||
compute_accuracy(training_images, training_labels);
|
||||
// We should get a training accuracy of around 93% and a testing accuracy of around 88%.
|
||||
cout << "\ntraining accuracy" << endl;
|
||||
compute_accuracy(features, training_labels);
|
||||
cout << "\ntesting accuracy" << endl;
|
||||
compute_accuracy(testing_images, testing_labels);
|
||||
features = fnet(testing_images, 4 * batch_size);
|
||||
compute_accuracy(features, testing_labels);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
catch (const exception& e)
|
||||
|
Loading…
Reference in New Issue
Block a user