From a1f158379e2f328e8697b63ad653926594c8a771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com> Date: Sat, 10 Oct 2020 21:42:10 +0900 Subject: [PATCH] Do not use sqrt_2 in device code (fixes #2208) (#2210) * do not use sqrt_2 in device code * use CUDART_SQRT_2PI * better sort includes --- dlib/cuda/cpu_dlib.cpp | 2 +- dlib/cuda/cuda_dlib.cu | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp index 252cb9413..b1ba41a83 100644 --- a/dlib/cuda/cpu_dlib.cpp +++ b/dlib/cuda/cpu_dlib.cpp @@ -1711,7 +1711,7 @@ namespace dlib const tensor& gradient_input ) { - const float beta = 1.0f / std::sqrt(pi) / sqrt_2; + const float beta = 1.0f / std::sqrt(2.0f * pi); const auto compute_gradient = [beta](float x) { const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2)); diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 7ac740d63..9f4cc9287 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -4,6 +4,7 @@ #include "cuda_utils.h" #include "cuda_dlib.h" #include "cudnn_dlibapi.h" +#include namespace dlib @@ -1501,7 +1502,7 @@ namespace dlib __device__ float gelu_compute_gradient(float x) { - const float beta = 1.0f / std::sqrt(pi) / sqrt_2; + const float beta = 1.0f / CUDART_SQRT_2PI; const float cdf = normcdf(x); const float pdf = beta*std::exp(-0.5f*x*x); return cdf + x * pdf;