Modified structural_svm_problem_threaded to reduce the amount of copying overhead.
This commit is contained in:
Davis King 2013-01-26 00:44:58 -05:00
parent 51666563e3
commit 633a9cdc9a

View File

@ -11,6 +11,8 @@
#include "sparse_vector.h" #include "sparse_vector.h"
#include <iostream> #include <iostream>
#include "../threads.h" #include "../threads.h"
#include "../misc_api.h"
#include "../statistics.h"
namespace dlib namespace dlib
{ {
@ -32,7 +34,8 @@ namespace dlib
explicit structural_svm_problem_threaded ( explicit structural_svm_problem_threaded (
unsigned long num_threads unsigned long num_threads
) : ) :
tp(num_threads) tp(num_threads),
num_iterations_executed(0)
{} {}
unsigned long get_num_threads ( unsigned long get_num_threads (
@ -46,25 +49,35 @@ namespace dlib
const structural_svm_problem_threaded& self_, const structural_svm_problem_threaded& self_,
matrix_type& w_, matrix_type& w_,
matrix_type& subgradient_, matrix_type& subgradient_,
scalar_type& total_loss_ scalar_type& total_loss_,
) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_) {} bool buffer_subgradients_locally_
) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_),
buffer_subgradients_locally(buffer_subgradients_locally_){}
void call_oracle ( void call_oracle (
long begin, long begin,
long end long end
) )
{ {
// If we are only going to call the separation oracle once then // If we are only going to call the separation oracle once then don't run
// don't run the slightly more complex for loop version of this code. // the slightly more complex for loop version of this code. Or if we just
if (end-begin <= 1) // don't want to run the complex buffering one. The code later on decides
// if we should do the buffering based on how long it takes to execute. We
// do this because, when the subgradient is really high dimensional it can
// take a lot of time to add them together. So we might want to avoid
// doing that.
if (end-begin <= 1 || !buffer_subgradients_locally)
{ {
scalar_type loss; scalar_type loss;
feature_vector_type ftemp; feature_vector_type ftemp;
self.separation_oracle_cached(begin, w, loss, ftemp); for (long i = begin; i < end; ++i)
{
self.separation_oracle_cached(i, w, loss, ftemp);
auto_mutex lock(self.accum_mutex); auto_mutex lock(self.accum_mutex);
total_loss += loss; total_loss += loss;
add_to(subgradient, ftemp); add_to(subgradient, ftemp);
}
} }
else else
{ {
@ -92,6 +105,7 @@ namespace dlib
matrix_type& w; matrix_type& w;
matrix_type& subgradient; matrix_type& subgradient;
scalar_type& total_loss; scalar_type& total_loss;
bool buffer_subgradients_locally;
}; };
@ -101,22 +115,47 @@ namespace dlib
scalar_type& total_loss scalar_type& total_loss
) const ) const
{ {
++num_iterations_executed;
const long num = this->get_num_samples(); const long num = this->get_num_samples();
// how many samples to process in a single task (aim for 4 jobs per worker) // how many samples to process in a single task (aim for 4 jobs per worker)
const long num_workers = std::max(1UL, tp.num_threads_in_pool()); const long num_workers = std::max(1UL, tp.num_threads_in_pool());
const long block_size = std::max(1L, num/(num_workers*4)); const long block_size = std::max(1L, num/(num_workers*4));
binder b(*this, w, subgradient, total_loss); const uint64 start_time = ts.get_timestamp();
bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();
// every 50 iterations we should try to flip the buffering scheme to see if
// doing it the other way might be better.
if ((num_iterations_executed%50) == 0)
{
buffer_subgradients_locally = !buffer_subgradients_locally;
}
binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally);
for (long i = 0; i < num; i+=block_size) for (long i = 0; i < num; i+=block_size)
{ {
tp.add_task(b, &binder::call_oracle, i, std::min(i+block_size, num)); tp.add_task(b, &binder::call_oracle, i, std::min(i+block_size, num));
} }
tp.wait_for_all_tasks(); tp.wait_for_all_tasks();
const uint64 stop_time = ts.get_timestamp();
if (buffer_subgradients_locally)
with_buffer_time.add(stop_time-start_time);
else
without_buffer_time.add(stop_time-start_time);
} }
mutable thread_pool tp; mutable thread_pool tp;
mutable mutex accum_mutex; mutable mutex accum_mutex;
mutable timestamper ts;
mutable running_stats<double> with_buffer_time;
mutable running_stats<double> without_buffer_time;
mutable unsigned long num_iterations_executed;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------