Added the option to learn non-negative weights to the svm_multiclass_linear_trainer.

This commit is contained in:
Davis King 2013-11-17 16:25:34 -05:00
parent 8e6b5a40c6
commit 83217d764a
2 changed files with 41 additions and 2 deletions

View File

@ -177,7 +177,8 @@ namespace dlib
num_threads(4),
C(1),
eps(0.001),
verbose(false)
verbose(false),
learn_nonnegative_weights(false)
{
}
@ -243,6 +244,16 @@ namespace dlib
return kernel_type();
}
bool learns_nonnegative_weights (
) const { return learn_nonnegative_weights; }
void set_learns_nonnegative_weights (
bool value
)
{
learn_nonnegative_weights = value;
}
void set_c (
scalar_type C_
)
@ -297,7 +308,13 @@ namespace dlib
problem.set_c(C);
problem.set_epsilon(eps);
svm_objective = solver(problem, weights);
unsigned long num_nonnegative = 0;
if (learn_nonnegative_weights)
{
num_nonnegative = problem.get_num_dimensions();
}
svm_objective = solver(problem, weights, num_nonnegative);
trained_function_type df;
@ -315,6 +332,7 @@ namespace dlib
scalar_type eps;
bool verbose;
oca solver;
bool learn_nonnegative_weights;
};
// ----------------------------------------------------------------------------------------

View File

@ -32,6 +32,7 @@ namespace dlib
INITIAL VALUE
- get_num_threads() == 4
- learns_nonnegative_weights() == false
- get_epsilon() == 0.001
- get_c() == 1
- this object will not be verbose unless be_verbose() is called
@ -155,6 +156,26 @@ namespace dlib
generalization.
!*/
bool learns_nonnegative_weights (
) const;
/*!
ensures
- The output of training is a set of weights and bias values that together
define the behavior of a multiclass_linear_decision_function object. If
learns_nonnegative_weights() == true then the resulting weights and bias
values will always have non-negative values. That is, if this function
returns true then all the numbers in the multiclass_linear_decision_function
objects output by train() will be non-negative.
!*/
void set_learns_nonnegative_weights (
bool value
);
/*!
ensures
- #learns_nonnegative_weights() == value
!*/
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels