Added an option to solve the L2-loss version of the SVM objective function.

This commit is contained in:
Davis King 2016-02-17 10:37:19 -05:00
parent 1af517bcf0
commit 4203da2b16
3 changed files with 107 additions and 5 deletions

View File

@ -47,7 +47,8 @@ namespace dlib
verbose(false), verbose(false),
have_bias(true), have_bias(true),
last_weight_1(false), last_weight_1(false),
do_shrinking(true) do_shrinking(true),
do_svm_l2(false)
{ {
} }
@ -61,7 +62,8 @@ namespace dlib
verbose(false), verbose(false),
have_bias(true), have_bias(true),
last_weight_1(false), last_weight_1(false),
do_shrinking(true) do_shrinking(true),
do_svm_l2(false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(0 < C_, DLIB_ASSERT(0 < C_,
@ -104,6 +106,13 @@ namespace dlib
bool enabled bool enabled
) { do_shrinking = enabled; } ) { do_shrinking = enabled; }
bool solving_svm_l2_problem (
) const { return do_svm_l2; }
void solve_svm_l2_problem (
bool enabled
) { do_svm_l2 = enabled; }
void be_verbose ( void be_verbose (
) )
{ {
@ -497,11 +506,15 @@ namespace dlib
std::vector<long>& index = state.index; std::vector<long>& index = state.index;
const long dims = state.dims; const long dims = state.dims;
unsigned long active_size = index.size(); unsigned long active_size = index.size();
scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity(); scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity();
scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity(); scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
const scalar_type Dii_pos = 1/(2*Cpos);
const scalar_type Dii_neg = 1/(2*Cneg);
// main loop // main loop
for (unsigned long iter = 0; iter < max_iterations; ++iter) for (unsigned long iter = 0; iter < max_iterations; ++iter)
{ {
@ -521,8 +534,16 @@ namespace dlib
{ {
const long i = index[ii]; const long i = index[ii];
const scalar_type G = y(i)*dot(w, x(i)) - 1; scalar_type G = y(i)*dot(w, x(i)) - 1;
if (do_svm_l2)
{
if (y(i) > 0)
G += Dii_pos*alpha[i];
else
G += Dii_neg*alpha[i];
}
const scalar_type C = (y(i) > 0) ? Cpos : Cneg; const scalar_type C = (y(i) > 0) ? Cpos : Cneg;
const scalar_type U = do_svm_l2 ? std::numeric_limits<scalar_type>::infinity() : C;
scalar_type PG = 0; scalar_type PG = 0;
if (alpha[i] == 0) if (alpha[i] == 0)
@ -539,7 +560,7 @@ namespace dlib
if (G < 0) if (G < 0)
PG = G; PG = G;
} }
else if (alpha[i] == C) else if (alpha[i] == U)
{ {
if (G < PG_min_prev) if (G < PG_min_prev)
{ {
@ -567,7 +588,7 @@ namespace dlib
if (std::abs(PG) > 1e-12) if (std::abs(PG) > 1e-12)
{ {
const scalar_type alpha_old = alpha[i]; const scalar_type alpha_old = alpha[i];
alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), C); alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), U);
const scalar_type delta = (alpha[i]-alpha_old)*y(i); const scalar_type delta = (alpha[i]-alpha_old)*y(i);
add_to(w, x(i), delta); add_to(w, x(i), delta);
if (have_bias && !last_weight_1) if (have_bias && !last_weight_1)
@ -660,6 +681,7 @@ namespace dlib
bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1. bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1.
bool last_weight_1; bool last_weight_1;
bool do_shrinking; bool do_shrinking;
bool do_svm_l2;
}; // end of class svm_c_linear_dcd_trainer }; // end of class svm_c_linear_dcd_trainer

View File

@ -67,6 +67,7 @@ namespace dlib
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #includes_bias() == true - #includes_bias() == true
- #shrinking_enabled() == true - #shrinking_enabled() == true
- #solving_svm_l2_problem() == false
!*/ !*/
explicit svm_c_linear_dcd_trainer ( explicit svm_c_linear_dcd_trainer (
@ -86,6 +87,7 @@ namespace dlib
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #includes_bias() == true - #includes_bias() == true
- #shrinking_enabled() == true - #shrinking_enabled() == true
- #solving_svm_l2_problem() == false
!*/ !*/
bool includes_bias ( bool includes_bias (
@ -140,6 +142,23 @@ namespace dlib
- #shrinking_enabled() == enabled - #shrinking_enabled() == enabled
!*/ !*/
bool solving_svm_l2_problem (
) const;
/*!
ensures
- returns true if this solver will solve the L2 version of the SVM
objective function. That is, if solving_svm_l2_problem()==true then this
object, rather than using the hinge loss, uses the squared hinge loss.
!*/
void solve_svm_l2_problem (
bool enabled
);
/*!
ensures
- #solving_svm_l2_problem() == enabled
!*/
void be_verbose ( void be_verbose (
); );
/*! /*!

View File

@ -440,6 +440,65 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void test_l2_version ()
{
typedef std::map<unsigned long,double> sample_type;
typedef sparse_linear_kernel<sample_type> kernel_type;
svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
linear_trainer.set_c(10);
linear_trainer.set_epsilon(1e-5);
std::vector<sample_type> samples;
std::vector<double> labels;
// make an instance of a sample vector so we can use it below
sample_type sample;
// Now let's go into a loop and randomly generate 10000 samples.
double label = +1;
for (int i = 0; i < 1000; ++i)
{
// flip this flag
label *= -1;
sample.clear();
// now make a random sparse sample with at most 10 non-zero elements
for (int j = 0; j < 10; ++j)
{
int idx = std::rand()%100;
double value = static_cast<double>(std::rand())/RAND_MAX;
sample[idx] = label*value;
}
// Also save the samples we are generating so we can let the svm_c_linear_trainer
// learn from them below.
samples.push_back(sample);
labels.push_back(label);
}
decision_function<kernel_type> df = linear_trainer.train(samples, labels);
sample.clear();
sample[4] = 0.3;
sample[10] = 0.9;
DLIB_TEST(df(sample) > 0);
sample.clear();
sample[83] = -0.3;
sample[26] = -0.9;
sample[58] = -0.7;
DLIB_TEST(df(sample) < 0);
sample.clear();
sample[0] = -0.2;
sample[9] = -0.8;
DLIB_TEST(df(sample) < 0);
}
class tester_svm_c_linear_dcd : public tester class tester_svm_c_linear_dcd : public tester
{ {
public: public:
@ -474,6 +533,8 @@ namespace
print_spinner(); print_spinner();
test_sparse_1_sample(-1); test_sparse_1_sample(-1);
print_spinner(); print_spinner();
test_l2_version();
} }
} a; } a;