Added an RBF network trainer
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402397
commit e32aa6cf90
parent b68878b412
@@ -8,6 +8,7 @@
 #include "svm/kcentroid.h"
 #include "svm/kkmeans.h"
 #include "svm/feature_ranking.h"
+#include "svm/rbf_network.h"

 #endif // DLIB_SVm_HEADER
dlib/svm/rbf_network.h (new file, 176 lines)
@@ -0,0 +1,176 @@
// Copyright (C) 2008  Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_RBf_NETWORK_
#define DLIB_RBf_NETWORK_

#include "../matrix.h"
#include "rbf_network_abstract.h"
#include "kernel.h"
#include "kcentroid.h"
#include "function.h"
#include "../algs.h"

namespace dlib
{

// ------------------------------------------------------------------------------

    template <
        typename sample_type_
        >
    class rbf_network_trainer
    {
        /*!
            This is an implementation of an RBF network trainer that follows
            the directions right off Wikipedia, basically.  So nothing
            particularly fancy.
        !*/
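
        /*
            Outline of the training procedure implemented in do_train() below:
              1. run all the samples through a kcentroid object; the support vectors
                 it retains become the RBF centers (the tolerance controls how many).
              2. build the matrix K with K(r,c) = k(x_r, center_c), where k is the
                 radial_basis_kernel k(a,b) = exp(-gamma*||a-b||^2), plus a constant
                 column of ones for the bias term.
              3. solve the least squares problem for the output weights using the
                 pseudo-inverse:  w = pinv(K)*y.
        */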

    public:
        typedef radial_basis_kernel<sample_type_> kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;
        typedef decision_function<kernel_type> trained_function_type;

        rbf_network_trainer (
        ) :
            gamma(0.1),
            tolerance(0.01)
        {
        }

        void set_gamma (
            scalar_type gamma_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(gamma_ > 0,
                "\tvoid rbf_network_trainer::set_gamma(gamma_)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t gamma: " << gamma_
                );
            gamma = gamma_;
        }

        const scalar_type get_gamma (
        ) const
        {
            return gamma;
        }

        void set_tolerance (
            const scalar_type& tol
        )
        {
            tolerance = tol;
        }

        const scalar_type& get_tolerance (
        ) const
        {
            return tolerance;
        }

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const
        {
            return do_train(vector_to_matrix(x), vector_to_matrix(y));
        }

        void swap (
            rbf_network_trainer& item
        )
        {
            exchange(gamma, item.gamma);
            exchange(tolerance, item.tolerance);
        }

    private:

    // ------------------------------------------------------------------------------------

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> do_train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const
        {
            typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;

            // make sure requires clause is not broken
            DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
                "\tdecision_function rbf_network_trainer::train(x,y)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t x.nr(): " << x.nr()
                << "\n\t y.nr(): " << y.nr()
                << "\n\t x.nc(): " << x.nc()
                << "\n\t y.nc(): " << y.nc()
                << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
                );

            // first run all the samples through a kcentroid object to find the rbf centers
            const kernel_type kernel(gamma);
            kcentroid<kernel_type> kc(kernel,tolerance);
            for (long i = 0; i < x.size(); ++i)
            {
                kc.train(x(i));
            }

            // now we have a trained kcentroid so let's just extract its results.  Note that
            // all we want out of the kcentroid is really just the set of support vectors
            // it contains so that we can use them as the RBF centers.
            distance_function<kernel_type> df(kc.get_distance_function());
            const long num_centers = df.support_vectors.nr();

            // fill the K matrix with the output of the kernel for all the center and sample point pairs
            matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
            for (long r = 0; r < x.nr(); ++r)
            {
                for (long c = 0; c < num_centers; ++c)
                {
                    K(r,c) = kernel(x(r), df.support_vectors(c));
                }
                // This last column of the K matrix takes care of the bias term
                K(r,num_centers) = 1;
            }

            // compute the best weights by using the pseudo-inverse
            scalar_vector_type weights(pinv(K)*y);

            // now put everything into a decision_function object and return it
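            // Note: decision_function evaluates sum_i alpha(i)*k(x, support_vectors(i)) - b,
            // so the weight learned for the constant bias column of K is negated before it
            // is stored in the b field of the returned decision_function.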
            return decision_function<kernel_type> (remove_row(weights,num_centers),
                                                   -weights(num_centers),
                                                   kernel,
                                                   df.support_vectors);

        }

        scalar_type gamma;
        scalar_type tolerance;

    }; // end of class rbf_network_trainer

// ----------------------------------------------------------------------------------------

    template <typename sample_type>
    void swap (
        rbf_network_trainer<sample_type>& a,
        rbf_network_trainer<sample_type>& b
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RBf_NETWORK_
dlib/svm/rbf_network_abstract.h (new file, 138 lines)
@@ -0,0 +1,138 @@
// Copyright (C) 2008  Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_RBf_NETWORK_ABSTRACT_
#ifdef DLIB_RBf_NETWORK_ABSTRACT_

#include "../matrix/matrix_abstract.h"
#include "../algs.h"
#include "function_abstract.h"
#include "kernel_abstract.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename sample_type_
        >
    class rbf_network_trainer
    {
        /*!
            REQUIREMENTS ON sample_type_
                is a dlib::matrix type

            INITIAL VALUE
                - get_gamma() == 0.1
                - get_tolerance() == 0.01

            WHAT THIS OBJECT REPRESENTS
                This object implements a trainer for a radial basis function network
                used to solve binary classification problems.

                The implementation of this algorithm follows the normal RBF training
                process.  For more details see the code or the Wikipedia article
                about RBF networks.
        !*/
    public:
        typedef radial_basis_kernel<sample_type_> kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;
        typedef decision_function<kernel_type> trained_function_type;

        rbf_network_trainer (
        );
        /*!
            ensures
                - this object is properly initialized
        !*/

        void set_gamma (
            scalar_type gamma
        );
        /*!
            requires
                - gamma > 0
            ensures
                - #get_gamma() == gamma
        !*/

        const scalar_type get_gamma (
        ) const;
        /*!
            ensures
                - returns the gamma argument used in the radial_basis_kernel used
                  to represent each node in an RBF network.
        !*/

        void set_tolerance (
            const scalar_type& tol
        );
        /*!
            ensures
                - #get_tolerance() == tol
        !*/

        const scalar_type& get_tolerance (
        ) const;
        /*!
            ensures
                - returns the tolerance parameter.  This parameter controls how many
                  RBF centers (a.k.a. support_vectors in the trained decision_function)
                  you get when you call the train function.  A smaller tolerance
                  results in more centers while a bigger number results in fewer.
        !*/
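
        // Note: the tolerance is simply forwarded to the kcentroid object this trainer
        // uses internally to select the RBF centers (see kcentroid.h).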

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const;
        /*!
            requires
                - is_binary_classification_problem(x,y) == true
            ensures
                - trains an RBF network given the training samples in x and
                  labels in y.
                - returns a decision function F with the following properties:
                    - if (new_x is a sample predicted to have a +1 label) then
                        - F(new_x) >= 0
                    - else
                        - F(new_x) < 0
            throws
                - std::bad_alloc
        !*/

        void swap (
            rbf_network_trainer& item
        );
        /*!
            ensures
                - swaps *this and item
        !*/

    };

// ----------------------------------------------------------------------------------------

    template <typename sample_type>
    void swap (
        rbf_network_trainer<sample_type>& a,
        rbf_network_trainer<sample_type>& b
    ) { a.swap(b); }
    /*!
        provides a global swap
    !*/

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RBf_NETWORK_ABSTRACT_
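
For context, here is a minimal usage sketch of the new trainer. It is not part of this commit, and the gamma and tolerance values are arbitrary illustration choices. It builds a small toy problem with the +1/-1 labels that is_binary_classification_problem() expects, trains the network, and evaluates the returned decision_function.

#include <iostream>
#include <vector>
#include "dlib/svm.h"

using namespace dlib;

int main()
{
    typedef matrix<double, 2, 1> sample_type;

    // build a tiny toy problem: points inside the unit circle get label +1
    std::vector<sample_type> samples;
    std::vector<double> labels;
    for (double r = -1.5; r <= 1.5; r += 0.5)
    {
        for (double c = -1.5; c <= 1.5; c += 0.5)
        {
            sample_type samp;
            samp = r, c;
            samples.push_back(samp);
            labels.push_back((r*r + c*c <= 1.0) ? +1.0 : -1.0);
        }
    }

    // train the RBF network
    rbf_network_trainer<sample_type> trainer;
    trainer.set_gamma(2.0);       // kernel width, illustration value only
    trainer.set_tolerance(0.01);  // controls how many centers are kept

    decision_function<radial_basis_kernel<sample_type> > df = trainer.train(samples, labels);

    // a point near the origin should produce a non-negative output (predicted +1)
    sample_type test;
    test = 0.2, 0.1;
    std::cout << "f(test) = " << df(test) << std::endl;
    return 0;
}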