mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added the distance_function object
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402387
This commit is contained in:
parent
0c1c9f67b9
commit
f7d97090ba
@ -231,6 +231,144 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
struct distance_function
|
||||
{
|
||||
typedef typename K::scalar_type scalar_type;
|
||||
typedef typename K::sample_type sample_type;
|
||||
typedef typename K::mem_manager_type mem_manager_type;
|
||||
|
||||
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
|
||||
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
|
||||
|
||||
const scalar_vector_type alpha;
|
||||
const scalar_type b;
|
||||
const K kernel_function;
|
||||
const sample_vector_type support_vectors;
|
||||
|
||||
distance_function (
|
||||
) : b(0), kernel_function(K()) {}
|
||||
|
||||
distance_function (
|
||||
const distance_function& d
|
||||
) :
|
||||
alpha(d.alpha),
|
||||
b(d.b),
|
||||
kernel_function(d.kernel_function),
|
||||
support_vectors(d.support_vectors)
|
||||
{}
|
||||
|
||||
distance_function (
|
||||
const scalar_vector_type& alpha_,
|
||||
const scalar_type& b_,
|
||||
const K& kernel_function_,
|
||||
const sample_vector_type& support_vectors_
|
||||
) :
|
||||
alpha(alpha_),
|
||||
b(b_),
|
||||
kernel_function(kernel_function_),
|
||||
support_vectors(support_vectors_)
|
||||
{}
|
||||
|
||||
distance_function& operator= (
|
||||
const distance_function& d
|
||||
)
|
||||
{
|
||||
if (this != &d)
|
||||
{
|
||||
const_cast<scalar_vector_type&>(alpha) = d.alpha;
|
||||
const_cast<scalar_type&>(b) = d.b;
|
||||
const_cast<K&>(kernel_function) = d.kernel_function;
|
||||
const_cast<sample_vector_type&>(support_vectors) = d.support_vectors;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
scalar_type operator() (
|
||||
const sample_type& x
|
||||
) const
|
||||
{
|
||||
scalar_type temp = 0;
|
||||
for (long i = 0; i < alpha.nr(); ++i)
|
||||
temp += alpha(i) * kernel_function(x,support_vectors(i));
|
||||
|
||||
temp = b + kernel_function(x,x) - 2*temp;
|
||||
if (temp > 0)
|
||||
return std::sqrt(temp);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
scalar_type operator() (
|
||||
const distance_function& x
|
||||
) const
|
||||
{
|
||||
scalar_type temp = 0;
|
||||
for (long i = 0; i < alpha.nr(); ++i)
|
||||
for (long j = 0; j < x.alpha.nr(); ++j)
|
||||
temp += alpha(i)*x.alpha(j) * kernel_function(support_vectors(i), x.support_vectors(j));
|
||||
|
||||
temp = b + x.b - 2*temp;
|
||||
if (temp > 0)
|
||||
return std::sqrt(temp);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
void serialize (
|
||||
const distance_function<K>& item,
|
||||
std::ostream& out
|
||||
)
|
||||
{
|
||||
try
|
||||
{
|
||||
serialize(item.alpha, out);
|
||||
serialize(item.b, out);
|
||||
serialize(item.kernel_function, out);
|
||||
serialize(item.support_vectors, out);
|
||||
}
|
||||
catch (serialization_error e)
|
||||
{
|
||||
throw serialization_error(e.info + "\n while serializing object of type distance_function");
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
void deserialize (
|
||||
distance_function<K>& item,
|
||||
std::istream& in
|
||||
)
|
||||
{
|
||||
typedef typename K::scalar_type scalar_type;
|
||||
typedef typename K::sample_type sample_type;
|
||||
typedef typename K::mem_manager_type mem_manager_type;
|
||||
|
||||
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
|
||||
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
|
||||
try
|
||||
{
|
||||
deserialize(const_cast<scalar_vector_type&>(item.alpha), in);
|
||||
deserialize(const_cast<scalar_type&>(item.b), in);
|
||||
deserialize(const_cast<K&>(item.kernel_function), in);
|
||||
deserialize(const_cast<sample_vector_type&>(item.support_vectors), in);
|
||||
}
|
||||
catch (serialization_error e)
|
||||
{
|
||||
throw serialization_error(e.info + "\n while deserializing object of type distance_function");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ namespace dlib
|
||||
for (long i = 0; i < alpha.nr(); ++i)
|
||||
temp += alpha(i) * kernel_function(x,support_vectors(i));
|
||||
|
||||
returns temp - b;
|
||||
return temp - b;
|
||||
}
|
||||
};
|
||||
|
||||
@ -225,6 +225,141 @@ namespace dlib
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
struct distance_function
|
||||
{
|
||||
/*!
|
||||
REQUIREMENTS ON K
|
||||
K must be a kernel function object type as defined at the
|
||||
top of dlib/svm/kernel_abstract.h
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This object represents a point in kernel induced feature space.
|
||||
You may use this object to find the distance from the point it
|
||||
represents to points in input space.
|
||||
!*/
|
||||
|
||||
typedef typename K::scalar_type scalar_type;
|
||||
typedef typename K::sample_type sample_type;
|
||||
typedef typename K::mem_manager_type mem_manager_type;
|
||||
|
||||
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
|
||||
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
|
||||
|
||||
const scalar_vector_type alpha;
|
||||
const scalar_type b;
|
||||
const K kernel_function;
|
||||
const sample_vector_type support_vectors;
|
||||
|
||||
distance_function (
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #b == 0
|
||||
- #alpha.nr() == 0
|
||||
- #support_vectors.nr() == 0
|
||||
!*/
|
||||
|
||||
distance_function (
|
||||
const distance_function& f
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #*this is a copy of f
|
||||
!*/
|
||||
|
||||
distance_function (
|
||||
const scalar_vector_type& alpha_,
|
||||
const scalar_type& b_,
|
||||
const K& kernel_function_,
|
||||
const sample_vector_type& support_vectors_
|
||||
) : alpha(alpha_), b(b_), kernel_function(kernel_function_), support_vectors(support_vectors_) {}
|
||||
/*!
|
||||
ensures
|
||||
- populates the decision function with the given support vectors, weights(i.e. alphas),
|
||||
b term, and kernel function.
|
||||
!*/
|
||||
|
||||
distance_function& operator= (
|
||||
const distance_function& d
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #*this is identical to d
|
||||
- returns *this
|
||||
!*/
|
||||
|
||||
scalar_type operator() (
|
||||
const sample_type& x
|
||||
) const
|
||||
/*!
|
||||
ensures
|
||||
- Let O(x) represent the point x projected into kernel induced feature space.
|
||||
- let c == sum alpha(i)*O(support_vectors(i)) == the point in kernel space that
|
||||
this object represents.
|
||||
- Then this object returns the distance between the points O(x) and c in kernel
|
||||
space.
|
||||
!*/
|
||||
{
|
||||
scalar_type temp = 0;
|
||||
for (long i = 0; i < alpha.nr(); ++i)
|
||||
temp += alpha(i) * kernel_function(x,support_vectors(i));
|
||||
|
||||
temp = b + kernel_function(x,x) - 2*temp;
|
||||
if (temp > 0)
|
||||
return std::sqrt(temp);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
scalar_type operator() (
|
||||
const distance_function& x
|
||||
) const
|
||||
/*!
|
||||
ensures
|
||||
- returns the distance between the point in kernel space represented by *this and x.
|
||||
!*/
|
||||
{
|
||||
scalar_type temp = 0;
|
||||
for (long i = 0; i < alpha.nr(); ++i)
|
||||
for (long j = 0; j < x.alpha.nr(); ++j)
|
||||
temp += alpha(i)*x.alpha(j) * kernel_function(support_vectors(i), x.support_vectors(j));
|
||||
|
||||
temp = b + x.b - 2*temp;
|
||||
if (temp > 0)
|
||||
return std::sqrt(temp);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
void serialize (
|
||||
const distance_function<K>& item,
|
||||
std::ostream& out
|
||||
);
|
||||
/*!
|
||||
provides serialization support for distance_function
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename K
|
||||
>
|
||||
void deserialize (
|
||||
distance_function<K>& item,
|
||||
std::istream& in
|
||||
);
|
||||
/*!
|
||||
provides serialization support for distance_function
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_SVm_FUNCTION_ABSTRACT_
|
||||
|
@ -215,6 +215,16 @@ namespace dlib
|
||||
item.bias_is_stale = true;
|
||||
}
|
||||
|
||||
distance_function<kernel_type> get_distance_function (
|
||||
) const
|
||||
{
|
||||
refresh_bias();
|
||||
return distance_function<kernel_type>(vector_to_matrix(alpha),
|
||||
bias,
|
||||
kernel,
|
||||
vector_to_matrix(dictionary));
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void refresh_bias (
|
||||
|
@ -193,6 +193,16 @@ namespace dlib
|
||||
- returns the number of "support vectors" in the dictionary.
|
||||
!*/
|
||||
|
||||
distance_function<kernel_type> get_distance_function (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns a distance function F that represents the point learned
|
||||
by this object so far. I.e. it is the case that:
|
||||
- for all x: F(x) == (*this)(x)
|
||||
!*/
|
||||
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue
Block a user