mirror of https://github.com/davisking/dlib.git
Add dnn self supervised learning example (#2434)
* wip: loss goes down when training without a dnn_trainer if I use a dnn_trainer, it segfaults (also with bigger batch sizes...) * remove commented code * fix gradient computation (hopefully) * fix loss computation * fix crash in input_rgb_image_pair::to_tensor * fix alias tensor offset * refactor loss and input layers and complete the example * add more data augmentation * add documentation * add documentation * small fix in the gradient computation and reuse terms * fix warning in comment * use tensor_tools instead of matrix to compute the gradients * complete the example program * add support for mult-gpu * Update dlib/dnn/input_abstract.h * Update dlib/dnn/input_abstract.h * Update dlib/dnn/loss_abstract.h * Update examples/dnn_self_supervised_learning_ex.cpp * Update examples/dnn_self_supervised_learning_ex.cpp * Update examples/dnn_self_supervised_learning_ex.cpp * Update examples/dnn_self_supervised_learning_ex.cpp * [TYPE_SAFE_UNION] upgrade (#2443) * [TYPE_SAFE_UNION] upgrade * MSVC doesn't like keyword not * MSVC doesn't like keyword and * added tests for emplate(), copy semantics, move semantics, swap, overloaded and apply_to_contents with non void return types * - didn't need is_void anymore - added result_of_t - didn't really need ostream_helper or istream_helper - split apply_to_contents into apply_to_contents (return void) and visit (return anything so long as visitor is publicly accessible) * - updated abstract file * - added get_type_t - removed deserialize_helper dupplicate - don't use std::decay_t, that's c++14 * - removed white spaces - don't need a return-statement when calling apply_to_contents_impl() - use unchecked_get() whenever possible to minimise explicit use of pointer casting. lets keep that to a minimum * - added type_safe_union_size - added type_safe_union_size_v if C++14 is available - added tests for above * - test type_safe_union_size_v * testing nested unions with visitors. * re-added comment * added index() in abstract file * - refactored reset() to clear() - added comment about clear() in abstract file - in deserialize(), only reset the object if necessary * - removed unecessary comment about exceptions - removed unecessary // ------------- - struct is_valid is not mentioned in abstract. Instead rather requiring T to be a valid type, it is ensured! - get_type and get_type_t are private. Client code shouldn't need this. - shuffled some functions around - type_safe_union_size and type_safe_union_size_v are removed. not needed - reset() -> clear() - bug fix in deserialize() index counts from 1, not 0 - improved the abstract file * refactored index() to get_current_type_id() as per suggestion * maybe slightly improved docs * - HURRAY, don't need std::result_of or std::invoke_result for visit() to work. Just privately define your own type trait, in this case called return_type and return_type_t. it works! - apply_to_contents() now always calls visit() * example with private visitor using friendship with non-void return types. * Fix up contracts It can't be a post condition that T is a valid type, since the choice of T is up to the caller, it's not something these functions decide. Making it a precondition. * Update dlib/type_safe_union/type_safe_union_kernel_abstract.h * Update dlib/type_safe_union/type_safe_union_kernel_abstract.h * Update dlib/type_safe_union/type_safe_union_kernel_abstract.h * - added more tests for copy constructors/assignments, move constructors/assignments, and converting constructors/assignments - helper_copy -> helper_forward - added validate_type<T> in a couple of places * - helper_move only takes non-const lvalue references. So we are not using std::move with universal references ! - use enable_if<is_valid<T>> in favor of validate_type<T>() * - use enable_if<is_valid<T>> in favor of validate_type<T>() * - added is_valid_check<>. This wraps enable_if<is_valid<T>,bool> and makes use of SFINAE more robust Co-authored-by: pfeatherstone <peter@me> Co-authored-by: pf <pf@me> Co-authored-by: Davis E. King <davis685@gmail.com> * Just minor cleanup of docs and renamed some stuff, tweaked formatting. * fix spelling error * fix most vexing parse error Co-authored-by: Davis E. King <davis@dlib.net> Co-authored-by: pfeatherstone <45853521+pfeatherstone@users.noreply.github.com> Co-authored-by: pfeatherstone <peter@me> Co-authored-by: pf <pf@me> Co-authored-by: Davis E. King <davis685@gmail.com>pull/2453/head
parent
f323d1824c
commit
2e8bac1915
@ -0,0 +1,364 @@
|
||||
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
/*
|
||||
This is an example illustrating the use of the deep learning tools from the dlib C++
|
||||
Library. I'm assuming you have already read the dnn_introduction_ex.cpp, the
|
||||
dnn_introduction2_ex.cpp and the dnn_introduction3_ex.cpp examples. In this example
|
||||
program we are going to show how one can train a neural network using an unsupervised
|
||||
loss function. In particular, we will train the ResNet50 model from the paper
|
||||
"Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing
|
||||
Ren, Jian Sun.
|
||||
|
||||
To train the unsupervised loss, we will use the self-supervised learning (SSL) method
|
||||
called Barlow Twins, introduced in this paper:
|
||||
"Barlow Twins: Self-Supervised Learning via Redundancy Reduction" by Jure Zbontar,
|
||||
Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny.
|
||||
|
||||
The paper contains a good explanation on how and why this works, but the main idea
|
||||
behind the Barlow Twins method is:
|
||||
- generate two distorted views of a batch of images: YA, YB
|
||||
- feed them to a deep neural network and obtain their representations and
|
||||
and batch normalize them: ZA, ZB
|
||||
- compute the empirical cross-correlation matrix between both feature
|
||||
representations as: C = trans(ZA) * ZB.
|
||||
- make C as close as possible to the identity matrix.
|
||||
|
||||
This removes the redundancy of the feature representations, by maximizing the
|
||||
encoded information about the images themselves, while minimizing the information
|
||||
about the transforms and data augmentations used to obtain the representations.
|
||||
|
||||
The original Barlow Twins paper uses the ImageNet dataset, but in this example we
|
||||
are using CIFAR-10, so we will follow the recommendations of this paper, instead:
|
||||
"A Note on Connecting Barlow Twins with Negative-Sample-Free Contrastive Learning"
|
||||
by Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, Ruslan Salakhutdinov,
|
||||
in which they experiment with Barlow Twins on CIFAR-10 and Tiny ImageNet. Since
|
||||
the CIFAR-10 contains relatively small images, we will define a ResNet50 architecture
|
||||
that doesn't downsample the input in the first convolutional layer, and doesn't have
|
||||
a max pooling layer afterwards, like the paper does.
|
||||
*/
|
||||
|
||||
#include <dlib/dnn.h>
|
||||
#include <dlib/data_io.h>
|
||||
#include <dlib/cmd_line_parser.h>
|
||||
#include <dlib/gui_widgets.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
// A custom definition of ResNet50 with a downsampling factor of 8 instead of 32.
|
||||
// It is essentially the original ResNet50, but without the max pooling and a
|
||||
// convolutional layer with a stride of 1 instead of 2 at the input.
|
||||
namespace resnet50
|
||||
{
|
||||
using namespace dlib;
|
||||
template <template <typename> class BN>
|
||||
struct def
|
||||
{
|
||||
template <long N, int K, int S, typename SUBNET>
|
||||
using conv = add_layer<con_<N, K, K, S, S, K / 2, K / 2>, SUBNET>;
|
||||
|
||||
template<long N, int S, typename SUBNET>
|
||||
using bottleneck = BN<conv<4 * N, 1, 1, relu<BN<conv<N, 3, S, relu<BN<conv<N, 1, 1, SUBNET>>>>>>>>;
|
||||
|
||||
template <long N, typename SUBNET>
|
||||
using residual = add_prev1<bottleneck<N, 1, tag1<SUBNET>>>;
|
||||
|
||||
template <typename SUBNET> using res_512 = relu<residual<512, SUBNET>>;
|
||||
template <typename SUBNET> using res_256 = relu<residual<256, SUBNET>>;
|
||||
template <typename SUBNET> using res_128 = relu<residual<128, SUBNET>>;
|
||||
template <typename SUBNET> using res_64 = relu<residual<64, SUBNET>>;
|
||||
|
||||
template <long N, int S, typename SUBNET>
|
||||
using transition = add_prev2<BN<conv<4 * N, 1, S, skip1<tag2<bottleneck<N, S, tag1<SUBNET>>>>>>>;
|
||||
|
||||
template <typename INPUT>
|
||||
using backbone = avg_pool_everything<
|
||||
repeat<2, res_512, transition<512, 2,
|
||||
repeat<5, res_256, transition<256, 2,
|
||||
repeat<3, res_128, transition<128, 2,
|
||||
repeat<2, res_64, transition<64, 1,
|
||||
relu<BN<conv<64, 3, 1,INPUT>>>>>>>>>>>>;
|
||||
};
|
||||
};
|
||||
|
||||
// This model namespace contains the definitions for:
|
||||
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
|
||||
// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
|
||||
namespace model
|
||||
{
|
||||
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
|
||||
template <typename SUBNET> using classifier = fc<10, SUBNET>;
|
||||
|
||||
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
|
||||
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
|
||||
}
|
||||
|
||||
rectangle make_random_cropping_rect(
|
||||
const matrix<rgb_pixel>& image,
|
||||
dlib::rand& rnd
|
||||
)
|
||||
{
|
||||
const double mins = 7. / 15.;
|
||||
const double maxs = 7. / 8.;
|
||||
const auto scale = rnd.get_double_in_range(mins, maxs);
|
||||
const auto size = scale * std::min(image.nr(), image.nc());
|
||||
const rectangle rect(size, size);
|
||||
const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
|
||||
rnd.get_random_32bit_number() % (image.nr() - rect.height()));
|
||||
return move_rect(rect, offset);
|
||||
}
|
||||
|
||||
matrix<rgb_pixel> augment(
|
||||
const matrix<rgb_pixel>& image,
|
||||
const bool prime,
|
||||
dlib::rand& rnd
|
||||
)
|
||||
{
|
||||
matrix<rgb_pixel> crop;
|
||||
// blur
|
||||
matrix<rgb_pixel> blurred;
|
||||
const double sigma = rnd.get_double_in_range(0.1, 1.1);
|
||||
if (!prime || (prime && rnd.get_random_double() < 0.1))
|
||||
{
|
||||
const auto rect = gaussian_blur(image, blurred, sigma);
|
||||
extract_image_chip(blurred, rect, crop);
|
||||
blurred = crop;
|
||||
}
|
||||
else
|
||||
{
|
||||
blurred = image;
|
||||
}
|
||||
|
||||
// randomly crop
|
||||
const auto rect = make_random_cropping_rect(image, rnd);
|
||||
extract_image_chip(blurred, chip_details(rect, chip_dims(32, 32)), crop);
|
||||
|
||||
// image left-right flip
|
||||
if (rnd.get_random_double() < 0.5)
|
||||
flip_image_left_right(crop);
|
||||
|
||||
// color augmentation
|
||||
if (rnd.get_random_double() < 0.8)
|
||||
disturb_colors(crop, rnd, 0.5, 0.5);
|
||||
|
||||
// grayscale
|
||||
if (rnd.get_random_double() < 0.2)
|
||||
{
|
||||
matrix<unsigned char> gray;
|
||||
assign_image(gray, crop);
|
||||
assign_image(crop, gray);
|
||||
}
|
||||
|
||||
// solarize
|
||||
if (prime && rnd.get_random_double() < 0.2)
|
||||
{
|
||||
for (auto& p : crop)
|
||||
{
|
||||
if (p.red > 128)
|
||||
p.red = 255 - p.red;
|
||||
if (p.green > 128)
|
||||
p.green = 255 - p.green;
|
||||
if (p.blue > 128)
|
||||
p.blue = 255 - p.blue;
|
||||
}
|
||||
}
|
||||
return crop;
|
||||
}
|
||||
|
||||
int main(const int argc, const char** argv)
|
||||
try
|
||||
{
|
||||
// The default settings are fine for the example already.
|
||||
command_line_parser parser;
|
||||
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
|
||||
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
|
||||
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
|
||||
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
|
||||
parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
|
||||
parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
|
||||
parser.set_group_name("Help Options");
|
||||
parser.add_option("h", "alias for --help");
|
||||
parser.add_option("help", "display this message and exit");
|
||||
parser.parse(argc, argv);
|
||||
|
||||
if (parser.number_of_arguments() < 1 || parser.option("h") || parser.option("help"))
|
||||
{
|
||||
cout << "This example needs the CIFAR-10 dataset to run." << endl;
|
||||
cout << "You can get CIFAR-10 from https://www.cs.toronto.edu/~kriz/cifar.html" << endl;
|
||||
cout << "Download the binary version the dataset, decompress it, and put the 6" << endl;
|
||||
cout << "bin files in a folder. Then give that folder as input to this program." << endl;
|
||||
parser.print_options();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
const size_t num_gpus = get_option(parser, "num-gpus", 1);
|
||||
const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
|
||||
const long dims = get_option(parser, "dims", 128);
|
||||
const double lambda = get_option(parser, "lambda", 1.0 / dims);
|
||||
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
|
||||
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
|
||||
|
||||
// Load the CIFAR-10 dataset into memory.
|
||||
std::vector<matrix<rgb_pixel>> training_images, testing_images;
|
||||
std::vector<unsigned long> training_labels, testing_labels;
|
||||
load_cifar_10_dataset(parser[0], training_images, training_labels, testing_images, testing_labels);
|
||||
|
||||
// Initialize the model with the specified projector dimensions and lambda. According to the
|
||||
// second paper, lambda = 1/dims works well on CIFAR-10.
|
||||
model::train net((loss_barlow_twins_(lambda)));
|
||||
layer<1>(net).layer_details().set_num_outputs(dims);
|
||||
disable_duplicative_biases(net);
|
||||
dlib::rand rnd;
|
||||
std::vector<int> gpus(num_gpus);
|
||||
std::iota(gpus.begin(), gpus.end(), 0);
|
||||
|
||||
// Train the feature extractor using the Barlow Twins method
|
||||
{
|
||||
dnn_trainer<model::train, adam> trainer(net, adam(1e-6, 0.9, 0.999), gpus);
|
||||
trainer.set_mini_batch_size(batch_size);
|
||||
trainer.set_learning_rate(learning_rate);
|
||||
trainer.set_min_learning_rate(min_learning_rate);
|
||||
trainer.set_iterations_without_progress_threshold(10000);
|
||||
trainer.set_synchronization_file("barlow_twins_sync");
|
||||
trainer.be_verbose();
|
||||
cout << trainer << endl;
|
||||
|
||||
// During the training, we will compute the empirical cross-correlation matrix
|
||||
// between the features of both versions of the augmented images. This matrix
|
||||
// should be getting close to the identity matrix as the training progresses.
|
||||
// Note that this step is already done in the loss layer, and it's not necessary
|
||||
// to do it here for the example to work. However, it provides a nice
|
||||
// visualization of the training progress: the closer to the identity matrix,
|
||||
// the better.
|
||||
resizable_tensor eccm;
|
||||
eccm.set_size(dims, dims);
|
||||
// Some tensors needed to perform batch normalization
|
||||
resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta;
|
||||
const double eps = DEFAULT_BATCH_NORM_EPS;
|
||||
gamma.set_size(1, dims);
|
||||
beta.set_size(1, dims);
|
||||
image_window win;
|
||||
|
||||
std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
|
||||
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
|
||||
{
|
||||
batch.clear();
|
||||
while (batch.size() < trainer.get_mini_batch_size())
|
||||
{
|
||||
const auto idx = rnd.get_random_32bit_number() % training_images.size();
|
||||
auto image = training_images[idx];
|
||||
batch.emplace_back(augment(image, false, rnd), augment(image, true, rnd));
|
||||
}
|
||||
trainer.train_one_step(batch);
|
||||
|
||||
// Compute the empirical cross-correlation matrix every 100 steps. Again,
|
||||
// this is not needed for the training to work, but it's nice to visualize.
|
||||
if (trainer.get_train_one_step_calls() % 100 == 0)
|
||||
{
|
||||
// Wait for threaded processing to stop in the trainer.
|
||||
trainer.get_net(force_flush_to_disk::no);
|
||||
// Get the output from the last fc layer
|
||||
const auto& out = net.subnet().get_output();
|
||||
// The trainer might have synchronized its state to the disk and cleaned
|
||||
// the network state. If that happens, the output will be empty, in
|
||||
// which case, we just skip the empirical cross-correlation matrix
|
||||
// computation.
|
||||
if (out.size() == 0)
|
||||
continue;
|
||||
// Separate both augmented versions of the images
|
||||
alias_tensor split(out.num_samples() / 2, dims);
|
||||
auto za = split(out);
|
||||
auto zb = split(out, split.size());
|
||||
gamma = 1;
|
||||
beta = 0;
|
||||
// Perform batch normalization on each feature representation, independently.
|
||||
tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
|
||||
tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
|
||||
// Compute the empirical cross-correlation matrix between the features and
|
||||
// visualize it.
|
||||
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
|
||||
eccm /= batch_size;
|
||||
win.set_image(abs(mat(eccm)) * 255);
|
||||
win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
|
||||
}
|
||||
}
|
||||
trainer.get_net();
|
||||
net.clean();
|
||||
// After training, we can discard the projector head and just keep the backone
|
||||
// to train it or finetune it on other downstream tasks.
|
||||
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
|
||||
}
|
||||
|
||||
// To check the quality of the learned feature representations, we will train a linear
|
||||
// classififer on top of the frozen backbone.
|
||||
model::infer inet;
|
||||
// Assign the network, without the projector, which is only used for the self-supervised
|
||||
// training.
|
||||
layer<2>(inet) = layer<5>(net);
|
||||
// Freeze the backbone
|
||||
set_all_learning_rate_multipliers(layer<2>(inet), 0);
|
||||
// Train the network
|
||||
{
|
||||
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
|
||||
// Since this model doesn't train with pairs, just single images, we can increase
|
||||
// the batch size.
|
||||
trainer.set_mini_batch_size(2 * batch_size);
|
||||
trainer.set_learning_rate(learning_rate);
|
||||
trainer.set_min_learning_rate(min_learning_rate);
|
||||
trainer.set_iterations_without_progress_threshold(5000);
|
||||
trainer.set_synchronization_file("cifar_10_sync");
|
||||
trainer.be_verbose();
|
||||
cout << trainer << endl;
|
||||
|
||||
std::vector<matrix<rgb_pixel>> images;
|
||||
std::vector<unsigned long> labels;
|
||||
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
|
||||
{
|
||||
images.clear();
|
||||
labels.clear();
|
||||
while (images.size() < trainer.get_mini_batch_size())
|
||||
{
|
||||
const auto idx = rnd.get_random_32bit_number() % training_images.size();
|
||||
images.push_back(augment(training_images[idx], false, rnd));
|
||||
labels.push_back(training_labels[idx]);
|
||||
}
|
||||
trainer.train_one_step(images, labels);
|
||||
}
|
||||
trainer.get_net();
|
||||
inet.clean();
|
||||
serialize("resnet50_cifar_10.dnn") << inet;
|
||||
}
|
||||
|
||||
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
|
||||
auto compute_accuracy = [&inet, batch_size](
|
||||
const std::vector<matrix<rgb_pixel>>& images,
|
||||
const std::vector<unsigned long>& labels
|
||||
)
|
||||
{
|
||||
size_t num_right = 0;
|
||||
size_t num_wrong = 0;
|
||||
const auto preds = inet(images, batch_size * 2);
|
||||
for (size_t i = 0; i < labels.size(); ++i)
|
||||
{
|
||||
if (labels[i] == preds[i])
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
}
|
||||
cout << "num right: " << num_right << endl;
|
||||
cout << "num wrong: " << num_wrong << endl;
|
||||
cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
|
||||
cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
|
||||
};
|
||||
|
||||
// If everything works as expected, we should get accuracies that are between 87% and 90%.
|
||||
cout << "training accuracy" << endl;
|
||||
compute_accuracy(training_images, training_labels);
|
||||
cout << "\ntesting accuracy" << endl;
|
||||
compute_accuracy(testing_images, testing_labels);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
catch (const exception& e)
|
||||
{
|
||||
cout << e.what() << endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
Loading…
Reference in new issue