Mirror of https://github.com/davisking/dlib.git (synced 2024-11-01 10:14:53 +08:00)
add leaky_relu activation layer (#2033)
* add leaky_relu activation layer
* add inplace case for leaky_relu and test_layer
* make clear that alpha is not learned by leaky_relu
* remove branch from cuda kernel
This commit is contained in:
parent 74123841bb
commit d610e56c2a
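
For orientation, here is a minimal usage sketch of the new layer. It is not part of this commit: the network shape, loss, and printed output are illustrative assumptions; only the leaky_relu / leaky_relu_ names come from the diff below.

    // Hypothetical usage sketch (not from this commit): a tiny classifier that
    // uses the new leaky_relu layer where relu would normally go.
    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    // leaky_relu<SUBNET> wraps leaky_relu_ via add_layer, mirroring how relu<SUBNET> works.
    using net_type = loss_multiclass_log<
                         fc<10,
                         leaky_relu<
                         fc<32,
                         input<matrix<float>>>>>>;

    int main()
    {
        net_type net;                  // the default alpha of 0.01 is used here
        std::cout << net << std::endl; // prints each layer, e.g. a line like "leaky_relu (alpha=0.01)"
    }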
@@ -1607,6 +1607,57 @@ namespace dlib
            params_grad.host()[0] = pgrad;
        }

    // ------------------------------------------------------------------------------------

        void leaky_relu (
            tensor& dest,
            const tensor& src,
            const float alpha
        )
        {
            const float* s = src.host();
            float* d = dest.host();
            for (size_t i = 0; i < dest.size(); ++i)
            {
                if (s[i] > 0)
                    d[i] = s[i];
                else
                    d[i] = alpha * s[i];
            }
        }

        void leaky_relu_gradient (
            tensor& grad,
            const tensor& dest,
            const tensor& gradient_input,
            const float alpha
        )
        {
            const float* gi = gradient_input.host();
            const float* in = dest.host();
            float* out = grad.host();
            if (is_same_object(grad, gradient_input))
            {
                for (size_t i = 0; i < dest.size(); ++i)
                {
                    if (in[i] > 0)
                        out[i] = gi[i];
                    else
                        out[i] = alpha * gi[i];
                }
            }
            else
            {
                for (size_t i = 0; i < dest.size(); ++i)
                {
                    if (in[i] > 0)
                        out[i] += gi[i];
                    else
                        out[i] += alpha * gi[i];
                }
            }
        }

    // ------------------------------------------------------------------------------------

        void tanh (
@@ -323,6 +323,21 @@ namespace dlib
            tensor& params_grad
        );

    // ------------------------------------------------------------------------------------

        void leaky_relu (
            tensor& dest,
            const tensor& src,
            const float alpha
        );

        void leaky_relu_gradient (
            tensor& grad,
            const tensor& dest,
            const tensor& gradient_input,
            const float alpha
        );

    // ------------------------------------------------------------------------------------

        void tanh (
@@ -1350,6 +1350,73 @@ namespace dlib
                          grad.device(), src.device(), gradient_input.device(), grad.size(),
                          param.device(), params_grad.device());
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_leaky_relu(const float* s, float* d, size_t n, const float alpha)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] > 0)
                    d[i] = s[i];
                else
                    d[i] = alpha * s[i];
            }
        }

        void leaky_relu (
            tensor& dest,
            const tensor& src,
            const float alpha
        )
        {
            launch_kernel(_cuda_leaky_relu, max_jobs(dest.size()),
                          src.device(), dest.device(), src.size(), alpha);
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_leaky_relu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float alpha)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] > 0)
                    out[i] = gi[i];
                else
                    out[i] = alpha * gi[i];
            }
        }

        __global__ void _cuda_leaky_relu_gradient(float* out, const float* s, const float* gi, size_t n, const float alpha)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] > 0)
                    out[i] += gi[i];
                else
                    out[i] += alpha * gi[i];
            }
        }

        void leaky_relu_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input,
            const float alpha
        )
        {
            float* out = grad.device();
            const float* gi = gradient_input.device();
            if (out == gi)
            {
                launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
                              out, src.device(), gi, grad.size(), alpha);
            }
            else
            {
                launch_kernel(_cuda_leaky_relu_gradient, max_jobs(grad.size()),
                              out, src.device(), gi, grad.size(), alpha);
            }
        }

    // ----------------------------------------------------------------------------------------

@@ -1408,6 +1475,7 @@ namespace dlib
        {
            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
@@ -367,6 +367,21 @@ namespace dlib
            tensor& params_grad
        );

    // ----------------------------------------------------------------------------------------

        void leaky_relu (
            tensor& dest,
            const tensor& src,
            const float alpha
        );

        void leaky_relu_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input,
            const float alpha
        );

    // ----------------------------------------------------------------------------------------

        void mish (
@@ -909,6 +909,35 @@ namespace dlib { namespace tt
#endif
    }

// ----------------------------------------------------------------------------------------

    void leaky_relu (
        tensor& dest,
        const tensor& src,
        const float alpha
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::leaky_relu(dest, src, alpha);
#else
        cpu::leaky_relu(dest, src, alpha);
#endif
    }

    void leaky_relu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float alpha
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::leaky_relu_gradient(grad, dest, gradient_input, alpha);
#else
        cpu::leaky_relu_gradient(grad, dest, gradient_input, alpha);
#endif
    }

// ----------------------------------------------------------------------------------------

    void tanh (
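
As a quick illustration of these dispatch routines, here is a hedged sketch written for this review (not part of the commit). The tensor sizes and fill values are arbitrary; the comments assume the semantics documented in the abstract header in the next hunk.

    #include <dlib/dnn.h>

    using namespace dlib;

    int main()
    {
        resizable_tensor src(1, 1, 2, 2), dest, grad, grad_in;
        src = -1;                 // all negative, so the alpha branch is exercised everywhere
        dest.copy_size(src);
        grad.copy_size(src);
        grad_in.copy_size(src);
        grad = 0;
        grad_in = 1;

        const float alpha = 0.01f;
        tt::leaky_relu(dest, src, alpha);                        // dest becomes -0.01 everywhere
        tt::leaky_relu_gradient(grad, dest, grad_in, alpha);     // distinct tensors: grad += alpha*grad_in
        tt::leaky_relu_gradient(grad_in, dest, grad_in, alpha);  // aliased (in-place): grad_in = alpha*grad_in
    }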
@@ -1443,6 +1443,45 @@ namespace dlib { namespace tt
            adds the gradient with respect to src to #grad.
    !*/

// ----------------------------------------------------------------------------------------

    void leaky_relu (
        tensor& dest,
        const tensor& src,
        const float alpha
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - if (src.host()[i] > 0) then
                    - #dest.host()[i] == src.host()[i]
                - else
                    - #dest.host()[i] == src.host()[i] * alpha
    !*/

    void leaky_relu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float alpha
    );
    /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
        ensures
            - Recalling that dest is the output of leaky_relu(dest,SRC) for some SRC tensor,
              let f(SRC) == dot(gradient_input,dest).  Then this function computes the
              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

// ----------------------------------------------------------------------------------------

    void tanh (
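
Side note (not part of the diff): written out, the pointwise function and derivative these specs describe are

    f(x)  = x>0 ? x : alpha*x
    f'(x) = x>0 ? 1 : alpha

For any alpha >= 0, f(x) > 0 exactly when x > 0, so dest[i] > 0 exactly when SRC[i] > 0. That is why leaky_relu_gradient() can branch on dest, the saved layer output, rather than needing the original SRC tensor.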
@@ -2761,6 +2761,84 @@ namespace dlib
    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class leaky_relu_
    {
    public:
        explicit leaky_relu_(
            float alpha_ = 0.01f
        ) : alpha(alpha_)
        {
        }

        float get_alpha(
        ) const {
            return alpha;
        }

        template <typename SUBNET>
        void setup(const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::leaky_relu(output, input, alpha);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::leaky_relu_gradient(data_grad, computed_output, gradient_input, alpha);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const leaky_relu_& item, std::ostream& out)
        {
            serialize("leaky_relu_", out);
            serialize(item.alpha, out);
        }

        friend void deserialize(leaky_relu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "leaky_relu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::leaky_relu_.");
            deserialize(item.alpha, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const leaky_relu_& item)
        {
            out << "leaky_relu\t("
                << "alpha=" << item.alpha
                << ")";
            return out;
        }

        friend void to_xml(const leaky_relu_& item, std::ostream& out)
        {
            out << "<leaky_relu alpha='"<< item.alpha << "'>\n";
            out << "</leaky_relu>\n";
        }

    private:
        resizable_tensor params;
        float alpha;
    };

    template <typename SUBNET>
    using leaky_relu = add_layer<leaky_relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class sig_
@@ -2090,6 +2090,55 @@ namespace dlib
    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class leaky_relu_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                defined above.  In particular, it defines a leaky rectified linear
                layer.  Therefore, it passes its inputs through the function
                    f(x) = x>0 ? x : alpha*x
                where f() is applied pointwise across the input tensor and alpha is a
                non-learned scalar.

                This is the layer type introduced in the paper:
                    A. L. Maas, A. Y. Hannun, and A. Y. Ng.  "Rectifier nonlinearities
                    improve neural network acoustic models".  In ICML, 2013.
        !*/

    public:
        explicit leaky_relu_(
            float alpha = 0.01f
        );
        /*!
            ensures
                - the alpha parameter will be initialized with the alpha value
        !*/

        float get_alpha(
        ) const;
        /*!
            ensures
                - returns the alpha parameter of the leaky_relu
        !*/

        template <typename SUBNET> void setup(const SUBNET& sub);
        void forward_inplace(const tensor& input, tensor& output);
        void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
        dpoint map_input_to_output(dpoint p) const;
        dpoint map_output_to_input(dpoint p) const;
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
        !*/
    };

    template <typename SUBNET>
    using leaky_relu = add_layer<leaky_relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class sig_
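
A minimal sketch of the interface documented above (hypothetical example written for this review, not from the commit):

    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    int main()
    {
        // Construct the layer detail object with a non-default slope.
        leaky_relu_ l(0.2f);
        std::cout << "alpha = " << l.get_alpha() << std::endl;            // prints 0.2

        // The layer has no learnable parameters, so its parameter tensor stays empty.
        std::cout << "params = " << l.get_layer_params().size() << std::endl;  // prints 0
    }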
@@ -243,6 +243,32 @@ namespace
#endif // DLIB_USE_CUDA
    }

    void test_leaky_relu()
    {
#ifdef DLIB_USE_CUDA
        using namespace dlib::tt;
        print_spinner();
        const long n = 5;
        const long k = 5;
        const long nr = 3;
        const long nc = 3;
        const float alpha = 0.01;
        resizable_tensor src(n, k, nr, nc);
        tt::tensor_rand rnd;
        rnd.fill_uniform(src);
        resizable_tensor dest1, dest2;
        dest1.copy_size(src);
        dest2.copy_size(src);
        // initialize to different values in order to make sure the output is actually changed
        dest1 = 1;
        dest2 = 2;
        cuda::leaky_relu(dest1, src, alpha);
        cpu::leaky_relu(dest2, src, alpha);

        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
#endif // DLIB_USE_CUDA
    }

    void test_batch_normalize()
    {
        using namespace dlib::tt;
@@ -1866,6 +1892,12 @@ namespace
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            leaky_relu_ l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            sig_ l;
@@ -3703,6 +3735,7 @@ namespace
        test_softmax_all();
        test_sigmoid();
        test_mish();
        test_leaky_relu();
        test_batch_normalize();
        test_batch_normalize_conv();
        test_basic_tensor_ops();