Add dnn self supervised learning example (#2434)

* wip: loss goes down when training without a dnn_trainer

if I use a dnn_trainer, it segfaults (also with bigger batch sizes...)

* remove commented code

* fix gradient computation (hopefully)

* fix loss computation

* fix crash in input_rgb_image_pair::to_tensor

* fix alias tensor offset

* refactor loss and input layers and complete the example

* add more data augmentation

* add documentation

* add documentation

* small fix in the gradient computation and reuse terms

* fix warning in comment

* use tensor_tools instead of matrix to compute the gradients

* complete the example program

* add support for mult-gpu

* Update dlib/dnn/input_abstract.h

* Update dlib/dnn/input_abstract.h

* Update dlib/dnn/loss_abstract.h

* Update examples/dnn_self_supervised_learning_ex.cpp

* Update examples/dnn_self_supervised_learning_ex.cpp

* Update examples/dnn_self_supervised_learning_ex.cpp

* Update examples/dnn_self_supervised_learning_ex.cpp

* [TYPE_SAFE_UNION] upgrade (#2443)

* [TYPE_SAFE_UNION] upgrade

* MSVC doesn't like keyword not

* MSVC doesn't like keyword and

* added tests for emplate(), copy semantics, move semantics, swap, overloaded and apply_to_contents with non void return types

* - didn't need is_void anymore
- added result_of_t
- didn't really need ostream_helper or istream_helper
- split apply_to_contents into apply_to_contents (return void) and visit (return anything so long as visitor is publicly accessible)

* - updated abstract file

* - added get_type_t
- removed deserialize_helper dupplicate
- don't use std::decay_t, that's c++14

* - removed white spaces
- don't need a return-statement when calling apply_to_contents_impl()
- use unchecked_get() whenever possible to minimise explicit use of pointer casting. lets keep that to a minimum

* - added type_safe_union_size
- added type_safe_union_size_v if C++14 is available
- added tests for above

* - test type_safe_union_size_v

* testing nested unions with visitors.

* re-added comment

* added index() in abstract file

* - refactored reset() to clear()
- added comment about clear() in abstract file
- in deserialize(), only reset the object if necessary

* - removed unecessary comment about exceptions
- removed unecessary // -------------
- struct is_valid is not mentioned in abstract. Instead rather requiring T to be a valid type, it is ensured!
- get_type and get_type_t are private. Client code shouldn't need this.
- shuffled some functions around
- type_safe_union_size and type_safe_union_size_v are removed. not needed
- reset() -> clear()
- bug fix in deserialize() index counts from 1, not 0
- improved the abstract file

* refactored index() to get_current_type_id() as per suggestion

* maybe slightly improved docs

* - HURRAY, don't need std::result_of or std::invoke_result for visit() to work. Just privately define your own type trait, in this case called return_type and return_type_t. it works!
- apply_to_contents() now always calls visit()

* example with private visitor using friendship with non-void return types.

* Fix up contracts

It can't be a post condition that T is a valid type, since the choice of T is up to the caller, it's not something these functions decide.  Making it a precondition.

* Update dlib/type_safe_union/type_safe_union_kernel_abstract.h

* Update dlib/type_safe_union/type_safe_union_kernel_abstract.h

* Update dlib/type_safe_union/type_safe_union_kernel_abstract.h

* - added more tests for copy constructors/assignments, move constructors/assignments, and converting constructors/assignments
- helper_copy -> helper_forward
- added validate_type<T> in a couple of places

* - helper_move only takes non-const lvalue references. So we are not using std::move with universal references !
- use enable_if<is_valid<T>> in favor of validate_type<T>()

* - use enable_if<is_valid<T>> in favor of validate_type<T>()

* - added is_valid_check<>. This wraps enable_if<is_valid<T>,bool> and makes use of SFINAE more robust

Co-authored-by: pfeatherstone <peter@me>
Co-authored-by: pf <pf@me>
Co-authored-by: Davis E. King <davis685@gmail.com>

* Just minor cleanup of docs and renamed some stuff, tweaked formatting.

* fix spelling error

* fix most vexing parse error

Co-authored-by: Davis E. King <davis@dlib.net>
Co-authored-by: pfeatherstone <45853521+pfeatherstone@users.noreply.github.com>
Co-authored-by: pfeatherstone <peter@me>
Co-authored-by: pf <pf@me>
Co-authored-by: Davis E. King <davis685@gmail.com>
pull/2453/head
Adrià Arrufat 3 years ago committed by GitHub
parent f323d1824c
commit 2e8bac1915
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,6 +31,8 @@ namespace dlib
template <size_t NR, size_t NC=NR>
class input_rgb_image_sized;
class input_rgb_image_pair;
class input_rgb_image
{
public:
@ -54,7 +56,11 @@ namespace dlib
template <size_t NR, size_t NC>
inline input_rgb_image (
const input_rgb_image_sized<NR,NC>& item
);
);
inline input_rgb_image (
const input_rgb_image_pair& item
);
float get_avg_red() const { return avg_red; }
float get_avg_green() const { return avg_green; }
@ -87,7 +93,7 @@ namespace dlib
);
}
// initialize data to the right size to contain the stuff in the iterator range.
data.set_size(std::distance(ibegin,iend), 3, nr, nc);
@ -127,7 +133,7 @@ namespace dlib
{
std::string version;
deserialize(version, in);
if (version != "input_rgb_image" && version != "input_rgb_image_sized")
if (version != "input_rgb_image" && version != "input_rgb_image_sized" && version != "input_rgb_image_pair")
throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image.");
deserialize(item.avg_red, in);
deserialize(item.avg_green, in);
@ -216,7 +222,7 @@ namespace dlib
);
}
// initialize data to the right size to contain the stuff in the iterator range.
data.set_size(std::distance(ibegin,iend), 3, NR, NC);
@ -303,6 +309,164 @@ namespace dlib
avg_blue(item.get_avg_blue())
{}
// ----------------------------------------------------------------------------------------
class input_rgb_image_pair
{
public:
typedef std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>> input_type;
input_rgb_image_pair (
) :
avg_red(122.782),
avg_green(117.001),
avg_blue(104.298)
{
}
input_rgb_image_pair (
float avg_red,
float avg_green,
float avg_blue
) : avg_red(avg_red), avg_green(avg_green), avg_blue(avg_blue)
{}
inline input_rgb_image_pair (
const input_rgb_image& item
) :
avg_red(item.get_avg_red()),
avg_green(item.get_avg_green()),
avg_blue(item.get_avg_blue())
{}
template <size_t NR, size_t NC>
inline input_rgb_image_pair (
const input_rgb_image_sized<NR, NC>& item
) :
avg_red(item.get_avg_red()),
avg_green(item.get_avg_green()),
avg_blue(item.get_avg_blue())
{}
float get_avg_red() const { return avg_red; }
float get_avg_green() const { return avg_green; }
float get_avg_blue() const { return avg_blue; }
bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
template <typename forward_iterator>
void to_tensor (
forward_iterator ibegin,
forward_iterator iend,
resizable_tensor& data
) const
{
DLIB_CASSERT(std::distance(ibegin, iend) > 0);
const auto nr = ibegin->first.nr();
const auto nc = ibegin->first.nc();
// make sure all the input matrices have the same dimensions
for (auto i = ibegin; i != iend; ++i)
{
DLIB_CASSERT(i->first.nr() == nr && i->first.nc()==nc &&
i->second.nr() == nr && i->second.nc() == nc,
"\t input_rgb_image_pair::to_tensor()"
<< "\n\t All matrices given to to_tensor() must have the same dimensions."
<< "\n\t nr: " << nr
<< "\n\t nc: " << nc
<< "\n\t i->first.nr(): " << i->first.nr()
<< "\n\t i->first.nc(): " << i->first.nc()
<< "\n\t i->second.nr(): " << i->second.nr()
<< "\n\t i->second.nc(): " << i->second.nc()
);
}
// initialize data to the right size to contain the stuff in the iterator range.
data.set_size(2 * std::distance(ibegin, iend), 3, nr, nc);
const size_t offset = nr * nc;
const size_t offset2 = data.size() / 2;
auto ptr = data.host();
for (auto i = ibegin; i != iend; ++i)
{
for (long r = 0; r < nr; ++r)
{
for (long c = 0; c < nc; ++c)
{
rgb_pixel temp_first = i->first(r, c);
rgb_pixel temp_second = i->second(r, c);
auto p = ptr++;
*p = (temp_first.red - avg_red) / 256.0;
*(p + offset2) = (temp_second.red - avg_red) / 256.0;
p += offset;
*p = (temp_first.green - avg_green) / 256.0;
*(p + offset2) = (temp_second.green - avg_green) / 256.0;
p += offset;
*p = (temp_first.blue - avg_blue) / 256.0;
*(p + offset2) = (temp_second.blue - avg_blue) / 256.0;
p += offset;
}
}
ptr += offset * (data.k() - 1);
}
}
friend void serialize(const input_rgb_image_pair& item, std::ostream& out)
{
serialize("input_rgb_image_pair", out);
serialize(item.avg_red, out);
serialize(item.avg_green, out);
serialize(item.avg_blue, out);
}
friend void deserialize(input_rgb_image_pair& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input_rgb_image_pair" && version != "input_rgb_image" && version != "input_rgb_image_sized")
throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_pair.");
deserialize(item.avg_red, in);
deserialize(item.avg_green, in);
deserialize(item.avg_blue, in);
// read and discard the sizes if this was really a sized input layer.
if (version == "input_rgb_image_sized")
{
size_t nr, nc;
deserialize(nr, in);
deserialize(nc, in);
}
}
friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_pair& item)
{
out << "input_rgb_image_pair("<< item.avg_red<<","<<item.avg_green<<","<<item.avg_blue << ")";
return out;
}
friend void to_xml(const input_rgb_image_pair& item, std::ostream& out)
{
out << "<input_rgb_image_pair r='"<<item.avg_red<<"' g='"<<item.avg_green<<"' b='"<<item.avg_blue<<"'/>";
}
private:
float avg_red;
float avg_green;
float avg_blue;
};
// ----------------------------------------------------------------------------------------
input_rgb_image::
input_rgb_image (
const input_rgb_image_pair& item
) : avg_red(item.get_avg_red()),
avg_green(item.get_avg_green()),
avg_blue(item.get_avg_blue())
{}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L>

