- Added operator+() for running_stats and running_scalar_covariance

- Simplified and optimized the running_stats implementation
 - Clarified the spec a little
This commit is contained in:
Davis King 2012-10-01 21:25:31 -04:00
parent 3f0ba96aa8
commit daecb0ec85
3 changed files with 128 additions and 11 deletions

View File

@ -53,11 +53,8 @@ namespace dlib
const T& val const T& val
) )
{ {
const T div_n = 1/(n+1); sum += val;
const T n_div_n = n*div_n; sum_sqr += val*val;
sum = n_div_n*sum + val*div_n;
sum_sqr = n_div_n*sum_sqr + val*div_n*val;
if (val < min_value) if (val < min_value)
min_value = val; min_value = val;
@ -83,7 +80,10 @@ namespace dlib
T mean ( T mean (
) const ) const
{ {
return sum; if (n != 0)
return sum/n;
else
return 0;
} }
T max ( T max (
@ -122,8 +122,8 @@ namespace dlib
<< "\n\tthis: " << this << "\n\tthis: " << this
); );
T temp = n/(n-1); T temp = 1/(n-1);
temp = temp*(sum_sqr - sum*sum); temp = temp*(sum_sqr - sum*sum/n);
// make sure the variance is never negative. This might // make sure the variance is never negative. This might
// happen due to numerical errors. // happen due to numerical errors.
if (temp >= 0) if (temp >= 0)
@ -158,6 +158,29 @@ namespace dlib
return (val-mean())/std::sqrt(variance()); return (val-mean())/std::sqrt(variance());
} }
running_stats operator+ (
const running_stats& rhs
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(max_n() == rhs.max_n(),
"\trunning_stats running_stats::operator+(rhs)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t max_n(): " << max_n()
<< "\n\t rhs.max_n(): " << rhs.max_n()
<< "\n\t this: " << this
);
running_stats temp(*this);
temp.sum += rhs.sum;
temp.sum_sqr += rhs.sum_sqr;
temp.n += rhs.n;
temp.min_value = std::min(rhs.min_value, min_value);
temp.max_value = std::max(rhs.max_value, max_value);
return temp;
}
template <typename U> template <typename U>
friend void serialize ( friend void serialize (
const running_stats<U>& item, const running_stats<U>& item,
@ -373,6 +396,21 @@ namespace dlib
return std::sqrt(variance_y()); return std::sqrt(variance_y());
} }
running_scalar_covariance operator+ (
const running_scalar_covariance& rhs
) const
{
running_scalar_covariance temp(rhs);
temp.sum_xy += sum_xy;
temp.sum_x += sum_x;
temp.sum_y += sum_y;
temp.sum_xx += sum_xx;
temp.sum_yy += sum_yy;
temp.n += n;
return temp;
}
private: private:
T sum_xy; T sum_xy;

View File

@ -208,7 +208,7 @@ namespace dlib
requires requires
- current_n() > 1 - current_n() > 1
ensures ensures
- returns the variance of all the values presented to this - returns the unbiased sample variance of all the values presented to this
object so far. object so far.
!*/ !*/
@ -218,8 +218,8 @@ namespace dlib
requires requires
- current_n() > 1 - current_n() > 1
ensures ensures
- returns the standard deviation of all the values presented to this - returns the unbiased sampled standard deviation of all the values
object so far. presented to this object so far.
!*/ !*/
T max ( T max (
@ -249,6 +249,19 @@ namespace dlib
ensures ensures
- return (val-mean())/stddev(); - return (val-mean())/stddev();
!*/ !*/
running_stats operator+ (
const running_stats& rhs
) const;
/*!
requires
- max_n() == rhs.max_n()
ensures
- returns a new running_stats object that represents the combination of all
the values given to *this and rhs. That is, this function returns a
running_stats object, R, that is equivalent to what you would obtain if
all calls to this->add() and rhs.add() had instead been done to R.
!*/
}; };
template <typename T> template <typename T>
@ -400,6 +413,18 @@ namespace dlib
- returns the unbiased sample standard deviation of all y samples - returns the unbiased sample standard deviation of all y samples
presented to this object via add(). presented to this object via add().
!*/ !*/
running_scalar_covariance operator+ (
const running_covariance& rhs
) const;
/*!
ensures
- returns a new running_scalar_covariance object that represents the
combination of all the values given to *this and rhs. That is, this
function returns a running_scalar_covariance object, R, that is
equivalent to what you would obtain if all calls to this->add() and
rhs.add() had instead been done to R.
!*/
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------

View File

@ -286,6 +286,59 @@ namespace
} }
} }
void another_test()
{
std::vector<double> a;
running_stats<double> rs1, rs2;
for (int i = 0; i < 10; ++i)
{
rs1.add(i);
a.push_back(i);
}
DLIB_TEST(std::abs(variance(vector_to_matrix(a)) - rs1.variance()) < 1e-13);
DLIB_TEST(std::abs(mean(vector_to_matrix(a)) - rs1.mean()) < 1e-13);
for (int i = 10; i < 20; ++i)
{
rs2.add(i);
a.push_back(i);
}
DLIB_TEST(std::abs(variance(vector_to_matrix(a)) - (rs1+rs2).variance()) < 1e-13);
DLIB_TEST(std::abs(mean(vector_to_matrix(a)) - (rs1+rs2).mean()) < 1e-13);
DLIB_TEST((rs1+rs2).current_n() == 20);
running_scalar_covariance<double> rc1, rc2, rc3;
dlib::rand rnd;
for (double i = 0; i < 10; ++i)
{
const double a = i + rnd.get_random_gaussian();
const double b = i + rnd.get_random_gaussian();
rc1.add(a,b);
rc3.add(a,b);
}
for (double i = 11; i < 20; ++i)
{
const double a = i + rnd.get_random_gaussian();
const double b = i + rnd.get_random_gaussian();
rc2.add(a,b);
rc3.add(a,b);
}
DLIB_TEST(std::abs((rc1+rc2).mean_x() - rc3.mean_x()) < 1e-13);
DLIB_TEST(std::abs((rc1+rc2).mean_y() - rc3.mean_y()) < 1e-13);
DLIB_TEST_MSG(std::abs((rc1+rc2).variance_x() - rc3.variance_x()) < 1e-13, std::abs((rc1+rc2).variance_x() - rc3.variance_x()));
DLIB_TEST(std::abs((rc1+rc2).variance_y() - rc3.variance_y()) < 1e-13);
DLIB_TEST(std::abs((rc1+rc2).covariance() - rc3.covariance()) < 1e-13);
DLIB_TEST((rc1+rc2).current_n() == rc3.current_n());
rs1.set_max_n(50);
DLIB_TEST(rs1.max_n() == 50);
}
void perform_test ( void perform_test (
) )
{ {
@ -295,6 +348,7 @@ namespace
test_running_stats(); test_running_stats();
test_randomize_samples(); test_randomize_samples();
test_randomize_samples2(); test_randomize_samples2();
another_test();
} }
} a; } a;