mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Made the svm_multiclass_linear_trainer threaded. This also means you have to
#include dlib/svm_threaded.h instead of dlib/svm.h to get it now.
This commit is contained in:
parent
ef41b7f61c
commit
f9530dcdf1
@ -45,7 +45,6 @@
|
||||
#include "svm/one_vs_all_trainer.h"
|
||||
|
||||
#include "svm/structural_svm_problem.h"
|
||||
#include "svm/svm_multiclass_linear_trainer.h"
|
||||
#include "svm/sequence_labeler.h"
|
||||
#include "svm/assignment_function.h"
|
||||
#include "svm/active_learning.h"
|
||||
|
@ -4,7 +4,7 @@
|
||||
#define DLIB_SVm_MULTICLASS_LINEAR_TRAINER_H__
|
||||
|
||||
#include "svm_multiclass_linear_trainer_abstract.h"
|
||||
#include "structural_svm_problem.h"
|
||||
#include "structural_svm_problem_threaded.h"
|
||||
#include <vector>
|
||||
#include "../optimization/optimization_oca.h"
|
||||
#include "../matrix.h"
|
||||
@ -21,7 +21,7 @@ namespace dlib
|
||||
typename sample_type,
|
||||
typename label_type
|
||||
>
|
||||
class multiclass_svm_problem : public structural_svm_problem<matrix_type,
|
||||
class multiclass_svm_problem : public structural_svm_problem_threaded<matrix_type,
|
||||
std::vector<std::pair<unsigned long,typename matrix_type::type> > >
|
||||
{
|
||||
/*!
|
||||
@ -45,8 +45,10 @@ namespace dlib
|
||||
|
||||
multiclass_svm_problem (
|
||||
const std::vector<sample_type>& samples_,
|
||||
const std::vector<label_type>& labels_
|
||||
const std::vector<label_type>& labels_,
|
||||
const unsigned long num_threads
|
||||
) :
|
||||
structural_svm_problem_threaded<matrix_type, std::vector<std::pair<unsigned long,typename matrix_type::type> > >(num_threads),
|
||||
samples(samples_),
|
||||
labels(labels_),
|
||||
distinct_labels(select_all_distinct_labels(labels_)),
|
||||
@ -172,12 +174,26 @@ namespace dlib
|
||||
|
||||
svm_multiclass_linear_trainer (
|
||||
) :
|
||||
num_threads(4),
|
||||
C(1),
|
||||
eps(0.001),
|
||||
verbose(false)
|
||||
{
|
||||
}
|
||||
|
||||
void set_num_threads (
|
||||
unsigned long num
|
||||
)
|
||||
{
|
||||
num_threads = num;
|
||||
}
|
||||
|
||||
unsigned long get_num_threads (
|
||||
) const
|
||||
{
|
||||
return num_threads;
|
||||
}
|
||||
|
||||
void set_epsilon (
|
||||
scalar_type eps_
|
||||
)
|
||||
@ -273,7 +289,7 @@ namespace dlib
|
||||
|
||||
typedef matrix<scalar_type,0,1> w_type;
|
||||
w_type weights;
|
||||
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels);
|
||||
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels, num_threads);
|
||||
if (verbose)
|
||||
problem.be_verbose();
|
||||
|
||||
@ -293,6 +309,8 @@ namespace dlib
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
unsigned long num_threads;
|
||||
scalar_type C;
|
||||
scalar_type eps;
|
||||
bool verbose;
|
||||
|
@ -31,6 +31,7 @@ namespace dlib
|
||||
using operator<<.
|
||||
|
||||
INITIAL VALUE
|
||||
- get_num_threads() == 4
|
||||
- get_epsilon() == 0.001
|
||||
- get_c() == 1
|
||||
- this object will not be verbose unless be_verbose() is called
|
||||
@ -106,6 +107,23 @@ namespace dlib
|
||||
- returns a copy of the optimizer used to solve the SVM problem.
|
||||
!*/
|
||||
|
||||
void set_num_threads (
|
||||
unsigned long num
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_num_threads() == num
|
||||
!*/
|
||||
|
||||
unsigned long get_num_threads (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the number of threads used during training. You should
|
||||
usually set this equal to the number of processing cores on your
|
||||
machine.
|
||||
!*/
|
||||
|
||||
const kernel_type get_kernel (
|
||||
) const;
|
||||
/*!
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "svm/structural_svm_graph_labeling_problem.h"
|
||||
#include "svm/structural_graph_labeling_trainer.h"
|
||||
#include "svm/cross_validate_graph_labeling_trainer.h"
|
||||
#include "svm/svm_multiclass_linear_trainer.h"
|
||||
|
||||
#endif // DLIB_SVm_THREADED_HEADER
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
|
||||
#include "tester.h"
|
||||
#include <dlib/svm.h>
|
||||
#include <dlib/svm_threaded.h>
|
||||
#include <dlib/data_io.h>
|
||||
#include "create_iris_datafile.h"
|
||||
#include <vector>
|
||||
|
@ -446,7 +446,7 @@ namespace
|
||||
|
||||
typedef matrix<scalar_type,0,1> w_type;
|
||||
w_type weights;
|
||||
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels);
|
||||
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels,4);
|
||||
problem.set_max_cache_size(3);
|
||||
|
||||
problem.set_c(C);
|
||||
|
Loading…
Reference in New Issue
Block a user