diff --git a/dlib/svm.h b/dlib/svm.h
index 1a18726a0..17ca243de 100644
--- a/dlib/svm.h
+++ b/dlib/svm.h
@@ -8,6 +8,7 @@
 #include "svm/kcentroid.h"
 #include "svm/kkmeans.h"
 #include "svm/feature_ranking.h"
+#include "svm/rbf_network.h"
 
 
 #endif // DLIB_SVm_HEADER
diff --git a/dlib/svm/rbf_network.h b/dlib/svm/rbf_network.h
new file mode 100644
index 000000000..1323decec
--- /dev/null
+++ b/dlib/svm/rbf_network.h
@@ -0,0 +1,176 @@
+// Copyright (C) 2008  Davis E. King (davisking@users.sourceforge.net)
+// License: Boost Software License   See LICENSE.txt for the full license.
+#ifndef DLIB_RBf_NETWORK_
+#define DLIB_RBf_NETWORK_
+
+#include "../matrix.h"
+#include "rbf_network_abstract.h"
+#include "kernel.h"
+#include "kcentroid.h"
+#include "function.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ------------------------------------------------------------------------------
+
+    template <
+        typename sample_type_
+        >
+    class rbf_network_trainer
+    {
+        /*!
+            This is an implementation of an RBF network trainer that follows
+            the directions right off Wikipedia basically.  So nothing
+            particularly fancy.
+        !*/
+
+    public:
+        typedef radial_basis_kernel<sample_type_> kernel_type;
+        typedef typename kernel_type::scalar_type scalar_type;
+        typedef typename kernel_type::sample_type sample_type;
+        typedef typename kernel_type::mem_manager_type mem_manager_type;
+        typedef decision_function<kernel_type> trained_function_type;
+
+        rbf_network_trainer (
+        ) :
+            gamma(0.1),
+            tolerance(0.01)
+        {
+        }
+
+        void set_gamma (
+            scalar_type gamma_
+        )
+        {
+            // make sure requires clause is not broken
+            DLIB_ASSERT(gamma_ > 0,
+                "\tvoid rbf_network_trainer::set_gamma(gamma_)"
+                << "\n\t invalid inputs were given to this function"
+                << "\n\t gamma: " << gamma_
+                );
+            gamma = gamma_;
+        }
+
+        const scalar_type get_gamma (
+        ) const
+        {
+            return gamma;
+        }
+
+        void set_tolerance (
+            const scalar_type& tol
+        )
+        {
+            tolerance = tol;
+        }
+
+        const scalar_type& get_tolerance (
+        ) const
+        {
+            return tolerance;
+        }
+
+        template <
+            typename in_sample_vector_type,
+            typename in_scalar_vector_type
+            >
+        const decision_function<kernel_type> train (
+            const in_sample_vector_type& x,
+            const in_scalar_vector_type& y
+        ) const
+        {
+            return do_train(vector_to_matrix(x), vector_to_matrix(y));
+        }
+
+        void swap (
+            rbf_network_trainer& item
+        )
+        {
+            exchange(gamma, item.gamma);
+            exchange(tolerance, item.tolerance);
+        }
+
+    private:
+
+    // ------------------------------------------------------------------------------------
+
+        template <
+            typename in_sample_vector_type,
+            typename in_scalar_vector_type
+            >
+        const decision_function<kernel_type> do_train (
+            const in_sample_vector_type& x,
+            const in_scalar_vector_type& y
+        ) const
+        {
+            typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
+
+            // make sure requires clause is not broken
+            DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+                "\tdecision_function rbf_network_trainer::train(x,y)"
+                << "\n\t invalid inputs were given to this function"
+                << "\n\t x.nr(): " << x.nr()
+                << "\n\t y.nr(): " << y.nr()
+                << "\n\t x.nc(): " << x.nc()
+                << "\n\t y.nc(): " << y.nc()
+                << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
+                );
+
+            // first run all the samples through a kcentroid object to find the rbf centers
+            const kernel_type kernel(gamma);
+            kcentroid<kernel_type> kc(kernel,tolerance);
+            for (long i = 0; i < x.size(); ++i)
+            {
+                kc.train(x(i));
+            }
+
+            // now we have a trained kcentroid so lets just extract its results.  Note that
+            // all we want out of the kcentroid is really just the set of support vectors
+            // it contains so that we can use them as the RBF centers.
+            distance_function<kernel_type> df(kc.get_distance_function());
+            const long num_centers = df.support_vectors.nr();
+
+            // fill the K matrix with the output of the kernel for all the center and sample point pairs
+            matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
+            for (long r = 0; r < x.nr(); ++r)
+            {
+                for (long c = 0; c < num_centers; ++c)
+                {
+                    K(r,c) = kernel(x(r), df.support_vectors(c));
+                }
+                // This last column of the K matrix takes care of the bias term
+                K(r,num_centers) = 1;
+            }
+
+            // compute the best weights by using the pseudo-inverse
+            scalar_vector_type weights(pinv(K)*y);
+
+            // now put everything into a decision_function object and return it
+            return decision_function<kernel_type> (remove_row(weights,num_centers),
+                                                   -weights(num_centers),
+                                                   kernel,
+                                                   df.support_vectors);
+
+        }
+
+        scalar_type gamma;
+        scalar_type tolerance;
+
+    }; // end of class rbf_network_trainer
+
+// ----------------------------------------------------------------------------------------
+
+    template <typename sample_type>
+    void swap (
+        rbf_network_trainer<sample_type>& a,
+        rbf_network_trainer<sample_type>& b
+    ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RBf_NETWORK_
+
diff --git a/dlib/svm/rbf_network_abstract.h b/dlib/svm/rbf_network_abstract.h
new file mode 100644
index 000000000..7dd6487cb
--- /dev/null
+++ b/dlib/svm/rbf_network_abstract.h
@@ -0,0 +1,138 @@
+// Copyright (C) 2008  Davis E. King (davisking@users.sourceforge.net)
+// License: Boost Software License   See LICENSE.txt for the full license.
+#undef DLIB_RBf_NETWORK_ABSTRACT_
+#ifdef DLIB_RBf_NETWORK_ABSTRACT_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+    template <
+        typename sample_type_
+        >
+    class rbf_network_trainer
+    {
+        /*!
+            REQUIREMENTS ON sample_type_
+                is a dlib::matrix type
+
+            INITIAL VALUE
+                - get_gamma() == 0.1
+                - get_tolerance() == 0.01
+
+            WHAT THIS OBJECT REPRESENTS
+                This object implements a trainer for a radial basis function network for
+                solving binary classification problems.
+
+                The implementation of this algorithm follows the normal RBF training
+                process.  For more details see the code or the Wikipedia article
+                about RBF networks.
+        !*/
+    public:
+        typedef radial_basis_kernel<sample_type_> kernel_type;
+        typedef typename kernel_type::scalar_type scalar_type;
+        typedef typename kernel_type::sample_type sample_type;
+        typedef typename kernel_type::mem_manager_type mem_manager_type;
+        typedef decision_function<kernel_type> trained_function_type;
+
+        rbf_network_trainer (
+        );
+        /*!
+            ensures
+                - this object is properly initialized
+        !*/
+
+        void set_gamma (
+            scalar_type gamma
+        );
+        /*!
+            requires
+                - gamma > 0
+            ensures
+                - #get_gamma() == gamma
+        !*/
+
+        const scalar_type get_gamma (
+        ) const;
+        /*!
+            ensures
+                - returns the gamma argument used in the radial_basis_kernel used
+                  to represent each node in an RBF network.
+        !*/
+
+        void set_tolerance (
+            const scalar_type& tol
+        );
+        /*!
+            ensures
+                - #get_tolerance() == tol
+        !*/
+
+        const scalar_type& get_tolerance (
+        ) const;
+        /*!
+            ensures
+                - returns the tolerance parameter.  This parameter controls how many
+                  RBF centers (a.k.a. support_vectors in the trained decision_function)
+                  you get when you call the train function.  A smaller tolerance
+                  results in more centers while a bigger number results in fewer.
+        !*/
+
+        template <
+            typename in_sample_vector_type,
+            typename in_scalar_vector_type
+            >
+        const decision_function<kernel_type> train (
+            const in_sample_vector_type& x,
+            const in_scalar_vector_type& y
+        ) const;
+        /*!
+            requires
+                - is_binary_classification_problem(x,y) == true
+            ensures
+                - trains an RBF network given the training samples in x and
+                  labels in y.
+                - returns a decision function F with the following properties:
+                    - if (new_x is a sample predicted to have a +1 label) then
+                        - F(new_x) >= 0
+                    - else
+                        - F(new_x) < 0
+            throws
+                - std::bad_alloc
+        !*/
+
+        void swap (
+            rbf_network_trainer& item
+        );
+        /*!
+            ensures
+                - swaps *this and item
+        !*/
+
+    };
+
+// ----------------------------------------------------------------------------------------
+
+    template <typename sample_type>
+    void swap (
+        rbf_network_trainer<sample_type>& a,
+        rbf_network_trainer<sample_type>& b
+    ) { a.swap(b); }
+    /*!
+        provides a global swap
+    !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RBf_NETWORK_ABSTRACT_
+
+