@ -277,7 +277,7 @@ try
// visualize it.
tt : : gemm ( 0 , eccm , 1 , za_norm , true , zb_norm , false ) ;
eccm / = batch_size ;
win . set_image ( abs( mat ( eccm ) ) * 255 ) ;
win . set_image ( round( abs( mat ( eccm ) ) * 255 ) ) ;
win . set_title ( " Barlow Twins step#: " + to_string ( trainer . get_train_one_step_calls ( ) ) ) ;
}
}
@ -304,12 +304,14 @@ try
auto cross_validation_score = [ & ] ( const double c )
{
svm_multiclass_linear_trainer < linear_kernel < matrix < float , 0 , 1 > > , unsigned long > trainer ;
trainer . set_num_threads ( std : : thread : : hardware_concurrency ( ) ) ;
trainer . set_c ( c ) ;
trainer . set_epsilon ( 0.01 ) ;
trainer . set_max_iterations ( 100 ) ;
trainer . set_num_threads ( std : : thread : : hardware_concurrency ( ) ) ;
cout < < " C: " < < c < < endl ;
const auto cm = cross_validate_multiclass_trainer ( trainer , features , training_labels , 3 ) ;
const double accuracy = sum ( diag ( cm ) ) / sum ( cm ) ;
cout < < " cross validation accuracy: " < < accuracy < < endl ; ;
cout < < " cross validation accuracy: " < < accuracy < < endl ;
cout < < " confusion matrix: \n " < < cm < < endl ;
return accuracy ;
} ;
@ -345,7 +347,7 @@ try
cout < < " error rate: " < < num_wrong / static_cast < double > ( num_right + num_wrong ) < < endl ;
} ;
// We should get a training accuracy of around 93% and a testing accuracy of around 8 8 %.
// We should get a training accuracy of around 93% and a testing accuracy of around 8 9 %.
cout < < " \n training accuracy " < < endl ;
compute_accuracy ( features , training_labels ) ;
cout < < " \n testing accuracy " < < endl ;