Add sum method to running stats (#2728)

This commit is contained in:
Adrià Arrufat 2023-02-11 12:02:54 +09:00 committed by GitHub
parent 50b33753bb
commit 4519a7a6bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 15 deletions

View File

@ -35,7 +35,7 @@ namespace dlib
void clear() void clear()
{ {
sum = 0; sum_ = 0;
sum_sqr = 0; sum_sqr = 0;
sum_cub = 0; sum_cub = 0;
sum_four = 0; sum_four = 0;
@ -49,7 +49,7 @@ namespace dlib
const T& val const T& val
) )
{ {
sum += val; sum_ += val;
sum_sqr += val*val; sum_sqr += val*val;
sum_cub += cubed(val); sum_cub += cubed(val);
sum_four += quaded(val); sum_four += quaded(val);
@ -68,11 +68,17 @@ namespace dlib
return n; return n;
} }
T sum (
) const
{
return sum_;
}
T mean ( T mean (
) const ) const
{ {
if (n != 0) if (n != 0)
return sum/n; return sum_/n;
else else
return 0; return 0;
} }
@ -114,7 +120,7 @@ namespace dlib
); );
T temp = 1/(n-1); T temp = 1/(n-1);
temp = temp*(sum_sqr - sum*sum/n); temp = temp*(sum_sqr - sum_*sum_/n);
// make sure the variance is never negative. This might // make sure the variance is never negative. This might
// happen due to numerical errors. // happen due to numerical errors.
if (temp >= 0) if (temp >= 0)
@ -148,8 +154,8 @@ namespace dlib
T temp = 1/n; T temp = 1/n;
T temp1 = std::sqrt(n*(n-1))/(n-2); T temp1 = std::sqrt(n*(n-1))/(n-2);
temp = temp1*temp*(sum_cub - 3*sum_sqr*sum*temp + 2*cubed(sum)*temp*temp)/ temp = temp1*temp*(sum_cub - 3*sum_sqr*sum_*temp + 2*cubed(sum_)*temp*temp)/
(std::sqrt(std::pow(temp*(sum_sqr-sum*sum*temp),3))); (std::sqrt(std::pow(temp*(sum_sqr-sum_*sum_*temp),3)));
return temp; return temp;
} }
@ -165,9 +171,9 @@ namespace dlib
); );
T temp = 1/n; T temp = 1/n;
T m4 = temp*(sum_four - 4*sum_cub*sum*temp+6*sum_sqr*sum*sum*temp*temp T m4 = temp*(sum_four - 4*sum_cub*sum_*temp+6*sum_sqr*sum_*sum_*temp*temp
-3*quaded(sum)*cubed(temp)); -3*quaded(sum_)*cubed(temp));
T m2 = temp*(sum_sqr-sum*sum*temp); T m2 = temp*(sum_sqr-sum_*sum_*temp);
temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3)); temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3));
return temp; return temp;
@ -192,7 +198,7 @@ namespace dlib
{ {
running_stats temp(*this); running_stats temp(*this);
temp.sum += rhs.sum; temp.sum_ += rhs.sum_;
temp.sum_sqr += rhs.sum_sqr; temp.sum_sqr += rhs.sum_sqr;
temp.sum_cub += rhs.sum_cub; temp.sum_cub += rhs.sum_cub;
temp.sum_four += rhs.sum_four; temp.sum_four += rhs.sum_four;
@ -215,7 +221,7 @@ namespace dlib
); );
private: private:
T sum; T sum_;
T sum_sqr; T sum_sqr;
T sum_cub; T sum_cub;
T sum_four; T sum_four;
@ -236,7 +242,7 @@ namespace dlib
int version = 2; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.sum, out); serialize(item.sum_, out);
serialize(item.sum_sqr, out); serialize(item.sum_sqr, out);
serialize(item.sum_cub, out); serialize(item.sum_cub, out);
serialize(item.sum_four, out); serialize(item.sum_four, out);
@ -256,7 +262,7 @@ namespace dlib
if (version != 2) if (version != 2)
throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object."); throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object.");
deserialize(item.sum, in); deserialize(item.sum_, in);
deserialize(item.sum_sqr, in); deserialize(item.sum_sqr, in);
deserialize(item.sum_cub, in); deserialize(item.sum_cub, in);
deserialize(item.sum_four, in); deserialize(item.sum_four, in);

View File

@ -222,8 +222,9 @@ namespace dlib
); );
/*! /*!
ensures ensures
- updates the mean, variance, skewness, and kurtosis stored in this object - updates the sum, mean, variance, skewness, and kurtosis stored in this
so that the new value is factored into them. object so that the new value is factored into them.
- #sum() == sum() + val
- #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1). - #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1).
(i.e. the updated mean value that takes the new value into account) (i.e. the updated mean value that takes the new value into account)
- #variance() == the updated variance that takes this new value into account. - #variance() == the updated variance that takes this new value into account.
@ -232,6 +233,14 @@ namespace dlib
- #current_n() == current_n() + 1 - #current_n() == current_n() + 1
!*/ !*/
T sum (
) const;
/*!
ensures
- returns the sum of all the values presented to this object
so far.
!*/
T mean ( T mean (
) const; ) const;
/*! /*!