From 8fa65eb7b268851beb50a62eec4351216fc0291a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adri=C3=A0=20Arrufat?=
Date: Thu, 7 Apr 2022 01:45:30 +0900
Subject: [PATCH] Add SmeLU activation

---
 dlib/cuda/cpu_dlib.cpp     | 55 ++++++++++++++++++++++++++
 dlib/cuda/cpu_dlib.h       | 15 +++++++
 dlib/cuda/cuda_dlib.cu     | 75 ++++++++++++++++++++++++++++++++++-
 dlib/cuda/cuda_dlib.h      | 15 +++++++
 dlib/cuda/tensor_tools.cpp | 30 ++++++++++++++
 dlib/cuda/tensor_tools.h   | 43 ++++++++++++++++++++
 dlib/dnn/layers.h          | 80 ++++++++++++++++++++++++++++++++++++++
 dlib/dnn/layers_abstract.h | 53 +++++++++++++++++++++++++
 dlib/test/dnn.cpp          | 49 +++++++++++++++++++----
 9 files changed, 406 insertions(+), 9 deletions(-)

diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index 2d8e1e453..7c4de9d3f 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -1998,6 +1998,61 @@ namespace dlib
             }
         }
 
+        void smelu (
+            tensor& dest,
+            const tensor& src,
+            const float beta
+        )
+        {
+            const float* s = src.host();
+            float* d = dest.host();
+            for (size_t i = 0; i < dest.size(); ++i)
+            {
+                if (s[i] >= beta)
+                    d[i] = s[i];
+                else if (s[i] <= -beta)
+                    d[i] = 0;
+                else
+                    d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
+            }
+        }
+
+        void smelu_gradient (
+            tensor& grad,
+            const tensor& dest,
+            const tensor& gradient_input,
+            const float beta
+        )
+        {
+            const float* gi = gradient_input.host();
+            const float* in = dest.host();
+            float* out = grad.host();
+            if (is_same_object(grad, gradient_input))
+            {
+                for (size_t i = 0; i < dest.size(); ++i)
+                {
+                    if (in[i] >= beta)
+                        out[i] = gi[i];
+                    else if (in[i] == 0)
+                        out[i] = 0;
+                    else
+                        out[i] = std::sqrt(beta * in[i]) / beta * gi[i];
+                }
+            }
+            else
+            {
+                for (size_t i = 0; i < dest.size(); ++i)
+                {
+                    if (in[i] >= beta)
+                        out[i] += gi[i];
+                    else if (in[i] == 0)
+                        continue;
+                    else
+                        out[i] += std::sqrt(beta * in[i]) / beta * gi[i];
+                }
+            }
+        }
+
     // ----------------------------------------------------------------------------------------
 
         void resize_bilinear (
diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h
index 3e9e3d3a8..410609378 100644
--- a/dlib/cuda/cpu_dlib.h
+++ b/dlib/cuda/cpu_dlib.h
@@ -421,6 +421,21 @@ namespace dlib
 
     // ----------------------------------------------------------------------------------------
 
+        void smelu (
+            tensor& dest,
+            const tensor& src,
+            const float beta
+        );
+
+        void smelu_gradient (
+            tensor& grad,
+            const tensor& dest,
+            const tensor& gradient_input,
+            const float beta
+        );
+
+    // ------------------------------------------------------------------------------------
+
         void resize_bilinear (
             tensor& dest,
             long long dest_row_stride,
diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu
index 02c02ae84..3cfef9545 100644
--- a/dlib/cuda/cuda_dlib.cu
+++ b/dlib/cuda/cuda_dlib.cu
@@ -1366,7 +1366,7 @@ namespace dlib
         void leaky_relu(
             tensor& dest,
-            const tensor &src,
+            const tensor& src,
             const float alpha
         )
         {
@@ -1657,6 +1657,79 @@ namespace dlib
             launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
         }
 
+    // ----------------------------------------------------------------------------------------
+
+        __global__ void _cuda_smelu (const float* s, float* d, size_t n, const float beta)
+        {
+            for (auto i : grid_stride_range(0, n))
+            {
+                if (s[i] >= beta)
+                    d[i] = s[i];
+                else if (s[i] <= -beta)
+                    d[i] = 0;
+                else
+                    d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
+            }
+        }
+
+        void smelu (
+            tensor& dest,
+            const tensor& src,
+            const float beta
+        )
+        {
+            launch_kernel(_cuda_smelu, max_jobs(dest.size()), src.device(),
+                          dest.device(), src.size(), beta);
+        }
+    // ----------------------------------------------------------------------------------------
+
+        __global__ void _cuda_smelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float beta)
+        {
+            for (auto i : grid_stride_range(0, n))
+            {
+                if (s[i] >= beta)
+                    out[i] = gi[i];
+                else if (s[i] == 0)
+                    out[i] = 0;
+                else
+                    out[i] = std::sqrt(beta * s[i]) / beta * gi[i];
+            }
+        }
+
+        __global__ void _cuda_smelu_gradient(float* out, const float* s, const float* gi, size_t n, const float beta)
+        {
+            for (auto i : grid_stride_range(0, n))
+            {
+                if (s[i] >= beta)
+                    out[i] += gi[i];
+                else if (s[i] == 0)
+                    continue;
+                else
+                    out[i] += std::sqrt(beta * s[i]) / beta * gi[i];
+            }
+        }
+
+        void smelu_gradient (
+            tensor& grad,
+            const tensor& src,
+            const tensor& gradient_input,
+            const float beta
+        )
+        {
+            float* out = grad.device();
+            const float* gi = gradient_input.device();
+            if (out == gi)
+            {
+                launch_kernel(_cuda_smelu_gradient_inplace, max_jobs(grad.size()),
+                              out, src.device(), gi, grad.size(), beta);
+            }
+            else
+            {
+                launch_kernel(_cuda_smelu_gradient, max_jobs(grad.size()),
+                              out, src.device(), gi, grad.size(), beta);
+            }
+        }
+
     // ----------------------------------------------------------------------------------------
 
         __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h
index 8f00ceb48..d4f0cfbd4 100644
--- a/dlib/cuda/cuda_dlib.h
+++ b/dlib/cuda/cuda_dlib.h
@@ -465,6 +465,21 @@ namespace dlib
 
    // ----------------------------------------------------------------------------------------
 
+        void smelu (
+            tensor& dest,
+            const tensor& src,
+            const float beta
+        );
+
+        void smelu_gradient (
+            tensor& grad,
+            const tensor& dest,
+            const tensor& gradient_input,
+            const float beta
+        );
+
+    // ------------------------------------------------------------------------------------
+
         void resize_bilinear (
             tensor& dest,
             long long dest_row_stride,
diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp
index d0f44ab6a..5c4f3bed4 100644
--- a/dlib/cuda/tensor_tools.cpp
+++ b/dlib/cuda/tensor_tools.cpp
@@ -1084,6 +1084,36 @@ namespace dlib { namespace tt
 #endif
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void smelu (
+        tensor& dest,
+        const tensor& src,
+        const float beta
+    )
+    {
+        DLIB_CASSERT(beta > 0);
+#ifdef DLIB_USE_CUDA
+        cuda::smelu(dest, src, beta);
+#else
+        cpu::smelu(dest, src, beta);
+#endif
+    }
+
+    void smelu_gradient (
+        tensor& grad,
+        const tensor& dest,
+        const tensor& gradient_input,
+        const float beta
+    )
+    {
+        DLIB_CASSERT(beta > 0);
+#ifdef DLIB_USE_CUDA
+        cuda::smelu_gradient(grad, dest, gradient_input, beta);
+#else
+        cpu::smelu_gradient(grad, dest, gradient_input, beta);
+#endif
+    }
 // ----------------------------------------------------------------------------------------
 
     void resize_bilinear (
diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h
index 6d30d5168..07b7098cc 100644
--- a/dlib/cuda/tensor_tools.h
+++ b/dlib/cuda/tensor_tools.h
@@ -1747,6 +1747,49 @@ namespace dlib { namespace tt
             is_same_object(grad, gradient_input)==true
     !*/
 
+// ----------------------------------------------------------------------------------------
+
+    void smelu (
+        tensor& dest,
+        const tensor& src,
+        const float beta
+    );
+    /*!
+        requires
+            - have_same_dimensions(dest, src) == true
+            - beta > 0
+        ensures
+            - for all valid i:
+                - if (src.host()[i] > beta) then
+                    - #dest.host()[i] == src.host()[i]
+                - else if (src.host()[i] < -beta) then
+                    - #dest.host()[i] == 0
+                - else
+                    - #dest.host()[i] == std::pow(src.host()[i] + beta, 2) / (4 * beta)
+    !*/
+
+    void smelu_gradient (
+        tensor& grad,
+        const tensor& dest,
+        const tensor& gradient_input,
+        const float beta
+    );
+    /*!
+        requires
+            - have_same_dimensions(dest,gradient_input) == true
+            - have_same_dimensions(dest,grad) == true
+            - beta > 0
+        ensures
+            - Recalling that dest is the output of smelu(dest,SRC) for some SRC tensor,
+              let f(SRC) == dot(gradient_input,dest). Then this function computes the
+              gradient of f() with respect to SRC and stores it to grad. Moreover, if
+              is_same_object(grad,gradient_input)==true then the output is assigned to
+              grad, replacing its previous contents. Otherwise the output is added to
+              grad.
+            - This function supports in-place operation, i.e. having
+              is_same_object(grad, gradient_input)==true
+    !*/
+
 // ----------------------------------------------------------------------------------------
 
     void resize_bilinear (
diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index df9f3d3a6..c1e4fc73a 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -3113,6 +3113,7 @@ namespace dlib
     using prelu = add_layer<prelu_, SUBNET>;
 
 // ----------------------------------------------------------------------------------------
+
     class leaky_relu_
     {
     public:
@@ -3629,6 +3630,85 @@ namespace dlib
     template <typename SUBNET>
     using gelu = add_layer<gelu_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    class smelu_
+    {
+    public:
+        explicit smelu_(
+            float beta_ = 1
+        ) : beta(beta_)
+        {
+        }
+
+        float get_beta(
+        ) const {
+            return beta;
+        }
+
+        template <typename SUBNET>
+        void setup(const SUBNET& /*sub*/)
+        {
+        }
+
+        void forward_inplace(const tensor& input, tensor& output)
+        {
+            tt::smelu(output, input, beta);
+        }
+
+        void backward_inplace(
+            const tensor& computed_output,
+            const tensor& gradient_input,
+            tensor& data_grad,
+            tensor&
+        )
+        {
+            tt::smelu_gradient(data_grad, computed_output, gradient_input, beta);
+        }
+
+        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+        inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        friend void serialize(const smelu_& item, std::ostream& out)
+        {
+            serialize("smelu_", out);
+            serialize(item.beta, out);
+        }
+
+        friend void deserialize(smelu_& item, std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "smelu_")
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::smelu_.");
+            deserialize(item.beta, in);
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const smelu_& item)
+        {
+            out << "smelu\t("
+                << "beta=" << item.beta
+                << ")";
+            return out;
+        }
+
+        friend void to_xml(const smelu_& item, std::ostream& out)
+        {
+            out << "<smelu beta='" << item.beta << "'>\n";
+            out << "</smelu>\n";
+        }
+
+    private:
+        resizable_tensor params;
+        float beta;
+    };
+
+    template <typename SUBNET>
+    using smelu = add_layer<smelu_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class softmax_
diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h
index 2dd88ba16..545b0175e 100644
--- a/dlib/dnn/layers_abstract.h
+++ b/dlib/dnn/layers_abstract.h
@@ -2638,6 +2638,59 @@ namespace dlib
     template <typename SUBNET>
     using gelu = add_layer<gelu_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    class smelu_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+                defined above. In particular, it defines a smooth rectified linear
+                layer. Therefore, it passes its inputs through the function f(x):
+                    - if (x > beta) x
+                    - if (x < -beta) 0
+                    - else std::pow(x + beta, 2) / (4 * beta)
+                where f() is applied pointwise across the input tensor and beta is a
+                non-learned scalar.
+
+                This is the layer type introduced in the paper:
+                "Smooth activations and reproducibility in deep networks" by
+                Gil I. Shamir, Dong Lin, Lorenzo Coviello (https://arxiv.org/abs/2010.09931)
+        !*/
+
+    public:
+        explicit smelu_(
+            float beta = 1
+        );
+        /*!
+            ensures
+                - the beta parameter will be initialized with the beta value
+        !*/
+
+        float get_beta(
+        ) const;
+        /*!
+            ensures
+                - returns the beta parameter of the smelu
+        !*/
+
+        template <typename SUBNET> void setup(const SUBNET& sub);
+        void forward_inplace(const tensor& input, tensor& output);
+        void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+        dpoint map_input_to_output(dpoint p) const;
+        dpoint map_output_to_input(dpoint p) const;
+        const tensor& get_layer_params() const;
+        tensor& get_layer_params();
+        /*!
+            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+            interface. Note that this layer doesn't have any parameters, so the tensor
+            returned by get_layer_params() is always empty.
+        !*/
+    };
+
+    template <typename SUBNET>
+    using smelu = add_layer<smelu_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class softmax_
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 4bd5dd5b5..1345a2f9e 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -223,13 +223,13 @@ namespace
         // make sure that cuda::mish and cpu::mish return the same results
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         resizable_tensor src(n,k,nr,nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
 
         resizable_tensor dest1, dest2;
         dest1.copy_size(src);
@@ -239,7 +239,7 @@ namespace
         dest2 = 2;
         cuda::mish(dest1, src);
         cpu::mish(dest2, src);
 
-        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
+        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
 #endif // DLIB_USE_CUDA
     }
@@ -248,14 +248,14 @@ namespace
 #ifdef DLIB_USE_CUDA
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         const float alpha = 0.01;
         resizable_tensor src(n, k, nr, nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
         resizable_tensor dest_cuda, dest_cpu;
         dest_cuda.copy_size(src);
         dest_cpu.copy_size(src);
@@ -352,13 +352,13 @@ namespace
         // make sure that cuda::gelu and cpu::gelu return the same results
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         resizable_tensor src(n,k,nr,nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
 
         resizable_tensor dest1, dest2;
         dest1.copy_size(src);
@@ -368,7 +368,33 @@ namespace
         dest2 = 2;
         cuda::gelu(dest1, src);
         cpu::gelu(dest2, src);
 
-        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
+        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
+#endif // DLIB_USE_CUDA
+    }
+
+    void test_smelu()
+    {
+#ifdef DLIB_USE_CUDA
+        using namespace dlib::tt;
+        print_spinner();
+        const long n = 4;
+        const long k = 5;
+        const long nr = 3;
+        const long nc = 3;
+        const float beta = 1;
+        resizable_tensor src(n, k, nr, nc);
+        tt::tensor_rand rnd;
+        rnd.fill_gaussian(src);
+        resizable_tensor dest_cuda, dest_cpu;
+        dest_cuda.copy_size(src);
+        dest_cpu.copy_size(src);
+        // initialize to different values in order to make sure the output is actually changed
+        dest_cuda = 1;
+        dest_cpu = 2;
+        cuda::smelu(dest_cuda, src, beta);
+        cpu::smelu(dest_cpu, src, beta);
+
+        DLIB_TEST_MSG(max(abs(mat(dest_cuda) - mat(dest_cpu))) < 1e-7, max(abs(mat(dest_cuda) - mat(dest_cpu))));
 #endif // DLIB_USE_CUDA
     }
@@ -2103,6 +2129,12 @@ namespace
             auto res = test_layer(l);
             DLIB_TEST_MSG(res, res);
         }
+        {
+            print_spinner();
+            smelu_ l;
+            auto res = test_layer(l);
+            DLIB_TEST_MSG(res, res);
+        }
         {
             print_spinner();
             softmax_ l;
@@ -4286,6 +4318,7 @@ namespace
             test_clipped_relu();
             test_elu();
             test_gelu();
+            test_smelu();
             test_batch_normalize();
             test_batch_normalize_conv();
             test_layer_normalize();
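
Not part of the patch: a minimal, self-contained C++ sketch (illustrative names only) of the identity the backward pass relies on. For |x| < beta the forward output is y = (x + beta)^2 / (4 * beta), so x + beta = 2 * sqrt(beta * y) and therefore dy/dx = (x + beta) / (2 * beta) = sqrt(beta * y) / beta. That is why cpu::smelu_gradient and the CUDA kernels above can compute the derivative from the stored output (dest) rather than from the original input; the sketch below checks that expression against a central finite difference.

#include <cassert>
#include <cmath>
#include <cstdio>

// SmeLU forward pass, mirroring the branches in cpu::smelu above.
float smelu(float x, float beta)
{
    if (x >= beta)
        return x;
    if (x <= -beta)
        return 0;
    return (x + beta) * (x + beta) / (4 * beta);
}

// Derivative written in terms of the output y = smelu(x), as in smelu_gradient:
// for 0 < y < beta, y = (x + beta)^2 / (4 * beta), hence dy/dx = sqrt(beta * y) / beta.
float smelu_grad_from_output(float y, float beta)
{
    if (y >= beta)
        return 1;
    if (y == 0)
        return 0;
    return std::sqrt(beta * y) / beta;
}

int main()
{
    const float beta = 1.0f;
    const float eps = 1e-3f;
    for (float x = -2.0f; x <= 2.0f; x += 0.25f)
    {
        const float y = smelu(x, beta);
        const float analytic = smelu_grad_from_output(y, beta);
        const float numeric = (smelu(x + eps, beta) - smelu(x - eps, beta)) / (2 * eps);
        std::printf("x=% .2f  y=% .4f  dy/dx=% .4f  finite diff=% .4f\n", x, y, analytic, numeric);
        assert(std::fabs(analytic - numeric) < 1e-2f);
    }
    return 0;
}

Reading the derivative off the output is what allows the layer to run in-place (forward_inplace/backward_inplace), since the input tensor does not need to be kept around after the forward pass.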