From b0f6be80587b94585949f6943a7ee263394c2d15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:59:09 +0900 Subject: [PATCH] loss_barlow_twins: add get_eccm member function (#2906) This allows us to greatly simplify the self supervised learning example: - the computation in user code was a bit too distracting - avoids duplicated computation/allocation of this matrix - avoids edge case where net outputs are zero due to trainer synchronization --- dlib/dnn/loss.h | 3 ++ dlib/dnn/loss_abstract.h | 14 ++++++++ examples/dnn_self_supervised_learning_ex.cpp | 37 +++----------------- 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index 91e0b7886..70c316b0f 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -4154,6 +4154,9 @@ namespace dlib float get_lambda() const { return lambda; } + tensor& get_eccm() { return eccm; } + const tensor& get_eccm() const { return eccm; } + friend void serialize(const loss_barlow_twins_& item, std::ostream& out) { serialize("loss_barlow_twins_", out); diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h index cb3a2f145..9ddfb6a4a 100644 --- a/dlib/dnn/loss_abstract.h +++ b/dlib/dnn/loss_abstract.h @@ -2144,6 +2144,20 @@ namespace dlib in WHAT THIS OBJECT REPRESENTS for details. !*/ + tensor& get_eccm(); + /*! + ensures + - returns the empirical cross-correlation matrix computed by the loss. + - this is only meant to be used for visualization/debugging purposes. + !*/ + + const tensor& get_eccm() const; + /*! + ensures + - returns the empirical cross-correlation matrix computed by the loss. + - this is only meant to be used for visualization/debugging purposes. + !*/ + template < typename SUBNET > diff --git a/examples/dnn_self_supervised_learning_ex.cpp b/examples/dnn_self_supervised_learning_ex.cpp index 3e0079414..c8d472dbd 100644 --- a/examples/dnn_self_supervised_learning_ex.cpp +++ b/examples/dnn_self_supervised_learning_ex.cpp @@ -206,20 +206,10 @@ try trainer.be_verbose(); cout << trainer << endl; - // During the training, we will compute the empirical cross-correlation + // During the training, we will visualize the empirical cross-correlation // matrix between the features of both versions of the augmented images. // This matrix should be getting close to the identity matrix as the training - // progresses. Note that this step is already done in the loss layer, and it's - // not necessary to do it here for the example to work. However, it provides - // a nice visualization of the training progress: the closer to the identity - // matrix, the better. - resizable_tensor eccm; - eccm.set_size(dims, dims); - // Some tensors needed to perform batch normalization - resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta; - const double eps = DEFAULT_BATCH_NORM_EPS; - gamma.set_size(1, dims); - beta.set_size(1, dims); + // progresses. Note that this is done here for visualization purposes only. image_window win; std::vector, matrix>> batch; @@ -234,32 +224,13 @@ try } trainer.train_one_step(batch); - // Compute the empirical cross-correlation matrix every 100 steps. Again, + // Get the empirical cross-correlation matrix every 100 steps. Again, // this is not needed for the training to work, but it's nice to visualize. if (trainer.get_train_one_step_calls() % 100 == 0) { // Wait for threaded processing to stop in the trainer. trainer.get_net(force_flush_to_disk::no); - // Get the output from the last fc layer - const auto& out = net.subnet().get_output(); - // The trainer might have synchronized its state to the disk and cleaned - // the network state. If that happens, the output will be empty, in which - // case, we just skip the empirical cross-correlation matrix computation. - if (out.size() == 0) - continue; - // Separate both augmented versions of the images - alias_tensor split(out.num_samples() / 2, dims); - auto za = split(out); - auto zb = split(out, split.size()); - gamma = 1; - beta = 0; - // Perform batch normalization on each feature representation, independently. - tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta); - tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, zb, gamma, beta); - // Compute the empirical cross-correlation matrix between the features and - // visualize it. - tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false); - eccm /= batch_size; + const matrix eccm = mat(net.loss_details().get_eccm()); win.set_image(round(abs(mat(eccm)) * 255)); win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls())); }