Added a function for generating Gaussian random numbers.

This commit is contained in:
Davis King 2011-12-18 16:51:56 -05:00
parent 83623fd800
commit b436d840fb

View File

@ -44,6 +44,10 @@ namespace dlib
max_val *= 0x1000000;
max_val += 0xFFFFFF;
max_val += 0.01;
has_gaussian = false;
next_gaussian = 0;
}
virtual ~rand(
@ -56,6 +60,9 @@ namespace dlib
mt.seed();
seed.clear();
has_gaussian = false;
next_gaussian = 0;
// prime the generator a bit
for (int i = 0; i < 10000; ++i)
mt();
@ -92,6 +99,10 @@ namespace dlib
// prime the generator a bit
for (int i = 0; i < 10000; ++i)
mt();
has_gaussian = false;
next_gaussian = 0;
}
unsigned char get_random_8bit_number (
@ -164,6 +175,35 @@ namespace dlib
}
}
double get_random_gaussian (
)
{
if (has_gaussian)
{
has_gaussian = false;
return next_gaussian;
}
double x1, x2, w;
const double rndmax = std::numeric_limits<dlib::uint32>::max();
// Generate a pair of Gaussian random numbers using the Box-Muller transformation.
do
{
const double rnd1 = get_random_32bit_number()/rndmax;
const double rnd2 = get_random_32bit_number()/rndmax;
x1 = 2.0 * rnd1 - 1.0;
x2 = 2.0 * rnd2 - 1.0;
w = x1 * x1 + x2 * x2;
} while ( w >= 1.0 );
w = std::sqrt( (-2.0 * std::log( w ) ) / w );
next_gaussian = x2 * w;
has_gaussian = true;
return x1 * w;
}
void swap (
rand& item
@ -171,6 +211,8 @@ namespace dlib
{
exchange(mt,item.mt);
exchange(seed, item.seed);
exchange(has_gaussian, item.has_gaussian);
exchange(next_gaussian, item.next_gaussian);
}
friend void serialize(
@ -190,6 +232,8 @@ namespace dlib
double max_val;
bool has_gaussian;
double next_gaussian;
};
@ -210,8 +254,13 @@ namespace dlib
std::ostream& out
)
{
int version = 1;
serialize(version, out);
serialize(item.mt, out);
serialize(item.seed, out);
serialize(item.has_gaussian, out);
serialize(item.next_gaussian, out);
}
inline void deserialize(
@ -219,8 +268,15 @@ namespace dlib
std::istream& in
)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Error deserializing object of type rand: unexpected version.");
deserialize(item.mt, in);
deserialize(item.seed, in);
deserialize(item.has_gaussian, in);
deserialize(item.next_gaussian, in);
}
}