mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added an option to solve the L2-loss version of the SVM objective function.
This commit is contained in:
parent
1af517bcf0
commit
4203da2b16
@ -47,7 +47,8 @@ namespace dlib
|
||||
verbose(false),
|
||||
have_bias(true),
|
||||
last_weight_1(false),
|
||||
do_shrinking(true)
|
||||
do_shrinking(true),
|
||||
do_svm_l2(false)
|
||||
{
|
||||
}
|
||||
|
||||
@ -61,7 +62,8 @@ namespace dlib
|
||||
verbose(false),
|
||||
have_bias(true),
|
||||
last_weight_1(false),
|
||||
do_shrinking(true)
|
||||
do_shrinking(true),
|
||||
do_svm_l2(false)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(0 < C_,
|
||||
@ -104,6 +106,13 @@ namespace dlib
|
||||
bool enabled
|
||||
) { do_shrinking = enabled; }
|
||||
|
||||
bool solving_svm_l2_problem (
|
||||
) const { return do_svm_l2; }
|
||||
|
||||
void solve_svm_l2_problem (
|
||||
bool enabled
|
||||
) { do_svm_l2 = enabled; }
|
||||
|
||||
void be_verbose (
|
||||
)
|
||||
{
|
||||
@ -497,11 +506,15 @@ namespace dlib
|
||||
std::vector<long>& index = state.index;
|
||||
const long dims = state.dims;
|
||||
|
||||
|
||||
unsigned long active_size = index.size();
|
||||
|
||||
scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity();
|
||||
scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
|
||||
|
||||
const scalar_type Dii_pos = 1/(2*Cpos);
|
||||
const scalar_type Dii_neg = 1/(2*Cneg);
|
||||
|
||||
// main loop
|
||||
for (unsigned long iter = 0; iter < max_iterations; ++iter)
|
||||
{
|
||||
@ -521,8 +534,16 @@ namespace dlib
|
||||
{
|
||||
const long i = index[ii];
|
||||
|
||||
const scalar_type G = y(i)*dot(w, x(i)) - 1;
|
||||
scalar_type G = y(i)*dot(w, x(i)) - 1;
|
||||
if (do_svm_l2)
|
||||
{
|
||||
if (y(i) > 0)
|
||||
G += Dii_pos*alpha[i];
|
||||
else
|
||||
G += Dii_neg*alpha[i];
|
||||
}
|
||||
const scalar_type C = (y(i) > 0) ? Cpos : Cneg;
|
||||
const scalar_type U = do_svm_l2 ? std::numeric_limits<scalar_type>::infinity() : C;
|
||||
|
||||
scalar_type PG = 0;
|
||||
if (alpha[i] == 0)
|
||||
@ -539,7 +560,7 @@ namespace dlib
|
||||
if (G < 0)
|
||||
PG = G;
|
||||
}
|
||||
else if (alpha[i] == C)
|
||||
else if (alpha[i] == U)
|
||||
{
|
||||
if (G < PG_min_prev)
|
||||
{
|
||||
@ -567,7 +588,7 @@ namespace dlib
|
||||
if (std::abs(PG) > 1e-12)
|
||||
{
|
||||
const scalar_type alpha_old = alpha[i];
|
||||
alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), C);
|
||||
alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), U);
|
||||
const scalar_type delta = (alpha[i]-alpha_old)*y(i);
|
||||
add_to(w, x(i), delta);
|
||||
if (have_bias && !last_weight_1)
|
||||
@ -660,6 +681,7 @@ namespace dlib
|
||||
bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1.
|
||||
bool last_weight_1;
|
||||
bool do_shrinking;
|
||||
bool do_svm_l2;
|
||||
|
||||
}; // end of class svm_c_linear_dcd_trainer
|
||||
|
||||
|
@ -67,6 +67,7 @@ namespace dlib
|
||||
- #forces_last_weight_to_1() == false
|
||||
- #includes_bias() == true
|
||||
- #shrinking_enabled() == true
|
||||
- #solving_svm_l2_problem() == false
|
||||
!*/
|
||||
|
||||
explicit svm_c_linear_dcd_trainer (
|
||||
@ -86,6 +87,7 @@ namespace dlib
|
||||
- #forces_last_weight_to_1() == false
|
||||
- #includes_bias() == true
|
||||
- #shrinking_enabled() == true
|
||||
- #solving_svm_l2_problem() == false
|
||||
!*/
|
||||
|
||||
bool includes_bias (
|
||||
@ -140,6 +142,23 @@ namespace dlib
|
||||
- #shrinking_enabled() == enabled
|
||||
!*/
|
||||
|
||||
bool solving_svm_l2_problem (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns true if this solver will solve the L2 version of the SVM
|
||||
objective function. That is, if solving_svm_l2_problem()==true then this
|
||||
object, rather than using the hinge loss, uses the squared hinge loss.
|
||||
!*/
|
||||
|
||||
void solve_svm_l2_problem (
|
||||
bool enabled
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #solving_svm_l2_problem() == enabled
|
||||
!*/
|
||||
|
||||
void be_verbose (
|
||||
);
|
||||
/*!
|
||||
|
@ -440,6 +440,65 @@ namespace
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_l2_version ()
|
||||
{
|
||||
typedef std::map<unsigned long,double> sample_type;
|
||||
typedef sparse_linear_kernel<sample_type> kernel_type;
|
||||
|
||||
svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
|
||||
linear_trainer.set_c(10);
|
||||
linear_trainer.set_epsilon(1e-5);
|
||||
|
||||
std::vector<sample_type> samples;
|
||||
std::vector<double> labels;
|
||||
|
||||
// make an instance of a sample vector so we can use it below
|
||||
sample_type sample;
|
||||
|
||||
|
||||
// Now let's go into a loop and randomly generate 10000 samples.
|
||||
double label = +1;
|
||||
for (int i = 0; i < 1000; ++i)
|
||||
{
|
||||
// flip this flag
|
||||
label *= -1;
|
||||
|
||||
sample.clear();
|
||||
|
||||
// now make a random sparse sample with at most 10 non-zero elements
|
||||
for (int j = 0; j < 10; ++j)
|
||||
{
|
||||
int idx = std::rand()%100;
|
||||
double value = static_cast<double>(std::rand())/RAND_MAX;
|
||||
|
||||
sample[idx] = label*value;
|
||||
}
|
||||
|
||||
// Also save the samples we are generating so we can let the svm_c_linear_trainer
|
||||
// learn from them below.
|
||||
samples.push_back(sample);
|
||||
labels.push_back(label);
|
||||
}
|
||||
|
||||
decision_function<kernel_type> df = linear_trainer.train(samples, labels);
|
||||
|
||||
sample.clear();
|
||||
sample[4] = 0.3;
|
||||
sample[10] = 0.9;
|
||||
DLIB_TEST(df(sample) > 0);
|
||||
|
||||
sample.clear();
|
||||
sample[83] = -0.3;
|
||||
sample[26] = -0.9;
|
||||
sample[58] = -0.7;
|
||||
DLIB_TEST(df(sample) < 0);
|
||||
|
||||
sample.clear();
|
||||
sample[0] = -0.2;
|
||||
sample[9] = -0.8;
|
||||
DLIB_TEST(df(sample) < 0);
|
||||
}
|
||||
|
||||
class tester_svm_c_linear_dcd : public tester
|
||||
{
|
||||
public:
|
||||
@ -474,6 +533,8 @@ namespace
|
||||
print_spinner();
|
||||
test_sparse_1_sample(-1);
|
||||
print_spinner();
|
||||
|
||||
test_l2_version();
|
||||
}
|
||||
} a;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user