mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
worked on the rvm regression
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402457
This commit is contained in:
parent
69d481ca34
commit
4360b88512
@ -116,13 +116,14 @@ namespace dlib
|
||||
{
|
||||
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
|
||||
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
|
||||
"\tdecision_function rvm_trainer::train(x,y)"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t x.nr(): " << x.nr()
|
||||
<< "\n\t y.nr(): " << y.nr()
|
||||
<< "\n\t x.nc(): " << x.nc()
|
||||
<< "\n\t y.nc(): " << y.nc()
|
||||
<< "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
|
||||
);
|
||||
|
||||
// make a target vector where +1 examples have value 1 and -1 examples
|
||||
@ -638,10 +639,10 @@ namespace dlib
|
||||
>
|
||||
const decision_function<kernel_type> train (
|
||||
const in_sample_vector_type& x,
|
||||
const in_scalar_vector_type& y
|
||||
const in_scalar_vector_type& t
|
||||
) const
|
||||
{
|
||||
return do_train(vector_to_matrix(x), vector_to_matrix(y));
|
||||
return do_train(vector_to_matrix(x), vector_to_matrix(t));
|
||||
}
|
||||
|
||||
void swap (
|
||||
@ -665,30 +666,20 @@ namespace dlib
|
||||
>
|
||||
const decision_function<kernel_type> do_train (
|
||||
const in_sample_vector_type& x,
|
||||
const in_scalar_vector_type& y
|
||||
const in_scalar_vector_type& t
|
||||
) const
|
||||
{
|
||||
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
|
||||
"\tdecision_function rvm_regression_trainer::train(x,y)"
|
||||
DLIB_ASSERT(x.nr() > 1 && x.nr() == t.nr() && x.nc() == 1 && t.nc() == 1,
|
||||
"\tdecision_function rvm_regression_trainer::train(x,t)"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t x.nr(): " << x.nr()
|
||||
<< "\n\t y.nr(): " << y.nr()
|
||||
<< "\n\t t.nr(): " << t.nr()
|
||||
<< "\n\t x.nc(): " << x.nc()
|
||||
<< "\n\t y.nc(): " << y.nc()
|
||||
<< "\n\t t.nc(): " << t.nc()
|
||||
);
|
||||
|
||||
// make a target vector where +1 examples have value 1 and -1 examples
|
||||
// have a value of 0.
|
||||
scalar_vector_type t(y.size());
|
||||
for (long i = 0; i < y.size(); ++i)
|
||||
{
|
||||
if (y(i) == 1)
|
||||
t(i) = 1;
|
||||
else
|
||||
t(i) = 0;
|
||||
}
|
||||
|
||||
/*! This is the convention for the active_bases variable in the function:
|
||||
- if (active_bases(i) >= 0) then
|
||||
@ -725,9 +716,6 @@ namespace dlib
|
||||
matrix<scalar_type,1,0,mem_manager_type> tempv2, tempv3;
|
||||
scalar_matrix_type tempm;
|
||||
|
||||
scalar_vector_type t_estimate;
|
||||
scalar_vector_type beta;
|
||||
|
||||
|
||||
Q.set_size(x.nr());
|
||||
S.set_size(x.nr());
|
||||
@ -740,8 +728,6 @@ namespace dlib
|
||||
while (true)
|
||||
{
|
||||
// Compute optimal weights and sigma for current alpha using equation 6.
|
||||
|
||||
// compute the updated sigma matrix
|
||||
sigma = trans(phi)*phi/var;
|
||||
for (long r = 0; r < alpha.nr(); ++r)
|
||||
sigma(r,r) += alpha(r);
|
||||
@ -805,6 +791,8 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
|
||||
// recompute the variance
|
||||
var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma));
|
||||
|
||||
// next we update the selected alpha.
|
||||
|
||||
@ -873,8 +861,6 @@ namespace dlib
|
||||
}
|
||||
|
||||
|
||||
// recompute the variance
|
||||
var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma));
|
||||
|
||||
} // end while(true). So we have converged on the final answer.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user