Made rls run a bit faster, especially if the new mode that allows the regularization to decay away is activated.
Davis King 2016-11-08 14:40:19 -05:00
parent c8c1abb733
commit 28d76d011f
2 changed files with 59 additions and 11 deletions
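
For orientation, here is a minimal usage sketch of the new mode (an editor's illustration, not code from the commit). Only the rls constructor, train(), and get_w() signatures come from the diff below; the include path, sample data, and parameter values are assumptions.

    // Sketch: recursive least squares with the regularization allowed to decay.
    #include <dlib/svm/rls.h>
    #include <iostream>

    int main()
    {
        // forget_factor = 0.9, C = 1000, and the new third argument set to true:
        // the forget factor is also applied to C, so the regularizer fades away
        // and the per-sample update can skip its most expensive step.
        dlib::rls filt(0.9, 1000, true);

        dlib::matrix<double,2,1> x;
        for (int i = 0; i < 1000; ++i)
        {
            x = i, 1;                  // made-up sample: [input value, bias term]
            filt.train(x, 3.0*i + 2);  // target follows y = 3*input + 2
        }

        // get_w() holds the learned weights, which settle near [3, 2] here.
        std::cout << dlib::trans(filt.get_w()) << std::endl;
    }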

dlib/svm/rls.h

@@ -20,7 +20,8 @@ namespace dlib
         explicit rls(
             double forget_factor_,
-            double C_ = 1000
+            double C_ = 1000,
+            bool apply_forget_factor_to_C_ = false
         )
         {
             // make sure requires clause is not broken
@@ -36,6 +37,7 @@ namespace dlib
             C = C_;
             forget_factor = forget_factor_;
+            apply_forget_factor_to_C = apply_forget_factor_to_C_;
         }
 
         rls(
@@ -43,6 +45,7 @@ namespace dlib
         {
             C = 1000;
             forget_factor = 1;
+            apply_forget_factor_to_C = false;
         }
 
         double get_c(
@@ -57,6 +60,12 @@ namespace dlib
             return forget_factor;
         }
 
+        bool should_apply_forget_factor_to_C (
+        ) const
+        {
+            return apply_forget_factor_to_C;
+        }
+
         template <typename EXP>
         void train (
             const matrix_exp<EXP>& x,
@@ -84,20 +93,25 @@ namespace dlib
             // multiply by forget factor and incorporate x*trans(x) into R.
             const double l = 1.0/forget_factor;
             const double temp = 1 + l*trans(x)*R*x;
-            matrix<double,0,1> tmp = R*x;
+            tmp = R*x;
             R = l*R - l*l*(tmp*trans(tmp))/temp;
 
             // Since we multiplied by the forget factor, we need to add (1-forget_factor)/C
             // of the identity matrix back in to keep the regularization alive.
-            add_eye_to_inv(R, (1-forget_factor)/C);
+            if (forget_factor != 1 && !apply_forget_factor_to_C)
+                add_eye_to_inv(R, (1-forget_factor)/C);
 
             // R should always be symmetric.  This line improves the numeric stability of this algorithm.
-            R = 0.5*(R + trans(R));
+            if (cnt%100 == 0)
+                R = 0.5*(R + trans(R));
+            ++cnt;
 
             w = w + R*x*(y - trans(x)*w);
         }
 
         const matrix<double,0,1>& get_w(
         ) const
         {
@@ -145,25 +159,37 @@ namespace dlib
         friend inline void serialize(const rls& item, std::ostream& out)
         {
-            int version = 1;
+            int version = 2;
             serialize(version, out);
             serialize(item.w, out);
             serialize(item.R, out);
             serialize(item.C, out);
             serialize(item.forget_factor, out);
+            serialize(item.cnt, out);
+            serialize(item.apply_forget_factor_to_C, out);
         }
 
         friend inline void deserialize(rls& item, std::istream& in)
         {
             int version = 0;
             deserialize(version, in);
-            if (version != 1)
+            if (!(1 <= version && version <= 2))
                 throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
-            deserialize(item.w, in);
-            deserialize(item.R, in);
-            deserialize(item.C, in);
-            deserialize(item.forget_factor, in);
+            if (version >= 1)
+            {
+                deserialize(item.w, in);
+                deserialize(item.R, in);
+                deserialize(item.C, in);
+                deserialize(item.forget_factor, in);
+            }
+
+            item.cnt = 0;
+            item.apply_forget_factor_to_C = false;
+            if (version >= 2)
+            {
+                deserialize(item.cnt, in);
+                deserialize(item.apply_forget_factor_to_C, in);
+            }
         }
 
     private:
@@ -189,6 +215,13 @@ namespace dlib
         matrix<double> R;
         double C;
         double forget_factor;
+        int cnt = 0;
+        bool apply_forget_factor_to_C;
+
+        // This object is here only to avoid reallocation during training.  It doesn't
+        // logically contribute to the state of this object.
+        matrix<double,0,1> tmp;
     };
 
 // ----------------------------------------------------------------------------------------
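
An editor's note for context (not text from the commit): train() keeps R equal to the inverse of the forget-factor-weighted, regularized autocorrelation matrix. Writing λ for the forget factor and l = 1/λ, the lines temp = 1 + l*trans(x)*R*x; tmp = R*x; R = l*R - l*l*(tmp*trans(tmp))/temp; apply the Sherman-Morrison identity for a rank-1 update of a scaled inverse, with R playing the role of A^{-1}:

    (\lambda A + x x^{\top})^{-1}
        = \frac{1}{\lambda} A^{-1}
          - \frac{ \frac{1}{\lambda^{2}} A^{-1} x x^{\top} A^{-1} }
                 { 1 + \frac{1}{\lambda}\, x^{\top} A^{-1} x }

Scaling A by λ also shrinks its (1/C)*I regularization term to (λ/C)*I, which is why the code previously called add_eye_to_inv(R, (1-λ)/C) on every sample to top the regularizer back up. That correction touches the whole matrix rather than being rank-1, so it plausibly dominates the per-sample cost; in the new mode it is skipped entirely (as it also is when the forget factor is exactly 1, where it would add zero anyway), letting the regularizer decay like λ^t/C. Equivalently, C slowly grows toward infinity, matching the description in rls_abstract.h below of converging to a textbook RLS without regularization. Restricting the symmetrization R = 0.5*(R + trans(R)) to every 100th call (cnt%100 == 0) accounts for the rest of the speedup.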

dlib/svm/rls_abstract.h

@@ -37,7 +37,8 @@ namespace dlib
         explicit rls(
             double forget_factor,
-            double C = 1000
+            double C = 1000,
+            bool apply_forget_factor_to_C = false
         );
         /*!
             requires
@@ -47,6 +48,7 @@ namespace dlib
                 - #get_w().size() == 0
                 - #get_c() == C
                 - #get_forget_factor() == forget_factor
+                - #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
         !*/
 
         rls(
@@ -56,6 +58,7 @@ namespace dlib
                 - #get_w().size() == 0
                 - #get_c() == 1000
                 - #get_forget_factor() == 1
+                - #should_apply_forget_factor_to_C() == false
         !*/
 
         double get_c(
@@ -80,6 +83,18 @@ namespace dlib
                   zero the faster old examples are forgotten.
         !*/
 
+        bool should_apply_forget_factor_to_C (
+        ) const;
+        /*!
+            ensures
+                - If this function returns false then we are optimizing the objective
+                  function discussed in the WHAT THIS OBJECT REPRESENTS section above.
+                  However, if it returns true then the forget factor
+                  (get_forget_factor()) is also applied to the C value, which causes
+                  the algorithm to slowly increase C and converge to a textbook version
+                  of RLS without regularization.  The main reason you might want this
+                  is that it can make the algorithm run significantly faster.
+        !*/
+
         template <typename EXP>
         void train (