@ -271,6 +271,115 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
class input_rgb_image_pair
{
/*!
WHAT THIS OBJECT REPRESENTS
This input layer works with std::pair of RGB images of type matrix<rgb_pixel>.
It is useful when you want to input image pairs that are related to each other,
for instance, they are different distorted views of the same original image.
It is mainly supposed to be used with unsupervised loss functions such as
loss_barlow_twins_. You can also convert between input_rgb_image and
input_rgb_image_pair by copy construction or assignment.
!*/
public:
typedef std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>> input_type;
input_rgb_image_pair (
);
/*!
ensures
- #get_avg_red() == 122.782
- #get_avg_green() == 117.001
- #get_avg_blue() == 104.298
!*/
input_rgb_image_pair (
float avg_red,
float avg_green,
float avg_blue
);
/*!
ensures
- #get_avg_red() == avg_red
- #get_avg_green() == avg_green
- #get_avg_blue() == avg_blue
!*/
inline input_rgb_image_pair (
const input_rgb_image& item
);
/*!
ensures
- #get_avg_red() == item.get_avg_red()
- #get_avg_green() == item.get_avg_green()
- #get_avg_blue() == item.get_avg_blue()
!*/
template <size_t NR, size_t NC>
inline input_rgb_image_pair (
const input_rgb_image_sized<NR, NC>& item
);
/*!
ensures
- #get_avg_red() == item.get_avg_red()
- #get_avg_green() == item.get_avg_green()
- #get_avg_blue() == item.get_avg_blue()
!*/
float get_avg_red(
) const;
/*!
ensures
- returns the value subtracted from the red color channel.
!*/
float get_avg_green(
) const;
/*!
ensures
- returns the value subtracted from the green color channel.
!*/
float get_avg_blue(
) const;
/*!
ensures
- returns the value subtracted from the blue color channel.
!*/
void to_tensor (
forward_iterator ibegin,
forward_iterator iend,
resizable_tensor& data
) const;
/*!
requires
- [ibegin, iend) is an iterator range over input_type objects.
- std::distance(ibegin,iend) > 0
- The input range should contain images that all have the same
dimensions.
ensures
- Converts the iterator range into a tensor and stores it into #data. In
particular, if the input images have R rows, C columns then we will have:
- #data.num_samples() == 2 * std::distance(ibegin,iend)
- #data.nr() == R
- #data.nc() == C
- #data.k() == 3
Moreover, each color channel is normalized by having its average value
subtracted (according to get_avg_red(), get_avg_green(), or
get_avg_blue()) and then is divided by 256.0.
Additionally, the first elements in each pair are placed in the first half
of the batch, and the second elements in the second half.
!*/
// Provided for compatibility with input_rgb_image_pyramid's interface
bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
// ----------------------------------------------------------------------------------------
template <

@ -3970,6 +3970,155 @@ namespace dlib
template <template <typename> class TAG_1, template <typename> class TAG_2, template <typename> class TAG_3, typename SUBNET>
using loss_yolo = add_loss_layer<loss_yolo_<TAG_1, TAG_2, TAG_3>, SUBNET>;
// ----------------------------------------------------------------------------------------
class loss_barlow_twins_
{
public:
loss_barlow_twins_() = default;
loss_barlow_twins_(float lambda) : lambda(lambda)
{
DLIB_CASSERT(lambda > 0);
}
template <
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
SUBNET& sub
) const
{
const tensor& output_tensor = sub.get_output();
tensor& grad = sub.get_gradient_input();
DLIB_CASSERT(sub.sample_expansion_factor() == 2);
DLIB_CASSERT(input_tensor.num_samples() != 0);
DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0);
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1);
DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1);
const auto batch_size = output_tensor.num_samples() / 2;
const auto sample_size = output_tensor.k();
const auto offset = batch_size * sample_size;
// Alias helpers to access the samples in the batch
alias_tensor split(batch_size, sample_size);
auto za = split(output_tensor);
auto zb = split(output_tensor, offset);
// Normalize both batches independently across the batch dimension
const double eps = 1e-4;
resizable_tensor za_norm, means_a, invstds_a;
resizable_tensor zb_norm, means_b, invstds_b;
resizable_tensor rms, rvs, g, b;
g.set_size(1, sample_size);
g = 1;
b.set_size(1, sample_size);
b = 0;
tt::batch_normalize(eps, za_norm, means_a, invstds_a, 1, rms, rvs, za, g, b);
tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);
// Compute the empirical cross-correlation matrix
resizable_tensor eccm;
eccm.set_size(sample_size, sample_size);
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
// Compute the loss: MSE between eccm and the identity matrix.
// Off-diagonal terms are weighed by lambda.
const matrix<float> C = mat(eccm);
const double diagonal_loss = sum(squared(diag(C) - 1));
const double off_diag_loss = sum(squared(C - diagm(diag(C))));
double loss = diagonal_loss + lambda * off_diag_loss;
// Loss gradient, which will be used as the input of the batch normalization gradient
resizable_tensor grad_input;
grad_input.copy_size(grad);
auto grad_input_a = split(grad_input);
auto grad_input_b = split(grad_input, offset);
// Compute the loss: notation from http://www.matrixcalculus.org/
// A = za_norm
// B = zb_norm
// C = eccm
// D = off_mask: a mask that keeps only the elements outside the diagonal
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// --------------------------------------------
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
resizable_tensor cdiag_1(diagm(diag(mat(eccm) - 1)));
tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
// off-diag term: sum(((A'* B) .* D).^2)
// --------------------------------
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
resizable_tensor off_diag(sample_size, sample_size);
tt::multiply(false, off_diag, eccm, off_mask);
tt::gemm(1, grad_input_a, lambda, zb_norm, false, off_diag, false);
tt::gemm(1, grad_input_b, lambda, za_norm, false, off_diag, false);
// Compute the batch norm gradients, g and b grads are not used
resizable_tensor g_grad, b_grad;
g_grad.copy_size(g);
b_grad.copy_size(b);
auto gza = split(grad);
auto gzb = split(grad, offset);
tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);
return loss;
}
float get_lambda() const { return lambda; }
friend void serialize(const loss_barlow_twins_& item, std::ostream& out)
{
serialize("loss_barlow_twins_", out);
serialize(item.lambda, out);
}
friend void deserialize(loss_barlow_twins_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version == "loss_barlow_twins_")
{
deserialize(item.lambda, in);
}
else
{
throw serialization_error("Unexpected version found while deserializing dlib::loss_barlow_twins_. Instead found " + version);
}
}
friend std::ostream& operator<<(std::ostream& out, const loss_barlow_twins_& item)
{
out << "loss_barlow_twins (lambda=" << item.lambda << ")";
return out;
}
friend void to_xml(const loss_barlow_twins_& item, std::ostream& out)
{
out << "<loss_barlow_twins lambda='" << item.lambda << "'/>";
}
private:
float lambda = 0.0051;
};
template <typename SUBNET>
using loss_barlow_twins = add_loss_layer<loss_barlow_twins_, SUBNET>;
}
#endif // DLIB_DNn_LOSS_H_

@ -2038,6 +2038,88 @@ namespace dlib
// ----------------------------------------------------------------------------------------
class loss_barlow_twins_
{
public:
/*!
WHAT THIS OBJECT REPRESENTS
This object implements the loss layer interface defined above by
EXAMPLE_LOSS_LAYER_. In particular, it implements the Barlow Twins loss
layer presented in the paper:
Barlow Twins: Self-Supervised Learning via Redundancy Reduction
by Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny
(https://arxiv.org/abs/2103.03230)
This means you use this loss to learn useful representations from data that
has no label information. Useful representations mean that can be used to
train another downstream task, such as classification.
In particular, this loss function applies the redundancy reduction principle
to the representations learned by the network it sits on top of.
To be specific, this layer requires the sample_expansion_factor to be 2, and
in each batch, the second half contains distorted versions of the first half.
Let Z_A and Z_B be the first and second half of the batch that goes into this
loss layer, respectively. Z_A and Z_B have dimensions N rows and D columns,
where N is half the batch size and D is the dimensionality of the output tensor.
Each row in Z_B should contain a distorted version of the corresponding row
in Z_A. Then, this loss computes the empirical cross-correlation matrix between
the batch-normalized versions of Z_A and Z_B:
C = trans(bn(Z_A)) * bn(Z_B)
It then applies the redundancy reduction principle by trying to make C as
close to the identity matrix as possible:
L = squared(diag(C) - 1) + lambda * squared(off_diag(C))
where off_diag grabs all the elements that are not on the diagonal of C and
lambda provides a trade-off between both terms in the loss function. The C
matrix has dimensions D x D: there are only D diagonal terms, but D * (D - 1)
off-diagonal elements. A reasonable value for lambda is 1 / D.
!*/
loss_barlow_twins_(
);
/*!
ensures
- #get_lambda() == 0.0051
!*/
loss_barlow_twins_(float lambda);
/*!
ensures
- #get_lambda() == lambda
!*/
float get_lambda() const;
/*!
ensures
- returns the lambda value used by the loss function. See the discussion
in WHAT THIS OBJECT REPRESENTS for details.
!*/
template <
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
SUBNET& sub
) const;
/*!
This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
except it has the additional calling requirements that:
- sub.get_output().nr() == 1
- sub.get_output().nc() == 1
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 2
!*/
};
template <typename SUBNET>
using loss_barlow_twins = add_loss_layer<loss_barlow_twins_, SUBNET>;
}
#endif // DLIB_DNn_LOSS_ABSTRACT_H_

