diff --git a/examples/dnn_self_supervised_learning_ex.cpp b/examples/dnn_self_supervised_learning_ex.cpp
index 6d87bca29..0b7c51e62 100644
--- a/examples/dnn_self_supervised_learning_ex.cpp
+++ b/examples/dnn_self_supervised_learning_ex.cpp
@@ -34,6 +34,15 @@
     the CIFAR-10 contains relatively small images, we will define a ResNet50 architecture
     that doesn't downsample the input in the first convolutional layer, and doesn't have a
     max pooling layer afterwards, like the paper does.
+
+    This example shows how to use the Barlow Twins loss for this common scenario:
+    Let's imagine that we have collected some images, but we don't have enough resources
+    to label them all, just a small fraction of them.
+    We can train a feature extractor using the Barlow Twins loss on all the available
+    training data (both labeled and unlabeled images) to learn meaningful representations
+    for the dataset.
+    Once the feature extractor is trained, we can train a multiclass SVM classifier on
+    top of it using only the labeled fraction of the data.
 */
 
 #include 
@@ -83,7 +92,7 @@ namespace resnet50
 }
 
 // This model namespace contains the definitions for:
-// - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
+// - An SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
 // - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
 namespace model
 {
@@ -107,6 +116,7 @@ rectangle make_random_cropping_rect(
     return move_rect(rect, offset);
 }
 
+// A helper function that generates different kinds of augmentations of an image, depending on the value of prime.
 matrix<rgb_pixel> augment(
     const matrix<rgb_pixel>& image,
     const bool prime,
@@ -175,6 +185,7 @@ try
     parser.add_option("learning-rate", "set the initial learning rate (default: 1e-3)", 1);
     parser.add_option("min-learning-rate", "set the min learning rate (default: 1e-5)", 1);
     parser.add_option("num-gpus", "number of GPUs (default: 1)", 1);
+    parser.add_option("fraction", "fraction of labels to use (default: 0.1)", 1);
     parser.set_group_name("Help Options");
     parser.add_option("h", "alias for --help");
     parser.add_option("help", "display this message and exit");
@@ -190,12 +201,14 @@
         return EXIT_SUCCESS;
     }
 
+    parser.check_option_arg_range("fraction", 0.0, 1.0);
     const size_t num_gpus = get_option(parser, "num-gpus", 1);
     const size_t batch_size = get_option(parser, "batch", 64) * num_gpus;
     const long dims = get_option(parser, "dims", 128);
     const double lambda = get_option(parser, "lambda", 1.0 / dims);
     const double learning_rate = get_option(parser, "learning-rate", 1e-3);
     const double min_learning_rate = get_option(parser, "min-learning-rate", 1e-5);
+    const double labels_fraction = get_option(parser, "fraction", 0.1);
 
     // Load the CIFAR-10 dataset into memory.
     std::vector<matrix<rgb_pixel>> training_images, testing_images;
@@ -211,7 +224,7 @@
     std::vector<int> gpus(num_gpus);
     std::iota(gpus.begin(), gpus.end(), 0);
 
-    // Train the feature extractor using the Barlow Twins method
+    // Train the feature extractor using the Barlow Twins method on all the training data.
     {
         dnn_trainer trainer(net, adam(1e-6, 0.9, 0.999), gpus);
         trainer.set_mini_batch_size(batch_size);
@@ -290,10 +303,27 @@
 
     // Now, we initialize the feature extractor model with the backbone we have just learned.
     model::feats fnet(layer<5>(net));
-    // And we will generate all the features for the training set to train a multiclass SVM
-    // classifier.
+
+    // Use only the specified fraction of training labels.
+    if (labels_fraction < 1.0)
+    {
+        randomize_samples(training_images, training_labels);
+        std::vector<matrix<rgb_pixel>> sub_images(
+            training_images.begin(),
+            training_images.begin() + std::lround(training_images.size() * labels_fraction));
+
+        std::vector<unsigned long> sub_labels(
+            training_labels.begin(),
+            training_labels.begin() + std::lround(training_labels.size() * labels_fraction));
+
+        std::swap(sub_images, training_images);
+        std::swap(sub_labels, training_labels);
+    }
+
+    // Let's generate the features for the labeled samples, which we will use to train a
+    // multiclass SVM classifier.
     std::vector<matrix<float, 0, 1>> features;
-    cout << "Extracting features for linear classifier..." << endl;
+    cout << "Extracting features for linear classifier from " << training_images.size() << " samples..." << endl;
     features = fnet(training_images, 4 * batch_size);
     vector_normalizer<matrix<float, 0, 1>> normalizer;
     normalizer.train(features);
@@ -347,7 +377,10 @@
         cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
     };
 
-    // We should get a training accuracy of around 93% and a testing accuracy of around 89%.
+    // Using 10% of the training labels should result in training and testing accuracies of
+    // around 92% and 87%, respectively.
+    // Had we used all the labels to train the multiclass SVM classifier, we would have gotten
+    // a training accuracy of around 93% and a testing accuracy of around 89% instead.
     cout << "\ntraining accuracy" << endl;
     compute_accuracy(features, training_labels);
     cout << "\ntesting accuracy" << endl;
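Note: the label-subsetting step added by this patch can also be written as a small standalone helper. The following is a minimal sketch, not part of the patch, assuming the CIFAR-10 types used by this example (matrix<rgb_pixel> images and unsigned long labels) and dlib's randomize_samples from <dlib/svm.h>; the helper name keep_label_fraction is hypothetical.

    #include <cmath>
    #include <vector>
    #include <dlib/matrix.h>
    #include <dlib/pixel.h>
    #include <dlib/svm.h> // for randomize_samples

    // Keep only a random fraction of the labeled samples, shuffling images and labels in
    // unison first so that the retained subset is unbiased.
    void keep_label_fraction(
        std::vector<dlib::matrix<dlib::rgb_pixel>>& images,
        std::vector<unsigned long>& labels,
        const double fraction
    )
    {
        if (fraction >= 1.0)
            return;
        dlib::randomize_samples(images, labels);
        const auto num_keep = std::lround(images.size() * fraction);
        images.resize(num_keep);
        labels.resize(num_keep);
    }

    // Example usage, mirroring the patch's default of keeping 10% of the labels:
    //   keep_label_fraction(training_images, training_labels, 0.1);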