mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added find_k_nearest_neighbors_lsh() and hash_samples()
This commit is contained in:
parent
4e96485601
commit
9d055a4e87
269
dlib/graph_utils/find_k_nearest_neighbors_lsh.h
Normal file
269
dlib/graph_utils/find_k_nearest_neighbors_lsh.h
Normal file
@ -0,0 +1,269 @@
|
||||
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__
|
||||
#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__
|
||||
|
||||
#include "find_k_nearest_neighbors_lsh_abstract.h"
#include "../threads.h"
#include "../lsh/hashes.h"
#include <limits>
#include <queue>
#include <utility>
#include <vector>
#include "sample_pair.h"
#include "edge_list_graphs.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
namespace impl
|
||||
{
|
||||
struct compare_sample_pair_with_distance
|
||||
{
|
||||
inline bool operator() (const sample_pair& a, const sample_pair& b) const
|
||||
{
|
||||
return a.distance() < b.distance();
|
||||
}
|
||||
};
|
||||
|
||||
// Function object that hashes a single sample per invocation.  Calling it with
// index i stores hash_funct(samples[i]) into hashes[i] and touches no other
// element of hashes.  The object only keeps references to its three arguments,
// so they must outlive it, and hashes must already be large enough to be indexed
// by every i that is passed in.
template <
    typename vector_type,
    typename hash_function_type
    >
class hash_block
{
public:
    hash_block(
        const vector_type& samples_,
        const hash_function_type& hash_funct_,
        std::vector<typename hash_function_type::result_type>& hashes_
    ) :
        samples(samples_),
        hash_funct(hash_funct_),
        hashes(hashes_)
    {}

    // Hash the i-th sample and write the result into slot i of the output.
    void operator() (long i) const
    {
        hashes[i] = hash_funct(samples[i]);
    }

    const vector_type& samples;             // input samples (not owned)
    const hash_function_type& hash_funct;   // hashing function (not owned)
    std::vector<typename hash_function_type::result_type>& hashes;  // output (not owned)
};
|
||||
|
||||
template <
|
||||
typename vector_type,
|
||||
typename distance_function_type,
|
||||
typename hash_function_type,
|
||||
typename alloc
|
||||
>
|
||||
class scan_find_k_nearest_neighbors_lsh
|
||||
{
|
||||
public:
|
||||
scan_find_k_nearest_neighbors_lsh (
|
||||
const vector_type& samples_,
|
||||
const distance_function_type& dist_funct_,
|
||||
const hash_function_type& hash_funct_,
|
||||
const unsigned long k_,
|
||||
std::vector<sample_pair, alloc>& edges_,
|
||||
const unsigned long k_oversample_,
|
||||
const std::vector<typename hash_function_type::result_type>& hashes_
|
||||
) :
|
||||
samples(samples_),
|
||||
dist_funct(dist_funct_),
|
||||
hash_funct(hash_funct_),
|
||||
k(k_),
|
||||
edges(edges_),
|
||||
k_oversample(k_oversample_),
|
||||
hashes(hashes_)
|
||||
{
|
||||
edges.clear();
|
||||
edges.reserve(samples.size()*k/2);
|
||||
}
|
||||
|
||||
mutex m;
|
||||
const vector_type& samples;
|
||||
const distance_function_type& dist_funct;
|
||||
const hash_function_type& hash_funct;
|
||||
const unsigned long k;
|
||||
std::vector<sample_pair, alloc>& edges;
|
||||
const unsigned long k_oversample;
|
||||
const std::vector<typename hash_function_type::result_type>& hashes;
|
||||
|
||||
void operator() (unsigned long i) const
|
||||
{
|
||||
const unsigned long k_hash = k*k_oversample;
|
||||
|
||||
std::priority_queue<std::pair<unsigned long, unsigned long> > best_hashes;
|
||||
std::priority_queue<sample_pair, std::vector<sample_pair>, dlib::impl::compare_sample_pair_with_distance> best_samples;
|
||||
unsigned long worst_distance = std::numeric_limits<unsigned long>::max();
|
||||
// scan over the hashes and find the best matches for hashes[i]
|
||||
for (unsigned long j = 0; j < hashes.size(); ++j)
|
||||
{
|
||||
if (i == j)
|
||||
continue;
|
||||
|
||||
const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]);
|
||||
if (dist < worst_distance || best_hashes.size() < k_hash)
|
||||
{
|
||||
if (best_hashes.size() >= k_hash)
|
||||
best_hashes.pop();
|
||||
best_hashes.push(std::make_pair(dist, j));
|
||||
worst_distance = best_hashes.top().first;
|
||||
}
|
||||
}
|
||||
|
||||
// Now figure out which of the best_hashes are actually the k best matches
|
||||
// according to dist_funct()
|
||||
while (best_hashes.size() != 0)
|
||||
{
|
||||
const unsigned long j = best_hashes.top().second;
|
||||
best_hashes.pop();
|
||||
|
||||
const double dist = dist_funct(samples[i], samples[j]);
|
||||
if (dist < std::numeric_limits<double>::infinity())
|
||||
{
|
||||
if (best_samples.size() >= k)
|
||||
best_samples.pop();
|
||||
best_samples.push(sample_pair(i,j,dist));
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, now put the k best matches according to dist_funct() into edges
|
||||
auto_mutex lock(m);
|
||||
while (best_samples.size() != 0)
|
||||
{
|
||||
edges.push_back(best_samples.top());
|
||||
best_samples.pop();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename vector_type,
|
||||
typename hash_function_type
|
||||
>
|
||||
void hash_samples (
|
||||
const vector_type& samples,
|
||||
const hash_function_type& hash_funct,
|
||||
const unsigned long num_threads,
|
||||
std::vector<typename hash_function_type::result_type>& hashes
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- hash_funct() is threadsafe. This means that it must be safe for multiple
|
||||
threads to invoke the member functions of hash_funct() at the same time.
|
||||
- vector_type is any container that looks like a std::vector or dlib::array.
|
||||
- hash_funct must be a function object with an interface compatible with the
|
||||
objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct
|
||||
must be capable of hashing the elements in the samples vector.
|
||||
ensures
|
||||
- This function hashes all the elements in samples and stores the results in
|
||||
hashes. It will also use num_threads concurrent threads to do this. You
|
||||
should set this value equal to the number of processing cores on your
|
||||
computer for maximum speed.
|
||||
- #hashes.size() == 0
|
||||
- for all valid i:
|
||||
- #hashes[i] = hash_funct(samples[i])
|
||||
(i.e. #hashes[i] will contain the hash of samples[i])
|
||||
!*/
|
||||
{
|
||||
hashes.resize(samples.size());
|
||||
|
||||
typedef impl::hash_block<vector_type,hash_function_type> block_type;
|
||||
parallel_for(num_threads, 0, samples.size(), block_type(samples, hash_funct, hashes));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename vector_type,
|
||||
typename distance_function_type,
|
||||
typename hash_function_type,
|
||||
typename alloc
|
||||
>
|
||||
void find_k_nearest_neighbors_lsh (
|
||||
const vector_type& samples,
|
||||
const distance_function_type& dist_funct,
|
||||
const hash_function_type& hash_funct,
|
||||
const unsigned long k,
|
||||
const unsigned long num_threads,
|
||||
std::vector<sample_pair, alloc>& edges,
|
||||
const unsigned long k_oversample = 20
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- hash_funct and dist_funct are threadsafe. This means that it must be safe
|
||||
for multiple threads to invoke the member functions of these objects at the
|
||||
same time.
|
||||
- k > 0
|
||||
- k_oversample > 0
|
||||
- dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
|
||||
to a floating point number
|
||||
- vector_type is any container that looks like a std::vector or dlib::array.
|
||||
- hash_funct must be a function object with an interface compatible with the
|
||||
objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct
|
||||
must be capable of hashing the elements in the samples vector.
|
||||
ensures
|
||||
- This function computes an approximate form of a k nearest neighbors graph of
|
||||
the elements in samples. In particular, the way it works is that it first
|
||||
hashes all elements in samples using the provided locality sensitive hash
|
||||
function hash_funct(). Then it performs an exact k nearest neighbors on the
|
||||
hashes which can be done very quickly. For each of these neighbors we
|
||||
compute the true distance using dist_funct() and the k nearest neighbors for
|
||||
each sample are stored into #edges.
|
||||
- Note that samples with an infinite distance between them are considered to be
|
||||
not connected at all. Therefore, we exclude edges with such distances from
|
||||
being output.
|
||||
- for all valid i:
|
||||
- #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()])
|
||||
- #edges[i].distance() < std::numeric_limits<double>::infinity()
|
||||
- contains_duplicate_pairs(#edges) == false
|
||||
- This function will use num_threads concurrent threads of processing. You
|
||||
should set this value equal to the number of processing cores on your
|
||||
computer for maximum speed.
|
||||
- The hash based k nearest neighbor step is approximate, however, you can
|
||||
improve the output accuracy by using a larger k value for this first step.
|
||||
Therefore, this function finds k*k_oversample nearest neighbors during the
|
||||
first hashing based step.
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(k > 0 && k_oversample > 0,
|
||||
"\t void find_k_nearest_neighbors_lsh()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t k: " << k
|
||||
<< "\n\t k_oversample: " << k_oversample
|
||||
);
|
||||
|
||||
edges.clear();
|
||||
|
||||
if (samples.size() <= 1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
typedef typename hash_function_type::result_type hash_type;
|
||||
std::vector<hash_type> hashes;
|
||||
hash_samples(samples, hash_funct, num_threads, hashes);
|
||||
|
||||
typedef impl::scan_find_k_nearest_neighbors_lsh<vector_type, distance_function_type,hash_function_type,alloc> scan_type;
|
||||
parallel_for(num_threads, 0, hashes.size(), scan_type(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes));
|
||||
|
||||
remove_duplicate_edges(edges);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_H__
|
||||
|
||||
|
102
dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h
Normal file
102
dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h
Normal file
@ -0,0 +1,102 @@
|
||||
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__
|
||||
#ifdef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__
|
||||
|
||||
#include "../lsh/hashes_abstract.h"
|
||||
#include "sample_pair_abstract.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
    typename vector_type,
    typename hash_function_type
    >
void hash_samples (
    const vector_type& samples,
    const hash_function_type& hash_funct,
    const unsigned long num_threads,
    std::vector<typename hash_function_type::result_type>& hashes
);
/*!
    requires
        - hash_funct() is threadsafe.  This means that it must be safe for multiple
          threads to invoke the member functions of hash_funct() at the same time.
        - vector_type is any container that looks like a std::vector or dlib::array.
        - hash_funct must be a function object with an interface compatible with the
          objects defined in dlib/lsh/hashes_abstract.h.  In particular, hash_funct
          must be capable of hashing the elements in the samples vector.
    ensures
        - This function hashes all the elements in samples and stores the results in
          hashes.  It will also use num_threads concurrent threads to do this.  You
          should set this value equal to the number of processing cores on your
          computer for maximum speed.
        - #hashes.size() == samples.size()
        - for all valid i:
            - #hashes[i] == hash_funct(samples[i])
              (i.e. #hashes[i] will contain the hash of samples[i])
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename vector_type,
|
||||
typename distance_function_type,
|
||||
typename hash_function_type,
|
||||
typename alloc
|
||||
>
|
||||
void find_k_nearest_neighbors_lsh (
|
||||
const vector_type& samples,
|
||||
const distance_function_type& dist_funct,
|
||||
const hash_function_type& hash_funct,
|
||||
const unsigned long k,
|
||||
const unsigned long num_threads,
|
||||
std::vector<sample_pair, alloc>& edges,
|
||||
const unsigned long k_oversample = 20
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- hash_funct and dist_funct are threadsafe. This means that it must be safe
|
||||
for multiple threads to invoke the member functions of these objects at the
|
||||
same time.
|
||||
- k > 0
|
||||
- k_oversample > 0
|
||||
- dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
|
||||
to a floating point number
|
||||
- vector_type is any container that looks like a std::vector or dlib::array.
|
||||
- hash_funct must be a function object with an interface compatible with the
|
||||
objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct
|
||||
must be capable of hashing the elements in the samples vector.
|
||||
ensures
|
||||
- This function computes an approximate form of a k nearest neighbors graph of
|
||||
the elements in samples. In particular, the way it works is that it first
|
||||
hashes all elements in samples using the provided locality sensitive hash
|
||||
function hash_funct(). Then it performs an exact k nearest neighbors on the
|
||||
hashes which can be done very quickly. For each of these neighbors we
|
||||
compute the true distance using dist_funct() and the k nearest neighbors for
|
||||
each sample are stored into #edges.
|
||||
- Note that samples with an infinite distance between them are considered to be
|
||||
not connected at all. Therefore, we exclude edges with such distances from
|
||||
being output.
|
||||
- for all valid i:
|
||||
- #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()])
|
||||
- #edges[i].distance() < std::numeric_limits<double>::infinity()
|
||||
- contains_duplicate_pairs(#edges) == false
|
||||
- This function will use num_threads concurrent threads of processing. You
|
||||
should set this value equal to the number of processing cores on your
|
||||
computer for maximum speed.
|
||||
- The hash based k nearest neighbor step is approximate, however, you can
|
||||
improve the output accuracy by using a larger k value for this first step.
|
||||
Therefore, this function finds k*k_oversample nearest neighbors during the
|
||||
first hashing based step.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_H__
|
||||
|
11
dlib/graph_utils_threaded.h
Normal file
11
dlib/graph_utils_threaded.h
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_GRAPH_UTILs_THREADED_H_
|
||||
#define DLIB_GRAPH_UTILs_THREADED_H_
|
||||
|
||||
#include "graph_utils/find_k_nearest_neighbors_lsh.h"
|
||||
|
||||
#endif // DLIB_GRAPH_UTILs_THREADED_H_
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user