Add sum method to running stats (#2728)

This commit is contained in:
Adrià Arrufat 2023-02-11 12:02:54 +09:00 committed by GitHub
parent 50b33753bb
commit 4519a7a6bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 15 deletions

View File

@ -35,7 +35,7 @@ namespace dlib
void clear()
{
sum = 0;
sum_ = 0;
sum_sqr = 0;
sum_cub = 0;
sum_four = 0;
@ -49,7 +49,7 @@ namespace dlib
const T& val
)
{
sum += val;
sum_ += val;
sum_sqr += val*val;
sum_cub += cubed(val);
sum_four += quaded(val);
@ -68,11 +68,17 @@ namespace dlib
return n;
}
T sum (
) const
{
return sum_;
}
T mean (
) const
{
if (n != 0)
return sum/n;
return sum_/n;
else
return 0;
}
@ -114,7 +120,7 @@ namespace dlib
);
T temp = 1/(n-1);
temp = temp*(sum_sqr - sum*sum/n);
temp = temp*(sum_sqr - sum_*sum_/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
@ -148,8 +154,8 @@ namespace dlib
T temp = 1/n;
T temp1 = std::sqrt(n*(n-1))/(n-2);
temp = temp1*temp*(sum_cub - 3*sum_sqr*sum*temp + 2*cubed(sum)*temp*temp)/
(std::sqrt(std::pow(temp*(sum_sqr-sum*sum*temp),3)));
temp = temp1*temp*(sum_cub - 3*sum_sqr*sum_*temp + 2*cubed(sum_)*temp*temp)/
(std::sqrt(std::pow(temp*(sum_sqr-sum_*sum_*temp),3)));
return temp;
}
@ -165,9 +171,9 @@ namespace dlib
);
T temp = 1/n;
T m4 = temp*(sum_four - 4*sum_cub*sum*temp+6*sum_sqr*sum*sum*temp*temp
-3*quaded(sum)*cubed(temp));
T m2 = temp*(sum_sqr-sum*sum*temp);
T m4 = temp*(sum_four - 4*sum_cub*sum_*temp+6*sum_sqr*sum_*sum_*temp*temp
-3*quaded(sum_)*cubed(temp));
T m2 = temp*(sum_sqr-sum_*sum_*temp);
temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3));
return temp;
@ -192,7 +198,7 @@ namespace dlib
{
running_stats temp(*this);
temp.sum += rhs.sum;
temp.sum_ += rhs.sum_;
temp.sum_sqr += rhs.sum_sqr;
temp.sum_cub += rhs.sum_cub;
temp.sum_four += rhs.sum_four;
@ -215,7 +221,7 @@ namespace dlib
);
private:
T sum;
T sum_;
T sum_sqr;
T sum_cub;
T sum_four;
@ -236,7 +242,7 @@ namespace dlib
int version = 2;
serialize(version, out);
serialize(item.sum, out);
serialize(item.sum_, out);
serialize(item.sum_sqr, out);
serialize(item.sum_cub, out);
serialize(item.sum_four, out);
@ -256,7 +262,7 @@ namespace dlib
if (version != 2)
throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object.");
deserialize(item.sum, in);
deserialize(item.sum_, in);
deserialize(item.sum_sqr, in);
deserialize(item.sum_cub, in);
deserialize(item.sum_four, in);

View File

@ -222,8 +222,9 @@ namespace dlib
);
/*!
ensures
- updates the mean, variance, skewness, and kurtosis stored in this object
so that the new value is factored into them.
- updates the sum, mean, variance, skewness, and kurtosis stored in this
object so that the new value is factored into them.
- #sum() == sum() + val
- #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1).
(i.e. the updated mean value that takes the new value into account)
- #variance() == the updated variance that takes this new value into account.
@ -232,6 +233,14 @@ namespace dlib
- #current_n() == current_n() + 1
!*/
T sum (
) const;
/*!
ensures
- returns the sum of all the values presented to this object
so far.
!*/
T mean (
) const;
/*!