Made test a little more numerically robust.

This commit is contained in:
Davis King 2014-02-22 13:03:14 -05:00
parent 5bcfa6853e
commit 6d32e8c804

View File

@ -116,18 +116,17 @@ namespace
ovo_trainer trainer; ovo_trainer trainer;
typedef polynomial_kernel<sample_type> poly_kernel; typedef histogram_intersection_kernel<sample_type> hist_kernel;
typedef radial_basis_kernel<sample_type> rbf_kernel; typedef radial_basis_kernel<sample_type> rbf_kernel;
// make the binary trainers and set some parameters // make the binary trainers and set some parameters
krr_trainer<rbf_kernel> rbf_trainer; krr_trainer<rbf_kernel> rbf_trainer;
svm_nu_trainer<poly_kernel> poly_trainer; svm_nu_trainer<hist_kernel> hist_trainer;
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
rbf_trainer.set_kernel(rbf_kernel(0.1)); rbf_trainer.set_kernel(rbf_kernel(0.1));
trainer.set_trainer(rbf_trainer); trainer.set_trainer(rbf_trainer);
trainer.set_trainer(poly_trainer, 1, 2); trainer.set_trainer(hist_trainer, 1, 2);
randomize_samples(samples, labels); randomize_samples(samples, labels);
matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
@ -143,8 +142,7 @@ namespace
// test using a normalized_function with a one_vs_one_decision_function // test using a normalized_function with a one_vs_one_decision_function
{ {
poly_trainer.set_kernel(poly_kernel(1.1, 1, 2)); trainer.set_trainer(hist_trainer, 1, 2);
trainer.set_trainer(poly_trainer, 1, 2);
vector_normalizer<sample_type> normalizer; vector_normalizer<sample_type> normalizer;
normalizer.train(samples); normalizer.train(samples);
for (unsigned long i = 0; i < samples.size(); ++i) for (unsigned long i = 0; i < samples.size(); ++i)
@ -156,8 +154,7 @@ namespace
DLIB_TEST(ndf(samples[40]) == labels[40]); DLIB_TEST(ndf(samples[40]) == labels[40]);
DLIB_TEST(ndf(samples[90]) == labels[90]); DLIB_TEST(ndf(samples[90]) == labels[90]);
DLIB_TEST(ndf(samples[120]) == labels[120]); DLIB_TEST(ndf(samples[120]) == labels[120]);
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2)); trainer.set_trainer(hist_trainer, 1, 2);
trainer.set_trainer(poly_trainer, 1, 2);
print_spinner(); print_spinner();
} }
@ -173,7 +170,7 @@ namespace
one_vs_one_decision_function<ovo_trainer, one_vs_one_decision_function<ovo_trainer,
decision_function<poly_kernel>, // This is the output of the poly_trainer decision_function<hist_kernel>, // This is the output of the hist_trainer
decision_function<rbf_kernel> // This is the output of the rbf_trainer decision_function<rbf_kernel> // This is the output of the rbf_trainer
> df2, df3; > df2, df3;