mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added unit tests for svm_rank_trainer::force_last_weight_to_1()
This commit is contained in:
parent
824eb4558d
commit
19c02d3862
@ -222,6 +222,89 @@ namespace
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template <typename K>
|
||||||
|
class simple_rank_trainer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
template <typename T>
|
||||||
|
decision_function<K> train (
|
||||||
|
const ranking_pair<T>& pair
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
typedef matrix<double,10,1> sample_type;
|
||||||
|
|
||||||
|
std::vector<sample_type> relevant = pair.relevant;
|
||||||
|
std::vector<sample_type> nonrelevant = pair.nonrelevant;
|
||||||
|
|
||||||
|
std::vector<sample_type> samples;
|
||||||
|
std::vector<double> labels;
|
||||||
|
for (unsigned long i = 0; i < relevant.size(); ++i)
|
||||||
|
{
|
||||||
|
for (unsigned long j = 0; j < nonrelevant.size(); ++j)
|
||||||
|
{
|
||||||
|
samples.push_back(relevant[i] - nonrelevant[j]);
|
||||||
|
labels.push_back(+1);
|
||||||
|
samples.push_back(nonrelevant[i] - relevant[j]);
|
||||||
|
labels.push_back(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
svm_c_linear_dcd_trainer<K> trainer;
|
||||||
|
trainer.set_c(1.0/samples.size());
|
||||||
|
trainer.set_epsilon(1e-10);
|
||||||
|
trainer.force_last_weight_to_1(true);
|
||||||
|
//trainer.be_verbose();
|
||||||
|
return trainer.train(samples, labels);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void test_svmrank_weight_force_dense()
|
||||||
|
{
|
||||||
|
print_spinner();
|
||||||
|
typedef matrix<double,10,1> sample_type;
|
||||||
|
typedef linear_kernel<sample_type> kernel_type;
|
||||||
|
|
||||||
|
ranking_pair<sample_type> pair;
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; ++i)
|
||||||
|
{
|
||||||
|
pair.relevant.push_back(abs(gaussian_randm(10,1,i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; ++i)
|
||||||
|
{
|
||||||
|
pair.nonrelevant.push_back(-abs(gaussian_randm(10,1,i+10000)));
|
||||||
|
pair.nonrelevant.back()(9) += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
svm_rank_trainer<kernel_type> trainer;
|
||||||
|
trainer.force_last_weight_to_1(true);
|
||||||
|
trainer.set_epsilon(1e-13);
|
||||||
|
//trainer.be_verbose();
|
||||||
|
decision_function<kernel_type> df;
|
||||||
|
df = trainer.train(pair);
|
||||||
|
|
||||||
|
dlog << LINFO << "weights: "<< trans(df.basis_vectors(0));
|
||||||
|
const double acc1 = test_ranking_function(df, pair);
|
||||||
|
dlog << LINFO << "ranking accuracy: " << acc1;
|
||||||
|
DLIB_TEST(std::abs(acc1 - 1) == 0);
|
||||||
|
|
||||||
|
simple_rank_trainer<kernel_type> strainer;
|
||||||
|
decision_function<kernel_type> df2;
|
||||||
|
df2 = strainer.train(pair);
|
||||||
|
dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
|
||||||
|
const double acc2 = test_ranking_function(df2, pair);
|
||||||
|
dlog << LINFO << "ranking accuracy: " << acc2;
|
||||||
|
DLIB_TEST(std::abs(acc2 - 1) == 0);
|
||||||
|
|
||||||
|
dlog << LINFO << "w error: " << max(abs(df.basis_vectors(0) - df2.basis_vectors(0)));
|
||||||
|
dlog << LINFO << "b error: " << abs(df.b - df2.b);
|
||||||
|
DLIB_TEST(std::abs(max(abs(df.basis_vectors(0) - df2.basis_vectors(0)))) < 1e-8);
|
||||||
|
DLIB_TEST(std::abs(abs(df.b - df2.b)) < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
@ -242,6 +325,8 @@ namespace
|
|||||||
test_count_ranking_inversions();
|
test_count_ranking_inversions();
|
||||||
dotest1();
|
dotest1();
|
||||||
dotest_sparse_vectors();
|
dotest_sparse_vectors();
|
||||||
|
test_svmrank_weight_force_dense();
|
||||||
|
|
||||||
}
|
}
|
||||||
} a;
|
} a;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user