Fix computation of the Barlow Twins loss gradient (#2680)

Adria Arrufat 2022-11-02 20:55:58 +09:00 committed by GitHub
parent 7f06f6e185
commit bdb1089ae6
2 changed files with 19 additions and 17 deletions


@@ -4128,16 +4128,16 @@ namespace dlib
  // --------------------------------------------
  // => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
  // => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
- tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
- tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
+ tt::gemm(0, grad_input_a, 2.0 / batch_size, zb_norm, false, cdiag_1, false);
+ tt::gemm(0, grad_input_b, 2.0 / batch_size, za_norm, false, cdiag_1, false);

  // off-diag term: sum(((A'* B) .* D).^2)
  // --------------------------------
- // => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
+ // => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C' .* (D .* D)) = 2 * B * (C' .* D)
  // => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
  tt::multiply(false, off_diag, eccm, off_mask);
- tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
- tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
+ tt::gemm(1, grad_input_a, lambda * 2.0 / batch_size, zb_norm, false, off_diag, true);
+ tt::gemm(1, grad_input_b, lambda * 2.0 / batch_size, za_norm, false, off_diag, false);

  // Compute the batch norm gradients, g and b grads are not used
  auto gza = split(grad);
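For reference, a sketch of the derivation these calls implement, assuming the usual Barlow Twins formulation in which A and B are the two N x D batch-normalized embedding matrices (za_norm and zb_norm), C = A'B / N is the batch-averaged empirical cross-correlation matrix (eccm), and D is the off-diagonal mask:

\[
\mathcal{L} = \sum_i \bigl(1 - C_{ii}\bigr)^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2,
\qquad C = \tfrac{1}{N} A^\top B,
\qquad \frac{\partial \mathcal{L}}{\partial C} = 2\,\mathrm{diag}\bigl(\mathrm{diag}(C) - \mathbf{1}\bigr) + 2\lambda \,(C \odot D).
\]
\[
\frac{\partial \mathcal{L}}{\partial A}
= \frac{1}{N} B \left(\frac{\partial \mathcal{L}}{\partial C}\right)^{\top}
= \frac{2}{N} B\,\mathrm{diag}\bigl(\mathrm{diag}(C) - \mathbf{1}\bigr) + \frac{2\lambda}{N} B\,(C^\top \odot D),
\qquad
\frac{\partial \mathcal{L}}{\partial B}
= \frac{1}{N} A \frac{\partial \mathcal{L}}{\partial C}
= \frac{2}{N} A\,\mathrm{diag}\bigl(\mathrm{diag}(C) - \mathbf{1}\bigr) + \frac{2\lambda}{N} A\,(C \odot D).
\]

The 1/N inherited from C = A'B / N is what the new 2.0 / batch_size and lambda * 2.0 / batch_size gemm scales supply, and the C' (rather than C) in dL/dA is what the new transpose flag on off_diag supplies for grad_input_a.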

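Below is a self-contained plain C++ sketch of the same formulas. It does not use dlib's tensor API; the Mat alias, the toy matrices, and the lambda value are made up for illustration. It evaluates the loss, the analytic dL/dA from the corrected expression, and a central-difference approximation of each entry so the two can be compared:

// Plain C++ sketch (no dlib): Barlow Twins loss and the corrected gradient w.r.t. A,
// checked against central finite differences. Everything here (Mat, the toy values,
// lambda = 0.25) is illustrative only.
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<double>>;  // row-major; A, B are N x D, C is D x D

// C = A'B / N, the batch-averaged cross-correlation of the two embedding matrices.
Mat cross_correlation(const Mat& A, const Mat& B)
{
    const size_t N = A.size(), D = A[0].size();
    Mat C(D, std::vector<double>(D, 0.0));
    for (size_t i = 0; i < D; ++i)
        for (size_t j = 0; j < D; ++j)
            for (size_t n = 0; n < N; ++n)
                C[i][j] += A[n][i] * B[n][j] / N;
    return C;
}

// L = sum_i (1 - C_ii)^2 + lambda * sum_{i != j} C_ij^2
double loss(const Mat& A, const Mat& B, double lambda)
{
    const Mat C = cross_correlation(A, B);
    double on_diag = 0, off_diag = 0;
    for (size_t i = 0; i < C.size(); ++i)
        for (size_t j = 0; j < C.size(); ++j)
        {
            if (i == j) on_diag  += (1 - C[i][j]) * (1 - C[i][j]);
            else        off_diag += C[i][j] * C[i][j];
        }
    return on_diag + lambda * off_diag;
}

// dL/dA = (2/N) * B * M, with M = diag(diag(C) - 1) + lambda * (C' .* D).
Mat grad_wrt_A(const Mat& A, const Mat& B, double lambda)
{
    const size_t N = A.size(), D = A[0].size();
    const Mat C = cross_correlation(A, B);
    Mat M(D, std::vector<double>(D, 0.0));
    for (size_t i = 0; i < D; ++i)
        for (size_t j = 0; j < D; ++j)
            M[i][j] = (i == j) ? C[i][i] - 1.0 : lambda * C[j][i];  // note C[j][i]: the transpose
    Mat G(N, std::vector<double>(D, 0.0));
    for (size_t n = 0; n < N; ++n)
        for (size_t j = 0; j < D; ++j)
            for (size_t k = 0; k < D; ++k)
                G[n][j] += 2.0 / N * B[n][k] * M[k][j];
    return G;
}

int main()
{
    const double lambda = 0.25, eps = 1e-6;
    Mat A = {{ 0.3, -1.2}, { 0.7, 0.4}, {-0.5,  0.9}};  // N = 3 samples, D = 2 dims
    Mat B = {{ 0.1,  0.8}, {-0.6, 0.2}, { 1.1, -0.3}};
    const Mat G = grad_wrt_A(A, B, lambda);
    for (size_t n = 0; n < A.size(); ++n)
        for (size_t j = 0; j < A[0].size(); ++j)
        {
            A[n][j] += eps;     const double lp = loss(A, B, lambda);  // L at A[n][j] + eps
            A[n][j] -= 2 * eps; const double lm = loss(A, B, lambda);  // L at A[n][j] - eps
            A[n][j] += eps;     // restore
            std::printf("dL/dA[%zu][%zu]  analytic % .6f  numeric % .6f\n",
                        n, j, G[n][j], (lp - lm) / (2 * eps));
        }
    return 0;
}

With the 2/N factor and the C' term in place, the analytic and numeric columns agree up to finite-difference error; without them they do not, which is the mismatch this commit fixes.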

@@ -107,7 +107,7 @@ rectangle make_random_cropping_rect(
  const double mins = 7. / 15.;
  const double maxs = 7. / 8.;
  const auto scale = rnd.get_double_in_range(mins, maxs);
- const auto size = scale * std::min(image.nr(), image.nc());
+ const auto size = scale * min(image.nr(), image.nc());
  const rectangle rect(size, size);
  const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
                     rnd.get_random_32bit_number() % (image.nr() - rect.height()));
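For a sense of scale, assuming the 32x32 CIFAR-10 images this example uses: the crop side ranges from about 7/15 * 32 ≈ 15 to 7/8 * 32 = 28 pixels, placed at a random offset that keeps the rectangle inside the image.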
@@ -179,11 +179,12 @@ try
  command_line_parser parser;
  parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
  parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
- parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
+ parser.add_option("lambda", "off-diagonal terms penalty (default: 1/dims)", 1);
  parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
  parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
  parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
  parser.add_option("fraction", "fraction of labels to use (default: 0.1)", 1);
+ parser.add_option("patience", "steps without progress threshold (default: 10000)", 1);
  parser.set_group_name("Help Options");
  parser.add_option("h", "alias for --help");
  parser.add_option("help", "display this message and exit");
@@ -200,13 +201,14 @@ try
  }
  parser.check_option_arg_range("fraction", 0.0, 1.0);
- const double labels_fraction = get_option(parser, "fraction", 0.1);
  const size_t num_gpus = get_option(parser, "num-gpus", 1);
  const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
  const long dims = get_option(parser, "dims", 128);
  const double lambda = get_option(parser, "lambda", 1.0 / dims);
  const double learning_rate = get_option(parser, "learning-rate", 1e-3);
  const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
+ const double labels_fraction = get_option(parser, "fraction", 0.1);
+ const size_t patience = get_option(parser, "patience", 10000);

  // Load the CIFAR-10 dataset into memory.
  std::vector<matrix<rgb_pixel>> training_images, testing_images;
@@ -220,7 +222,7 @@ try
  disable_duplicative_biases(net);
  dlib::rand rnd;
  std::vector<int> gpus(num_gpus);
- std::iota(gpus.begin(), gpus.end(), 0);
+ iota(gpus.begin(), gpus.end(), 0);

  // Train the feature extractor using the Barlow Twins method on all the training
  // data.
@@ -229,7 +231,7 @@ try
  trainer.set_mini_batch_size(batch_size);
  trainer.set_learning_rate(learning_rate);
  trainer.set_min_learning_rate(min_learning_rate);
- trainer.set_iterations_without_progress_threshold(10000);
+ trainer.set_iterations_without_progress_threshold(patience);
  trainer.set_synchronization_file("barlow_twins_sync");
  trainer.be_verbose();
  cout << trainer << endl;
@@ -250,7 +252,7 @@ try
  beta.set_size(1, dims);
  image_window win;
- std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
+ std::vector<pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
  while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
  {
      batch.clear();
@@ -308,14 +310,14 @@ try
  randomize_samples(training_images, training_labels);
  std::vector<matrix<rgb_pixel>> sub_images(
      training_images.begin(),
-     training_images.begin() + std::lround(training_images.size() * labels_fraction));
+     training_images.begin() + lround(training_images.size() * labels_fraction));
  std::vector<unsigned long> sub_labels(
      training_labels.begin(),
-     training_labels.begin() + std::lround(training_labels.size() * labels_fraction));
+     training_labels.begin() + lround(training_labels.size() * labels_fraction));

- std::swap(sub_images, training_images);
- std::swap(sub_labels, training_labels);
+ swap(sub_images, training_images);
+ swap(sub_labels, training_labels);
  }

  // Let's generate the features for those samples that have labels to train a
@@ -324,12 +326,12 @@ try
  cout << "Extracting features for linear classifier from " << training_images.size() << " samples..." << endl;
  features = fnet(training_images, 4 * batch_size);
- const auto df = auto_train_multiclass_svm_linear_classifier(features, training_labels, std::chrono::minutes(1));
+ const auto df = auto_train_multiclass_svm_linear_classifier(features, training_labels, chrono::minutes(1));
  serialize("multiclass_svm_cifar_10.dat") << df;

  // Finally, we can compute the accuracy of the model on the CIFAR-10 train and
  // test images.
- auto compute_accuracy = [&df](
+ const auto compute_accuracy = [&df](
      const std::vector<matrix<float, 0, 1>>& samples,
      const std::vector<unsigned long>& labels
  )