mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Fix computation of the Barlow Twins loss gradient (#2680)
This commit is contained in:
parent
7f06f6e185
commit
bdb1089ae6
@ -4128,16 +4128,16 @@ namespace dlib
|
||||
// --------------------------------------------
|
||||
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
|
||||
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
|
||||
tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
|
||||
tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
|
||||
tt::gemm(0, grad_input_a, 2.0 / batch_size, zb_norm, false, cdiag_1, false);
|
||||
tt::gemm(0, grad_input_b, 2.0 / batch_size, za_norm, false, cdiag_1, false);
|
||||
|
||||
// off-diag term: sum(((A'* B) .* D).^2)
|
||||
// --------------------------------
|
||||
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
|
||||
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C' .* (D .* D)) = 2 * B * (C' .* D)
|
||||
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
|
||||
tt::multiply(false, off_diag, eccm, off_mask);
|
||||
tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
|
||||
tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
|
||||
tt::gemm(1, grad_input_a, lambda * 2.0 / batch_size, zb_norm, false, off_diag, true);
|
||||
tt::gemm(1, grad_input_b, lambda * 2.0 / batch_size, za_norm, false, off_diag, false);
|
||||
|
||||
// Compute the batch norm gradients, g and b grads are not used
|
||||
auto gza = split(grad);
|
||||
|
@ -107,7 +107,7 @@ rectangle make_random_cropping_rect(
|
||||
const double mins = 7. / 15.;
|
||||
const double maxs = 7. / 8.;
|
||||
const auto scale = rnd.get_double_in_range(mins, maxs);
|
||||
const auto size = scale * std::min(image.nr(), image.nc());
|
||||
const auto size = scale * min(image.nr(), image.nc());
|
||||
const rectangle rect(size, size);
|
||||
const point offset(rnd.get_random_32bit_number() % (image.nc() - rect.width()),
|
||||
rnd.get_random_32bit_number() % (image.nr() - rect.height()));
|
||||
@ -179,11 +179,12 @@ try
|
||||
command_line_parser parser;
|
||||
parser.add_option("batch", "set the mini batch size per GPU (default: 64)", 1);
|
||||
parser.add_option("dims", "set the projector dimensions (default: 128)", 1);
|
||||
parser.add_option("lambda", "penalize off-diagonal terms (default: 1/dims)", 1);
|
||||
parser.add_option("lambda", "off-diagonal terms penalty (default: 1/dims)", 1);
|
||||
parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
|
||||
parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
|
||||
parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
|
||||
parser.add_option("fraction", "fraction of labels to use (default: 0.1)", 1);
|
||||
parser.add_option("patience", "steps without progress threshold (default: 10000)", 1);
|
||||
parser.set_group_name("Help Options");
|
||||
parser.add_option("h", "alias for --help");
|
||||
parser.add_option("help", "display this message and exit");
|
||||
@ -200,13 +201,14 @@ try
|
||||
}
|
||||
|
||||
parser.check_option_arg_range("fraction", 0.0, 1.0);
|
||||
const double labels_fraction = get_option(parser, "fraction", 0.1);
|
||||
const size_t num_gpus = get_option(parser, "num-gpus", 1);
|
||||
const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
|
||||
const long dims = get_option(parser, "dims", 128);
|
||||
const double lambda = get_option(parser, "lambda", 1.0 / dims);
|
||||
const double learning_rate = get_option(parser, "learning-rate", 1e-3);
|
||||
const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
|
||||
const double labels_fraction = get_option(parser, "fraction", 0.1);
|
||||
const size_t patience = get_option(parser, "patience", 10000);
|
||||
|
||||
// Load the CIFAR-10 dataset into memory.
|
||||
std::vector<matrix<rgb_pixel>> training_images, testing_images;
|
||||
@ -220,7 +222,7 @@ try
|
||||
disable_duplicative_biases(net);
|
||||
dlib::rand rnd;
|
||||
std::vector<int> gpus(num_gpus);
|
||||
std::iota(gpus.begin(), gpus.end(), 0);
|
||||
iota(gpus.begin(), gpus.end(), 0);
|
||||
|
||||
// Train the feature extractor using the Barlow Twins method on all the training
|
||||
// data.
|
||||
@ -229,7 +231,7 @@ try
|
||||
trainer.set_mini_batch_size(batch_size);
|
||||
trainer.set_learning_rate(learning_rate);
|
||||
trainer.set_min_learning_rate(min_learning_rate);
|
||||
trainer.set_iterations_without_progress_threshold(10000);
|
||||
trainer.set_iterations_without_progress_threshold(patience);
|
||||
trainer.set_synchronization_file("barlow_twins_sync");
|
||||
trainer.be_verbose();
|
||||
cout << trainer << endl;
|
||||
@ -250,7 +252,7 @@ try
|
||||
beta.set_size(1, dims);
|
||||
image_window win;
|
||||
|
||||
std::vector<std::pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
|
||||
std::vector<pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
|
||||
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
|
||||
{
|
||||
batch.clear();
|
||||
@ -308,14 +310,14 @@ try
|
||||
randomize_samples(training_images, training_labels);
|
||||
std::vector<matrix<rgb_pixel>> sub_images(
|
||||
training_images.begin(),
|
||||
training_images.begin() + std::lround(training_images.size() * labels_fraction));
|
||||
training_images.begin() + lround(training_images.size() * labels_fraction));
|
||||
|
||||
std::vector<unsigned long> sub_labels(
|
||||
training_labels.begin(),
|
||||
training_labels.begin() + std::lround(training_labels.size() * labels_fraction));
|
||||
training_labels.begin() + lround(training_labels.size() * labels_fraction));
|
||||
|
||||
std::swap(sub_images, training_images);
|
||||
std::swap(sub_labels, training_labels);
|
||||
swap(sub_images, training_images);
|
||||
swap(sub_labels, training_labels);
|
||||
}
|
||||
|
||||
// Let's generate the features for those samples that have labels to train a
|
||||
@ -324,12 +326,12 @@ try
|
||||
cout << "Extracting features for linear classifier from " << training_images.size() << " samples..." << endl;
|
||||
features = fnet(training_images, 4 * batch_size);
|
||||
|
||||
const auto df = auto_train_multiclass_svm_linear_classifier(features, training_labels, std::chrono::minutes(1));
|
||||
const auto df = auto_train_multiclass_svm_linear_classifier(features, training_labels, chrono::minutes(1));
|
||||
serialize("multiclass_svm_cifar_10.dat") << df;
|
||||
|
||||
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and
|
||||
// test images.
|
||||
auto compute_accuracy = [&df](
|
||||
const auto compute_accuracy = [&df](
|
||||
const std::vector<matrix<float, 0, 1>>& samples,
|
||||
const std::vector<unsigned long>& labels
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user