From 633a9cdc9a7b35407428a78c173355b8627fe45d Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 26 Jan 2013 00:44:58 -0500 Subject: [PATCH] Modified structural_svm_problem_threaded to reduce the amount of copying overhead. --- dlib/svm/structural_svm_problem_threaded.h | 61 ++++++++++++++++++---- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/dlib/svm/structural_svm_problem_threaded.h b/dlib/svm/structural_svm_problem_threaded.h index cd3ce7808..cc2ecc1d9 100644 --- a/dlib/svm/structural_svm_problem_threaded.h +++ b/dlib/svm/structural_svm_problem_threaded.h @@ -11,6 +11,8 @@ #include "sparse_vector.h" #include #include "../threads.h" +#include "../misc_api.h" +#include "../statistics.h" namespace dlib { @@ -32,7 +34,8 @@ namespace dlib explicit structural_svm_problem_threaded ( unsigned long num_threads ) : - tp(num_threads) + tp(num_threads), + num_iterations_executed(0) {} unsigned long get_num_threads ( @@ -46,25 +49,35 @@ namespace dlib const structural_svm_problem_threaded& self_, matrix_type& w_, matrix_type& subgradient_, - scalar_type& total_loss_ - ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_) {} + scalar_type& total_loss_, + bool buffer_subgradients_locally_ + ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_), + buffer_subgradients_locally(buffer_subgradients_locally_){} void call_oracle ( long begin, long end ) { - // If we are only going to call the separation oracle once then - // don't run the slightly more complex for loop version of this code. - if (end-begin <= 1) + // If we are only going to call the separation oracle once, or if we have + // been told not to buffer subgradients locally, then don't run the more + // complex buffering version of this code. The code later on decides + // if we should do the buffering based on how long it takes to execute. 
We + // do this because, when the subgradients are really high dimensional, it can + // take a lot of time to add them together. So we might want to avoid + // doing that. + if (end-begin <= 1 || !buffer_subgradients_locally) { scalar_type loss; feature_vector_type ftemp; - self.separation_oracle_cached(begin, w, loss, ftemp); + for (long i = begin; i < end; ++i) + { + self.separation_oracle_cached(i, w, loss, ftemp); - auto_mutex lock(self.accum_mutex); - total_loss += loss; - add_to(subgradient, ftemp); + auto_mutex lock(self.accum_mutex); + total_loss += loss; + add_to(subgradient, ftemp); + } } else { @@ -92,6 +105,7 @@ namespace dlib matrix_type& w; matrix_type& subgradient; scalar_type& total_loss; + bool buffer_subgradients_locally; }; @@ -101,22 +115,47 @@ namespace dlib scalar_type& total_loss ) const { + ++num_iterations_executed; const long num = this->get_num_samples(); // how many samples to process in a single task (aim for 4 jobs per worker) const long num_workers = std::max(1UL, tp.num_threads_in_pool()); const long block_size = std::max(1L, num/(num_workers*4)); - binder b(*this, w, subgradient, total_loss); + const uint64 start_time = ts.get_timestamp(); + + bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); + + // Every 50 iterations we flip the buffering scheme to see if + // doing it the other way might be better. 
+ if ((num_iterations_executed%50) == 0) + { + buffer_subgradients_locally = !buffer_subgradients_locally; + } + + + binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally); for (long i = 0; i < num; i+=block_size) { tp.add_task(b, &binder::call_oracle, i, std::min(i+block_size, num)); } tp.wait_for_all_tasks(); + + const uint64 stop_time = ts.get_timestamp(); + + if (buffer_subgradients_locally) + with_buffer_time.add(stop_time-start_time); + else + without_buffer_time.add(stop_time-start_time); + } mutable thread_pool tp; mutable mutex accum_mutex; + mutable timestamper ts; + mutable running_stats with_buffer_time; + mutable running_stats without_buffer_time; + mutable unsigned long num_iterations_executed; }; // ----------------------------------------------------------------------------------------