Added find_k_nearest_neighbors_lsh() and hash_samples()

This commit is contained in:
Davis King 2013-03-14 20:36:48 -04:00
parent 4e96485601
commit 9d055a4e87
3 changed files with 382 additions and 0 deletions

View File

@ -0,0 +1,269 @@
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__
#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__
#include "find_k_nearest_neighbors_lsh_abstract.h"
#include "../threads.h"
#include "../lsh/hashes.h"
#include <vector>
#include <queue>
#include "sample_pair.h"
#include "edge_list_graphs.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
struct compare_sample_pair_with_distance
{
inline bool operator() (const sample_pair& a, const sample_pair& b) const
{
return a.distance() < b.distance();
}
};
// Function object that hashes one sample per call.  hash_samples() hands an
// instance of this to parallel_for(), which then calls it once per index.
//
// The object only stores references, so samples, hash_funct, and hashes must
// all outlive it.  hashes must already be large enough to hold every index
// passed to operator(), because results are assigned in place, not appended.
template <
    typename vector_type,
    typename hash_function_type
    >
class hash_block
{
    typedef std::vector<typename hash_function_type::result_type> hash_vector_type;

public:
    hash_block(
        const vector_type& samples_,
        const hash_function_type& hash_funct_,
        hash_vector_type& hashes_
    ) :
        samples(samples_),
        hash_funct(hash_funct_),
        hashes(hashes_)
    {
    }

    // Computes the hash of samples[i] and stores it in hashes[i].
    void operator() (long i) const
    {
        hashes[i] = hash_funct(samples[i]);
    }

    const vector_type& samples;             // input elements to hash
    const hash_function_type& hash_funct;   // hash function applied to each element
    hash_vector_type& hashes;               // output, written one slot per call
};
// Function object that finds the (approximate) k nearest neighbors of a single
// sample.  find_k_nearest_neighbors_lsh() hands an instance of this to
// parallel_for(), which calls operator() once per sample index.
//
// For sample i the work is done in two stages:
//   1. Scan all hashes and keep the k*k_oversample indices whose hash distance
//      to hashes[i] is smallest.
//   2. Evaluate dist_funct() on those candidates only and keep the k with the
//      smallest finite distance.  These are appended to edges.
//
// The object only stores references, so every constructor argument passed by
// reference must outlive it.
template <
    typename vector_type,
    typename distance_function_type,
    typename hash_function_type,
    typename alloc
    >
class scan_find_k_nearest_neighbors_lsh
{
public:
    // Note that constructing this object empties edges_ and reserves space in it.
    scan_find_k_nearest_neighbors_lsh (
        const vector_type& samples_,
        const distance_function_type& dist_funct_,
        const hash_function_type& hash_funct_,
        const unsigned long k_,
        std::vector<sample_pair, alloc>& edges_,
        const unsigned long k_oversample_,
        const std::vector<typename hash_function_type::result_type>& hashes_
    ) :
        samples(samples_),
        dist_funct(dist_funct_),
        hash_funct(hash_funct_),
        k(k_),
        edges(edges_),
        k_oversample(k_oversample_),
        hashes(hashes_)
    {
        edges.clear();
        // This is only a capacity hint, the final size depends on what is found.
        edges.reserve(samples.size()*k/2);
    }

    // Held while appending to edges, since edges is shared by every call of
    // operator() and those calls are issued through parallel_for().
    mutex m;
    const vector_type& samples;
    const distance_function_type& dist_funct;
    const hash_function_type& hash_funct;
    const unsigned long k;
    std::vector<sample_pair, alloc>& edges;
    const unsigned long k_oversample;
    const std::vector<typename hash_function_type::result_type>& hashes;

