mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added more unit tests for the forces_last_weight_to_1 stuff.
This commit is contained in:
parent
9ab59297b2
commit
3b0f4ff135
@ -223,7 +223,7 @@ namespace
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename K>
|
||||
template <typename K, bool use_dcd_trainer>
|
||||
class simple_rank_trainer
|
||||
{
|
||||
public:
|
||||
@ -250,18 +250,35 @@ namespace
|
||||
}
|
||||
}
|
||||
|
||||
svm_c_linear_dcd_trainer<K> trainer;
|
||||
trainer.set_c(1.0/samples.size());
|
||||
trainer.set_epsilon(1e-10);
|
||||
trainer.force_last_weight_to_1(true);
|
||||
//trainer.be_verbose();
|
||||
return trainer.train(samples, labels);
|
||||
if (use_dcd_trainer)
|
||||
{
|
||||
svm_c_linear_dcd_trainer<K> trainer;
|
||||
trainer.set_c(1.0/samples.size());
|
||||
trainer.set_epsilon(1e-10);
|
||||
trainer.force_last_weight_to_1(true);
|
||||
//trainer.be_verbose();
|
||||
return trainer.train(samples, labels);
|
||||
}
|
||||
else
|
||||
{
|
||||
svm_c_linear_trainer<K> trainer;
|
||||
trainer.set_c(1.0);
|
||||
trainer.set_epsilon(1e-13);
|
||||
trainer.force_last_weight_to_1(true);
|
||||
//trainer.be_verbose();
|
||||
decision_function<K> df = trainer.train(samples, labels);
|
||||
DLIB_TEST_MSG(df.b == 0, df.b);
|
||||
return df;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool use_dcd_trainer>
|
||||
void test_svmrank_weight_force_dense()
|
||||
{
|
||||
print_spinner();
|
||||
dlog << LINFO << "use_dcd_trainer: "<< use_dcd_trainer;
|
||||
|
||||
typedef matrix<double,10,1> sample_type;
|
||||
typedef linear_kernel<sample_type> kernel_type;
|
||||
|
||||
@ -291,7 +308,7 @@ namespace
|
||||
dlog << LINFO << "ranking accuracy: " << acc1;
|
||||
DLIB_TEST(std::abs(acc1 - 1) == 0);
|
||||
|
||||
simple_rank_trainer<kernel_type> strainer;
|
||||
simple_rank_trainer<kernel_type,use_dcd_trainer> strainer;
|
||||
decision_function<kernel_type> df2;
|
||||
df2 = strainer.train(pair);
|
||||
dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
|
||||
@ -325,7 +342,8 @@ namespace
|
||||
test_count_ranking_inversions();
|
||||
dotest1();
|
||||
dotest_sparse_vectors();
|
||||
test_svmrank_weight_force_dense();
|
||||
test_svmrank_weight_force_dense<true>();
|
||||
test_svmrank_weight_force_dense<false>();
|
||||
|
||||
}
|
||||
} a;
|
||||
|
@ -250,6 +250,7 @@ namespace
|
||||
typedef linear_kernel<sample_type> kernel_type;
|
||||
|
||||
|
||||
svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
|
||||
|
||||
svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
|
||||
|
||||
@ -257,7 +258,9 @@ namespace
|
||||
|
||||
const double C = 1;
|
||||
linear_trainer.set_epsilon(1e-10);
|
||||
linear_trainer_cpa.set_epsilon(1e-11);
|
||||
|
||||
linear_trainer_cpa.force_last_weight_to_1(force_weight);
|
||||
|
||||
linear_trainer.force_last_weight_to_1(force_weight);
|
||||
linear_trainer.include_bias(have_bias);
|
||||
@ -268,7 +271,7 @@ namespace
|
||||
// make an instance of a sample vector so we can use it below
|
||||
sample_type sample;
|
||||
|
||||
decision_function<kernel_type> df;
|
||||
decision_function<kernel_type> df, df2;
|
||||
|
||||
running_stats<double> rs;
|
||||
|
||||
@ -299,11 +302,22 @@ namespace
|
||||
labels.push_back(label);
|
||||
|
||||
linear_trainer.set_c(C);
|
||||
linear_trainer_cpa.set_c(C*samples.size());
|
||||
|
||||
df = linear_trainer.train(samples, labels, state);
|
||||
|
||||
if (force_weight)
|
||||
{
|
||||
DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8);
|
||||
DLIB_TEST(std::abs(df.b) < 1e-8);
|
||||
|
||||
if (samples.size() > 1)
|
||||
{
|
||||
df2 = linear_trainer_cpa.train(samples, labels);
|
||||
DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0))));
|
||||
DLIB_TEST( std::abs(df.b - df2.b) < 1e-7);
|
||||
}
|
||||
}
|
||||
|
||||
if (!have_bias)
|
||||
DLIB_TEST(std::abs(df.b) < 1e-8);
|
||||
|
Loading…
Reference in New Issue
Block a user