mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Replace fc classifier with svm_multiclass_linear_trainer
This commit is contained in:
parent
a41b3d7ce8
commit
c8c810f221
@ -40,6 +40,7 @@
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/cmd_line_parser.h>
|
||||
#include <dlib/gui_widgets.h>
|
||||
#include <dlib/svm_threaded.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
@ -82,14 +83,12 @@ namespace resnet50
|
||||
|
||||
// This model namespace contains the definitions for:
|
||||
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
|
||||
// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
|
||||
// - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
|
||||
namespace model
|
||||
{
|
||||
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
|
||||
template <typename SUBNET> using classifier = fc<10, SUBNET>;
|
||||
|
||||
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
|
||||
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
|
||||
using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
|
||||
}
|
||||
|
||||
rectangle make_random_cropping_rect(
|
||||
@ -170,6 +169,7 @@ try
|
||||
// The default settings are fine for the example already.
|
||||
command_line_parser parser;
|
||||
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
|
||||
parser.add_option("c", "SVM C parameter for the linear classifier (default: 5000)", 1);
|
||||
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
|
||||
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
|
||||
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
|
||||
@ -196,6 +196,7 @@ try
|
||||
const double lambda = get_option(parser, "lambda", 1.0 / dims);
|
||||
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
|
||||
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
|
||||
const double svm_c = get_option(parser, "c", 5000);
|
||||
|
||||
// Load the CIFAR-10 dataset into memory.
|
||||
std::vector<matrix<rgb_pixel>> training_images, testing_images;
|
||||
@ -288,73 +289,47 @@ try
|
||||
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
|
||||
}
|
||||
|
||||
// To check the quality of the learned feature representations, we will train a linear
|
||||
// classififer on top of the frozen backbone.
|
||||
model::infer inet;
|
||||
// Assign the network, without the projector, which is only used for the self-supervised
|
||||
// training.
|
||||
layer<2>(inet) = layer<5>(net);
|
||||
// Freeze the backbone
|
||||
set_all_learning_rate_multipliers(layer<2>(inet), 0);
|
||||
// Train the network
|
||||
{
|
||||
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
|
||||
// Since this model doesn't train with pairs, just single images, we can increase
|
||||
// the batch size.
|
||||
trainer.set_mini_batch_size(2 * batch_size);
|
||||
trainer.set_learning_rate(learning_rate);
|
||||
trainer.set_min_learning_rate(min_learning_rate);
|
||||
trainer.set_iterations_without_progress_threshold(5000);
|
||||
trainer.set_synchronization_file("cifar_10_sync");
|
||||
trainer.be_verbose();
|
||||
cout << trainer << endl;
|
||||
|
||||
std::vector<matrix<rgb_pixel>> images;
|
||||
std::vector<unsigned long> labels;
|
||||
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
|
||||
{
|
||||
images.clear();
|
||||
labels.clear();
|
||||
while (images.size() < trainer.get_mini_batch_size())
|
||||
{
|
||||
const auto idx = rnd.get_random_32bit_number() % training_images.size();
|
||||
images.push_back(augment(training_images[idx], false, rnd));
|
||||
labels.push_back(training_labels[idx]);
|
||||
}
|
||||
trainer.train_one_step(images, labels);
|
||||
}
|
||||
trainer.get_net();
|
||||
inet.clean();
|
||||
serialize("resnet50_cifar_10.dnn") << inet;
|
||||
}
|
||||
// Now, we initialize the feature extractor model with the backbone we have just learned.
|
||||
model::feats fnet(layer<5>(net));
|
||||
// And we will generate all the features for the training set to train a multiclass SVM
|
||||
// classifier.
|
||||
std::vector<matrix<float, 0, 1>> features;
|
||||
cout << "Extracting features for linear classifier..." << endl;
|
||||
features = fnet(training_images, 4 * batch_size);
|
||||
svm_multiclass_linear_trainer<linear_kernel<matrix<float,0,1>>, unsigned long> trainer;
|
||||
trainer.set_num_threads(std::thread::hardware_concurrency());
|
||||
trainer.set_c(svm_c);
|
||||
cout << "Training Multiclass SVM..." << endl;
|
||||
const auto df = trainer.train(features, training_labels);
|
||||
serialize("multiclass_svm_cifar_10.dat") << df;
|
||||
|
||||
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
|
||||
auto compute_accuracy = [&inet, batch_size](
|
||||
const std::vector<matrix<rgb_pixel>>& images,
|
||||
auto compute_accuracy = [&fnet, &df, batch_size](
|
||||
const std::vector<matrix<float, 0, 1>>& samples,
|
||||
const std::vector<unsigned long>& labels
|
||||
)
|
||||
{
|
||||
size_t num_right = 0;
|
||||
size_t num_wrong = 0;
|
||||
const auto preds = inet(images, batch_size * 2);
|
||||
for (size_t i = 0; i < labels.size(); ++i)
|
||||
{
|
||||
if (labels[i] == preds[i])
|
||||
if (labels[i] == df(samples[i]))
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
}
|
||||
cout << "num right: " << num_right << endl;
|
||||
cout << "num wrong: " << num_wrong << endl;
|
||||
cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
|
||||
cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
|
||||
cout << " num right: " << num_right << endl;
|
||||
cout << " num wrong: " << num_wrong << endl;
|
||||
cout << " accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
|
||||
cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
|
||||
};
|
||||
|
||||
// If everything works as expected, we should get accuracies that are between 87% and 90%.
|
||||
cout << "training accuracy" << endl;
|
||||
compute_accuracy(training_images, training_labels);
|
||||
// We should get a training accuracy of around 93% and a testing accuracy of around 88%.
|
||||
cout << "\ntraining accuracy" << endl;
|
||||
compute_accuracy(features, training_labels);
|
||||
cout << "\ntesting accuracy" << endl;
|
||||
compute_accuracy(testing_images, testing_labels);
|
||||
features = fnet(testing_images, 4 * batch_size);
|
||||
compute_accuracy(features, testing_labels);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
catch (const exception& e)
|
||||
|
Loading…
Reference in New Issue
Block a user