@ -36,10 +36,12 @@
a max pooling layer afterwards , like the paper does .
*/
# include <dlib/dnn.h>
# include <dlib/data_io.h>
# include <dlib/cmd_line_parser.h>
# include <dlib/data_io.h>
# include <dlib/dnn.h>
# include <dlib/global_optimization.h>
# include <dlib/gui_widgets.h>
# include <dlib/svm_threaded.h>
using namespace std ;
using namespace dlib ;
@ -82,14 +84,12 @@ namespace resnet50
// Network definitions used by this example. Each one wraps the resnet50 backbone:
//   - train: the self-supervised model. loss_barlow_twins on top of the projector head,
//     taking an input_rgb_image_pair and using bn_con layers in the backbone.
//   - infer: the classifier model. loss_multiclass_log on top of an fc<10> layer,
//     taking an input_rgb_image and using affine layers in the backbone.
//   - feats: the feature extractor model. loss_metric placed directly on the backbone
//     (to get its outputs), taking an input_rgb_image and using affine layers.
namespace model
{
    // Projector head: fc<512> -> bn_fc -> relu -> fc<128>, stacked on top of NET.
    template <typename NET>
    using projector = fc<128, relu<bn_fc<fc<512, NET>>>>;
    using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;

    // Classification head: a single fc<10> layer stacked on top of NET.
    template <typename NET>
    using classifier = fc<10, NET>;
    using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;

