diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h
index a238f9288..43c218302 100644
--- a/dlib/dnn/loss.h
+++ b/dlib/dnn/loss.h
@@ -4066,8 +4066,8 @@ namespace dlib
             resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
             resizable_tensor off_diag(sample_size, sample_size);
             tt::multiply(false, off_diag, eccm, off_mask);
-            tt::gemm(1, grad_input_a, lambda, zb_norm, false, off_diag, false);
-            tt::gemm(1, grad_input_b, lambda, za_norm, false, off_diag, false);
+            tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
+            tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);

             // Compute the batch norm gradients, g and b grads are not used
             resizable_tensor g_grad, b_grad;
diff --git a/examples/dnn_self_supervised_learning_ex.cpp b/examples/dnn_self_supervised_learning_ex.cpp
index 6ad166117..cb73d1d67 100644
--- a/examples/dnn_self_supervised_learning_ex.cpp
+++ b/examples/dnn_self_supervised_learning_ex.cpp
@@ -277,7 +277,7 @@ try
             // visualize it.
             tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
             eccm /= batch_size;
-            win.set_image(abs(mat(eccm)) * 255);
+            win.set_image(round(abs(mat(eccm)) * 255));
             win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
         }
     }
@@ -304,12 +304,14 @@ try
     auto cross_validation_score = [&](const double c)
     {
         svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
-        trainer.set_num_threads(std::thread::hardware_concurrency());
         trainer.set_c(c);
+        trainer.set_epsilon(0.01);
+        trainer.set_max_iterations(100);
+        trainer.set_num_threads(std::thread::hardware_concurrency());
         cout << "C: " << c << endl;
         const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
         const double accuracy = sum(diag(cm)) / sum(cm);
-        cout << "cross validation accuracy: " << accuracy << endl;;
+        cout << "cross validation accuracy: " << accuracy << endl;
         cout << "confusion matrix:\n " << cm << endl;
         return accuracy;
     };
@@ -345,7 +347,7 @@ try
         cout << "  error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
     };

-    // We should get a training accuracy of around 93% and a testing accuracy of around 88%.
+    // We should get a training accuracy of around 93% and a testing accuracy of around 89%.
     cout << "\ntraining accuracy" << endl;
     compute_accuracy(features, training_labels);
     cout << "\ntesting accuracy" << endl;
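
Note on the first hunk: a sketch of where the factor of 2 comes from, assuming eccm holds the empirical cross-correlation matrix C from the Barlow Twins paper (Zbontar et al., 2021) and off_diag its masked off-diagonal part. The redundancy-reduction term of the loss is quadratic in each off-diagonal entry of C, so its derivative carries a factor of 2 that the old gemm calls dropped:

\[
\mathcal{L}_{\mathrm{BT}} = \sum_i \left(1 - C_{ii}\right)^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2,
\qquad
\frac{\partial \mathcal{L}_{\mathrm{BT}}}{\partial C_{ij}} = 2 \lambda C_{ij} \quad (i \neq j).
\]

Backpropagating through C (a product of the normalized embeddings za_norm and zb_norm, up to the batch-size normalization handled elsewhere) then gives off-diagonal gradient contributions proportional to 2*lambda * zb_norm * off_diag and 2*lambda * za_norm * off_diag, which is what the corrected tt::gemm calls accumulate into grad_input_a and grad_input_b.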