mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Use CUDA in LayerNorm gradient computation
I don't know how I could miss this.
This commit is contained in:
parent
3a267db577
commit
49314c12d9
@ -687,7 +687,11 @@ namespace dlib { namespace tt
|
|||||||
tensor& beta_grad
|
tensor& beta_grad
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
#ifdef DLIB_USE_CUDA
|
||||||
|
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
|
||||||
|
#else
|
||||||
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
|
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
Loading…
Reference in New Issue
Block a user