Added more unit tests for the forces_last_weight_to_1 stuff.

This commit is contained in:
Davis King 2013-01-03 22:17:10 -05:00
parent 9ab59297b2
commit 3b0f4ff135
2 changed files with 42 additions and 10 deletions

View File

@ -223,7 +223,7 @@ namespace
// ----------------------------------------------------------------------------------------
template <typename K>
template <typename K, bool use_dcd_trainer>
class simple_rank_trainer
{
public:
@ -250,18 +250,35 @@ namespace
}
}
svm_c_linear_dcd_trainer<K> trainer;
trainer.set_c(1.0/samples.size());
trainer.set_epsilon(1e-10);
trainer.force_last_weight_to_1(true);
//trainer.be_verbose();
return trainer.train(samples, labels);
if (use_dcd_trainer)
{
svm_c_linear_dcd_trainer<K> trainer;
trainer.set_c(1.0/samples.size());
trainer.set_epsilon(1e-10);
trainer.force_last_weight_to_1(true);
//trainer.be_verbose();
return trainer.train(samples, labels);
}
else
{
svm_c_linear_trainer<K> trainer;
trainer.set_c(1.0);
trainer.set_epsilon(1e-13);
trainer.force_last_weight_to_1(true);
//trainer.be_verbose();
decision_function<K> df = trainer.train(samples, labels);
DLIB_TEST_MSG(df.b == 0, df.b);
return df;
}
}
};
template <bool use_dcd_trainer>
void test_svmrank_weight_force_dense()
{
print_spinner();
dlog << LINFO << "use_dcd_trainer: "<< use_dcd_trainer;
typedef matrix<double,10,1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
@ -291,7 +308,7 @@ namespace
dlog << LINFO << "ranking accuracy: " << acc1;
DLIB_TEST(std::abs(acc1 - 1) == 0);
simple_rank_trainer<kernel_type> strainer;
simple_rank_trainer<kernel_type,use_dcd_trainer> strainer;
decision_function<kernel_type> df2;
df2 = strainer.train(pair);
dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
@ -325,7 +342,8 @@ namespace
test_count_ranking_inversions();
dotest1();
dotest_sparse_vectors();
test_svmrank_weight_force_dense();
test_svmrank_weight_force_dense<true>();
test_svmrank_weight_force_dense<false>();
}
} a;

View File

@ -250,6 +250,7 @@ namespace
typedef linear_kernel<sample_type> kernel_type;
svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
@ -257,7 +258,9 @@ namespace
const double C = 1;
linear_trainer.set_epsilon(1e-10);
linear_trainer_cpa.set_epsilon(1e-11);
linear_trainer_cpa.force_last_weight_to_1(force_weight);
linear_trainer.force_last_weight_to_1(force_weight);
linear_trainer.include_bias(have_bias);
@ -268,7 +271,7 @@ namespace
// make an instance of a sample vector so we can use it below
sample_type sample;
decision_function<kernel_type> df;
decision_function<kernel_type> df, df2;
running_stats<double> rs;
@ -299,11 +302,22 @@ namespace
labels.push_back(label);
linear_trainer.set_c(C);
linear_trainer_cpa.set_c(C*samples.size());
df = linear_trainer.train(samples, labels, state);
if (force_weight)
{
DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8);
DLIB_TEST(std::abs(df.b) < 1e-8);
if (samples.size() > 1)
{
df2 = linear_trainer_cpa.train(samples, labels);
DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0))));
DLIB_TEST( std::abs(df.b - df2.b) < 1e-7);
}
}
if (!have_bias)
DLIB_TEST(std::abs(df.b) < 1e-8);