mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added an option to solve the L2-loss version of the SVM objective function.
This commit is contained in:
parent
1af517bcf0
commit
4203da2b16
@ -47,7 +47,8 @@ namespace dlib
|
|||||||
verbose(false),
|
verbose(false),
|
||||||
have_bias(true),
|
have_bias(true),
|
||||||
last_weight_1(false),
|
last_weight_1(false),
|
||||||
do_shrinking(true)
|
do_shrinking(true),
|
||||||
|
do_svm_l2(false)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,7 +62,8 @@ namespace dlib
|
|||||||
verbose(false),
|
verbose(false),
|
||||||
have_bias(true),
|
have_bias(true),
|
||||||
last_weight_1(false),
|
last_weight_1(false),
|
||||||
do_shrinking(true)
|
do_shrinking(true),
|
||||||
|
do_svm_l2(false)
|
||||||
{
|
{
|
||||||
// make sure requires clause is not broken
|
// make sure requires clause is not broken
|
||||||
DLIB_ASSERT(0 < C_,
|
DLIB_ASSERT(0 < C_,
|
||||||
@ -104,6 +106,13 @@ namespace dlib
|
|||||||
bool enabled
|
bool enabled
|
||||||
) { do_shrinking = enabled; }
|
) { do_shrinking = enabled; }
|
||||||
|
|
||||||
|
bool solving_svm_l2_problem (
|
||||||
|
) const { return do_svm_l2; }
|
||||||
|
|
||||||
|
void solve_svm_l2_problem (
|
||||||
|
bool enabled
|
||||||
|
) { do_svm_l2 = enabled; }
|
||||||
|
|
||||||
void be_verbose (
|
void be_verbose (
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
@ -497,11 +506,15 @@ namespace dlib
|
|||||||
std::vector<long>& index = state.index;
|
std::vector<long>& index = state.index;
|
||||||
const long dims = state.dims;
|
const long dims = state.dims;
|
||||||
|
|
||||||
|
|
||||||
unsigned long active_size = index.size();
|
unsigned long active_size = index.size();
|
||||||
|
|
||||||
scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity();
|
scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity();
|
||||||
scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
|
scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
|
||||||
|
|
||||||
|
const scalar_type Dii_pos = 1/(2*Cpos);
|
||||||
|
const scalar_type Dii_neg = 1/(2*Cneg);
|
||||||
|
|
||||||
// main loop
|
// main loop
|
||||||
for (unsigned long iter = 0; iter < max_iterations; ++iter)
|
for (unsigned long iter = 0; iter < max_iterations; ++iter)
|
||||||
{
|
{
|
||||||
@ -521,8 +534,16 @@ namespace dlib
|
|||||||
{
|
{
|
||||||
const long i = index[ii];
|
const long i = index[ii];
|
||||||
|
|
||||||
const scalar_type G = y(i)*dot(w, x(i)) - 1;
|
scalar_type G = y(i)*dot(w, x(i)) - 1;
|
||||||
|
if (do_svm_l2)
|
||||||
|
{
|
||||||
|
if (y(i) > 0)
|
||||||
|
G += Dii_pos*alpha[i];
|
||||||
|
else
|
||||||
|
G += Dii_neg*alpha[i];
|
||||||
|
}
|
||||||
const scalar_type C = (y(i) > 0) ? Cpos : Cneg;
|
const scalar_type C = (y(i) > 0) ? Cpos : Cneg;
|
||||||
|
const scalar_type U = do_svm_l2 ? std::numeric_limits<scalar_type>::infinity() : C;
|
||||||
|
|
||||||
scalar_type PG = 0;
|
scalar_type PG = 0;
|
||||||
if (alpha[i] == 0)
|
if (alpha[i] == 0)
|
||||||
@ -539,7 +560,7 @@ namespace dlib
|
|||||||
if (G < 0)
|
if (G < 0)
|
||||||
PG = G;
|
PG = G;
|
||||||
}
|
}
|
||||||
else if (alpha[i] == C)
|
else if (alpha[i] == U)
|
||||||
{
|
{
|
||||||
if (G < PG_min_prev)
|
if (G < PG_min_prev)
|
||||||
{
|
{
|
||||||
@ -567,7 +588,7 @@ namespace dlib
|
|||||||
if (std::abs(PG) > 1e-12)
|
if (std::abs(PG) > 1e-12)
|
||||||
{
|
{
|
||||||
const scalar_type alpha_old = alpha[i];
|
const scalar_type alpha_old = alpha[i];
|
||||||
alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), C);
|
alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), U);
|
||||||
const scalar_type delta = (alpha[i]-alpha_old)*y(i);
|
const scalar_type delta = (alpha[i]-alpha_old)*y(i);
|
||||||
add_to(w, x(i), delta);
|
add_to(w, x(i), delta);
|
||||||
if (have_bias && !last_weight_1)
|
if (have_bias && !last_weight_1)
|
||||||
@ -660,6 +681,7 @@ namespace dlib
|
|||||||
bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1.
|
bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1.
|
||||||
bool last_weight_1;
|
bool last_weight_1;
|
||||||
bool do_shrinking;
|
bool do_shrinking;
|
||||||
|
bool do_svm_l2;
|
||||||
|
|
||||||
}; // end of class svm_c_linear_dcd_trainer
|
}; // end of class svm_c_linear_dcd_trainer
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ namespace dlib
|
|||||||
- #forces_last_weight_to_1() == false
|
- #forces_last_weight_to_1() == false
|
||||||
- #includes_bias() == true
|
- #includes_bias() == true
|
||||||
- #shrinking_enabled() == true
|
- #shrinking_enabled() == true
|
||||||
|
- #solving_svm_l2_problem() == false
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
explicit svm_c_linear_dcd_trainer (
|
explicit svm_c_linear_dcd_trainer (
|
||||||
@ -86,6 +87,7 @@ namespace dlib
|
|||||||
- #forces_last_weight_to_1() == false
|
- #forces_last_weight_to_1() == false
|
||||||
- #includes_bias() == true
|
- #includes_bias() == true
|
||||||
- #shrinking_enabled() == true
|
- #shrinking_enabled() == true
|
||||||
|
- #solving_svm_l2_problem() == false
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
bool includes_bias (
|
bool includes_bias (
|
||||||
@ -140,6 +142,23 @@ namespace dlib
|
|||||||
- #shrinking_enabled() == enabled
|
- #shrinking_enabled() == enabled
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
bool solving_svm_l2_problem (
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns true if this solver will solve the L2 version of the SVM
|
||||||
|
objective function. That is, if solving_svm_l2_problem()==true then this
|
||||||
|
object, rather than using the hinge loss, uses the squared hinge loss.
|
||||||
|
!*/
|
||||||
|
|
||||||
|
void solve_svm_l2_problem (
|
||||||
|
bool enabled
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- #solving_svm_l2_problem() == enabled
|
||||||
|
!*/
|
||||||
|
|
||||||
void be_verbose (
|
void be_verbose (
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
|
@ -440,6 +440,65 @@ namespace
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void test_l2_version ()
|
||||||
|
{
|
||||||
|
typedef std::map<unsigned long,double> sample_type;
|
||||||
|
typedef sparse_linear_kernel<sample_type> kernel_type;
|
||||||
|
|
||||||
|
svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
|
||||||
|
linear_trainer.set_c(10);
|
||||||
|
linear_trainer.set_epsilon(1e-5);
|
||||||
|
|
||||||
|
std::vector<sample_type> samples;
|
||||||
|
std::vector<double> labels;
|
||||||
|
|
||||||
|
// make an instance of a sample vector so we can use it below
|
||||||
|
sample_type sample;
|
||||||
|
|
||||||
|
|
||||||
|
// Now let's go into a loop and randomly generate 1000 samples.
|
||||||
|
double label = +1;
|
||||||
|
for (int i = 0; i < 1000; ++i)
|
||||||
|
{
|
||||||
|
// flip the sign of the label so consecutive samples alternate between the two classes
|
||||||
|
label *= -1;
|
||||||
|
|
||||||
|
sample.clear();
|
||||||
|
|
||||||
|
// now make a random sparse sample with at most 10 non-zero elements
|
||||||
|
for (int j = 0; j < 10; ++j)
|
||||||
|
{
|
||||||
|
int idx = std::rand()%100;
|
||||||
|
double value = static_cast<double>(std::rand())/RAND_MAX;
|
||||||
|
|
||||||
|
sample[idx] = label*value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also save the samples we are generating so we can let the svm_c_linear_dcd_trainer
|
||||||
|
// learn from them below.
|
||||||
|
samples.push_back(sample);
|
||||||
|
labels.push_back(label);
|
||||||
|
}
|
||||||
|
|
||||||
|
decision_function<kernel_type> df = linear_trainer.train(samples, labels);
|
||||||
|
|
||||||
|
sample.clear();
|
||||||
|
sample[4] = 0.3;
|
||||||
|
sample[10] = 0.9;
|
||||||
|
DLIB_TEST(df(sample) > 0);
|
||||||
|
|
||||||
|
sample.clear();
|
||||||
|
sample[83] = -0.3;
|
||||||
|
sample[26] = -0.9;
|
||||||
|
sample[58] = -0.7;
|
||||||
|
DLIB_TEST(df(sample) < 0);
|
||||||
|
|
||||||
|
sample.clear();
|
||||||
|
sample[0] = -0.2;
|
||||||
|
sample[9] = -0.8;
|
||||||
|
DLIB_TEST(df(sample) < 0);
|
||||||
|
}
|
||||||
|
|
||||||
class tester_svm_c_linear_dcd : public tester
|
class tester_svm_c_linear_dcd : public tester
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -474,6 +533,8 @@ namespace
|
|||||||
print_spinner();
|
print_spinner();
|
||||||
test_sparse_1_sample(-1);
|
test_sparse_1_sample(-1);
|
||||||
print_spinner();
|
print_spinner();
|
||||||
|
|
||||||
|
test_l2_version();
|
||||||
}
|
}
|
||||||
} a;
|
} a;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user