From 697df854f9a4c50a65724f9c7443dc5318e75526 Mon Sep 17 00:00:00 2001 From: Davis King Date: Tue, 8 Jul 2008 03:07:45 +0000 Subject: [PATCH] Added the reduced_decision_function_trainer object and reduced() function. --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402390 --- dlib/svm/svm.h | 96 +++++++++++++++++++++++++++++++++++++++++ dlib/svm/svm_abstract.h | 80 ++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+) diff --git a/dlib/svm/svm.h b/dlib/svm/svm.h index 07046eb4d..32ee58d48 100644 --- a/dlib/svm/svm.h +++ b/dlib/svm/svm.h @@ -15,6 +15,7 @@ #include "function.h" #include "kernel.h" #include "../enable_if.h" +#include "kcentroid.h" namespace dlib { @@ -1419,6 +1420,101 @@ namespace dlib svm_nu_trainer& b ) { a.swap(b); } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer + { + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + explicit reduced_decision_function_trainer ( + const trainer_type& trainer_, + const scalar_type tolerance_ = 0.001 + ) : + trainer(trainer_), + tolerance(tolerance_) + { + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(vector_to_matrix(x), vector_to_matrix(y)); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\tdecision_function reduced_decision_function_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false") + ); + + + kcentroid kc(trainer.get_kernel(), tolerance); + decision_function dec_funct = trainer.train(x,y); + + // find the point in kernel space that is approximately the same as what is in the decision_function + // already. + for (long i = 0; i < dec_funct.support_vectors.nr(); ++i) + { + kc.train(dec_funct.support_vectors(i), 1, dec_funct.alpha(i)); + } + + distance_function dist_funct = kc.get_distance_function(); + + return decision_function (dist_funct.alpha, + dec_funct.b, + dist_funct.kernel_function, + dist_funct.support_vectors); + } + + const trainer_type& trainer; + const scalar_type tolerance; + + + }; // end of class reduced_decision_function_trainer + + template + const reduced_decision_function_trainer reduced ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& tolerance = 0.001 + ) + { + return reduced_decision_function_trainer(trainer, tolerance); + } + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/svm/svm_abstract.h b/dlib/svm/svm_abstract.h index cf289922b..036ba20c2 100644 --- a/dlib/svm/svm_abstract.h +++ b/dlib/svm/svm_abstract.h @@ -292,6 +292,86 @@ namespace dlib - std::bad_alloc !*/ +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of a reduced set algorithm + for support vector decision functions. This object acts as a post + processor for anything that creates decision_function objects. It + wraps another trainer object and performs this reduced set post + processing with the goal of representing the original decision + function in a form that involves fewer support vectors. + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + explicit reduced_decision_function_trainer ( + const trainer_type& trainer_, + const scalar_type tolerance_ = 0.001 + ); + /*! + requires + - tolerance >= 0 + - trainer_type == some kind of trainer object (e.g. svm_nu_trainer) + ensures + - returns a trainer object that applies post processing to the decision_function + objects created by the given trainer object with the goal of creating + decision_function objects with fewer support vectors. + - tolerance == a parameter that controls how accurate the post processing + is in preserving the original decision_function. Larger values + result in a decision_function with fewer support vectors but may + decrease the accuracy of the resulting decision_function. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + ensures + - trains a decision_function using the trainer that was supplied to + this object's constructor and then finds a reduced representation + for it and returns the reduced version. + throws + - std::bad_alloc + - any exceptions thrown by the trainer_type object + !*/ + + }; + + template < + typename trainer_type + > + const reduced_decision_function_trainer reduced ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& tolerance = 0.001 + ) { return reduced_decision_function_trainer(trainer, tolerance); } + /*! + requires + - tolerance >= 0 + - trainer_type == some kind of trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a reduced_decision_function_trainer object that has been + instantiated with the given arguments. + !*/ + + // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // Miscellaneous functions