worked on the rvm regression

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402457
This commit is contained in:
Davis King 2008-08-05 01:44:55 +00:00
parent 69d481ca34
commit 4360b88512

View File

@ -116,13 +116,14 @@ namespace dlib
{
// make sure requires clause is not broken
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\tdecision_function rvm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
);
// make a target vector where +1 examples have value 1 and -1 examples
@ -638,10 +639,10 @@ namespace dlib
>
const decision_function<kernel_type> train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
const in_scalar_vector_type& t
) const
{
return do_train(vector_to_matrix(x), vector_to_matrix(y));
return do_train(vector_to_matrix(x), vector_to_matrix(t));
}
void swap (
@ -665,30 +666,20 @@ namespace dlib
>
const decision_function<kernel_type> do_train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
const in_scalar_vector_type& t
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
"\tdecision_function rvm_regression_trainer::train(x,y)"
DLIB_ASSERT(x.nr() > 1 && x.nr() == t.nr() && x.nc() == 1 && t.nc() == 1,
"\tdecision_function rvm_regression_trainer::train(x,t)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t t.nr(): " << t.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t t.nc(): " << t.nc()
);
// make a target vector where +1 examples have value 1 and -1 examples
// have a value of 0.
scalar_vector_type t(y.size());
for (long i = 0; i < y.size(); ++i)
{
if (y(i) == 1)
t(i) = 1;
else
t(i) = 0;
}
/*! This is the convention for the active_bases variable in the function:
- if (active_bases(i) >= 0) then
@ -725,9 +716,6 @@ namespace dlib
matrix<scalar_type,1,0,mem_manager_type> tempv2, tempv3;
scalar_matrix_type tempm;
scalar_vector_type t_estimate;
scalar_vector_type beta;
Q.set_size(x.nr());
S.set_size(x.nr());
@ -740,8 +728,6 @@ namespace dlib
while (true)
{
// Compute optimal weights and sigma for current alpha using equation 6.
// compute the updated sigma matrix
sigma = trans(phi)*phi/var;
for (long r = 0; r < alpha.nr(); ++r)
sigma(r,r) += alpha(r);
@ -805,6 +791,8 @@ namespace dlib
}
}
// recompute the variance
var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma));
// next we update the selected alpha.
@ -873,12 +861,10 @@ namespace dlib
}
// recompute the variance
var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma));
} // end while(true). So we have converged on the final answer.
// now put everything into a decision_function object and return it
std_vector_c<sample_type> dictionary;
std_vector_c<scalar_type> final_weights;