mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Made rls run a bit faster, especially if the new mode that allows the
regularization to decay away is activated.
This commit is contained in:
parent
c8c1abb733
commit
28d76d011f
@ -20,7 +20,8 @@ namespace dlib
|
||||
|
||||
explicit rls(
|
||||
double forget_factor_,
|
||||
double C_ = 1000
|
||||
double C_ = 1000,
|
||||
bool apply_forget_factor_to_C_ = false
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
@ -36,6 +37,7 @@ namespace dlib
|
||||
|
||||
C = C_;
|
||||
forget_factor = forget_factor_;
|
||||
apply_forget_factor_to_C = apply_forget_factor_to_C_;
|
||||
}
|
||||
|
||||
rls(
|
||||
@ -43,6 +45,7 @@ namespace dlib
|
||||
{
|
||||
C = 1000;
|
||||
forget_factor = 1;
|
||||
apply_forget_factor_to_C = false;
|
||||
}
|
||||
|
||||
double get_c(
|
||||
@ -57,6 +60,12 @@ namespace dlib
|
||||
return forget_factor;
|
||||
}
|
||||
|
||||
bool should_apply_forget_factor_to_C (
|
||||
) const
|
||||
{
|
||||
return apply_forget_factor_to_C;
|
||||
}
|
||||
|
||||
template <typename EXP>
|
||||
void train (
|
||||
const matrix_exp<EXP>& x,
|
||||
@ -84,20 +93,25 @@ namespace dlib
|
||||
// multiply by forget factor and incorporate x*trans(x) into R.
|
||||
const double l = 1.0/forget_factor;
|
||||
const double temp = 1 + l*trans(x)*R*x;
|
||||
matrix<double,0,1> tmp = R*x;
|
||||
tmp = R*x;
|
||||
R = l*R - l*l*(tmp*trans(tmp))/temp;
|
||||
|
||||
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
|
||||
// identity matrix back in to keep the regularization alive.
|
||||
add_eye_to_inv(R, (1-forget_factor)/C);
|
||||
if (forget_factor != 1 && !apply_forget_factor_to_C)
|
||||
add_eye_to_inv(R, (1-forget_factor)/C);
|
||||
|
||||
// R should always be symmetric. This line improves numeric stability of this algorithm.
|
||||
R = 0.5*(R + trans(R));
|
||||
if (cnt%100 == 0)
|
||||
R = 0.5*(R + trans(R));
|
||||
++cnt;
|
||||
|
||||
w = w + R*x*(y - trans(x)*w);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
const matrix<double,0,1>& get_w(
|
||||
) const
|
||||
{
|
||||
@ -145,25 +159,37 @@ namespace dlib
|
||||
|
||||
friend inline void serialize(const rls& item, std::ostream& out)
|
||||
{
|
||||
int version = 1;
|
||||
int version = 2;
|
||||
serialize(version, out);
|
||||
serialize(item.w, out);
|
||||
serialize(item.R, out);
|
||||
serialize(item.C, out);
|
||||
serialize(item.forget_factor, out);
|
||||
serialize(item.cnt, out);
|
||||
serialize(item.apply_forget_factor_to_C, out);
|
||||
}
|
||||
|
||||
friend inline void deserialize(rls& item, std::istream& in)
|
||||
{
|
||||
int version = 0;
|
||||
deserialize(version, in);
|
||||
if (version != 1)
|
||||
if (!(1 <= version && version <= 2))
|
||||
throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
|
||||
|
||||
deserialize(item.w, in);
|
||||
deserialize(item.R, in);
|
||||
deserialize(item.C, in);
|
||||
deserialize(item.forget_factor, in);
|
||||
if (version >= 1)
|
||||
{
|
||||
deserialize(item.w, in);
|
||||
deserialize(item.R, in);
|
||||
deserialize(item.C, in);
|
||||
deserialize(item.forget_factor, in);
|
||||
}
|
||||
item.cnt = 0;
|
||||
item.apply_forget_factor_to_C = false;
|
||||
if (version >= 2)
|
||||
{
|
||||
deserialize(item.cnt, in);
|
||||
deserialize(item.apply_forget_factor_to_C, in);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -189,6 +215,13 @@ namespace dlib
|
||||
matrix<double> R;
|
||||
double C;
|
||||
double forget_factor;
|
||||
int cnt = 0;
|
||||
bool apply_forget_factor_to_C;
|
||||
|
||||
|
||||
// This object is here only to avoid reallocation during training. It don't
|
||||
// logically contribute to the state of this object.
|
||||
matrix<double,0,1> tmp;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -37,7 +37,8 @@ namespace dlib
|
||||
|
||||
explicit rls(
|
||||
double forget_factor,
|
||||
double C = 1000
|
||||
double C = 1000,
|
||||
bool apply_forget_factor_to_C = false
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
@ -47,6 +48,7 @@ namespace dlib
|
||||
- #get_w().size() == 0
|
||||
- #get_c() == C
|
||||
- #get_forget_factor() == forget_factor
|
||||
- #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
|
||||
!*/
|
||||
|
||||
rls(
|
||||
@ -56,6 +58,7 @@ namespace dlib
|
||||
- #get_w().size() == 0
|
||||
- #get_c() == 1000
|
||||
- #get_forget_factor() == 1
|
||||
- #should_apply_forget_factor_to_C() == false
|
||||
!*/
|
||||
|
||||
double get_c(
|
||||
@ -80,6 +83,18 @@ namespace dlib
|
||||
zero the faster old examples are forgotten.
|
||||
!*/
|
||||
|
||||
bool should_apply_forget_factor_to_C (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- If this function returns false then it means we are optimizing the
|
||||
objective function discussed in the WHAT THIS OBJECT REPRESENTS section
|
||||
above. However, if it returns true then we will allow the forget factor
|
||||
(get_forget_factor()) to be applied to the C value which causes the
|
||||
algorithm to slowly increase C and convert into a textbook version of RLS
|
||||
without regularization. The main reason you might want to do this is
|
||||
because it can make the algorithm run significantly faster.
|
||||
!*/
|
||||
|
||||
template <typename EXP>
|
||||
void train (
|
||||
|
Loading…
Reference in New Issue
Block a user