@ -156,6 +156,7 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
add_example(dnn_metric_learning_on_images_ex)
add_gui_example(dnn_dcgan_train_ex)
add_gui_example(dnn_yolo_train_ex)
add_gui_example(dnn_self_supervised_learning_ex)
endif()

@ -0,0 +1,364 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the dlib C++
Library. I'm assuming you have already read the dnn_introduction_ex.cpp, the
dnn_introduction2_ex.cpp and the dnn_introduction3_ex.cpp examples. In this example
program we are going to show how one can train a neural network using an unsupervised
loss function. In particular, we will train the ResNet50 model from the paper
"Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing
Ren, Jian Sun.
To train the unsupervised loss, we will use the self-supervised learning (SSL) method
called Barlow Twins, introduced in this paper:
"Barlow Twins: Self-Supervised Learning via Redundancy Reduction" by Jure Zbontar,
Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny.
The paper contains a good explanation on how and why this works, but the main idea
behind the Barlow Twins method is:
- generate two distorted views of a batch of images: YA, YB
- feed them to a deep neural network and obtain their representations and
and batch normalize them: ZA, ZB
- compute the empirical cross-correlation matrix between both feature
representations as: C = trans(ZA) * ZB.
- make C as close as possible to the identity matrix.
This removes the redundancy of the feature representations, by maximizing the
encoded information about the images themselves, while minimizing the information
about the transforms and data augmentations used to obtain the representations.
The original Barlow Twins paper uses the ImageNet dataset, but in this example we
are using CIFAR-10, so we will follow the recommendations of this paper, instead:
"A Note on Connecting Barlow Twins with Negative-Sample-Free Contrastive Learning"
by Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, Ruslan Salakhutdinov,
in which they experiment with Barlow Twins on CIFAR-10 and Tiny ImageNet. Since
the CIFAR-10 contains relatively small images, we will define a ResNet50 architecture
that doesn't downsample the input in the first convolutional layer, and doesn't have
a max pooling layer afterwards, like the paper does.
*/
#include <dlib/dnn.h>
#include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h>
#include <dlib/gui_widgets.h>
using namespace std;
using namespace dlib;
// A custom definition of ResNet50 with a downsampling factor of 8 instead of 32.
// It is essentially the original ResNet50, but without the max pooling and a
// convolutional layer with a stride of 1 instead of 2 at the input.
namespace resnet50
{
using namespace dlib;
template <template <typename> class BN>
struct def
{
template <long N, int K, int S, typename SUBNET>
using conv = add_layer<con_<N, K, K, S, S, K / 2, K / 2>, SUBNET>;
template<long N, int S, typename SUBNET>
using bottleneck = BN<conv<4 * N, 1, 1, relu<BN<conv<N, 3, S, relu<BN<conv<N, 1, 1, SUBNET>>>>>>>>;
template <long N, typename SUBNET>
using residual = add_prev1<bottleneck<N, 1, tag1<SUBNET>>>;
template <typename SUBNET> using res_512 = relu<residual<512, SUBNET>>;
template <typename SUBNET> using res_256 = relu<residual<256, SUBNET>>;
template <typename SUBNET> using res_128 = relu<residual<128, SUBNET>>;
template <typename SUBNET> using res_64 = relu<residual<64, SUBNET>>;
template <long N, int S, typename SUBNET>
using transition = add_prev2<BN<conv<4 * N, 1, S, skip1<tag2<bottleneck<N, S, tag1<SUBNET>>>>>>>;
template <typename INPUT>
using backbone = avg_pool_everything<
repeat<2, res_512, transition<512, 2,
repeat<5, res_256, transition<256, 2,
repeat<3, res_128, transition<128, 2,
repeat<2, res_64, transition<64, 1,
relu<BN<conv<64, 3, 1,INPUT>>>>>>>>>>>>;
};
};
// This model namespace contains the definitions for:
// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
namespace model
{
template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
template <typename SUBNET> using classifier = fc<10, SUBNET>;
using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
}
rectangle make_random_cropping_rect(
const matrix<rgb_pixel>& image,
dlib::rand& rnd
)
{
const double mins = 7. / 15.;
const double maxs = 7. / 8.;
const auto scale = rnd.get_double_in_range(mins, maxs);
const auto size = scale * std::min(image.nr(), image.nc());
const rectangle rect(size, size);
const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
rnd.get_random_32bit_number() % (image.nr() - rect.height()));
return move_rect(rect, offset);
}
matrix<rgb_pixel> augment(
const matrix<rgb_pixel>& image,
const bool prime,
dlib::rand& rnd
)
{
matrix<rgb_pixel> crop;
// blur
matrix<rgb_pixel> blurred;
const double sigma = rnd.get_double_in_range(0.1, 1.1);
if (!prime || (prime && rnd.get_random_double() < 0.1))
{
const auto rect = gaussian_blur(image, blurred, sigma);
extract_image_chip(blurred, rect, crop);
blurred = crop;
}
else
{
blurred = image;
}
// randomly crop
const auto rect = make_random_cropping_rect(image, rnd);
extract_image_chip(blurred, chip_details(rect, chip_dims(32, 32)), crop);
// image left-right flip
if (rnd.get_random_double() < 0.5)
flip_image_left_right(crop);
// color augmentation
if (rnd.get_random_double() < 0.8)
disturb_colors(crop, rnd, 0.5, 0.5);
// grayscale
if (rnd.get_random_double() < 0.2)
{
matrix<unsigned char> gray;
assign_image(gray, crop);
assign_image(crop, gray);
}
// solarize
if (prime && rnd.get_random_double() < 0.2)
{
for (auto& p : crop)
{
if (p.red > 128)
p.red = 255 - p.red;
if (p.green > 128)
p.green = 255 - p.green;
if (p.blue > 128)
p.blue = 255 - p.blue;
}
}
return crop;
}
int main(const int argc, const char** argv)
try
{
// The default settings are fine for the example already.
command_line_parser parser;
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
parser.set_group_name("Help Options");
parser.add_option("h", "alias for --help");
parser.add_option("help", "display this message and exit");
parser.parse(argc, argv);
if (parser.number_of_arguments() < 1 || parser.option("h") || parser.option("help"))
{
cout << "This example needs the CIFAR-10 dataset to run." << endl;
cout << "You can get CIFAR-10 from https://www.cs.toronto.edu/~kriz/cifar.html" << endl;
cout << "Download the binary version the dataset, decompress it, and put the 6" << endl;
cout << "bin files in a folder. Then give that folder as input to this program." << endl;
parser.print_options();
return EXIT_SUCCESS;
}
const size_t num_gpus = get_option(parser, "num-gpus", 1);
const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
const long dims = get_option(parser, "dims", 128);
const double lambda = get_option(parser, "lambda", 1.0 / dims);
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
// Load the CIFAR-10 dataset into memory.
std::vector<matrix<rgb_pixel>> training_images, testing_images;
std::vector<unsigned long> training_labels, testing_labels;
load_cifar_10_dataset(parser[0], training_images, training_labels, testing_images, testing_labels);
// Initialize the model with the specified projector dimensions and lambda. According to the
// second paper, lambda = 1/dims works well on CIFAR-10.
model::train net((loss_barlow_twins_(lambda)));
layer<1>(net).layer_details().set_num_outputs(dims);
disable_duplicative_biases(net);
dlib::rand rnd;
std::vector<int> gpus(num_gpus);
std::iota(gpus.begin(), gpus.end(), 0);
// Train the feature extractor using the Barlow Twins method
{
dnn_trainer<model::train, adam> trainer(net, adam(1e-6, 0.9, 0.999), gpus);
trainer.set_mini_batch_size(batch_size);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(min_learning_rate);
trainer.set_iterations_without_progress_threshold(10000);
trainer.set_synchronization_file("barlow_twins_sync");
trainer.be_verbose();
cout << trainer << endl;
// During the training, we will compute the empirical cross-correlation matrix
// between the features of both versions of the augmented images. This matrix
// should be getting close to the identity matrix as the training progresses.
// Note that this step is already done in the loss layer, and it's not necessary
// to do it here for the example to work. However, it provides a nice
// visualization of the training progress: the closer to the identity matrix,
// the better.
resizable_tensor eccm;
eccm.set_size(dims, dims);
// Some tensors needed to perform batch normalization
resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta;
const double eps = DEFAULT_BATCH_NORM_EPS;
gamma.set_size(1, dims);
beta.set_size(1, dims);
image_window win;
std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
{
batch.clear();
while (batch.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
auto image = training_images[idx];
batch.emplace_back(augment(image, false, rnd), augment(image, true, rnd));
}
trainer.train_one_step(batch);
// Compute the empirical cross-correlation matrix every 100 steps. Again,
// this is not needed for the training to work, but it's nice to visualize.
if (trainer.get_train_one_step_calls() % 100 == 0)
{
// Wait for threaded processing to stop in the trainer.
trainer.get_net(force_flush_to_disk::no);
// Get the output from the last fc layer
const auto& out = net.subnet().get_output();
// The trainer might have synchronized its state to the disk and cleaned
// the network state. If that happens, the output will be empty, in
// which case, we just skip the empirical cross-correlation matrix
// computation.
if (out.size() == 0)
continue;
// Separate both augmented versions of the images
alias_tensor split(out.num_samples() / 2, dims);
auto za = split(out);
auto zb = split(out, split.size());
gamma = 1;
beta = 0;
// Perform batch normalization on each feature representation, independently.
tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
// Compute the empirical cross-correlation matrix between the features and
// visualize it.
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
win.set_image(abs(mat(eccm)) * 255);
win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
}
}
trainer.get_net();
net.clean();
// After training, we can discard the projector head and just keep the backone
// to train it or finetune it on other downstream tasks.
serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
}
// To check the quality of the learned feature representations, we will train a linear
// classififer on top of the frozen backbone.
model::infer inet;
// Assign the network, without the projector, which is only used for the self-supervised
// training.
layer<2>(inet) = layer<5>(net);
// Freeze the backbone
set_all_learning_rate_multipliers(layer<2>(inet), 0);
// Train the network
{
dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
// Since this model doesn't train with pairs, just single images, we can increase
// the batch size.
trainer.set_mini_batch_size(2 * batch_size);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(min_learning_rate);
trainer.set_iterations_without_progress_threshold(5000);
trainer.set_synchronization_file("cifar_10_sync");
trainer.be_verbose();
cout << trainer << endl;
std::vector<matrix<rgb_pixel>> images;
std::vector<unsigned long> labels;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
{
images.clear();
labels.clear();
while (images.size() < trainer.get_mini_batch_size())
{
const auto idx = rnd.get_random_32bit_number() % training_images.size();
images.push_back(augment(training_images[idx], false, rnd));
labels.push_back(training_labels[idx]);
}
trainer.train_one_step(images, labels);
}
trainer.get_net();
inet.clean();
serialize("resnet50_cifar_10.dnn") << inet;
}
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
auto compute_accuracy = [&inet, batch_size](
const std::vector<matrix<rgb_pixel>>& images,
const std::vector<unsigned long>& labels
)
{
size_t num_right = 0;
size_t num_wrong = 0;
const auto preds = inet(images, batch_size * 2);
for (size_t i = 0; i < labels.size(); ++i)
{
if (labels[i] == preds[i])
++num_right;
else
++num_wrong;
}
cout << "num right: " << num_right << endl;
cout << "num wrong: " << num_wrong << endl;
cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
};
// If everything works as expected, we should get accuracies that are between 87% and 90%.
cout << "training accuracy" << endl;
compute_accuracy(training_images, training_labels);
cout << "\ntesting accuracy" << endl;
compute_accuracy(testing_images, testing_labels);
return EXIT_SUCCESS;
}
catch (const exception& e)
{
cout << e.what() << endl;
return EXIT_FAILURE;
}
Loading…
Cancel
Save