    // Bare backbone wrapped in loss_metric, with no extra head.
    using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
}
rectangle make_random_cropping_rect (
@ -288,73 +288,65 @@ try
serialize ( " resnet50_self_supervised_cifar_10.net " ) < < layer < 5 > ( net ) ;
}
// To check the quality of the learned feature representations, we will train a linear
// classifier on top of the frozen backbone.
model : : infer inet ;
// Assign the network, without the projector, which is only used for the self-supervised
// training.
layer < 2 > ( inet ) = layer < 5 > ( net ) ;
// Freeze the backbone
set_all_learning_rate_multipliers ( layer < 2 > ( inet ) , 0 ) ;
// Train the network
// Now, we initialize the feature extractor model with the backbone we have just learned.
model : : feats fnet ( layer < 5 > ( net ) ) ;
// And we will generate all the features for the training set to train a multiclass SVM
// classifier.
std : : vector < matrix < float , 0 , 1 > > features ;
cout < < " Extracting features for linear classifier... " < < endl ;
features = fnet ( training_images , 4 * batch_size ) ;
// Find the most appropriate C setting using find_max_global.
// Objective function handed to find_max_global: given a candidate value `c`, it sets
// it on an svm_multiclass_linear_trainer via set_c, calls
// cross_validate_multiclass_trainer on `features` / `training_labels` (last argument 3),
// prints the result and returns the accuracy.
// Everything from the enclosing scope is captured by reference.
//
// NOTE(review): this body declares two objects named `trainer` (a dnn_trainer and an
// svm_multiclass_linear_trainer). It looks like two revisions of the code are
// interleaved here -- confirm which one is intended before building.
auto cross_validation_score = [ & ] ( const double c )
{
dnn_trainer < model : : infer , adam > trainer ( inet , adam ( 1e-6 , 0.9 , 0.999 ) , gpus ) ;
// Since this model doesn't train with pairs, just single images, we can increase
// the batch size.
trainer . set_mini_batch_size ( 2 * batch_size ) ;
trainer . set_learning_rate ( learning_rate ) ;
trainer . set_min_learning_rate ( min_learning_rate ) ;
trainer . set_iterations_without_progress_threshold ( 5000 ) ;
trainer . set_synchronization_file ( " cifar_10_sync " ) ;
trainer . be_verbose ( ) ;
cout < < trainer < < endl ;
// Second `trainer` declaration: the multiclass linear SVM trainer that receives `c`.
svm_multiclass_linear_trainer < linear_kernel < matrix < float , 0 , 1 > > , unsigned long > trainer ;
// Thread count is taken from std::thread::hardware_concurrency().
trainer . set_num_threads ( std : : thread : : hardware_concurrency ( ) ) ;
trainer . set_c ( c ) ;
cout < < " C: " < < c < < endl ;
// `cm` is printed below as the confusion matrix.
const auto cm = cross_validate_multiclass_trainer ( trainer , features , training_labels , 3 ) ;
// Accuracy = sum of the diagonal of cm divided by the sum of all of its entries.
const double accuracy = sum ( diag ( cm ) ) / sum ( cm ) ;
// NOTE(review): the doubled ';' at the end of the next line is an empty statement.
cout < < " cross validation accuracy: " < < accuracy < < endl ; ;
cout < < " confusion matrix: \n " < < cm < < endl ;
return accuracy ;
} ;
const auto result = find_max_global ( cross_validation_score , 1e-4 , 10000 , max_function_calls ( 50 ) ) ;
cout < < " Best C: " < < result . x ( 0 ) < < endl ;
std : : vector < matrix < rgb_pixel > > images ;
std : : vector < unsigned long > labels ;
while ( trainer . get_learning_rate ( ) > = trainer . get_min_learning_rate ( ) )
{
images . clear ( ) ;
labels . clear ( ) ;
while ( images . size ( ) < trainer . get_mini_batch_size ( ) )
{
const auto idx = rnd . get_random_32bit_number ( ) % training_images . size ( ) ;
images . push_back ( augment ( training_images [ idx ] , false , rnd ) ) ;
labels . push_back ( training_labels [ idx ] ) ;
}
trainer . train_one_step ( images , labels ) ;
}
trainer . get_net ( ) ;
inet . clean ( ) ;
serialize ( " resnet50_cifar_10.dnn " ) < < inet ;
}
// Proceed to train the SVM classifier with the best C.
svm_multiclass_linear_trainer < linear_kernel < matrix < float , 0 , 1 > > , unsigned long > trainer ;
trainer . set_num_threads ( std : : thread : : hardware_concurrency ( ) ) ;
trainer . set_c ( result . x ( 0 ) ) ;
cout < < " Training Multiclass SVM... " < < endl ;
const auto df = trainer . train ( features , training_labels ) ;
serialize ( " multiclass_svm_cifar_10.dat " ) < < df ;
// Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
// Helper that walks over `labels`, counts how many entries match the predicted label
// at the same index, and prints the number right, the number wrong, the accuracy and
// the error rate to cout. It returns nothing.
//
// NOTE(review): two lambda introducers appear below (one capturing `inet`, one
// capturing `fnet` and `df`), along with two `if` comparisons and a doubled set of
// print statements. It looks like two revisions of this helper are interleaved --
// confirm which one is intended before building.
auto compute_accuracy = [ & inet , batch_size ] (
const std : : vector < matrix < rgb_pixel > > & imag es,
auto compute_accuracy = [ & fnet , & df , batch_size ] (
const std : : vector < matrix < float , 0 , 1 > > & samples ,
const std : : vector < unsigned long > & labels
)
{
size_t num_right = 0 ;
size_t num_wrong = 0 ;
// Variant 1: predictions come from calling `inet` on `images`, with batch_size * 2
// as the second argument.
const auto preds = inet ( images , batch_size * 2 ) ;
for ( size_t i = 0 ; i < labels . size ( ) ; + + i )
{
// Variant 1 compares against preds[i]; variant 2 compares against df(samples[i]).
if ( labels [ i ] = = preds[ i ] )
if ( labels [ i ] = = df( samples [ i ] ) )
+ + num_right ;
else
+ + num_wrong ;
}
// Accuracy and error rate are the counts divided by (num_right + num_wrong), with
// the denominator cast to double so the division is done in floating point.
cout < < " num right: " < < num_right < < endl ;
cout < < " num wrong: " < < num_wrong < < endl ;
cout < < " accuracy: " < < num_right / static_cast < double > ( num_right + num_wrong ) < < endl ;
cout < < " error rate: " < < num_wrong / static_cast < double > ( num_right + num_wrong ) < < endl ;
cout < < " num right: " < < num_right < < endl ;
cout < < " num wrong: " < < num_wrong < < endl ;
cout < < " accuracy: " < < num_right / static_cast < double > ( num_right + num_wrong ) < < endl ;
cout < < " error rate: " < < num_wrong / static_cast < double > ( num_right + num_wrong ) < < endl ;
} ;
// If everything works as expected, we should get accuracies that are between 87% and 90 %.
cout < < " training accuracy " < < endl ;
compute_accuracy ( training_imag es, training_labels ) ;
// We should get a training accuracy of around 93% and a testing accuracy of around 88 %.
cout < < " \n training accuracy " < < endl ;
compute_accuracy ( fea tu res, training_labels ) ;
cout < < " \n testing accuracy " < < endl ;
compute_accuracy ( testing_images , testing_labels ) ;
features = fnet ( testing_images , 4 * batch_size ) ;
compute_accuracy ( features , testing_labels ) ;
return EXIT_SUCCESS ;
}
catch ( const exception & e )