Fixed a bug which triggered when the last weight was forced to 1.

This commit is contained in:
Davis King 2012-12-17 20:33:34 -05:00
parent 034eb0acc0
commit f9805be856

View File

@ -301,7 +301,7 @@ namespace dlib
for (long i = new_idx; i < x.size(); ++i)
{
Q.push_back(dlib::dot(x(i),x(i)));
Q.push_back(length_squared(x(i)));
if (have_bias)
{
@ -318,6 +318,44 @@ namespace dlib
w(dims-1) = 1;
}
template <typename T>
typename enable_if<is_matrix<T>,scalar_type>::type length_squared (const T& x) const
{
if (!last_weight_1)
{
return dlib::dot(x,x);
}
else
{
// skip the last dimension
return dlib::dot(colm(x,0,x.size()-1),
colm(x,0,x.size()-1));
}
}
template <typename T>
typename disable_if<is_matrix<T>,scalar_type>::type length_squared (const T& x) const
{
if (!last_weight_1)
{
return dlib::dot(x,x);
}
else
{
scalar_type temp = 0;
typename T::const_iterator i;
for (i = x.begin(); i != x.end(); ++i)
{
// skip the last dimension
if (static_cast<long>(i->first) < dims-1)
temp += i->second*i->second;
}
return temp;
}
}
bool did_init;
bool have_bias;
bool last_weight_1;