Do not use sqrt_2 in device code (fixes #2208) (#2210)

* do not use sqrt_2 in device code

* use CUDART_SQRT_2PI

* better sort includes
pull/2217/head
Adrià Arrufat 4 years ago committed by GitHub
parent 3ba004f875
commit a1f158379e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1711,7 +1711,7 @@ namespace dlib
const tensor& gradient_input
)
{
const float beta = 1.0f / std::sqrt(pi) / sqrt_2;
const float beta = 1.0f / std::sqrt(2.0f * pi);
const auto compute_gradient = [beta](float x)
{
const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2));

@ -4,6 +4,7 @@
#include "cuda_utils.h"
#include "cuda_dlib.h"
#include "cudnn_dlibapi.h"
#include <math_constants.h>
namespace dlib
@ -1501,7 +1502,7 @@ namespace dlib
__device__ float gelu_compute_gradient(float x)
{
const float beta = 1.0f / std::sqrt(pi) / sqrt_2;
const float beta = 1.0f / CUDART_SQRT_2PI;
const float cdf = normcdf(x);
const float pdf = beta*std::exp(-0.5f*x*x);
return cdf + x * pdf;

Loading…
Cancel
Save