loss_barlow_twins: add get_eccm member function (#2906)

This allows us to greatly simplify the self-supervised learning example:
- the computation in user code was a bit too distracting
- avoids duplicated computation/allocation of the empirical cross-correlation matrix
- avoids the edge case where the net outputs are empty because the trainer synchronized its state to disk and cleaned the network state
This commit is contained in:
Adrià Arrufat 2024-01-09 11:59:09 +09:00 committed by GitHub
parent 46e59a2174
commit b0f6be8058
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 33 deletions

View File

@ -4154,6 +4154,9 @@ namespace dlib
float get_lambda() const { return lambda; } float get_lambda() const { return lambda; }
tensor& get_eccm() { return eccm; }
const tensor& get_eccm() const { return eccm; }
friend void serialize(const loss_barlow_twins_& item, std::ostream& out) friend void serialize(const loss_barlow_twins_& item, std::ostream& out)
{ {
serialize("loss_barlow_twins_", out); serialize("loss_barlow_twins_", out);

View File

@ -2144,6 +2144,20 @@ namespace dlib
in WHAT THIS OBJECT REPRESENTS for details. in WHAT THIS OBJECT REPRESENTS for details.
!*/ !*/
tensor& get_eccm();
/*!
ensures
- returns the empirical cross-correlation matrix computed by the loss.
- this is only meant to be used for visualization/debugging purposes.
!*/
const tensor& get_eccm() const;
/*!
ensures
- returns the empirical cross-correlation matrix computed by the loss.
- this is only meant to be used for visualization/debugging purposes.
!*/
template < template <
typename SUBNET typename SUBNET
> >

View File

@ -206,20 +206,10 @@ try
trainer.be_verbose(); trainer.be_verbose();
cout << trainer << endl; cout << trainer << endl;
// During the training, we will compute the empirical cross-correlation // During the training, we will visualize the empirical cross-correlation
// matrix between the features of both versions of the augmented images. // matrix between the features of both versions of the augmented images.
// This matrix should be getting close to the identity matrix as the training // This matrix should be getting close to the identity matrix as the training
// progresses. Note that this step is already done in the loss layer, and it's // progresses. Note that this is done here for visualization purposes only.
// not necessary to do it here for the example to work. However, it provides
// a nice visualization of the training progress: the closer to the identity
// matrix, the better.
resizable_tensor eccm;
eccm.set_size(dims, dims);
// Some tensors needed to perform batch normalization
resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta;
const double eps = DEFAULT_BATCH_NORM_EPS;
gamma.set_size(1, dims);
beta.set_size(1, dims);
image_window win; image_window win;
std::vector<pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch; std::vector<pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
@ -234,32 +224,13 @@ try
} }
trainer.train_one_step(batch); trainer.train_one_step(batch);
// Compute the empirical cross-correlation matrix every 100 steps. Again, // Get the empirical cross-correlation matrix every 100 steps. Again,
// this is not needed for the training to work, but it's nice to visualize. // this is not needed for the training to work, but it's nice to visualize.
if (trainer.get_train_one_step_calls() % 100 == 0) if (trainer.get_train_one_step_calls() % 100 == 0)
{ {
// Wait for threaded processing to stop in the trainer. // Wait for threaded processing to stop in the trainer.
trainer.get_net(force_flush_to_disk::no); trainer.get_net(force_flush_to_disk::no);
// Get the output from the last fc layer const matrix<float> eccm = mat(net.loss_details().get_eccm());
const auto& out = net.subnet().get_output();
// The trainer might have synchronized its state to the disk and cleaned
// the network state. If that happens, the output will be empty, in which
// case, we just skip the empirical cross-correlation matrix computation.
if (out.size() == 0)
continue;
// Separate both augmented versions of the images
alias_tensor split(out.num_samples() / 2, dims);
auto za = split(out);
auto zb = split(out, split.size());
gamma = 1;
beta = 0;
// Perform batch normalization on each feature representation, independently.
tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, zb, gamma, beta);
// Compute the empirical cross-correlation matrix between the features and
// visualize it.
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
win.set_image(round(abs(mat(eccm)) * 255)); win.set_image(round(abs(mat(eccm)) * 255));
win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls())); win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
} }