diff --git a/dlib/test/svm.cpp b/dlib/test/svm.cpp index bf47e99ce..90f0c1389 100644 --- a/dlib/test/svm.cpp +++ b/dlib/test/svm.cpp @@ -341,6 +341,11 @@ namespace rvm_trainer rvm_trainer; rvm_trainer.set_kernel(kernel_type(gamma)); + svm_pegasos pegasos_trainer; + pegasos_trainer.set_kernel(kernel_type(gamma)); + pegasos_trainer.set_lambda(0.00001); + + svm_nu_trainer trainer; trainer.set_kernel(kernel_type(gamma)); trainer.set_nu(0.05); @@ -352,14 +357,18 @@ namespace print_spinner(); matrix rbf_cv = cross_validate_trainer_threaded(rbf_trainer, x,y, 4, 2); print_spinner(); + matrix peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2); + print_spinner(); dlog << LDEBUG << "rvm cv: " << rvm_cv; dlog << LDEBUG << "svm cv: " << svm_cv; dlog << LDEBUG << "rbf cv: " << rbf_cv; + dlog << LDEBUG << "peg cv: " << peg_cv; DLIB_CASSERT(mean(rvm_cv) > 0.9, rvm_cv); DLIB_CASSERT(mean(svm_cv) > 0.9, svm_cv); DLIB_CASSERT(mean(rbf_cv) > 0.9, rbf_cv); + DLIB_CASSERT(mean(peg_cv) > 0.9, rbf_cv); const long num_sv = trainer.train(x,y).support_vectors.size(); print_spinner();