dlib/examples/dnn_self_supervised_learnin...

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the dlib C++
Library. I'm assuming you have already read the dnn_introduction_ex.cpp, the
dnn_introduction2_ex.cpp and the dnn_introduction3_ex.cpp examples. In this example
program we are going to show how one can train a neural network using an unsupervised
loss function. In particular, we will train the ResNet50 model from the paper
"Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing
Ren, Jian Sun.
For the unsupervised loss, we will use the self-supervised learning (SSL) method
called Barlow Twins, introduced in this paper:
"Barlow Twins: Self-Supervised Learning via Redundancy Reduction" by Jure Zbontar,
Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny.
The paper contains a good explanation of how and why this works, but the main idea
behind the Barlow Twins method is:
- generate two distorted views of a batch of images: YA, YB
- feed them to a deep neural network, obtain their representations,
and batch normalize them: ZA, ZB
- compute the empirical cross-correlation matrix between both feature
representations as: C = trans(ZA) * ZB.
- make C as close as possible to the identity matrix.
This removes the redundancy of the feature representations, by maximizing the
encoded information about the images themselves, while minimizing the information
about the transforms and data augmentations used to obtain the representations.
The original Barlow Twins paper uses the ImageNet dataset, but in this example we
are using CIFAR-10, so we will follow the recommendations of this paper, instead:
"A Note on Connecting Barlow Twins with Negative-Sample-Free Contrastive Learning"
by Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, Ruslan Salakhutdinov,
in which they experiment with Barlow Twins on CIFAR-10 and Tiny ImageNet. Since
CIFAR-10 contains relatively small images, we will define a ResNet50 architecture
that doesn't downsample the input in the first convolutional layer, and doesn't have
a max pooling layer afterwards, like the paper does.
*/
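/*
The steps listed above can be sketched without dlib. This is only an illustration
under stated assumptions, not the code used by loss_barlow_twins: feature_batch,
standardize and barlow_twins_loss are hypothetical names, the batch is assumed to be
non-empty with no constant feature (no epsilon is added), and the objective is the
one from the Barlow Twins paper: sum_i (1 - C_ii)^2 + lambda * sum_{i != j} C_ij^2.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper type: one row per sample, one column per feature.
using feature_batch = std::vector<std::vector<double>>;

// Standardize every feature to zero mean and unit variance across the batch.
// Assumes a non-empty batch in which no feature is constant.
feature_batch standardize(const feature_batch& z)
{
    const std::size_t n = z.size();
    const std::size_t d = z.front().size();
    feature_batch out(n, std::vector<double>(d));
    for (std::size_t j = 0; j < d; ++j)
    {
        double mean = 0;
        for (std::size_t i = 0; i < n; ++i)
            mean += z[i][j];
        mean /= n;
        double var = 0;
        for (std::size_t i = 0; i < n; ++i)
            var += (z[i][j] - mean) * (z[i][j] - mean);
        var /= n;
        for (std::size_t i = 0; i < n; ++i)
            out[i][j] = (z[i][j] - mean) / std::sqrt(var);
    }
    return out;
}

// Cross-correlate the standardized views and penalize the distance to identity.
double barlow_twins_loss(const feature_batch& za, const feature_batch& zb, const double lambda)
{
    const feature_batch a = standardize(za);
    const feature_batch b = standardize(zb);
    const std::size_t n = a.size();
    const std::size_t d = a.front().size();
    double loss = 0;
    for (std::size_t i = 0; i < d; ++i)
    {
        for (std::size_t j = 0; j < d; ++j)
        {
            // c is the entry (i, j) of C, averaged over the batch.
            double c = 0;
            for (std::size_t k = 0; k < n; ++k)
                c += a[k][i] * b[k][j];
            c /= n;
            loss += (i == j) ? (1 - c) * (1 - c) : lambda * c * c;
        }
    }
    return loss;
}
```

When both views yield identical, decorrelated features, C is the identity matrix and
the loss is zero. Any correlation between different features adds a penalty that is
weighted by lambda.
*/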
#include <dlib/dnn.h>
#include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h>
#include <dlib/gui_widgets.h>
#include <algorithm> // std::min
#include <numeric>   // std::iota
using namespace std;
using namespace dlib;
// A custom definition of ResNet50 with a downsampling factor of 8 instead of 32.
// It is essentially the original ResNet50, but without the max pooling layer and
// with a stride of 1 instead of 2 in the first convolutional layer.
namespace resnet50
{
using namespace dlib;
template <template <typename> class BN>
struct def
{
template <long N, int K, int S, typename SUBNET>
using conv = add_layer<con_<N, K, K, S, S, K / 2, K / 2>, SUBNET>;
template<long N, int S, typename SUBNET>
using bottleneck = BN<conv<4 * N, 1, 1, relu<BN<conv<N, 3, S, relu<BN<conv<N, 1, 1, SUBNET>>>>>>>>;
template <long N, typename SUBNET>
using residual = add_prev1<bottleneck<N, 1, tag1<SUBNET>>>;
template <typename SUBNET> using res_512 = relu<residual<512, SUBNET>>;
template <typename SUBNET> using res_256 = relu<residual<256, SUBNET>>;
template <typename SUBNET> using res_128 = relu<residual<128, SUBNET>>;
template <typename SUBNET> using res_64 = relu<residual<64, SUBNET>>;
template <long N, int S, typename SUBNET>
using transition = add_prev2<BN<conv<4 * N, 1, S, skip1<tag2<bottleneck<N, S, tag1<SUBNET>>>>>>>;
template <typename INPUT>
using backbone = avg_pool_everything<
repeat<2, res_512, transition<512, 2,
repeat<5, res_256, transition<256, 2,
repeat<3, res_128, transition<128, 2,
repeat<2, res_64, transition<64, 1,
relu<BN<conv<64, 3, 1,INPUT>>>>>>>>>>>>;
};
}
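/*
The downsampling factors of 8 and 32 follow from multiplying the strides of every
stage. A small sketch of that arithmetic, where downsampling_factor and the two
stride lists are hypothetical names; the strides of this example are read from the
backbone definition, and those of the original ResNet50 come from its paper:

```cpp
#include <vector>

// Multiply the strides of all stages to get the overall downsampling factor.
int downsampling_factor(const std::vector<int>& strides)
{
    int factor = 1;
    for (const int s : strides)
        factor *= s;
    return factor;
}

// Original ResNet50: 7x7 input convolution, max pooling, and four residual stages.
const std::vector<int> original_strides = {2, 2, 1, 2, 2, 2};
// This example: 3x3 input convolution, no max pooling, and four residual stages.
const std::vector<int> cifar_strides = {1, 1, 2, 2, 2};
```

With a 32x32 CIFAR-10 image, a factor of 8 leaves a 4x4 feature map before the
global average pooling, whereas a factor of 32 would leave a single 1x1 location.
*/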
// This model namespace contains the definitions for:
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
// - Classifier model using loss_multiclass_log, an fc layer and an input_rgb_image.
namespace model
{
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
template <typename SUBNET> using classifier = fc<10, SUBNET>;
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
}
rectangle make_random_cropping_rect(
const matrix<rgb_pixel>& image,
dlib::rand& rnd
)
{
const double mins = 7. / 15.;
const double maxs = 7. / 8.;
const auto scale = rnd.get_double_in_range(mins, maxs);
const auto size = scale * std::min(image.nr(), image.nc());
const rectangle rect(size, size);
const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
rnd.get_random_32bit_number() % (image.nr() - rect.height()));
return move_rect(rect, offset);
}
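/*
To get a feeling for the crops produced by make_random_cropping_rect, here is a
sketch of its arithmetic without dlib. crop_side and num_offsets are hypothetical
helpers, and the truncation to whole pixels is an assumption about what happens when
the floating point size is passed to the integer-sized rectangle constructor.

```cpp
#include <algorithm>

// Side length of the square crop, truncated to whole pixels.
long crop_side(const long nr, const long nc, const double scale)
{
    return static_cast<long>(scale * std::min(nr, nc));
}

// Number of distinct offsets along one axis: a modulo by (length - side)
// yields offsets in the range [0, length - side).
long num_offsets(const long length, const long side)
{
    return length - side;
}
```

For a 32x32 image, the two scale limits of 7/15 and 7/8 correspond to crops with a
side of 14 and 28 pixels, and a 28 pixel crop can start at 4 different columns.
*/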
matrix<rgb_pixel> augment(
const matrix<rgb_pixel>& image,
const bool prime,
dlib::rand& rnd
)
{
matrix<rgb_pixel> crop;
// blur
matrix<rgb_pixel> blurred;
const double sigma = rnd.get_double_in_range(0.1, 1.1);
if (!prime || rnd.get_random_double() < 0.1)
{
const auto rect = gaussian_blur(image, blurred, sigma);
extract_image_chip(blurred, rect, crop);
blurred = crop;
}
else
{
blurred = image;
}
// randomly crop
const auto rect = make_random_cropping_rect(image, rnd);
extract_image_chip(blurred, chip_details(rect, chip_dims(32, 32)), crop);
// image left-right flip
if (rnd.get_random_double() < 0.5)
flip_image_left_right(crop);
// color augmentation
if (rnd.get_random_double() < 0.8)
disturb_colors(crop, rnd, 0.5, 0.5);
// grayscale
if (rnd.get_random_double() < 0.2)
{
matrix<unsigned char> gray;
assign_image(gray, crop);
assign_image(crop, gray);
}
// solarize
if (prime && rnd.get_random_double() < 0.2)
{
for (auto& p : crop)
{
if (p.red > 128)
p.red = 255 - p.red;
if (p.green > 128)
p.green = 255 - p.green;
if (p.blue > 128)
p.blue = 255 - p.blue;
}
}
return crop;
}
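/*
The solarize step above inverts only the bright values of each channel. A
stand-alone sketch for a single 8-bit channel, where solarize and its threshold
parameter are hypothetical names used for illustration:

```cpp
// Solarize one 8-bit channel: values above the threshold are inverted,
// the remaining values are left untouched.
unsigned char solarize(const unsigned char value, const unsigned char threshold = 128)
{
    return value > threshold ? static_cast<unsigned char>(255 - value) : value;
}
```

For example, a value of 200 becomes 55, while 128 and everything below it stay
the same.
*/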
int main(const int argc, const char** argv)
try
{
// The default settings are fine for the example already.
command_line_parser parser;
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
parser.set_group_name("Help Options");
parser.add_option("h", "alias for --help");
parser.add_option("help", "display this message and exit");
parser.parse(argc, argv);
if (parser.number_of_arguments() < 1 || parser.option("h") || parser.option("help"))
{
cout << "This example needs the CIFAR-10 dataset to run." << endl;
cout << "You can get CIFAR-10 from https://www.cs.toronto.edu/~kriz/cifar.html" << endl;
cout << "Download the binary version of the dataset, decompress it, and put the 6" << endl;
cout << "bin files in a folder. Then give that folder as input to this program." << endl;
parser.print_options();
return EXIT_SUCCESS;
}
const size_t num_gpus = get_option(parser, "num-gpus", 1);
const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
const long dims = get_option(parser, "dims", 128);
const double lambda = get_option(parser, "lambda", 1.0 / dims);
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
// Load the CIFAR-10 dataset into memory.
std::vector<matrix<rgb_pixel>> training_images, testing_images;
std::vector<unsigned long> training_labels, testing_labels;
load_cifar_10_dataset(parser[0], training_images, training_labels, testing_images, testing_labels);
// Initialize the model with the specified projector dimensions and lambda. According to the
// second paper, lambda = 1/dims works well on CIFAR-10.
model::train net((loss_barlow_twins_(lambda)));
layer<1>(net).layer_details().set_num_outputs(dims);
disable_duplicative_biases(net);
dlib::rand rnd;
std::vector<int> gpus(num_gpus);
std::iota(gpus.begin(), gpus.end(), 0);
// Train the feature extractor using the Barlow Twins method
{
dnn_trainer<model::train, adam> trainer(net, adam(1e-6, 0.9, 0.999), gpus);
trainer.set_mini_batch_size(batch_size);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(min_learning_rate);
trainer.set_iterations_without_progress_threshold(10000);
trainer.set_synchronization_file("barlow_twins_sync");
trainer.be_verbose();
cout << trainer << endl;
// During the training, we will compute the empirical cross-correlation matrix
// between the features of both versions of the augmented images. This matrix
// should be getting close to the identity matrix as the training progresses.
// Note that this step is already done in the loss layer, and it's not necessary
// to do it here for the example to work. However, it provides a nice
// visualization of the training progress: the closer to the identity matrix,
// the better.
resizable_tensor eccm;
eccm.set_size(dims, dims);
// Some tensors needed to perform batch normalization
resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta;
const double eps = DEFAULT_BATCH_NORM_EPS;
gamma.set_size(1, dims);
beta.set_size(1, dims);
image_window win;
std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
{
batch.clear();
while (batch.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
auto image = training_images[idx];
batch.emplace_back(augment(image, false, rnd), augment(image, true, rnd));
}
trainer.train_one_step(batch);
// Compute the empirical cross-correlation matrix every 100 steps. Again,
// this is not needed for the training to work, but it's nice to visualize.
if (trainer.get_train_one_step_calls() % 100 == 0)
{
// Wait for threaded processing to stop in the trainer.
trainer.get_net(force_flush_to_disk::no);
// Get the output from the last fc layer
const auto& out = net.subnet().get_output();
// The trainer might have synchronized its state to the disk and cleaned
// the network state. If that happens, the output will be empty, in
// which case, we just skip the empirical cross-correlation matrix
// computation.
if (out.size() == 0)
continue;
// Separate both augmented versions of the images
alias_tensor split(out.num_samples() / 2, dims);
auto za = split(out);
auto zb = split(out, split.size());
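/*
The offset passed to the alias tensor above is counted in elements, not in samples.
The sketch below reproduces that bookkeeping for a flat, row-major buffer. It
assumes the output stores all the first views followed by all the second views;
half_views and split_pairs are hypothetical names, not dlib types.

```cpp
#include <cstddef>

// Hypothetical stand-in for the two aliased views into a flat buffer.
struct half_views
{
    std::size_t za_offset; // first element of the first views
    std::size_t zb_offset; // first element of the second views
    std::size_t size;      // number of elements in each half
};

// Split num_samples rows of dims values each into two equally sized halves.
half_views split_pairs(const std::size_t num_samples, const std::size_t dims)
{
    const std::size_t half = (num_samples / 2) * dims;
    return {0, half, half};
}
```

With 4 samples of 3 features each, the first views occupy elements 0 to 5 and the
second views start at element 6.
*/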
gamma = 1;
beta = 0;
// Perform batch normalization on each feature representation, independently.
tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, zb, gamma, beta);
// Compute the empirical cross-correlation matrix between the features and
// visualize it.
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
win.set_image(abs(mat(eccm)) * 255);
win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
}
}
trainer.get_net();
net.clean();
// After training, we can discard the projector head and just keep the backbone
// to train it or finetune it on other downstream tasks.
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
}
// To check the quality of the learned feature representations, we will train a linear
// classifier on top of the frozen backbone.
model::infer inet;
// Assign the network, without the projector, which is only used for the self-supervised
// training.
layer<2>(inet) = layer<5>(net);
// Freeze the backbone
set_all_learning_rate_multipliers(layer<2>(inet), 0);
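/*
Setting the learning rate multipliers to 0 freezes a layer because the multiplier
scales the learning rate used for that layer. The sketch below shows the effect on
a plain gradient descent update. This is a simplification chosen for illustration:
the trainers in this example use adam, and sgd_step is a hypothetical helper.

```cpp
// One plain gradient descent update in which the multiplier of the layer
// scales the learning rate.
double sgd_step(const double weight, const double gradient,
                const double learning_rate, const double multiplier)
{
    return weight - learning_rate * multiplier * gradient;
}
```

With a multiplier of 0 the weight is returned unchanged, whatever the gradient is.
*/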
// Train the network
{
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
// Since this model doesn't train with pairs, just single images, we can increase
// the batch size.
trainer.set_mini_batch_size(2 * batch_size);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(min_learning_rate);
trainer.set_iterations_without_progress_threshold(5000);
trainer.set_synchronization_file("cifar_10_sync");
trainer.be_verbose();
cout << trainer << endl;
std::vector<matrix<rgb_pixel>> images;
std::vector<unsigned long> labels;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
{
images.clear();
labels.clear();
while (images.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
images.push_back(augment(training_images[idx], false, rnd));
labels.push_back(training_labels[idx]);
}
trainer.train_one_step(images, labels);
}
trainer.get_net();
inet.clean();
serialize("resnet50_cifar_10.dnn") << inet;
}
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
auto compute_accuracy = [&inet, batch_size](
const std::vector<matrix<rgb_pixel>>& images,
const std::vector<unsigned long>& labels
)
{
size_t num_right = 0;
size_t num_wrong = 0;
const auto preds = inet(images, batch_size * 2);
for (size_t i = 0; i < labels.size(); ++i)
{
if (labels[i] == preds[i])
++num_right;
else
++num_wrong;
}
cout << "num right: " << num_right << endl;
cout << "num wrong: " << num_wrong << endl;
cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
};
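/*
The accuracy reported by the lambda above is the fraction of predictions that match
the labels. A stand-alone sketch of the same computation, where accuracy is a
hypothetical helper that assumes both vectors have the same, non-zero size:

```cpp
#include <cstddef>
#include <vector>

// Fraction of predictions that are equal to their corresponding labels.
double accuracy(const std::vector<unsigned long>& labels,
                const std::vector<unsigned long>& preds)
{
    std::size_t num_right = 0;
    for (std::size_t i = 0; i < labels.size(); ++i)
    {
        if (labels[i] == preds[i])
            ++num_right;
    }
    return num_right / static_cast<double>(labels.size());
}
```

The error rate is one minus this value, since every prediction is counted either as
right or as wrong.
*/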
// If everything works as expected, we should get accuracies that are between 87% and 90%.
cout << "training accuracy" << endl;
compute_accuracy(training_images, training_labels);
cout << "\ntesting accuracy" << endl;
compute_accuracy(testing_images, testing_labels);
return EXIT_SUCCESS;
}
catch (const exception& e)
{
cout << e.what() << endl;
return EXIT_FAILURE;
}