diff --git a/dlib/svm/kernel.h b/dlib/svm/kernel.h index 9a13b003b..66f0af5cc 100644 --- a/dlib/svm/kernel.h +++ b/dlib/svm/kernel.h @@ -400,6 +400,52 @@ namespace dlib const linear_kernel& k; }; +// ---------------------------------------------------------------------------------------- + + template + struct histogram_intersection_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + scalar_type temp = 0; + for (long i = 0; i < a.size(); ++i) + { + temp += std::min(a(i), b(i)); + } + return temp; + } + + bool operator== ( + const histogram_intersection_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const histogram_intersection_kernel& , + std::ostream& + ){} + + template < + typename T + > + void deserialize ( + histogram_intersection_kernel& , + std::istream& + ){} + // ---------------------------------------------------------------------------------------- template diff --git a/dlib/svm/kernel_abstract.h b/dlib/svm/kernel_abstract.h index 45866de19..96a3ea351 100644 --- a/dlib/svm/kernel_abstract.h +++ b/dlib/svm/kernel_abstract.h @@ -422,6 +422,71 @@ namespace dlib provides deserialization support for linear_kernel !*/ +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct histogram_intersection_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a histogram intersection kernel kernel + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - is_vector(a) + - is_vector(b) + - a.size() == b.size() + - min(a) >= 0 + - min(b) >= 0 + ensures + - returns sum over all i: std::min(a(i), b(i)) + !*/ + + bool operator== ( + const histogram_intersection_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const histogram_intersection_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for histogram_intersection_kernel + !*/ + + template < + typename T + > + void deserialize ( + histogram_intersection_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for histogram_intersection_kernel + !*/ + // ---------------------------------------------------------------------------------------- template < diff --git a/dlib/svm/sparse_kernel.h b/dlib/svm/sparse_kernel.h index d3b4166d7..0df4f5f0d 100644 --- a/dlib/svm/sparse_kernel.h +++ b/dlib/svm/sparse_kernel.h @@ -311,6 +311,69 @@ namespace dlib std::istream& in ){} +// ---------------------------------------------------------------------------------------- + + template + struct sparse_histogram_intersection_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + typename sample_type::const_iterator ai = a.begin(); + typename sample_type::const_iterator bi = b.begin(); + + scalar_type sum = 0; + while (ai != a.end() && bi != b.end()) + { + if (ai->first == bi->first) + { + sum += std::min(ai->second , bi->second); + ++ai; + ++bi; + } + else if (ai->first < bi->first) + { + ++ai; + } + else + { + ++bi; + } + } + + return sum; + } + + bool operator== ( + const sparse_histogram_intersection_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const sparse_histogram_intersection_kernel& item, + std::ostream& out + ){} + + template < + typename T + > + void deserialize ( + sparse_histogram_intersection_kernel& item, + std::istream& in + ){} + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/svm/sparse_kernel_abstract.h b/dlib/svm/sparse_kernel_abstract.h index b1cdfbf00..454cd81ad 100644 --- a/dlib/svm/sparse_kernel_abstract.h +++ b/dlib/svm/sparse_kernel_abstract.h @@ -64,8 +64,8 @@ namespace dlib ) const; /*! requires - - a contains a sorted range - - b contains a sorted range + - a is a sparse vector + - b is a sparse vector ensures - returns exp(-gamma * sparse_vector::distance_squared(a,b)) !*/ @@ -170,8 +170,8 @@ namespace dlib ) const; /*! requires - - a contains a sorted range - - b contains a sorted range + - a is a sparse vector + - b is a sparse vector ensures - returns tanh(gamma * sparse_vector::dot(a,b) + coef) !*/ @@ -282,8 +282,8 @@ namespace dlib ) const; /*! requires - - a contains a sorted range - - b contains a sorted range + - a is a sparse vector + - b is a sparse vector ensures - returns pow(gamma * sparse_vector::dot(a,b) + coef, degree) !*/ @@ -359,8 +359,8 @@ namespace dlib ) const; /*! requires - - a contains a sorted range - - b contains a sorted range + - a is a sparse vector + - b is a sparse vector ensures - returns sparse_vector::dot(a,b) !*/ @@ -396,6 +396,72 @@ namespace dlib provides deserialization support for sparse_linear_kernel !*/ +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_histogram_intersection_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a histogram intersection kernel + that works with sparse vectors. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + - all the values in a and b are >= 0 + ensures + - Let A(i) denote the value of the ith dimension of the a vector. + - Let B(i) denote the value of the ith dimension of the b vector. + - returns sum over all i: std::min(A(i), B(i)) + !*/ + + bool operator== ( + const sparse_histogram_intersection_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const sparse_histogram_intersection_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_histogram_intersection_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_histogram_intersection_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_histogram_intersection_kernel + !*/ + // ---------------------------------------------------------------------------------------- }