mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
- Added operator+() for running_stats and running_scalar_covariance
- Simplified and optimized the running_stats implementation - Clarified the spec a little
This commit is contained in:
parent
3f0ba96aa8
commit
daecb0ec85
@ -53,11 +53,8 @@ namespace dlib
|
|||||||
const T& val
|
const T& val
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
const T div_n = 1/(n+1);
|
sum += val;
|
||||||
const T n_div_n = n*div_n;
|
sum_sqr += val*val;
|
||||||
|
|
||||||
sum = n_div_n*sum + val*div_n;
|
|
||||||
sum_sqr = n_div_n*sum_sqr + val*div_n*val;
|
|
||||||
|
|
||||||
if (val < min_value)
|
if (val < min_value)
|
||||||
min_value = val;
|
min_value = val;
|
||||||
@ -83,7 +80,10 @@ namespace dlib
|
|||||||
T mean (
|
T mean (
|
||||||
) const
|
) const
|
||||||
{
|
{
|
||||||
return sum;
|
if (n != 0)
|
||||||
|
return sum/n;
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
T max (
|
T max (
|
||||||
@ -122,8 +122,8 @@ namespace dlib
|
|||||||
<< "\n\tthis: " << this
|
<< "\n\tthis: " << this
|
||||||
);
|
);
|
||||||
|
|
||||||
T temp = n/(n-1);
|
T temp = 1/(n-1);
|
||||||
temp = temp*(sum_sqr - sum*sum);
|
temp = temp*(sum_sqr - sum*sum/n);
|
||||||
// make sure the variance is never negative. This might
|
// make sure the variance is never negative. This might
|
||||||
// happen due to numerical errors.
|
// happen due to numerical errors.
|
||||||
if (temp >= 0)
|
if (temp >= 0)
|
||||||
@ -158,6 +158,29 @@ namespace dlib
|
|||||||
return (val-mean())/std::sqrt(variance());
|
return (val-mean())/std::sqrt(variance());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
running_stats operator+ (
|
||||||
|
const running_stats& rhs
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(max_n() == rhs.max_n(),
|
||||||
|
"\trunning_stats running_stats::operator+(rhs)"
|
||||||
|
<< "\n\t invalid inputs were given to this function"
|
||||||
|
<< "\n\t max_n(): " << max_n()
|
||||||
|
<< "\n\t rhs.max_n(): " << rhs.max_n()
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
|
running_stats temp(*this);
|
||||||
|
|
||||||
|
temp.sum += rhs.sum;
|
||||||
|
temp.sum_sqr += rhs.sum_sqr;
|
||||||
|
temp.n += rhs.n;
|
||||||
|
temp.min_value = std::min(rhs.min_value, min_value);
|
||||||
|
temp.max_value = std::max(rhs.max_value, max_value);
|
||||||
|
return temp;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
friend void serialize (
|
friend void serialize (
|
||||||
const running_stats<U>& item,
|
const running_stats<U>& item,
|
||||||
@ -373,6 +396,21 @@ namespace dlib
|
|||||||
return std::sqrt(variance_y());
|
return std::sqrt(variance_y());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
running_scalar_covariance operator+ (
|
||||||
|
const running_scalar_covariance& rhs
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
running_scalar_covariance temp(rhs);
|
||||||
|
|
||||||
|
temp.sum_xy += sum_xy;
|
||||||
|
temp.sum_x += sum_x;
|
||||||
|
temp.sum_y += sum_y;
|
||||||
|
temp.sum_xx += sum_xx;
|
||||||
|
temp.sum_yy += sum_yy;
|
||||||
|
temp.n += n;
|
||||||
|
return temp;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
T sum_xy;
|
T sum_xy;
|
||||||
|
@ -208,7 +208,7 @@ namespace dlib
|
|||||||
requires
|
requires
|
||||||
- current_n() > 1
|
- current_n() > 1
|
||||||
ensures
|
ensures
|
||||||
- returns the variance of all the values presented to this
|
- returns the unbiased sample variance of all the values presented to this
|
||||||
object so far.
|
object so far.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
@ -218,8 +218,8 @@ namespace dlib
|
|||||||
requires
|
requires
|
||||||
- current_n() > 1
|
- current_n() > 1
|
||||||
ensures
|
ensures
|
||||||
- returns the standard deviation of all the values presented to this
|
- returns the unbiased sampled standard deviation of all the values
|
||||||
object so far.
|
presented to this object so far.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
T max (
|
T max (
|
||||||
@ -249,6 +249,19 @@ namespace dlib
|
|||||||
ensures
|
ensures
|
||||||
- return (val-mean())/stddev();
|
- return (val-mean())/stddev();
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
running_stats operator+ (
|
||||||
|
const running_stats& rhs
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- max_n() == rhs.max_n()
|
||||||
|
ensures
|
||||||
|
- returns a new running_stats object that represents the combination of all
|
||||||
|
the values given to *this and rhs. That is, this function returns a
|
||||||
|
running_stats object, R, that is equivalent to what you would obtain if
|
||||||
|
all calls to this->add() and rhs.add() had instead been done to R.
|
||||||
|
!*/
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -400,6 +413,18 @@ namespace dlib
|
|||||||
- returns the unbiased sample standard deviation of all y samples
|
- returns the unbiased sample standard deviation of all y samples
|
||||||
presented to this object via add().
|
presented to this object via add().
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
running_scalar_covariance operator+ (
|
||||||
|
const running_covariance& rhs
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns a new running_scalar_covariance object that represents the
|
||||||
|
combination of all the values given to *this and rhs. That is, this
|
||||||
|
function returns a running_scalar_covariance object, R, that is
|
||||||
|
equivalent to what you would obtain if all calls to this->add() and
|
||||||
|
rhs.add() had instead been done to R.
|
||||||
|
!*/
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
@ -286,6 +286,59 @@ namespace
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void another_test()
|
||||||
|
{
|
||||||
|
std::vector<double> a;
|
||||||
|
|
||||||
|
running_stats<double> rs1, rs2;
|
||||||
|
|
||||||
|
for (int i = 0; i < 10; ++i)
|
||||||
|
{
|
||||||
|
rs1.add(i);
|
||||||
|
a.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
DLIB_TEST(std::abs(variance(vector_to_matrix(a)) - rs1.variance()) < 1e-13);
|
||||||
|
DLIB_TEST(std::abs(mean(vector_to_matrix(a)) - rs1.mean()) < 1e-13);
|
||||||
|
|
||||||
|
for (int i = 10; i < 20; ++i)
|
||||||
|
{
|
||||||
|
rs2.add(i);
|
||||||
|
a.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
DLIB_TEST(std::abs(variance(vector_to_matrix(a)) - (rs1+rs2).variance()) < 1e-13);
|
||||||
|
DLIB_TEST(std::abs(mean(vector_to_matrix(a)) - (rs1+rs2).mean()) < 1e-13);
|
||||||
|
DLIB_TEST((rs1+rs2).current_n() == 20);
|
||||||
|
|
||||||
|
running_scalar_covariance<double> rc1, rc2, rc3;
|
||||||
|
dlib::rand rnd;
|
||||||
|
for (double i = 0; i < 10; ++i)
|
||||||
|
{
|
||||||
|
const double a = i + rnd.get_random_gaussian();
|
||||||
|
const double b = i + rnd.get_random_gaussian();
|
||||||
|
rc1.add(a,b);
|
||||||
|
rc3.add(a,b);
|
||||||
|
}
|
||||||
|
for (double i = 11; i < 20; ++i)
|
||||||
|
{
|
||||||
|
const double a = i + rnd.get_random_gaussian();
|
||||||
|
const double b = i + rnd.get_random_gaussian();
|
||||||
|
rc2.add(a,b);
|
||||||
|
rc3.add(a,b);
|
||||||
|
}
|
||||||
|
|
||||||
|
DLIB_TEST(std::abs((rc1+rc2).mean_x() - rc3.mean_x()) < 1e-13);
|
||||||
|
DLIB_TEST(std::abs((rc1+rc2).mean_y() - rc3.mean_y()) < 1e-13);
|
||||||
|
DLIB_TEST_MSG(std::abs((rc1+rc2).variance_x() - rc3.variance_x()) < 1e-13, std::abs((rc1+rc2).variance_x() - rc3.variance_x()));
|
||||||
|
DLIB_TEST(std::abs((rc1+rc2).variance_y() - rc3.variance_y()) < 1e-13);
|
||||||
|
DLIB_TEST(std::abs((rc1+rc2).covariance() - rc3.covariance()) < 1e-13);
|
||||||
|
DLIB_TEST((rc1+rc2).current_n() == rc3.current_n());
|
||||||
|
|
||||||
|
rs1.set_max_n(50);
|
||||||
|
DLIB_TEST(rs1.max_n() == 50);
|
||||||
|
}
|
||||||
|
|
||||||
void perform_test (
|
void perform_test (
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
@ -295,6 +348,7 @@ namespace
|
|||||||
test_running_stats();
|
test_running_stats();
|
||||||
test_randomize_samples();
|
test_randomize_samples();
|
||||||
test_randomize_samples2();
|
test_randomize_samples2();
|
||||||
|
another_test();
|
||||||
}
|
}
|
||||||
} a;
|
} a;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user