From 4519a7a6bdf4e5fb455d6b1509d0b2575abf270d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com> Date: Sat, 11 Feb 2023 12:02:54 +0900 Subject: [PATCH] Add sum method to running stats (#2728) --- dlib/statistics/statistics.h | 32 ++++++++++++++++----------- dlib/statistics/statistics_abstract.h | 13 +++++++++-- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/dlib/statistics/statistics.h b/dlib/statistics/statistics.h index 9dee7006b..492ddf8e7 100644 --- a/dlib/statistics/statistics.h +++ b/dlib/statistics/statistics.h @@ -35,7 +35,7 @@ namespace dlib void clear() { - sum = 0; + sum_ = 0; sum_sqr = 0; sum_cub = 0; sum_four = 0; @@ -49,7 +49,7 @@ namespace dlib const T& val ) { - sum += val; + sum_ += val; sum_sqr += val*val; sum_cub += cubed(val); sum_four += quaded(val); @@ -68,11 +68,17 @@ namespace dlib return n; } + T sum ( + ) const + { + return sum_; + } + T mean ( ) const { if (n != 0) - return sum/n; + return sum_/n; else return 0; } @@ -114,7 +120,7 @@ namespace dlib ); T temp = 1/(n-1); - temp = temp*(sum_sqr - sum*sum/n); + temp = temp*(sum_sqr - sum_*sum_/n); // make sure the variance is never negative. This might // happen due to numerical errors. if (temp >= 0) @@ -148,8 +154,8 @@ namespace dlib T temp = 1/n; T temp1 = std::sqrt(n*(n-1))/(n-2); - temp = temp1*temp*(sum_cub - 3*sum_sqr*sum*temp + 2*cubed(sum)*temp*temp)/ - (std::sqrt(std::pow(temp*(sum_sqr-sum*sum*temp),3))); + temp = temp1*temp*(sum_cub - 3*sum_sqr*sum_*temp + 2*cubed(sum_)*temp*temp)/ + (std::sqrt(std::pow(temp*(sum_sqr-sum_*sum_*temp),3))); return temp; } @@ -165,9 +171,9 @@ namespace dlib ); T temp = 1/n; - T m4 = temp*(sum_four - 4*sum_cub*sum*temp+6*sum_sqr*sum*sum*temp*temp - -3*quaded(sum)*cubed(temp)); - T m2 = temp*(sum_sqr-sum*sum*temp); + T m4 = temp*(sum_four - 4*sum_cub*sum_*temp+6*sum_sqr*sum_*sum_*temp*temp + -3*quaded(sum_)*cubed(temp)); + T m2 = temp*(sum_sqr-sum_*sum_*temp); temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3)); return temp; @@ -192,7 +198,7 @@ namespace dlib { running_stats temp(*this); - temp.sum += rhs.sum; + temp.sum_ += rhs.sum_; temp.sum_sqr += rhs.sum_sqr; temp.sum_cub += rhs.sum_cub; temp.sum_four += rhs.sum_four; @@ -215,7 +221,7 @@ namespace dlib ); private: - T sum; + T sum_; T sum_sqr; T sum_cub; T sum_four; @@ -236,7 +242,7 @@ namespace dlib int version = 2; serialize(version, out); - serialize(item.sum, out); + serialize(item.sum_, out); serialize(item.sum_sqr, out); serialize(item.sum_cub, out); serialize(item.sum_four, out); @@ -256,7 +262,7 @@ namespace dlib if (version != 2) throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object."); - deserialize(item.sum, in); + deserialize(item.sum_, in); deserialize(item.sum_sqr, in); deserialize(item.sum_cub, in); deserialize(item.sum_four, in); diff --git a/dlib/statistics/statistics_abstract.h b/dlib/statistics/statistics_abstract.h index 77db856f8..b5738196d 100644 --- a/dlib/statistics/statistics_abstract.h +++ b/dlib/statistics/statistics_abstract.h @@ -222,8 +222,9 @@ namespace dlib ); /*! ensures - - updates the mean, variance, skewness, and kurtosis stored in this object - so that the new value is factored into them. + - updates the sum, mean, variance, skewness, and kurtosis stored in this + object so that the new value is factored into them. + - #sum() == sum() + val - #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1). (i.e. the updated mean value that takes the new value into account) - #variance() == the updated variance that takes this new value into account. @@ -232,6 +233,14 @@ namespace dlib - #current_n() == current_n() + 1 !*/ + T sum ( + ) const; + /*! + ensures + - returns the sum of all the values presented to this object + so far. + !*/ + T mean ( ) const; /*!