From 9d055a4e875ee8cca621fc4f6deceee6ff125287 Mon Sep 17 00:00:00 2001 From: Davis King Date: Thu, 14 Mar 2013 20:36:48 -0400 Subject: [PATCH] Added find_k_nearest_neighbors_lsh() and hash_samples() --- .../find_k_nearest_neighbors_lsh.h | 269 ++++++++++++++++++ .../find_k_nearest_neighbors_lsh_abstract.h | 102 +++++++ dlib/graph_utils_threaded.h | 11 + 3 files changed, 382 insertions(+) create mode 100644 dlib/graph_utils/find_k_nearest_neighbors_lsh.h create mode 100644 dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h create mode 100644 dlib/graph_utils_threaded.h diff --git a/dlib/graph_utils/find_k_nearest_neighbors_lsh.h b/dlib/graph_utils/find_k_nearest_neighbors_lsh.h new file mode 100644 index 000000000..4b510593a --- /dev/null +++ b/dlib/graph_utils/find_k_nearest_neighbors_lsh.h @@ -0,0 +1,269 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__ +#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__ + +#include "find_k_nearest_neighbors_lsh_abstract.h" +#include "../threads.h" +#include "../lsh/hashes.h" +#include +#include +#include "sample_pair.h" +#include "edge_list_graphs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct compare_sample_pair_with_distance // orders sample_pair objects by distance(); used as the priority_queue comparator for best_samples + { + inline bool operator() (const sample_pair& a, const sample_pair& b) const + { + return a.distance() < b.distance(); // true when a is strictly closer than b + } + }; + + template < + typename vector_type, + typename hash_function_type + > + class hash_block // function object that hash_samples() hands to parallel_for: call i stores hash_funct(samples[i]) in hashes[i] + { + public: + hash_block( + const vector_type& samples_, + const hash_function_type& hash_funct_, + std::vector& hashes_ + ) : + samples(samples_), + hash_funct(hash_funct_), + hashes(hashes_) + {} + + void operator() (long i) const // writes only slot i of hashes and takes no lock + { + hashes[i] = hash_funct(samples[i]); + } + + const vector_type& samples; // held by reference, not copied + const hash_function_type& hash_funct; // held by reference, not copied + 
std::vector& hashes; + }; + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + class scan_find_k_nearest_neighbors_lsh // function object run by parallel_for: call i finds the neighbors of samples[i] and appends them to edges + { + public: + scan_find_k_nearest_neighbors_lsh ( + const vector_type& samples_, + const distance_function_type& dist_funct_, + const hash_function_type& hash_funct_, + const unsigned long k_, + std::vector& edges_, + const unsigned long k_oversample_, + const std::vector& hashes_ + ) : + samples(samples_), + dist_funct(dist_funct_), + hash_funct(hash_funct_), + k(k_), + edges(edges_), + k_oversample(k_oversample_), + hashes(hashes_) + { + edges.clear(); + edges.reserve(samples.size()*k/2); + } + + mutex m; // serializes the appends to edges made by concurrent operator() calls + const vector_type& samples; + const distance_function_type& dist_funct; + const hash_function_type& hash_funct; + const unsigned long k; + std::vector& edges; + const unsigned long k_oversample; + const std::vector& hashes; + + void operator() (unsigned long i) const + { + const unsigned long k_hash = k*k_oversample; // number of candidates kept by the hash-only pass + + std::priority_queue > best_hashes; + std::priority_queue, dlib::impl::compare_sample_pair_with_distance> best_samples; + unsigned long worst_distance = std::numeric_limits::max(); + // scan over the hashes and find the best matches for hashes[i] + for (unsigned long j = 0; j < hashes.size(); ++j) + { + if (i == j) + continue; + + const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]); + if (dist < worst_distance || best_hashes.size() < k_hash) + { + if (best_hashes.size() >= k_hash) + best_hashes.pop(); + best_hashes.push(std::make_pair(dist, j)); + worst_distance = best_hashes.top().first; + } + } + + // Now figure out which of the best_hashes are actually the k best matches + // according to dist_funct(). best_samples keeps its largest distance at top(), so pushing first and popping only once more than k are held always evicts the current worst candidate and leaves the k smallest. + while (best_hashes.size() != 0) + { + const unsigned long j = best_hashes.top().second; + best_hashes.pop(); + + const double dist = dist_funct(samples[i], samples[j]); + if (dist < std::numeric_limits::infinity()) + { + best_samples.push(sample_pair(i,j,dist)); + if 
(best_samples.size() > k) + best_samples.pop(); + } + } + + // Finally, now put the k best matches according to dist_funct() into edges + auto_mutex lock(m); + while (best_samples.size() != 0) + { + edges.push_back(best_samples.top()); + best_samples.pop(); + } + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename hash_function_type + > + void hash_samples ( + const vector_type& samples, + const hash_function_type& hash_funct, + const unsigned long num_threads, + std::vector& hashes + ) + /*! + requires + - hash_funct() is threadsafe. This means that it must be safe for multiple + threads to invoke the member functions of hash_funct() at the same time. + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function hashes all the elements in samples and stores the results in + hashes. It will also use num_threads concurrent threads to do this. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - #hashes.size() == samples.size() + - for all valid i: + - #hashes[i] = hash_funct(samples[i]) + (i.e. 
#hashes[i] will contain the hash of samples[i]) + !*/ + { + hashes.resize(samples.size()); + + typedef impl::hash_block block_type; + parallel_for(num_threads, 0, samples.size(), block_type(samples, hash_funct, hashes)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + void find_k_nearest_neighbors_lsh ( + const vector_type& samples, + const distance_function_type& dist_funct, + const hash_function_type& hash_funct, + const unsigned long k, + const unsigned long num_threads, + std::vector& edges, + const unsigned long k_oversample = 20 + ) + /*! + requires + - hash_funct and dist_funct are threadsafe. This means that it must be safe + for multiple threads to invoke the member functions of these objects at the + same time. + - k > 0 + - k_oversample > 0 + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function computes an approximate form of a k nearest neighbors graph of + the elements in samples. In particular, the way it works is that it first + hashes all elements in samples using the provided locality sensitive hash + function hash_funct(). Then it performs an exact k nearest neighbors on the + hashes which can be done very quickly. For each of these neighbors we + compute the true distance using dist_funct() and the k nearest neighbors for + each sample are stored into #edges. + - Note that samples with an infinite distance between them are considered to be + not connected at all. 
Therefore, we exclude edges with such distances from + being output. + - for all valid i: + - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()]) + - #edges[i].distance() < std::numeric_limits::infinity() + - contains_duplicate_pairs(#edges) == false + - This function will use num_threads concurrent threads of processing. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - The hash based k nearest neighbor step is approximate, however, you can + improve the output accuracy by using a larger k value for this first step. + Therefore, this function finds k*k_oversample nearest neighbors during the + first hashing based step. + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(k > 0 && k_oversample > 0, + "\t void find_k_nearest_neighbors_lsh()" + << "\n\t Invalid inputs were given to this function." + << "\n\t samples.size(): " << samples.size() + << "\n\t k: " << k + << "\n\t k_oversample: " << k_oversample + ); + + edges.clear(); // the output is empty even when we take the early return below + + if (samples.size() <= 1) // fewer than two samples means there is no pair to connect + { + return; + } + + typedef typename hash_function_type::result_type hash_type; + std::vector hashes; + hash_samples(samples, hash_funct, num_threads, hashes); // fills hashes[i] with hash_funct(samples[i]) + + typedef impl::scan_find_k_nearest_neighbors_lsh scan_type; + parallel_for(num_threads, 0, hashes.size(), scan_type(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes)); // one scan per sample; each appends that sample's neighbors to edges + + remove_duplicate_edges(edges); // NOTE(review): the scans for i and j can each emit the pair (i,j); presumably this call is what establishes contains_duplicate_pairs(#edges) == false -- confirm against its definition + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__ + + diff --git a/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h b/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h new file mode 100644 index 000000000..df0216a67 --- /dev/null +++ b/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h @@ -0,0 +1,102 @@ +// Copyright (C) 2013 Davis E. 
King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__ +#ifdef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__ + +#include "../lsh/hashes_abstract.h" +#include "sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename hash_function_type + > + void hash_samples ( + const vector_type& samples, + const hash_function_type& hash_funct, + const unsigned long num_threads, + std::vector& hashes + ); + /*! + requires + - hash_funct() is threadsafe. This means that it must be safe for multiple + threads to invoke the member functions of hash_funct() at the same time. + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function hashes all the elements in samples and stores the results in + hashes. It will also use num_threads concurrent threads to do this. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - #hashes.size() == samples.size() + - for all valid i: + - #hashes[i] = hash_funct(samples[i]) + (i.e. 
#hashes[i] will contain the hash of samples[i]) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + void find_k_nearest_neighbors_lsh ( + const vector_type& samples, + const distance_function_type& dist_funct, + const hash_function_type& hash_funct, + const unsigned long k, + const unsigned long num_threads, + std::vector& edges, + const unsigned long k_oversample = 20 + ); + /*! + requires + - hash_funct and dist_funct are threadsafe. This means that it must be safe + for multiple threads to invoke the member functions of these objects at the + same time. + - k > 0 + - k_oversample > 0 + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function computes an approximate form of a k nearest neighbors graph of + the elements in samples. In particular, the way it works is that it first + hashes all elements in samples using the provided locality sensitive hash + function hash_funct(). Then it performs an exact k nearest neighbors on the + hashes which can be done very quickly. For each of these neighbors we + compute the true distance using dist_funct() and the k nearest neighbors for + each sample are stored into #edges. + - Note that samples with an infinite distance between them are considered to be + not connected at all. Therefore, we exclude edges with such distances from + being output. 
+ - for all valid i: + - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()]) + - #edges[i].distance() < std::numeric_limits::infinity() + - contains_duplicate_pairs(#edges) == false + - This function will use num_threads concurrent threads of processing. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - The hash based k nearest neighbor step is approximate, however, you can + improve the output accuracy by using a larger k value for this first step. + Therefore, this function finds k*k_oversample nearest neighbors during the + first hashing based step. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__ + diff --git a/dlib/graph_utils_threaded.h b/dlib/graph_utils_threaded.h new file mode 100644 index 000000000..9381b9d69 --- /dev/null +++ b/dlib/graph_utils_threaded.h @@ -0,0 +1,11 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_UTILs_THREADED_H_ +#define DLIB_GRAPH_UTILs_THREADED_H_ + +#include "graph_utils/find_k_nearest_neighbors_lsh.h" + +#endif // DLIB_GRAPH_UTILs_THREADED_H_ + + +