Add SiLU activation layer (#2584)

Adrià Arrufat 2022-05-08 22:28:47 +09:00 committed by GitHub
parent 8af4226057
commit 06b826540c
10 changed files with 333 additions and 1 deletion

View File

@@ -1998,6 +1998,8 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
void smelu (
tensor& dest,
const tensor& src,
@@ -2053,6 +2055,46 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
void silu (
tensor& dest,
const tensor& src
)
{
const auto d = dest.host();
const auto s = src.host();
for (size_t i = 0; i < src.size(); ++i)
d[i] = s[i] * impl::sigmoid(s[i]);
}
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
const auto g = grad.host();
const auto s = src.host();
const auto in = gradient_input.host();
if (is_same_object(grad, gradient_input))
{
for (size_t i = 0; i < src.size(); ++i)
{
const auto sig_s = impl::sigmoid(s[i]);
g[i] = in[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
}
}
else
{
for (size_t i = 0; i < src.size(); ++i)
{
const auto sig_s = impl::sigmoid(s[i]);
g[i] += in[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
}
}
}
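For reference, the expression sig_s * (1.0f + s[i] * (1.0f - sig_s)) used in both branches above is just the product rule applied to f(x) = x * sigmoid(x), using sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)):

    \frac{d}{dx}\bigl[x\,\sigma(x)\bigr]
        = \sigma(x) + x\,\sigma'(x)
        = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr)
        = \sigma(x)\bigl(1 + x\bigl(1-\sigma(x)\bigr)\bigr)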
// ----------------------------------------------------------------------------------------
void resize_bilinear (

View File

@@ -434,6 +434,19 @@ namespace dlib
const float beta
);
// ----------------------------------------------------------------------------------------
void silu (
tensor& dest,
const tensor& src
);
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void resize_bilinear (

View File

@@ -1729,6 +1729,58 @@ namespace dlib
out, src.device(), gi, grad.size(), beta);
}
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_silu(const float* s, float* d, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
d[i] = s[i] / (1.0f + std::exp(-s[i]));
}
}
void silu (
tensor& dest,
const tensor& src
)
{
launch_kernel(_cuda_silu, max_jobs(dest.size()), src.device(), dest.device(), src.size());
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_silu_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
const auto sig_s = 1.0f / (1.0f + std::exp(-s[i]));
out[i] = gi[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
}
}
__global__ void _cuda_silu_gradient(float* out, const float* s, const float* gi, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
const auto sig_s = 1.0f / (1.0f + std::exp(-s[i]));
out[i] += gi[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
}
}
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
float* out = grad.device();
const float* gi = gradient_input.device();
if (out == gi)
launch_kernel(_cuda_silu_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
else
launch_kernel(_cuda_silu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
}
// ----------------------------------------------------------------------------------------

View File

@@ -478,6 +478,19 @@ namespace dlib
const float beta
);
// ----------------------------------------------------------------------------------------
void silu (
tensor& dest,
const tensor& src
);
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void resize_bilinear (

View File

@@ -1118,6 +1118,34 @@ namespace dlib { namespace tt
cpu::smelu_gradient(grad, dest, gradient_input, beta);
#endif
}
// ----------------------------------------------------------------------------------------
void silu (
tensor& dest,
const tensor& src
)
{
#ifdef DLIB_USE_CUDA
cuda::silu(dest,src);
#else
cpu::silu(dest,src);
#endif
}
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
#ifdef DLIB_USE_CUDA
cuda::silu_gradient(grad, src, gradient_input);
#else
cpu::silu_gradient(grad, src, gradient_input);
#endif
}
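To make the calling convention concrete, here is a minimal sketch (not taken from this commit; the tensor names and sizes are arbitrary) of exercising the new tt:: wrappers directly:

    #include <dlib/dnn.h>

    int main()
    {
        dlib::resizable_tensor src(2, 3, 4, 4), dest, grad, grad_in;
        dlib::tt::tensor_rand rnd;
        rnd.fill_gaussian(src);

        dest.copy_size(src);
        dlib::tt::silu(dest, src);                    // dest[i] = src[i] * sigmoid(src[i])

        grad.copy_size(src);
        grad = 0;                                     // the gradient is accumulated when grad != gradient_input
        grad_in.copy_size(src);
        grad_in = 1;
        dlib::tt::silu_gradient(grad, src, grad_in);  // grad[i] += grad_in[i] * silu'(src[i])
        return 0;
    }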
// ----------------------------------------------------------------------------------------
void resize_bilinear (

View File

@@ -1790,6 +1790,41 @@ namespace dlib { namespace tt
is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
void silu (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- for all valid i:
- #dest.host()[i] == src.host()[i] * sigmoid(src.host()[i])
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void silu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
ensures
- Recalling that dest is the output of silu(dest,src), let f(src) ==
dot(gradient_input,dest). Then this function computes the gradient of f() with respect
to src and stores it to grad. Moreover, if is_same_object(grad,gradient_input)==true
then the output is assigned to grad, replacing its previous contents. Otherwise the
output is added to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
void resize_bilinear (

View File

@@ -3689,7 +3689,7 @@ namespace dlib
friend std::ostream& operator<<(std::ostream& out, const smelu_& item)
{
out << "smelu\t("
out << "smelu\t ("
<< "beta=" << item.beta
<< ")";
return out;
@@ -3709,6 +3709,77 @@ namespace dlib
template <typename SUBNET>
using smelu = add_layer<smelu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class silu_
{
public:
silu_(
)
{
}
template <typename SUBNET>
void setup(const SUBNET& /*sub*/)
{
}
template <typename SUBNET>
void forward(
const SUBNET& sub,
resizable_tensor& data_output)
{
data_output.copy_size(sub.get_output());
tt::silu(data_output, sub.get_output());
}
template <typename SUBNET>
void backward(
const tensor& gradient_input,
SUBNET& sub,
tensor&
)
{
tt::silu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input);
}
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const silu_& /*item*/, std::ostream& out)
{
serialize("silu_", out);
}
friend void deserialize(silu_& /*item*/, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "silu_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::silu_.");
}
friend std::ostream& operator<<(std::ostream& out, const silu_& /*item*/)
{
out << "silu";
return out;
}
friend void to_xml(const silu_& /*item*/, std::ostream& out)
{
out << "<silu/>\n";
}
private:
resizable_tensor params;
};
template <typename SUBNET>
using silu = add_layer<silu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class softmax_

View File

@@ -2691,6 +2691,44 @@ namespace dlib
template <typename SUBNET>
using smelu = add_layer<smelu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class silu_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. In particular, it defines a silu layer. Therefore, it
passes its inputs through the function
f(x)= x * sigmoid(x) = x / (1 + exp(-x))
where f() is applied pointwise across the input tensor.
This is the layer type introduced in the paper:
Dan Hendrycks, Kevin Gimpel. "Gaussian Error Linear Units (GELUs)".
!*/
public:
silu_(
);
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
interface. Note that this layer doesn't have any parameters, so the tensor
returned by get_layer_params() is always empty.
!*/
};
template <typename SUBNET>
using silu = add_layer<silu_, SUBNET>;
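For illustration only (the network below is hypothetical and not part of this commit), the alias composes like any other dlib activation layer:

    #include <dlib/dnn.h>

    // A hypothetical classifier that uses silu as its activation.
    using net_type = dlib::loss_multiclass_log<
                         dlib::fc<10,
                         dlib::silu<dlib::fc<64,
                         dlib::silu<dlib::fc<128,
                         dlib::input<dlib::matrix<float>>>>>>>>;

Such a net_type trains with dnn_trainer exactly like a relu-based network; since silu_ has no parameters, nothing else changes.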
// ----------------------------------------------------------------------------------------
class softmax_

View File

@@ -870,6 +870,14 @@ namespace dlib
update(i);
}
template <typename U, typename E>
void operator()(size_t i, const add_layer<silu_, U, E>&)
{
start_node(i, "silu");
end_node();
update(i);
}
template <typename U, typename E>
void operator()(size_t i, const add_layer<softmax_, U, E>&)
{

View File

@@ -398,6 +398,31 @@ namespace
#endif // DLIB_USE_CUDA
}
void test_silu()
{
#ifdef DLIB_USE_CUDA
using namespace dlib::tt;
print_spinner();
const long n = 4;
const long k = 5;
const long nr = 3;
const long nc = 3;
resizable_tensor src(n, k, nr, nc);
tt::tensor_rand rnd;
rnd.fill_gaussian(src);
resizable_tensor dest_cuda, dest_cpu;
dest_cuda.copy_size(src);
dest_cpu.copy_size(src);
// initialize to different values in order to make sure the output is actually changed
dest_cuda = 1;
dest_cpu = 2;
cuda::silu(dest_cuda, src);
cpu::silu(dest_cpu, src);
DLIB_TEST_MSG(max(abs(mat(dest_cuda) - mat(dest_cpu))) < 1e-6, max(abs(mat(dest_cuda) - mat(dest_cpu))));
#endif // DLIB_USE_CUDA
}
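A companion check for the gradient path would follow the same CUDA-vs-CPU pattern; the sketch below is hypothetical (the function name and tolerance mirror test_silu above rather than anything in this commit):

    void test_silu_gradient()
    {
    #ifdef DLIB_USE_CUDA
        using namespace dlib::tt;
        print_spinner();
        resizable_tensor src(4, 5, 3, 3), gradient_input;
        tt::tensor_rand rnd;
        rnd.fill_gaussian(src);
        gradient_input.copy_size(src);
        rnd.fill_gaussian(gradient_input);
        resizable_tensor grad_cuda, grad_cpu;
        grad_cuda.copy_size(src);
        grad_cpu.copy_size(src);
        // both implementations accumulate into grad when grad != gradient_input,
        // so start both from the same value to make the comparison meaningful
        grad_cuda = 0;
        grad_cpu = 0;
        cuda::silu_gradient(grad_cuda, src, gradient_input);
        cpu::silu_gradient(grad_cpu, src, gradient_input);
        DLIB_TEST_MSG(max(abs(mat(grad_cuda) - mat(grad_cpu))) < 1e-6, max(abs(mat(grad_cuda) - mat(grad_cpu))));
    #endif // DLIB_USE_CUDA
    }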
void test_batch_normalize()
{
using namespace dlib::tt;
@@ -2135,6 +2160,12 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
silu_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
softmax_ l;
@@ -4319,6 +4350,7 @@ namespace
test_elu();
test_gelu();
test_smelu();
test_silu();
test_batch_normalize();
test_batch_normalize_conv();
test_layer_normalize();