Added overloads of the kernel_derivative object for all the kernels in dlib.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403067
This commit is contained in:
Davis King 2009-05-27 02:21:00 +00:00
parent 4ed6922b37
commit 925a9be91c
3 changed files with 116 additions and 1 deletions

View File

@ -318,6 +318,28 @@ namespace dlib
}
}
template <
typename T
>
struct kernel_derivative<sigmoid_kernel<T> >
{
typedef typename T::type scalar_type;
typedef T sample_type;
typedef typename T::mem_manager_type mem_manager_type;
kernel_derivative(const sigmoid_kernel<T>& k_) : k(k_){}
const sample_type& operator() (const sample_type& x, const sample_type& y) const
{
// return the derivative of the rbf kernel
temp = k.gamma*x*(1-std::pow(k(x,y),2));
return temp;
}
const sigmoid_kernel<T>& k;
mutable sample_type temp;
};
// ----------------------------------------------------------------------------------------
template <typename T>
@ -359,6 +381,25 @@ namespace dlib
std::istream& in
){}
template <
typename T
>
struct kernel_derivative<linear_kernel<T> >
{
typedef typename T::type scalar_type;
typedef T sample_type;
typedef typename T::mem_manager_type mem_manager_type;
kernel_derivative(const linear_kernel<T>& k_) : k(k_){}
const sample_type& operator() (const sample_type& x, const sample_type& y) const
{
return x;
}
const linear_kernel<T>& k;
};
// ----------------------------------------------------------------------------------------
template <typename T>
@ -442,6 +483,25 @@ namespace dlib
}
}
template <
typename T
>
struct kernel_derivative<offset_kernel<T> >
{
typedef typename T::scalar_type scalar_type;
typedef typename T::sample_type sample_type;
typedef typename T::mem_manager_type mem_manager_type;
kernel_derivative(const offset_kernel<T>& k) : der(k.kernel){}
const sample_type operator() (const sample_type& x, const sample_type& y) const
{
return der(x,y);
}
kernel_derivative<T> der;
};
// ----------------------------------------------------------------------------------------
}

View File

@ -537,6 +537,9 @@ namespace dlib
kernel_type must be one of the following kernel types:
- radial_basis_kernel
- polynomial_kernel
- sigmoid_kernel
- linear_kernel
- offset_kernel
WHAT THIS OBJECT REPRESENTS
This is a function object that computes the derivative of a kernel
@ -562,7 +565,7 @@ namespace dlib
) const;
/*!
ensures
- returns the derivative of k with respect to y. Or in other words, k(x, y+dy)/dy
- returns the derivative of k with respect to y.
!*/
const kernel_type& k;

View File

@ -399,9 +399,60 @@ namespace
// ----------------------------------------------------------------------------------------
template <typename kernel_type>
struct kernel_der_obj
{
typename kernel_type::sample_type x;
kernel_type k;
double operator()(const typename kernel_type::sample_type& y) const { return k(x,y); }
};
template <typename kernel_type>
void test_kernel_derivative (
const kernel_type& k,
const typename kernel_type::sample_type& x,
const typename kernel_type::sample_type& y
)
{
kernel_der_obj<kernel_type> obj;
obj.x = x;
obj.k = k;
kernel_derivative<kernel_type> der(obj.k);
DLIB_CASSERT(dlib::equal(derivative(obj)(y) , der(obj.x,y), 1e-5), "");
}
void test_kernel_derivative (
)
{
typedef matrix<double, 2, 1> sample_type;
sigmoid_kernel<sample_type> k1;
radial_basis_kernel<sample_type> k2;
linear_kernel<sample_type> k3;
polynomial_kernel<sample_type> k4(2,3,4);
offset_kernel<sigmoid_kernel<sample_type> > k5;
offset_kernel<radial_basis_kernel<sample_type> > k6;
dlib::rand::float_1a rnd;
sample_type x, y;
for (int i = 0; i < 10; ++i)
{
x = randm(2,1,rnd);
y = randm(2,1,rnd);
test_kernel_derivative(k1, x, y);
test_kernel_derivative(k2, x, y);
test_kernel_derivative(k3, x, y);
test_kernel_derivative(k4, x, y);
test_kernel_derivative(k5, x, y);
test_kernel_derivative(k6, x, y);
}
}
// ----------------------------------------------------------------------------------------
class svm_tester : public tester
{
@ -415,6 +466,7 @@ namespace
void perform_test (
)
{
test_kernel_derivative();
test_binary_classification();
test_clutering();
test_regression();