mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added svm_rank_trainer. Need to flesh out abstracts and unit tests next.
This commit is contained in:
parent
c1a9572cbf
commit
6ffe8d799b
@ -3,6 +3,7 @@
|
||||
// Include guard.  The macro defined here must be the same one tested by the
// #ifndef, otherwise including this header twice is not prevented.
#ifndef DLIB_SVm_HEADER
#define DLIB_SVm_HEADER
|
||||
|
||||
#include "svm/svm_rank_trainer.h"
|
||||
#include "svm/svm.h"
|
||||
#include "svm/krls.h"
|
||||
#include "svm/rls.h"
|
||||
|
344
dlib/svm/ranking_tools.h
Normal file
344
dlib/svm/ranking_tools.h
Normal file
@ -0,0 +1,344 @@
|
||||
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_RANKING_ToOLS_H__
|
||||
#define DLIB_RANKING_ToOLS_H__
|
||||
|
||||
#include "ranking_tools_abstract.h"
|
||||
|
||||
#include "../algs.h"
|
||||
#include "../matrix.h"
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// A ranking_pair bundles two sets of samples of type T: the ones stored in
// `relevant` and the ones stored in `nonrelevant`.
template <typename T>
struct ranking_pair
{
    // The two sample sets.  Both are empty after default construction.
    std::vector<T> relevant;
    std::vector<T> nonrelevant;

    // Creates a pair that holds no samples.
    ranking_pair() {}

