mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
- Fixed a minor numerical error in the krls code so now it gets slightly better
results. - Added the ability to cap the number of dictionary vectors used by the krls object at a user specified number. This changes the serialization format of the object. I also removed the function to set the threshold after the object has been constructed. --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402311
This commit is contained in:
parent
f8fe979b7a
commit
307740b85d
114
dlib/svm/krls.h
114
dlib/svm/krls.h
@ -31,37 +31,33 @@ namespace dlib
|
||||
|
||||
explicit krls (
|
||||
const kernel_type& kernel_,
|
||||
scalar_type tolerance_ = 0.001
|
||||
scalar_type tolerance_ = 0.001,
|
||||
unsigned long max_dictionary_size_ = 1000000
|
||||
) :
|
||||
kernel(kernel_),
|
||||
tolerance(tolerance_)
|
||||
tolerance(tolerance_),
|
||||
max_dictionary_size(max_dictionary_size_)
|
||||
{
|
||||
clear_dictionary();
|
||||
}
|
||||
|
||||
void set_tolerance (scalar_type tolerance_)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(tolerance_ >= 0,
|
||||
"\tvoid krls::set_tolerance"
|
||||
<< "\n\tinvalid tolerance value"
|
||||
<< "\n\ttolerance: " << tolerance_
|
||||
<< "\n\tthis: " << this
|
||||
);
|
||||
tolerance = tolerance_;
|
||||
}
|
||||
|
||||
scalar_type get_tolerance() const
|
||||
{
|
||||
return tolerance;
|
||||
}
|
||||
|
||||
unsigned long get_max_dictionary_size() const
|
||||
{
|
||||
return max_dictionary_size;
|
||||
}
|
||||
|
||||
void clear_dictionary ()
|
||||
{
|
||||
dictionary.clear();
|
||||
alpha.clear();
|
||||
|
||||
K_inv.set_size(0,0);
|
||||
K.set_size(0,0);
|
||||
P.set_size(0,0);
|
||||
}
|
||||
|
||||
@ -71,7 +67,7 @@ namespace dlib
|
||||
{
|
||||
scalar_type temp = 0;
|
||||
for (unsigned long i = 0; i < alpha.size(); ++i)
|
||||
temp += alpha[i]*kernel(dictionary[i], x);
|
||||
temp += alpha[i]*kern(dictionary[i], x);
|
||||
|
||||
return temp;
|
||||
}
|
||||
@ -88,6 +84,8 @@ namespace dlib
|
||||
|
||||
K_inv.set_size(1,1);
|
||||
K_inv(0,0) = 1/kx;
|
||||
K.set_size(1,1);
|
||||
K(0,0) = kx;
|
||||
|
||||
alpha.push_back(y/kx);
|
||||
dictionary.push_back(x);
|
||||
@ -104,12 +102,25 @@ namespace dlib
|
||||
// compute the error we would have if we approximated the new x sample
|
||||
// with the dictionary. That is, do the ALD test from the KRLS paper.
|
||||
a = K_inv*k;
|
||||
const scalar_type delta = kx - trans(k)*a;
|
||||
scalar_type delta = kx - trans(k)*a;
|
||||
|
||||
// if this new vector isn't approximately linearly dependent on the vectors
|
||||
// in our dictionary.
|
||||
if (std::abs(delta) > tolerance)
|
||||
{
|
||||
if (dictionary.size() >= max_dictionary_size)
|
||||
{
|
||||
// We need to remove one of the old members of the dictionary before
|
||||
// we proceed with adding a new one. So remove the oldest one.
|
||||
remove_dictionary_vector(0);
|
||||
|
||||
// recompute these guys since they were computed with the old
|
||||
// kernel matrix
|
||||
k = remove_row(k,0);
|
||||
a = K_inv*k;
|
||||
delta = kx - trans(k)*a;
|
||||
}
|
||||
|
||||
// add x to the dictionary
|
||||
dictionary.push_back(x);
|
||||
|
||||
@ -127,6 +138,22 @@ namespace dlib
|
||||
temp.swap(K_inv);
|
||||
|
||||
|
||||
|
||||
|
||||
// update K (the kernel matrix)
|
||||
temp.set_size(K.nr()+1, K.nc()+1);
|
||||
set_subm(temp, get_rect(K)) = K;
|
||||
// update the right column of the matrix
|
||||
set_subm(temp, 0, K.nr(),K.nr(),1) = k;
|
||||
// update the bottom row of the matrix
|
||||
set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k);
|
||||
temp(K.nr(), K.nc()) = kx;
|
||||
// put temp into K
|
||||
temp.swap(K);
|
||||
|
||||
|
||||
|
||||
|
||||
// Now update the P matrix (equation 3.15)
|
||||
temp.set_size(P.nr()+1, P.nc()+1);
|
||||
set_subm(temp, get_rect(P)) = P;
|
||||
@ -170,12 +197,14 @@ namespace dlib
|
||||
dictionary.swap(item.dictionary);
|
||||
alpha.swap(item.alpha);
|
||||
K_inv.swap(item.K_inv);
|
||||
K.swap(item.K);
|
||||
P.swap(item.P);
|
||||
exchange(tolerance, item.tolerance);
|
||||
q.swap(item.q);
|
||||
a.swap(item.a);
|
||||
k.swap(item.k);
|
||||
temp_matrix.swap(item.temp_matrix);
|
||||
exchange(max_dictionary_size, item.max_dictionary_size);
|
||||
}
|
||||
|
||||
unsigned long dictionary_size (
|
||||
@ -186,7 +215,7 @@ namespace dlib
|
||||
{
|
||||
return decision_function<kernel_type>(
|
||||
vector_to_matrix(alpha),
|
||||
0, // the KRLS algorithm doesn't have a bias term
|
||||
-sum(vector_to_matrix(alpha))*tau,
|
||||
kernel,
|
||||
vector_to_matrix(dictionary)
|
||||
);
|
||||
@ -198,8 +227,10 @@ namespace dlib
|
||||
serialize(item.dictionary, out);
|
||||
serialize(item.alpha, out);
|
||||
serialize(item.K_inv, out);
|
||||
serialize(item.K, out);
|
||||
serialize(item.P, out);
|
||||
serialize(item.tolerance, out);
|
||||
serialize(item.max_dictionary_size, out);
|
||||
}
|
||||
|
||||
friend void deserialize(krls& item, std::istream& in)
|
||||
@ -208,15 +239,58 @@ namespace dlib
|
||||
deserialize(item.dictionary, in);
|
||||
deserialize(item.alpha, in);
|
||||
deserialize(item.K_inv, in);
|
||||
deserialize(item.K, in);
|
||||
deserialize(item.P, in);
|
||||
deserialize(item.tolerance, in);
|
||||
deserialize(item.max_dictionary_size, in);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
inline scalar_type kern (const sample_type& m1, const sample_type& m2) const
|
||||
{
|
||||
return kernel(m1,m2) + 0.001;
|
||||
return kernel(m1,m2) + tau;
|
||||
}
|
||||
|
||||
void remove_dictionary_vector (
|
||||
long i
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- 0 <= i < dictionary.size()
|
||||
ensures
|
||||
- #dictionary.size() == dictionary.size() - 1
|
||||
- #alpha.size() == alpha.size() - 1
|
||||
- updates the K_inv matrix so that it is still a proper inverse of the
|
||||
kernel matrix
|
||||
- also removes the necessary row and column from the K matrix
|
||||
- uses the this->a variable so after this function runs that variable
|
||||
will contain a different value.
|
||||
!*/
|
||||
{
|
||||
// remove the dictionary vector
|
||||
dictionary.erase(dictionary.begin()+i);
|
||||
|
||||
// remove the i'th vector from the inverse kernel matrix. This formula is basically
|
||||
// just the reverse of the way K_inv is updated by equation 3.14 during normal training.
|
||||
K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i);
|
||||
|
||||
// now compute the updated alpha values to take account that we just removed one of
|
||||
// our dictionary vectors
|
||||
a = (K_inv*remove_row(K,i)*vector_to_matrix(alpha));
|
||||
|
||||
// now copy over the new alpha values
|
||||
alpha.resize(alpha.size()-1);
|
||||
for (unsigned long k = 0; k < alpha.size(); ++k)
|
||||
{
|
||||
alpha[k] = a(k);
|
||||
}
|
||||
|
||||
// update the P matrix as well
|
||||
P = removerc(P,i,i);
|
||||
|
||||
// update the K matrix as well
|
||||
K = removerc(K,i,i);
|
||||
}
|
||||
|
||||
|
||||
@ -231,9 +305,11 @@ namespace dlib
|
||||
alpha_vector_type alpha;
|
||||
|
||||
matrix<scalar_type,0,0,mem_manager_type> K_inv;
|
||||
matrix<scalar_type,0,0,mem_manager_type> K;
|
||||
matrix<scalar_type,0,0,mem_manager_type> P;
|
||||
|
||||
scalar_type tolerance;
|
||||
unsigned long max_dictionary_size;
|
||||
|
||||
|
||||
// temp variables here just so we don't have to reconstruct them over and over. Thus,
|
||||
@ -243,6 +319,8 @@ namespace dlib
|
||||
matrix<scalar_type,0,1,mem_manager_type> k;
|
||||
matrix<scalar_type,1,0,mem_manager_type> temp_matrix;
|
||||
|
||||
const static double tau = 0.01;
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -32,6 +32,13 @@ namespace dlib
|
||||
The long and short of this algorithm is that it is an online kernel based
|
||||
regression algorithm. You give it samples (x,y) and it learns the function
|
||||
f(x) == y. For a detailed description of the algorithm read the above paper.
|
||||
|
||||
Also note that the algorithm internally keeps a set of "dictionary vectors"
|
||||
that are used to represent the regression function. You can force the
|
||||
algorithm to use no more than a set number of vectors by setting
|
||||
the 3rd constructor argument to whatever you want. However, note that
|
||||
doing this causes the algorithm to bias it's results towards more
|
||||
recent training examples.
|
||||
!*/
|
||||
|
||||
public:
|
||||
@ -42,7 +49,8 @@ namespace dlib
|
||||
|
||||
explicit krls (
|
||||
const kernel_type& kernel_,
|
||||
scalar_type tolerance_ = 0.001
|
||||
scalar_type tolerance_ = 0.001,
|
||||
unsigned long max_dictionary_size_ = 1000000
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
@ -50,16 +58,7 @@ namespace dlib
|
||||
- #get_tolerance() == tolerance_
|
||||
- #get_decision_function().kernel_function == kernel_
|
||||
(i.e. this object will use the given kernel function)
|
||||
!*/
|
||||
|
||||
void set_tolerance (
|
||||
scalar_type tolerance_
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- tolerance_ >= 0
|
||||
ensures
|
||||
- #get_tolerance() == tolerance_
|
||||
- #get_max_dictionary_size() == max_dictionary_size_
|
||||
!*/
|
||||
|
||||
scalar_type get_tolerance(
|
||||
@ -75,6 +74,15 @@ namespace dlib
|
||||
less accurate decision function but also in less support vectors.
|
||||
!*/
|
||||
|
||||
unsigned long get_max_dictionary_size(
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the maximum number of dictionary vectors this object
|
||||
will use at a time. That is, dictionary_size() will never be
|
||||
greater than get_max_dictionary_size().
|
||||
!*/
|
||||
|
||||
void clear_dictionary (
|
||||
);
|
||||
/*!
|
||||
@ -98,6 +106,11 @@ namespace dlib
|
||||
/*!
|
||||
ensures
|
||||
- trains this object that the given x should be mapped to the given y
|
||||
- if (dictionary_size() == get_max_dictionary_size() and training
|
||||
would add another dictionary vector to this object) then
|
||||
- discards the oldest dictionary vector so that we can still
|
||||
add a new one and remain below the max number of dictionary
|
||||
vectors.
|
||||
!*/
|
||||
|
||||
void swap (
|
||||
|
Loading…
Reference in New Issue
Block a user