    void operator() (unsigned long i) const
    {
        const unsigned long k_hash = k*k_oversample;

        // Both queues are max-heaps, so top() is always the worst element kept so far.
        std::priority_queue<std::pair<unsigned long, unsigned long> > best_hashes;
        std::priority_queue<sample_pair, std::vector<sample_pair>, dlib::impl::compare_sample_pair_with_distance> best_samples;

        unsigned long worst_distance = std::numeric_limits<unsigned long>::max();
        // scan over the hashes and find the best matches for hashes[i]
        for (unsigned long j = 0; j < hashes.size(); ++j)
        {
            if (i == j)
                continue;

            const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]);
            if (dist < worst_distance || best_hashes.size() < k_hash)
            {
                if (best_hashes.size() >= k_hash)
                    best_hashes.pop();
                best_hashes.push(std::make_pair(dist, j));
                worst_distance = best_hashes.top().first;
            }
        }

        // Now figure out which of the best_hashes are actually the k best matches
        // according to dist_funct()
        while (best_hashes.size() != 0)
        {
            const unsigned long j = best_hashes.top().second;
            best_hashes.pop();

            const double dist = dist_funct(samples[i], samples[j]);
            // Infinite (and NaN) distances mean "not connected", so skip them.
            if (dist < std::numeric_limits<double>::infinity())
            {
                // Insert first and only then evict the current worst entry.  Evicting
                // before inserting would let a candidate that is farther away than
                // everything in the queue push out a closer one, and we would no
                // longer end up with the k smallest distances.
                best_samples.push(sample_pair(i,j,dist));
                if (best_samples.size() > k)
                    best_samples.pop();
            }
        }

        // Finally, now put the k best matches according to dist_funct() into edges.
        // All the expensive work above was done without holding the lock.
        auto_mutex lock(m);
        while (best_samples.size() != 0)
        {
            edges.push_back(best_samples.top());
            best_samples.pop();
        }
    }
};
}
// ----------------------------------------------------------------------------------------
template <
    typename vector_type,
    typename hash_function_type
    >
void hash_samples (
    const vector_type& samples,
    const hash_function_type& hash_funct,
    const unsigned long num_threads,
    std::vector<typename hash_function_type::result_type>& hashes
)
/*!
    requires
        - hash_funct() is threadsafe.  This means that it must be safe for multiple
          threads to invoke the member functions of hash_funct() at the same time.
        - vector_type is any container that looks like a std::vector or dlib::array.
        - hash_funct must be a function object with an interface compatible with the
          objects defined in dlib/lsh/hashes_abstract.h.  In particular, hash_funct
          must be capable of hashing the elements in the samples vector.
    ensures
        - This function hashes all the elements in samples and stores the results in
          hashes.  It will also use num_threads concurrent threads to do this.  You
          should set this value equal to the number of processing cores on your
          computer for maximum speed.
        - #hashes.size() == samples.size()
        - for all valid i:
            - #hashes[i] == hash_funct(samples[i])
              (i.e. #hashes[i] will contain the hash of samples[i])
!*/
{
    // Size the output up front.  Each call of the block functor assigns into
    // hashes[i] for its own index i, so no element is ever appended later on.
    hashes.resize(samples.size());

    typedef impl::hash_block<vector_type,hash_function_type> block_type;
    parallel_for(num_threads, 0, samples.size(), block_type(samples, hash_funct, hashes));
}
// ----------------------------------------------------------------------------------------
template <
    typename vector_type,
    typename distance_function_type,
    typename hash_function_type,
    typename alloc
    >
