This commit is contained in:
Davis King 2016-11-29 11:05:55 -05:00
commit ca3a3aadb6
2 changed files with 199 additions and 5 deletions

View File

@ -631,6 +631,106 @@ namespace dlib
T forget; T forget;
}; };
// ----------------------------------------------------------------------------------------
template <
typename T
>
class running_stats_decayed
{
public:
explicit running_stats_decayed(
T decay_halflife = 1000
)
{
DLIB_ASSERT(decay_halflife > 0);
sum_x = 0;
sum_xx = 0;
forget = std::pow(0.5, 1/decay_halflife);
n = 0;
COMPILE_TIME_ASSERT ((
is_same_type<float,T>::value ||
is_same_type<double,T>::value ||
is_same_type<long double,T>::value
));
}
T forget_factor (
) const
{
return forget;
}
void add (
const T& x
)
{
sum_xx = sum_xx*forget + x*x;
sum_x = sum_x*forget + x;
n = n*forget + forget;
}
T current_n (
) const
{
return n;
}
T mean (
) const
{
if (n != 0)
return sum_x/n;
else
return 0;
}
T variance (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_stats_decayed::variance()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_xx - sum_x*sum_x/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
return temp;
else
return 0;
}
T stddev (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_stats_decayed::stddev()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return std::sqrt(variance());
}
private:
T sum_x;
T sum_xx;
T n;
T forget;
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <

View File

@ -435,7 +435,7 @@ namespace dlib
a stream of real number pairs. It is essentially the same as a stream of real number pairs. It is essentially the same as
running_scalar_covariance except that it forgets about data it has seen running_scalar_covariance except that it forgets about data it has seen
after a certain period of time. It does this by exponentially decaying old after a certain period of time. It does this by exponentially decaying old
statistic. statistics.
!*/ !*/
public: public:
@ -523,7 +523,7 @@ namespace dlib
requires requires
- current_n() > 0 - current_n() > 0
ensures ensures
- returns the unbiased sample variance value of all x samples presented - returns the sample variance value of all x samples presented
to this object via add(). to this object via add().
!*/ !*/
@ -533,7 +533,7 @@ namespace dlib
requires requires
- current_n() > 0 - current_n() > 0
ensures ensures
- returns the unbiased sample variance value of all y samples presented - returns the sample variance value of all y samples presented
to this object via add(). to this object via add().
!*/ !*/
@ -543,7 +543,7 @@ namespace dlib
requires requires
- current_n() > 0 - current_n() > 0
ensures ensures
- returns the unbiased sample standard deviation of all x samples - returns the sample standard deviation of all x samples
presented to this object via add(). presented to this object via add().
!*/ !*/
@ -553,11 +553,105 @@ namespace dlib
requires requires
- current_n() > 0 - current_n() > 0
ensures ensures
- returns the unbiased sample standard deviation of all y samples - returns the sample standard deviation of all y samples
presented to this object via add(). presented to this object via add().
!*/ !*/
}; };
// ----------------------------------------------------------------------------------------
template <
typename T
>
class running_stats_decayed
{
/*!
REQUIREMENTS ON T
- T must be a float, double, or long double type
INITIAL VALUE
- mean() == 0
- current_n() == 0
WHAT THIS OBJECT REPRESENTS
This object represents something that can compute the running mean and
variance of a stream of real numbers. It is similar to running_stats
except that it forgets about data it has seen after a certain period of
time. It does this by exponentially decaying old statistics.
!*/
public:
running_stats_decayed(
T decay_halflife = 1000
);
/*!
requires
- decay_halflife > 0
ensures
- #forget_factor() == std::pow(0.5, 1/decay_halflife);
(i.e. after decay_halflife calls to add() the data given to the first add
will be down weighted by 0.5 in the statistics stored in this object).
!*/
T forget_factor (
) const;
/*!
ensures
- returns the exponential forget factor used to forget old statistics when
add() is called.
!*/
void add (
const T& x
);
/*!
ensures
- updates the statistics stored in this object so that x is factored into
them.
- #current_n() == current_n()*forget_factor() + forget_factor()
- Down weights old statistics by a factor of forget_factor().
!*/
T current_n (
) const;
/*!
ensures
- returns the effective number of points given to this object. As add()
is called this value will converge to a constant, the value of which is
based on the decay_halflife supplied to the constructor.
!*/
T mean (
) const;
/*!
ensures
- returns the mean value of all x samples presented to this object
via add().
!*/
T variance (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the sample variance value of all x samples presented to this
object via add().
!*/
T stddev (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the sample standard deviation of all x samples presented to this
object via add().
!*/
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <