Added the reduced_decision_function_trainer object and

reduced() function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402390
pull/2/head
Davis King 16 years ago
parent 24974b75f8
commit 697df854f9

@ -15,6 +15,7 @@
#include "function.h"
#include "kernel.h"
#include "../enable_if.h"
#include "kcentroid.h"
namespace dlib
{
@ -1419,6 +1420,101 @@ namespace dlib
svm_nu_trainer<K>& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
class reduced_decision_function_trainer
{
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef typename trainer_type::trained_function_type trained_function_type;
explicit reduced_decision_function_trainer (
const trainer_type& trainer_,
const scalar_type tolerance_ = 0.001
) :
trainer(trainer_),
tolerance(tolerance_)
{
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
return do_train(vector_to_matrix(x), vector_to_matrix(y));
}
private:
// ------------------------------------------------------------------------------------
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\tdecision_function reduced_decision_function_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
);
kcentroid<kernel_type> kc(trainer.get_kernel(), tolerance);
decision_function<kernel_type> dec_funct = trainer.train(x,y);
// find the point in kernel space that is approximately the same as what is in the decision_function
// already.
for (long i = 0; i < dec_funct.support_vectors.nr(); ++i)
{
kc.train(dec_funct.support_vectors(i), 1, dec_funct.alpha(i));
}
distance_function<kernel_type> dist_funct = kc.get_distance_function();
return decision_function<kernel_type> (dist_funct.alpha,
dec_funct.b,
dist_funct.kernel_function,
dist_funct.support_vectors);
}
const trainer_type& trainer;
const scalar_type tolerance;
}; // end of class reduced_decision_function_trainer
template <typename trainer_type>
const reduced_decision_function_trainer<trainer_type> reduced (
const trainer_type& trainer,
const typename trainer_type::scalar_type& tolerance = 0.001
)
{
return reduced_decision_function_trainer<trainer_type>(trainer, tolerance);
}
// ----------------------------------------------------------------------------------------
}

@ -292,6 +292,86 @@ namespace dlib
- std::bad_alloc
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
class reduced_decision_function_trainer
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an implementation of a reduced set algorithm
for support vector decision functions. This object acts as a post
processor for anything that creates decision_function objects. It
wraps another trainer object and performs this reduced set post
processing with the goal of representing the original decision
function in a form that involves fewer support vectors.
!*/
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef typename trainer_type::trained_function_type trained_function_type;
explicit reduced_decision_function_trainer (
const trainer_type& trainer_,
const scalar_type tolerance_ = 0.001
);
/*!
requires
- tolerance >= 0
- trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
ensures
- returns a trainer object that applies post processing to the decision_function
objects created by the given trainer object with the goal of creating
decision_function objects with fewer support vectors.
- tolerance == a parameter that controls how accurate the post processing
is in preserving the original decision_function. Larger values
result in a decision_function with fewer support vectors but may
decrease the accuracy of the resulting decision_function.
!*/
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const;
/*!
ensures
- trains a decision_function using the trainer that was supplied to
this object's constructor and then finds a reduced representation
for it and returns the reduced version.
throws
- std::bad_alloc
- any exceptions thrown by the trainer_type object
!*/
};
template <
typename trainer_type
>
const reduced_decision_function_trainer<trainer_type> reduced (
const trainer_type& trainer,
const typename trainer_type::scalar_type& tolerance = 0.001
) { return reduced_decision_function_trainer<trainer_type>(trainer, tolerance); }
/*!
requires
- tolerance >= 0
- trainer_type == some kind of trainer object that creates decision_function
objects (e.g. svm_nu_trainer)
ensures
- returns a reduced_decision_function_trainer object that has been
instantiated with the given arguments.
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Miscellaneous functions

Loading…
Cancel
Save