mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Switched this code to use the oca object's ability to force a weight to 1
instead of rolling its own implementation.
This commit is contained in:
parent
277b47ae58
commit
b445ddbd8d
@ -37,15 +37,13 @@ namespace dlib
|
|||||||
const std::vector<ranking_pair<sample_type> >& samples_,
|
const std::vector<ranking_pair<sample_type> >& samples_,
|
||||||
const bool be_verbose_,
|
const bool be_verbose_,
|
||||||
const scalar_type eps_,
|
const scalar_type eps_,
|
||||||
const unsigned long max_iter,
|
const unsigned long max_iter
|
||||||
const bool last_weight_1_
|
|
||||||
) :
|
) :
|
||||||
samples(samples_),
|
samples(samples_),
|
||||||
C(C_),
|
C(C_),
|
||||||
be_verbose(be_verbose_),
|
be_verbose(be_verbose_),
|
||||||
eps(eps_),
|
eps(eps_),
|
||||||
max_iterations(max_iter),
|
max_iterations(max_iter)
|
||||||
last_weight_1(last_weight_1_)
|
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,8 +111,6 @@ namespace dlib
|
|||||||
// rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the
|
// rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the
|
||||||
// time.
|
// time.
|
||||||
|
|
||||||
if(last_weight_1)
|
|
||||||
w(w.size()-1) = 1;
|
|
||||||
|
|
||||||
std::vector<double> rel_scores;
|
std::vector<double> rel_scores;
|
||||||
std::vector<double> nonrel_scores;
|
std::vector<double> nonrel_scores;
|
||||||
@ -163,12 +159,6 @@ namespace dlib
|
|||||||
|
|
||||||
risk *= scale;
|
risk *= scale;
|
||||||
subgradient = scale*subgradient;
|
subgradient = scale*subgradient;
|
||||||
|
|
||||||
if(last_weight_1)
|
|
||||||
{
|
|
||||||
w(w.size()-1) = 0;
|
|
||||||
subgradient(w.size()-1) = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -183,7 +173,6 @@ namespace dlib
|
|||||||
const bool be_verbose;
|
const bool be_verbose;
|
||||||
const scalar_type eps;
|
const scalar_type eps;
|
||||||
const unsigned long max_iterations;
|
const unsigned long max_iterations;
|
||||||
const bool last_weight_1;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
@ -198,12 +187,11 @@ namespace dlib
|
|||||||
const std::vector<ranking_pair<sample_type> >& samples,
|
const std::vector<ranking_pair<sample_type> >& samples,
|
||||||
const bool be_verbose,
|
const bool be_verbose,
|
||||||
const scalar_type eps,
|
const scalar_type eps,
|
||||||
const unsigned long max_iterations,
|
const unsigned long max_iterations
|
||||||
const bool last_weight_1
|
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
return oca_problem_ranking_svm<matrix_type, sample_type>(
|
return oca_problem_ranking_svm<matrix_type, sample_type>(
|
||||||
C, samples, be_verbose, eps, max_iterations, last_weight_1);
|
C, samples, be_verbose, eps, max_iterations);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
@ -385,12 +373,17 @@ namespace dlib
|
|||||||
num_nonnegative = num_dims;
|
num_nonnegative = num_dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, last_weight_1),
|
unsigned long force_weight_1_idx = std::numeric_limits<unsigned long>::max();
|
||||||
w,
|
if (last_weight_1)
|
||||||
num_nonnegative);
|
{
|
||||||
|
force_weight_1_idx = num_dims-1;
|
||||||
|
}
|
||||||
|
|
||||||
|
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
|
||||||
|
w,
|
||||||
|
num_nonnegative,
|
||||||
|
force_weight_1_idx);
|
||||||
|
|
||||||
if(last_weight_1)
|
|
||||||
w(w.size()-1) = 1;
|
|
||||||
|
|
||||||
// put the solution into a decision function and then return it
|
// put the solution into a decision function and then return it
|
||||||
decision_function<kernel_type> df;
|
decision_function<kernel_type> df;
|
||||||
|
Loading…
Reference in New Issue
Block a user