- Added operator+() for running_stats and running_scalar_covariance

- Simplified and optimized the running_stats implementation
 - Clarified the spec a little
This commit is contained in:
Davis King 2012-10-01 21:25:31 -04:00
parent 3f0ba96aa8
commit daecb0ec85
3 changed files with 128 additions and 11 deletions

View File

@ -53,11 +53,8 @@ namespace dlib
const T& val
)
{
const T div_n = 1/(n+1);
const T n_div_n = n*div_n;
sum = n_div_n*sum + val*div_n;
sum_sqr = n_div_n*sum_sqr + val*div_n*val;
sum += val;
sum_sqr += val*val;
if (val < min_value)
min_value = val;
@ -83,7 +80,10 @@ namespace dlib
T mean (
) const
{
return sum;
if (n != 0)
return sum/n;
else
return 0;
}
T max (
@ -122,8 +122,8 @@ namespace dlib
<< "\n\tthis: " << this
);
T temp = n/(n-1);
temp = temp*(sum_sqr - sum*sum);
T temp = 1/(n-1);
temp = temp*(sum_sqr - sum*sum/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
@ -158,6 +158,29 @@ namespace dlib
return (val-mean())/std::sqrt(variance());
}
running_stats operator+ (
    const running_stats& rhs
) const
{
    // Merging two accumulators only makes sense when both use the same
    // forgetting horizon, so require matching max_n() values.
    DLIB_ASSERT(max_n() == rhs.max_n(),
        "\trunning_stats running_stats::operator+(rhs)"
        << "\n\t invalid inputs were given to this function"
        << "\n\t max_n(): " << max_n()
        << "\n\t rhs.max_n(): " << rhs.max_n()
        << "\n\t this: " << this
        );

    // All state is raw sums plus a count, so combining two objects is
    // just member-wise addition (min/max take the extremum of the pair).
    running_stats result(rhs);
    result.sum       += sum;
    result.sum_sqr   += sum_sqr;
    result.n         += n;
    result.min_value  = std::min(min_value, result.min_value);
    result.max_value  = std::max(max_value, result.max_value);
    return result;
}
template <typename U>
friend void serialize (
const running_stats<U>& item,
@ -373,6 +396,21 @@ namespace dlib
return std::sqrt(variance_y());
}
running_scalar_covariance operator+ (
    const running_scalar_covariance& rhs
) const
{
    // Every member is a raw sum (or the sample count), so the combined
    // accumulator is obtained by simple member-wise addition.
    running_scalar_covariance result(*this);
    result.sum_xy += rhs.sum_xy;
    result.sum_x  += rhs.sum_x;
    result.sum_y  += rhs.sum_y;
    result.sum_xx += rhs.sum_xx;
    result.sum_yy += rhs.sum_yy;
    result.n      += rhs.n;
    return result;
}
private:
T sum_xy;

View File

@ -208,7 +208,7 @@ namespace dlib
requires
- current_n() > 1
ensures
- returns the variance of all the values presented to this
- returns the unbiased sample variance of all the values presented to this
object so far.
!*/
@ -218,8 +218,8 @@ namespace dlib
requires
- current_n() > 1
ensures
- returns the standard deviation of all the values presented to this
object so far.
- returns the unbiased sample standard deviation of all the values
presented to this object so far.
!*/
T max (
@ -249,6 +249,19 @@ namespace dlib
ensures
- return (val-mean())/stddev();
!*/
running_stats operator+ (
const running_stats& rhs
) const;
/*!
requires
- max_n() == rhs.max_n()
ensures
- returns a new running_stats object that represents the combination of all
the values given to *this and rhs. That is, this function returns a
running_stats object, R, that is equivalent to what you would obtain if
all calls to this->add() and rhs.add() had instead been done to R.
!*/
};
template <typename T>
@ -400,6 +413,18 @@ namespace dlib
- returns the unbiased sample standard deviation of all y samples
presented to this object via add().
!*/
running_scalar_covariance operator+ (
const running_scalar_covariance& rhs
) const;
/*!
ensures
- returns a new running_scalar_covariance object that represents the
combination of all the values given to *this and rhs. That is, this
function returns a running_scalar_covariance object, R, that is
equivalent to what you would obtain if all calls to this->add() and
rhs.add() had instead been done to R.
!*/
};
// ----------------------------------------------------------------------------------------

View File

@ -286,6 +286,59 @@ namespace
}
}
void another_test()
{
    // Checks that running_stats::operator+() produces an object equivalent
    // to feeding every sample to a single accumulator, and likewise for
    // running_scalar_covariance::operator+().

    // NOTE(review): the sample vector was named `a`, which was shadowed by
    // the `const double a` locals in the loops below — renamed to avoid the
    // shadowing and the confusion it invites.
    std::vector<double> samples;
    running_stats<double> rs1, rs2;
    for (int i = 0; i < 10; ++i)
    {
        rs1.add(i);
        samples.push_back(i);
    }

    // rs1 alone must agree with the matrix-based reference implementation.
    DLIB_TEST(std::abs(variance(vector_to_matrix(samples)) - rs1.variance()) < 1e-13);
    DLIB_TEST(std::abs(mean(vector_to_matrix(samples)) - rs1.mean()) < 1e-13);

    for (int i = 10; i < 20; ++i)
    {
        rs2.add(i);
        samples.push_back(i);
    }

    // rs1+rs2 must behave as if all 20 samples went into one object.
    DLIB_TEST(std::abs(variance(vector_to_matrix(samples)) - (rs1+rs2).variance()) < 1e-13);
    DLIB_TEST(std::abs(mean(vector_to_matrix(samples)) - (rs1+rs2).mean()) < 1e-13);
    DLIB_TEST((rs1+rs2).current_n() == 20);

    // rc3 sees every (x,y) pair directly; rc1+rc2 must match it exactly
    // (up to floating point round-off).
    running_scalar_covariance<double> rc1, rc2, rc3;
    dlib::rand rnd;
    for (double i = 0; i < 10; ++i)
    {
        const double x = i + rnd.get_random_gaussian();
        const double y = i + rnd.get_random_gaussian();
        rc1.add(x,y);
        rc3.add(x,y);
    }
    for (double i = 11; i < 20; ++i)
    {
        const double x = i + rnd.get_random_gaussian();
        const double y = i + rnd.get_random_gaussian();
        rc2.add(x,y);
        rc3.add(x,y);
    }

    DLIB_TEST(std::abs((rc1+rc2).mean_x() - rc3.mean_x()) < 1e-13);
    DLIB_TEST(std::abs((rc1+rc2).mean_y() - rc3.mean_y()) < 1e-13);
    DLIB_TEST_MSG(std::abs((rc1+rc2).variance_x() - rc3.variance_x()) < 1e-13, std::abs((rc1+rc2).variance_x() - rc3.variance_x()));
    DLIB_TEST(std::abs((rc1+rc2).variance_y() - rc3.variance_y()) < 1e-13);
    DLIB_TEST(std::abs((rc1+rc2).covariance() - rc3.covariance()) < 1e-13);
    DLIB_TEST((rc1+rc2).current_n() == rc3.current_n());

    // Sanity check the max_n setter/getter round trip.
    rs1.set_max_n(50);
    DLIB_TEST(rs1.max_n() == 50);
}
void perform_test (
)
{
@ -295,6 +348,7 @@ namespace
test_running_stats();
test_randomize_samples();
test_randomize_samples2();
another_test();
}
} a;