Fix Barlow Twins loss gradient (#2518)
* Fix Barlow Twins loss gradient
* Update reference test accuracy after fix
* Round the empirical cross-correlation matrix

  Just a tiny modification that allows the values to actually reach 255 (perfect white).
parent 39852f092c
commit 50b78da53a
@@ -4066,8 +4066,8 @@ namespace dlib
 resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
 resizable_tensor off_diag(sample_size, sample_size);
 tt::multiply(false, off_diag, eccm, off_mask);
-tt::gemm(1, grad_input_a, lambda, zb_norm, false, off_diag, false);
-tt::gemm(1, grad_input_b, lambda, za_norm, false, off_diag, false);
+tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
+tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);

 // Compute the batch norm gradients, g and b grads are not used
 resizable_tensor g_grad, b_grad;
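The 2 * lambda scale is what differentiating the off-diagonal (redundancy reduction) term of the standard Barlow Twins objective gives; a short derivation, writing C for the empirical cross-correlation matrix (eccm in the code above) and lambda for the off-diagonal weight:

L = \sum_i (1 - C_{ii})^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2
\frac{\partial L}{\partial C_{ij}} = 2 \lambda C_{ij} \quad (i \neq j)

Each off-diagonal entry therefore contributes a gradient of 2*lambda*C_ij rather than lambda*C_ij, which is exactly the scale now passed to the two tt::gemm calls that back-propagate off_diag into grad_input_a and grad_input_b.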
@@ -277,7 +277,7 @@ try
 // visualize it.
 tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
 eccm /= batch_size;
-win.set_image(abs(mat(eccm)) * 255);
+win.set_image(round(abs(mat(eccm)) * 255));
 win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
 }
 }
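The round() matters because the correlations are floats scaled by 255; if the display path truncates them to 8-bit values, anything below exactly 255.0 ends up at 254 or less, so even a perfect correlation never renders as pure white. A minimal standalone sketch (plain C++, not dlib code) illustrating the difference under that truncation assumption:

#include <cmath>
#include <cstdio>

int main()
{
    // A near-perfect correlation, scaled to the 8-bit display range.
    const float scaled = 0.9999f * 255;  // 254.9745
    // Truncation never reaches 255 (perfect white)...
    const auto truncated = static_cast<unsigned char>(scaled);              // 254
    // ...while rounding first does.
    const auto rounded = static_cast<unsigned char>(std::lround(scaled));   // 255
    std::printf("truncated = %d, rounded = %d\n", truncated, rounded);
    return 0;
}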
@@ -304,12 +304,14 @@ try
 auto cross_validation_score = [&](const double c)
 {
 svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
-trainer.set_num_threads(std::thread::hardware_concurrency());
 trainer.set_c(c);
+trainer.set_epsilon(0.01);
+trainer.set_max_iterations(100);
+trainer.set_num_threads(std::thread::hardware_concurrency());
 cout << "C: " << c << endl;
 const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
 const double accuracy = sum(diag(cm)) / sum(cm);
-cout << "cross validation accuracy: " << accuracy << endl;;
+cout << "cross validation accuracy: " << accuracy << endl;
 cout << "confusion matrix:\n " << cm << endl;
 return accuracy;
 };
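With the epsilon and iteration cap set, each call to cross_validation_score is cheap enough to drive a hyperparameter search. A hypothetical usage sketch of the lambda above (the candidate C values are illustrative, not taken from the example):

// Hypothetical sketch: scan a coarse, log-spaced grid of C values with the
// cross_validation_score lambda defined above and keep the most accurate one.
double best_c = 1;
double best_accuracy = 0;
for (const double c : {1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0})
{
    const double accuracy = cross_validation_score(c);
    if (accuracy > best_accuracy)
    {
        best_accuracy = accuracy;
        best_c = c;
    }
}
cout << "best C: " << best_c << ", cross validation accuracy: " << best_accuracy << endl;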
@@ -345,7 +347,7 @@ try
 cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
 };

-// We should get a training accuracy of around 93% and a testing accuracy of around 88%.
+// We should get a training accuracy of around 93% and a testing accuracy of around 89%.
 cout << "\ntraining accuracy" << endl;
 compute_accuracy(features, training_labels);
 cout << "\ntesting accuracy" << endl;