    // Creates a pair holding copies of the given sample sets.
    ranking_pair(
        const std::vector<T>& r,
        const std::vector<T>& nr
    ) : relevant(r), nonrelevant(nr) {}
};
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
void serialize (
|
||||
const ranking_pair<T>& item,
|
||||
std::ostream& out
|
||||
)
|
||||
{
|
||||
int version = 1;
|
||||
serialize(version, out);
|
||||
serialize(item.relevant, out);
|
||||
serialize(item.nonrelevant, out);
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
void deserialize (
|
||||
ranking_pair<T>& item,
|
||||
std::istream& in
|
||||
)
|
||||
{
|
||||
int version = 0;
|
||||
deserialize(version, in);
|
||||
if (version != 1)
|
||||
throw dlib::serialization_error("Wrong version found while deserializing dlib::ranking_pair");
|
||||
|
||||
deserialize(item.relevant, in);
|
||||
deserialize(item.nonrelevant, in);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
bool is_ranking_problem (
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
)
|
||||
{
|
||||
if (samples.size() == 0)
|
||||
return false;
|
||||
|
||||
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
if (samples[i].relevant.size() == 0)
|
||||
return false;
|
||||
if (samples[i].nonrelevant.size() == 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
// If these are dense vectors then they must all have the same dimensionality.
|
||||
if (is_matrix<T>::value)
|
||||
{
|
||||
const long dims = max_index_plus_one(samples[0].relevant);
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
for (unsigned long j = 0; j < samples[i].relevant.size(); ++j)
|
||||
{
|
||||
if (samples[i].relevant[j].size() != dims)
|
||||
return false;
|
||||
}
|
||||
for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j)
|
||||
{
|
||||
if (samples[i].nonrelevant[j].size() != dims)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
unsigned long max_index_plus_one (
|
||||
const ranking_pair<T>& item
|
||||
)
|
||||
{
|
||||
return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
unsigned long max_index_plus_one (
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
)
|
||||
{
|
||||
unsigned long dims = 0;
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
dims = std::max(dims, max_index_plus_one(samples[i]));
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
void count_ranking_inversions (
    const std::vector<T>& x,
    const std::vector<T>& y,
    std::vector<unsigned long>& x_count,
    std::vector<unsigned long>& y_count
)
/*!
    ensures
        - Counts, for each element of x and of y, how many elements of the other
          vector form an inversion with it, where (x[i], y[j]) is an inversion
          when y[j] >= x[i].  Both inputs are sorted and then swept once, so all
          pairs never need to be compared individually.
        - #x_count.size() == x.size()
        - #y_count.size() == y.size()
        - for all valid i:
            - #x_count[i] == the number of values in y that are >= x[i].
        - for all valid j:
            - #y_count[j] == the number of values in x that are <= y[j].
!*/
{
    // Each value is stored together with its position in the input so the
    // counts can be written back to the right slot after sorting.
    typedef std::pair<T,unsigned long> tagged_value;

    const unsigned long num_x = x.size();
    const unsigned long num_y = y.size();

    x_count.assign(num_x, 0);
    y_count.assign(num_y, 0);

    // With one side empty there are no pairs, so every count stays 0.
    if (num_x == 0 || num_y == 0)
        return;

    std::vector<tagged_value> sorted_x(num_x);
    std::vector<tagged_value> sorted_y(num_y);
    for (unsigned long k = 0; k < num_x; ++k)
        sorted_x[k] = tagged_value(x[k], k);
    for (unsigned long k = 0; k < num_y; ++k)
        sorted_y[k] = tagged_value(y[k], k);

    std::sort(sorted_x.begin(), sorted_x.end());
    std::sort(sorted_y.begin(), sorted_y.end());

    // Counts for the x values.  As x grows, the first y that is >= x can only
    // move to the right, so y_pos is never reset inside the loop.
    unsigned long y_pos = 0;
    for (unsigned long x_pos = 0; x_pos < num_x; ++x_pos)
    {
        // Skip the y values that are strictly smaller than this x value.
        while (y_pos < num_y && sorted_x[x_pos].first > sorted_y[y_pos].first)
            ++y_pos;

        x_count[sorted_x[x_pos].second] = num_y - y_pos;
    }

    // Counts for the y values.  As y grows, the number of x values that are
    // <= y can only grow, so x_seen is never reset inside the loop.
    unsigned long x_seen = 0;
    for (unsigned long y_idx = 0; y_idx < num_y; ++y_idx)
    {
        // Consume the x values that are <= this y value.
        while (x_seen < num_x && sorted_x[x_seen].first <= sorted_y[y_idx].first)
            ++x_seen;

        y_count[sorted_y[y_idx].second] = x_seen;
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename ranking_function,
|
||||
typename T
|
||||
>
|
||||
double test_ranking_function (
|
||||
const ranking_function& funct,
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- returns the fraction of ranking pairs predicted correctly.
|
||||
!*/
|
||||
{
|
||||
unsigned long total_pairs = 0;
|
||||
unsigned long total_wrong = 0;
|
||||
|
||||
std::vector<double> rel_scores;
|
||||
std::vector<double> nonrel_scores;
|
||||
std::vector<unsigned long> rel_counts;
|
||||
std::vector<unsigned long> nonrel_counts;
|
||||
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
rel_scores.resize(samples[i].relevant.size());
|
||||
nonrel_scores.resize(samples[i].nonrelevant.size());
|
||||
|
||||
for (unsigned long k = 0; k < rel_scores.size(); ++k)
|
||||
rel_scores[k] = funct(samples[i].relevant[k]);
|
||||
|
||||
for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
|
||||
nonrel_scores[k] = funct(samples[i].nonrelevant[k]);
|
||||
|
||||
count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
|
||||
|
||||
total_pairs += rel_scores.size()*nonrel_scores.size();
|
||||
|
||||
// Note that we don't need to look at nonrel_counts since it is redundant with
|
||||
// the information in rel_counts in this case.
|
||||
total_wrong += sum(vector_to_matrix(rel_counts));
|
||||
|
||||
// TODO, remove
|
||||
DLIB_CASSERT(sum(vector_to_matrix(rel_counts)) == sum(vector_to_matrix(nonrel_counts)), "");
|
||||
}
|
||||
|
||||
return static_cast<double>(total_pairs - total_wrong) / total_pairs;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename T
|
||||
>
|
||||
double cross_validate_ranking_trainer (
|
||||
const trainer_type& trainer,
|
||||
const std::vector<ranking_pair<T> >& samples,
|
||||
const long folds
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_CASSERT(is_ranking_problem(samples) &&
|
||||
1 < folds && folds <= static_cast<long>(samples.size()),
|
||||
"\t double cross_validate_ranking_trainer()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t folds: " << folds
|
||||
<< "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
|
||||
);
|
||||
|
||||
|
||||
const long num_in_test = samples.size()/folds;
|
||||
const long num_in_train = samples.size() - num_in_test;
|
||||
|
||||
|
||||
std::vector<ranking_pair<T> > samples_test, samples_train;
|
||||
|
||||
|
||||
long next_test_idx = 0;
|
||||
|
||||
unsigned long total_pairs = 0;
|
||||
unsigned long total_wrong = 0;
|
||||
|
||||
std::vector<double> rel_scores;
|
||||
std::vector<double> nonrel_scores;
|
||||
std::vector<unsigned long> rel_counts;
|
||||
std::vector<unsigned long> nonrel_counts;
|
||||
|
||||
|
||||
for (long i = 0; i < folds; ++i)
|
||||
{
|
||||
samples_test.clear();
|
||||
samples_train.clear();
|
||||
|
||||
// load up the test samples
|
||||
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
||||
{
|
||||
samples_test.push_back(samples[next_test_idx]);
|
||||
next_test_idx = (next_test_idx + 1)%samples.size();
|
||||
}
|
||||
|
||||
// load up the training samples
|
||||
long next = next_test_idx;
|
||||
for (long cnt = 0; cnt < num_in_train; ++cnt)
|
||||
{
|
||||
samples_train.push_back(samples[next]);
|
||||
next = (next + 1)%samples.size();
|
||||
}
|
||||
|
||||
|
||||
const typename trainer_type::trained_function_type& df = trainer.train(samples_train);
|
||||
|
||||
// check how good df is on the test data
|
||||
for (unsigned long i = 0; i < samples_test.size(); ++i)
|
||||
{
|
||||
rel_scores.resize(samples_test[i].relevant.size());
|
||||
nonrel_scores.resize(samples_test[i].nonrelevant.size());
|
||||
|
||||
for (unsigned long k = 0; k < rel_scores.size(); ++k)
|
||||
rel_scores[k] = df(samples_test[i].relevant[k]);
|
||||
|
||||
for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
|
||||
nonrel_scores[k] = df(samples_test[i].nonrelevant[k]);
|
||||
|
||||
count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
|
||||
|
||||
total_pairs += rel_scores.size()*nonrel_scores.size();
|
||||
|
||||
// Note that we don't need to look at nonrel_counts since it is redundant with
|
||||
// the information in rel_counts in this case.
|
||||
total_wrong += sum(vector_to_matrix(rel_counts));
|
||||
}
|
||||
|
||||
} // for (long i = 0; i < folds; ++i)
|
||||
|
||||
return static_cast<double>(total_pairs - total_wrong) / total_pairs;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_RANKING_ToOLS_H__
|
||||
|
191
dlib/svm/ranking_tools_abstract.h
Normal file
191
dlib/svm/ranking_tools_abstract.h
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_RANKING_ToOLS_ABSTRACT_H__
|
||||
#ifdef DLIB_RANKING_ToOLS_ABSTRACT_H__
|
||||
|
||||
|
||||
#include "../algs.h"
|
||||
#include "../matrix.h"
|
||||
#include <vector>
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
    typename T
    >
struct ranking_pair
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This object is a container for one unit of ranking data.  It holds
            a set of relevant samples and a set of nonrelevant samples, both of
            type T.  The ranking tools treat a scoring function as correct on a
            (relevant, nonrelevant) pair of samples from this object when it
            gives the relevant sample a larger score than the nonrelevant one.
    !*/

    ranking_pair() {}
    /*!
        ensures
            - #relevant.size() == 0
            - #nonrelevant.size() == 0
    !*/

    ranking_pair(
        const std::vector<T>& r,
        const std::vector<T>& nr
    ) :
        relevant(r), nonrelevant(nr)
    {}
    /*!
        ensures
            - #relevant == r
            - #nonrelevant == nr
    !*/

    std::vector<T> relevant;
    std::vector<T> nonrelevant;
};
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
void serialize (
|
||||
const ranking_pair<T>& item,
|
||||
std::ostream& out
|
||||
);
|
||||
/*!
|
||||
provides serialization support
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
void deserialize (
|
||||
ranking_pair<T>& item,
|
||||
std::istream& in
|
||||
);
|
||||
/*!
|
||||
provides deserialization support
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
bool is_ranking_problem (
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
)
|
||||
{
|
||||
if (samples.size() == 0)
|
||||
return false;
|
||||
|
||||
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
if (samples[i].relevant.size() == 0)
|
||||
return false;
|
||||
if (samples[i].nonrelevant.size() == 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
// If these are dense vectors then they must all have the same dimensionality.
|
||||
if (is_matrix<T>::value)
|
||||
{
|
||||
const long dims = max_index_plus_one(samples[0].relevant);
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
for (unsigned long j = 0; j < samples[i].relevant.size(); ++j)
|
||||
{
|
||||
if (samples[i].relevant[j].size() != dims)
|
||||
return false;
|
||||
}
|
||||
for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j)
|
||||
{
|
||||
if (samples[i].nonrelevant[j].size() != dims)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
unsigned long max_index_plus_one (
|
||||
const ranking_pair<T>& item
|
||||
)
|
||||
{
|
||||
return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
unsigned long max_index_plus_one (
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
)
|
||||
{
|
||||
unsigned long dims = 0;
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
dims = std::max(dims, max_index_plus_one(samples[i]));
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
void count_ranking_inversions (
    const std::vector<T>& x,
    const std::vector<T>& y,
    std::vector<unsigned long>& x_count,
    std::vector<unsigned long>& y_count
);
/*!
    ensures
        - This function counts how many times a y value is greater than or
          equal to an x value.  It does this by sorting the inputs rather than
          by comparing every x value against every y value.
        - #x_count.size() == x.size()
        - #y_count.size() == y.size()
        - for all valid i:
            - #x_count[i] == the number of values in y that are >= x[i].
        - for all valid j:
            - #y_count[j] == the number of values in x that are <= y[j].
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename ranking_function,
|
||||
typename T
|
||||
>
|
||||
double test_ranking_function (
|
||||
const ranking_function& funct,
|
||||
const std::vector<ranking_pair<T> >& samples
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- returns the fraction of ranking pairs predicted correctly.
|
||||
- TODO
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename T
|
||||
>
|
||||
double cross_validate_ranking_trainer (
|
||||
const trainer_type& trainer,
|
||||
const std::vector<ranking_pair<T> >& samples,
|
||||
const long folds
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_ranking_problem(samples) == true
|
||||
- 1 < folds <= samples.size()
|
||||
ensures
|
||||
- TODO
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_RANKING_ToOLS_ABSTRACT_H__
|
||||
|
||||
|
392
dlib/svm/svm_rank_trainer.h
Normal file
392
dlib/svm/svm_rank_trainer.h
Normal file
@ -0,0 +1,392 @@
|
||||
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_SVM_RANK_TrAINER_H__
|
||||
#define DLIB_SVM_RANK_TrAINER_H__
|
||||
|
||||
#include "svm_rank_trainer_abstract.h"
|
||||
|
||||
#include "ranking_tools.h"
|
||||
#include "../algs.h"
|
||||
#include "../optimization.h"
|
||||
#include "function.h"
|
||||
#include "kernel.h"
|
||||
#include "sparse_vector.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename matrix_type,
|
||||
typename sample_type
|
||||
>
|
||||
class oca_problem_ranking_svm : public oca_problem<matrix_type >
|
||||
{
|
||||
public:
|
||||
/*
|
||||
This class is used as part of the implementation of the svm_rank_trainer
|
||||
defined towards the end of this file.
|
||||
*/
|
||||
|
||||
typedef typename matrix_type::type scalar_type;
|
||||
|
||||
oca_problem_ranking_svm(
|
||||
const scalar_type C_,
|
||||
const std::vector<ranking_pair<sample_type> >& samples_,
|
||||
const bool be_verbose_,
|
||||
const scalar_type eps_,
|
||||
const unsigned long max_iter
|
||||
) :
|
||||
samples(samples_),
|
||||
C(C_),
|
||||
be_verbose(be_verbose_),
|
||||
eps(eps_),
|
||||
max_iterations(max_iter)
|
||||
{
|
||||
}
|
||||
|
||||
virtual scalar_type get_c (
|
||||
) const
|
||||
{
|
||||
return C;
|
||||
}
|
||||
|
||||
virtual long get_num_dimensions (
|
||||
) const
|
||||
{
|
||||
return max_index_plus_one(samples);
|
||||
}
|
||||
|
||||
virtual bool optimization_status (
|
||||
scalar_type current_objective_value,
|
||||
scalar_type current_error_gap,
|
||||
scalar_type current_risk_value,
|
||||
scalar_type current_risk_gap,
|
||||
unsigned long num_cutting_planes,
|
||||
unsigned long num_iterations
|
||||
) const
|
||||
{
|
||||
if (be_verbose)
|
||||
{
|
||||
using namespace std;
|
||||
cout << "objective: " << current_objective_value << endl;
|
||||
cout << "objective gap: " << current_error_gap << endl;
|
||||
cout << "risk: " << current_risk_value << endl;
|
||||
cout << "risk gap: " << current_risk_gap << endl;
|
||||
cout << "num planes: " << num_cutting_planes << endl;
|
||||
cout << "iter: " << num_iterations << endl;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
if (num_iterations >= max_iterations)
|
||||
return true;
|
||||
|
||||
if (current_risk_gap < eps)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool risk_has_lower_bound (
|
||||
scalar_type& lower_bound
|
||||
) const
|
||||
{
|
||||
lower_bound = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual void get_risk (
|
||||
matrix_type& w,
|
||||
scalar_type& risk,
|
||||
matrix_type& subgradient
|
||||
) const
|
||||
{
|
||||
subgradient.set_size(w.size(),1);
|
||||
subgradient = 0;
|
||||
risk = 0;
|
||||
|
||||
// Note that we want the risk value to be in terms of the fraction of overall
|
||||
// rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the
|
||||
// time.
|
||||
|
||||
std::vector<double> rel_scores;
|
||||
std::vector<double> nonrel_scores;
|
||||
std::vector<unsigned long> rel_counts;
|
||||
std::vector<unsigned long> nonrel_counts;
|
||||
|
||||
unsigned long total_pairs = 0;
|
||||
|
||||
// loop over all the samples and compute the risk and its subgradient at the current solution point w
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
rel_scores.resize(samples[i].relevant.size());
|
||||
nonrel_scores.resize(samples[i].nonrelevant.size());
|
||||
|
||||
for (unsigned long k = 0; k < rel_scores.size(); ++k)
|
||||
rel_scores[k] = dot(samples[i].relevant[k], w);
|
||||
|
||||
for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
|
||||
nonrel_scores[k] = dot(samples[i].nonrelevant[k], w) + 1;
|
||||
|
||||
count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
|
||||
|
||||
total_pairs += rel_scores.size()*nonrel_scores.size();
|
||||
|
||||
for (unsigned long k = 0; k < rel_counts.size(); ++k)
|
||||
{
|
||||
if (rel_counts[k] != 0)
|
||||
{
|
||||
risk -= rel_counts[k]*rel_scores[k];
|
||||
subtract_from(subgradient, samples[i].relevant[k], rel_counts[k]);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned long k = 0; k < nonrel_counts.size(); ++k)
|
||||
{
|
||||
if (nonrel_counts[k] != 0)
|
||||
{
|
||||
risk += nonrel_counts[k]*nonrel_scores[k];
|
||||
add_to(subgradient, samples[i].nonrelevant[k], nonrel_counts[k]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const scalar_type scale = 1.0/total_pairs;
|
||||
|
||||
risk *= scale;
|
||||
subgradient = scale*subgradient;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// -----------------------------------------------------
|
||||
// -----------------------------------------------------
|
||||
|
||||
|
||||
const std::vector<ranking_pair<sample_type> >& samples;
|
||||
const scalar_type C;
|
||||
|
||||
const bool be_verbose;
|
||||
const scalar_type eps;
|
||||
const unsigned long max_iterations;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename matrix_type,
|
||||
typename sample_type,
|
||||
typename scalar_type
|
||||
>
|
||||
oca_problem_ranking_svm<matrix_type, sample_type> make_oca_problem_ranking_svm (
|
||||
const scalar_type C,
|
||||
const std::vector<ranking_pair<sample_type> >& samples,
|
||||
const bool be_verbose,
|
||||
const scalar_type eps,
|
||||
const unsigned long max_iterations
|
||||
)
|
||||
{
|
||||
return oca_problem_ranking_svm<matrix_type, sample_type>(
|
||||
C, samples, be_verbose, eps, max_iterations);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
class svm_rank_trainer
|
||||
{
|
||||
|
||||
public:
|
||||
typedef K kernel_type;
|
||||
typedef typename kernel_type::scalar_type scalar_type;
|
||||
typedef typename kernel_type::sample_type sample_type;
|
||||
typedef typename kernel_type::mem_manager_type mem_manager_type;
|
||||
typedef decision_function<kernel_type> trained_function_type;
|
||||
|
||||
// You are getting a compiler error on this line because you supplied a non-linear kernel
|
||||
// to the svm_rank_trainer object. You have to use one of the linear kernels with this
|
||||
// trainer.
|
||||
COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
|
||||
is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
|
||||
|
||||
svm_rank_trainer (
|
||||
)
|
||||
{
|
||||
C = 1;
|
||||
verbose = false;
|
||||
eps = 0.001;
|
||||
max_iterations = 10000;
|
||||
learn_nonnegative_weights = false;
|
||||
}
|
||||
|
||||
explicit svm_rank_trainer (
|
||||
const scalar_type& C_
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C_ > 0,
|
||||
"\t svm_rank_trainer::svm_rank_trainer()"
|
||||
<< "\n\t C_ must be greater than 0"
|
||||
<< "\n\t C_: " << C_
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
C = C_;
|
||||
verbose = false;
|
||||
eps = 0.001;
|
||||
max_iterations = 10000;
|
||||
learn_nonnegative_weights = false;
|
||||
}
|
||||
|
||||
void set_epsilon (
|
||||
scalar_type eps_
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(eps_ > 0,
|
||||
"\t void svm_rank_trainer::set_epsilon()"
|
||||
<< "\n\t eps_ must be greater than 0"
|
||||
<< "\n\t eps_: " << eps_
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
eps = eps_;
|
||||
}
|
||||
|
||||
const scalar_type get_epsilon (
|
||||
) const { return eps; }
|
||||
|
||||
unsigned long get_max_iterations (
|
||||
) const { return max_iterations; }
|
||||
|
||||
void set_max_iterations (
|
||||
unsigned long max_iter
|
||||
)
|
||||
{
|
||||
max_iterations = max_iter;
|
||||
}
|
||||
|
||||
void be_verbose (
|
||||
)
|
||||
{
|
||||
verbose = true;
|
||||
}
|
||||
|
||||
void be_quiet (
|
||||
)
|
||||
{
|
||||
verbose = false;
|
||||
}
|
||||
|
||||
void set_oca (
|
||||
const oca& item
|
||||
)
|
||||
{
|
||||
solver = item;
|
||||
}
|
||||
|
||||
const oca get_oca (
|
||||
) const
|
||||
{
|
||||
return solver;
|
||||
}
|
||||
|
||||
const kernel_type get_kernel (
|
||||
) const
|
||||
{
|
||||
return kernel_type();
|
||||
}
|
||||
|
||||
bool learns_nonnegative_weights (
|
||||
) const { return learn_nonnegative_weights; }
|
||||
|
||||
void set_learns_nonnegative_weights (
|
||||
bool value
|
||||
)
|
||||
{
|
||||
learn_nonnegative_weights = value;
|
||||
}
|
||||
|
||||
void set_c (
|
||||
scalar_type C_
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C_ > 0,
|
||||
"\t void svm_rank_trainer::set_c()"
|
||||
<< "\n\t C_ must be greater than 0"
|
||||
<< "\n\t C_: " << C_
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
C = C_;
|
||||
}
|
||||
|
||||
const scalar_type get_c (
|
||||
) const
|
||||
{
|
||||
return C;
|
||||
}
|
||||
|
||||
const decision_function<kernel_type> train (
|
||||
const std::vector<ranking_pair<sample_type> >& samples
|
||||
) const
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_CASSERT(is_ranking_problem(samples) == true,
|
||||
"\t decision_function svm_rank_trainer::train(samples)"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
|
||||
);
|
||||
|
||||
|
||||
typedef matrix<scalar_type,0,1> w_type;
|
||||
w_type w;
|
||||
|
||||
const unsigned long num_dims = max_index_plus_one(samples);
|
||||
|
||||
unsigned long num_nonnegative = 0;
|
||||
if (learn_nonnegative_weights)
|
||||
{
|
||||
num_nonnegative = num_dims;
|
||||
}
|
||||
|
||||
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
|
||||
w,
|
||||
num_nonnegative);
|
||||
|
||||
// put the solution into a decision function and then return it
|
||||
decision_function<kernel_type> df;
|
||||
df.b = 0;
|
||||
df.basis_vectors.set_size(1);
|
||||
// Copy the results into the output basis vector. The output vector might be a
|
||||
// sparse vector container so we need to use this special kind of copy to
|
||||
// handle that case.
|
||||
assign(df.basis_vectors(0), matrix_cast<scalar_type>(w));
|
||||
df.alpha.set_size(1);
|
||||
df.alpha(0) = 1;
|
||||
|
||||
return df;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
scalar_type C;
|
||||
oca solver;
|
||||
scalar_type eps;
|
||||
bool verbose;
|
||||
unsigned long max_iterations;
|
||||
bool learn_nonnegative_weights;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
#endif // DLIB_SVM_RANK_TrAINER_H__
|
||||
|
2
dlib/svm/svm_rank_trainer_abstract.h
Normal file
2
dlib/svm/svm_rank_trainer_abstract.h
Normal file
@ -0,0 +1,2 @@
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user