void find_k_nearest_neighbors_lsh (
    const vector_type& samples,
    const distance_function_type& dist_funct,
    const hash_function_type& hash_funct,
    const unsigned long k,
    const unsigned long num_threads,
    std::vector<sample_pair, alloc>& edges,
    const unsigned long k_oversample = 20
)
/*!
    requires
        - k > 0
        - k_oversample > 0
        - hash_funct and dist_funct are threadsafe.  That is, multiple threads must
          be able to invoke the member functions of these objects at the same time.
        - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
          to a floating point number
        - vector_type is any container that looks like a std::vector or dlib::array.
        - hash_funct must be a function object with an interface compatible with the
          objects defined in dlib/lsh/hashes_abstract.h.  In particular, hash_funct
          must be capable of hashing the elements in the samples vector.
    ensures
        - Computes an approximate k nearest neighbors graph of the elements in
          samples and stores it into #edges.  This is done in three steps:
            1. Every element of samples is hashed with the locality sensitive hash
               function hash_funct().
            2. For each sample, the k*k_oversample samples with the nearest hashes
               are found.  This search over the hashes is exact and fast.
            3. dist_funct() is evaluated only on the candidates found in step 2 and
               the k nearest of them become the output edges for that sample.
        - Step 2 is what makes the result approximate.  Using a larger k_oversample
          makes more candidates reach step 3 and so improves the output accuracy.
        - Samples with an infinite distance between them are considered to be not
          connected at all, so edges with such distances are never output.
        - for all valid i:
            - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()])
            - #edges[i].distance() < std::numeric_limits<double>::infinity()
        - contains_duplicate_pairs(#edges) == false
        - This function will use num_threads concurrent threads of processing.  You
          should set this value equal to the number of processing cores on your
          computer for maximum speed.
!*/
{
    // make sure requires clause is not broken
    DLIB_ASSERT(k > 0 && k_oversample > 0,
        "\t void find_k_nearest_neighbors_lsh()"
        << "\n\t Invalid inputs were given to this function."
        << "\n\t samples.size(): " << samples.size()
        << "\n\t k: " << k
        << "\n\t k_oversample: " << k_oversample
        );

    edges.clear();

    // With fewer than two samples there are no pairs to connect, so the empty
    // edge list is already the answer.
    if (samples.size() > 1)
    {
        // Step 1: hash every sample.
        std::vector<typename hash_function_type::result_type> sample_hashes;
        hash_samples(samples, hash_funct, num_threads, sample_hashes);

        // Steps 2 and 3: done per sample by the scanning functor.
        typedef impl::scan_find_k_nearest_neighbors_lsh<vector_type, distance_function_type,hash_function_type,alloc> neighbor_scanner;
        parallel_for(num_threads, 0, sample_hashes.size(), neighbor_scanner(samples, dist_funct, hash_funct, k, edges, k_oversample, sample_hashes));

        // The same pair can be reported once from each of its two endpoints.
        remove_duplicate_edges(edges);
    }
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__

View File

@ -0,0 +1,102 @@
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__
#ifdef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__
#include "../lsh/hashes_abstract.h"
#include "sample_pair_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename hash_function_type
>
void hash_samples (
const vector_type& samples,
const hash_function_type& hash_funct,
const unsigned long num_threads,
std::vector<typename hash_function_type::result_type>& hashes
);
/*!
    requires
        - hash_funct() is threadsafe.  This means that it must be safe for multiple
          threads to invoke the member functions of hash_funct() at the same time.
        - vector_type is any container that looks like a std::vector or dlib::array.
        - hash_funct must be a function object with an interface compatible with the
          objects defined in dlib/lsh/hashes_abstract.h.  In particular, hash_funct
          must be capable of hashing the elements in the samples vector.
    ensures
        - This function hashes all the elements in samples and stores the results in
          hashes.  It will also use num_threads concurrent threads to do this.  You
          should set this value equal to the number of processing cores on your
          computer for maximum speed.
        - #hashes.size() == samples.size()
          (hashes is resized to match samples, so its previous size does not matter)
        - for all valid i:
            - #hashes[i] == hash_funct(samples[i])
              (i.e. #hashes[i] will contain the hash of samples[i])
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename distance_function_type,
typename hash_function_type,
typename alloc
>
void find_k_nearest_neighbors_lsh (
const vector_type& samples,
const distance_function_type& dist_funct,
const hash_function_type& hash_funct,
const unsigned long k,
const unsigned long num_threads,
std::vector<sample_pair, alloc>& edges,
const unsigned long k_oversample = 20
);
/*!
    requires
        - hash_funct and dist_funct are threadsafe.  This means that it must be safe
          for multiple threads to invoke the member functions of these objects at the
          same time.
        - k > 0
        - k_oversample > 0
        - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
          to a floating point number
        - vector_type is any container that looks like a std::vector or dlib::array.
        - hash_funct must be a function object with an interface compatible with the
          objects defined in dlib/lsh/hashes_abstract.h.  In particular, hash_funct
          must be capable of hashing the elements in the samples vector.
    ensures
        - This function computes an approximate form of a k nearest neighbors graph of
          the elements in samples.  In particular, the way it works is that it first
          hashes all elements in samples using the provided locality sensitive hash
          function hash_funct().  Then it performs an exact k nearest neighbors on the
          hashes which can be done very quickly.  For each of these neighbors we
          compute the true distance using dist_funct() and the k nearest neighbors for
          each sample are stored into #edges.
        - Any edges that were in edges before the call are discarded.  In particular,
          if samples.size() <= 1 then #edges.size() == 0.
        - Note that samples with an infinite distance between them are considered to be
          not connected at all.  Therefore, we exclude edges with such distances from
          being output.
        - for all valid i:
            - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()])
            - #edges[i].distance() < std::numeric_limits<double>::infinity()
        - contains_duplicate_pairs(#edges) == false
        - This function will use num_threads concurrent threads of processing.  You
          should set this value equal to the number of processing cores on your
          computer for maximum speed.
        - The hash based k nearest neighbor step is approximate, however, you can
          improve the output accuracy by using a larger k value for this first step.
          Therefore, this function finds k*k_oversample nearest neighbors during the
          first hashing based step.  The default k_oversample of 20 means that 20*k
          candidates per sample are passed on to dist_funct().
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__

View File

@ -0,0 +1,11 @@
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_GRAPH_UTILs_THREADED_H_
#define DLIB_GRAPH_UTILs_THREADED_H_
#include "graph_utils/find_k_nearest_neighbors_lsh.h"
#endif // DLIB_GRAPH_UTILs_THREADED_H_