mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add SmeLU activation
This commit is contained in:
parent
66f9b2b5bc
commit
8fa65eb7b2
@ -1998,6 +1998,61 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const float* s = src.host();
|
||||||
|
float* d = dest.host();
|
||||||
|
for (size_t i = 0; i < dest.size(); ++i)
|
||||||
|
{
|
||||||
|
if (s[i] >= beta)
|
||||||
|
d[i] = s[i];
|
||||||
|
else if (s[i] <= -beta)
|
||||||
|
d[i] = 0;
|
||||||
|
else
|
||||||
|
d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void smelu_gradient (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& dest,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const float* gi = gradient_input.host();
|
||||||
|
const float* in = dest.host();
|
||||||
|
float* out = grad.host();
|
||||||
|
if (is_same_object(grad, gradient_input))
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < dest.size(); ++i)
|
||||||
|
{
|
||||||
|
if (in[i] >= beta)
|
||||||
|
out[i] = gi[i];
|
||||||
|
else if (in[i] == 0)
|
||||||
|
out[i] = 0;
|
||||||
|
else
|
||||||
|
out[i] = std::sqrt(beta * in[i]) / beta * gi[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < dest.size(); ++i)
|
||||||
|
{
|
||||||
|
if (in[i] >= beta)
|
||||||
|
out[i] += gi[i];
|
||||||
|
else if (in[i] == 0)
|
||||||
|
continue;
|
||||||
|
else
|
||||||
|
out[i] += std::sqrt(beta * in[i]) / beta * gi[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
void resize_bilinear (
|
void resize_bilinear (
|
||||||
|
@ -421,6 +421,21 @@ namespace dlib
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& dest,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------------------
|
||||||
|
|
||||||
void resize_bilinear (
|
void resize_bilinear (
|
||||||
tensor& dest,
|
tensor& dest,
|
||||||
long long dest_row_stride,
|
long long dest_row_stride,
|
||||||
|
@ -1366,7 +1366,7 @@ namespace dlib
|
|||||||
|
|
||||||
void leaky_relu(
|
void leaky_relu(
|
||||||
tensor& dest,
|
tensor& dest,
|
||||||
const tensor &src,
|
const tensor& src,
|
||||||
const float alpha
|
const float alpha
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
@ -1657,6 +1657,79 @@ namespace dlib
|
|||||||
launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
|
launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
__global__ void _cuda_smelu (const float* s, float* d, size_t n, const float beta)
|
||||||
|
{
|
||||||
|
for (auto i : grid_stride_range(0, n))
|
||||||
|
{
|
||||||
|
if (s[i] >= beta)
|
||||||
|
d[i] = s[i];
|
||||||
|
else if (s[i] <= -beta)
|
||||||
|
d[i] = 0;
|
||||||
|
else
|
||||||
|
d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
launch_kernel(_cuda_smelu, max_jobs(dest.size()), src.device(), dest.device(), src.size(), beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
__global__ void _cuda_smelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float beta)
|
||||||
|
{
|
||||||
|
for (auto i : grid_stride_range(0, n))
|
||||||
|
{
|
||||||
|
if (s[i] >= beta)
|
||||||
|
out[i] = gi[i];
|
||||||
|
else if (s[i] == 0)
|
||||||
|
out[i] = 0;
|
||||||
|
else
|
||||||
|
out[i] = std::sqrt(beta * s[i]) / beta * gi[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void _cuda_smelu_gradient(float* out, const float* s, const float* gi, size_t n, const float beta)
|
||||||
|
{
|
||||||
|
for (auto i : grid_stride_range(0, n))
|
||||||
|
{
|
||||||
|
if (s[i] >= beta)
|
||||||
|
out[i] += gi[i];
|
||||||
|
else if (s[i] == 0)
|
||||||
|
continue;
|
||||||
|
else
|
||||||
|
out[i] += std::sqrt(beta * s[i]) / beta * gi[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void smelu_gradient (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& src,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
float* out = grad.device();
|
||||||
|
const float* gi = gradient_input.device();
|
||||||
|
if (out == gi)
|
||||||
|
{
|
||||||
|
launch_kernel(_cuda_smelu_gradient_inplace, max_jobs(grad.size()),
|
||||||
|
out, src.device(), gi, grad.size(), beta);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
launch_kernel(_cuda_smelu_gradient, max_jobs(grad.size()),
|
||||||
|
out, src.device(), gi, grad.size(), beta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
__global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
|
__global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
|
||||||
|
@ -465,6 +465,21 @@ namespace dlib
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
|
||||||
|
void smelu_gradient (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& dest,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------------------
|
||||||
|
|
||||||
void resize_bilinear (
|
void resize_bilinear (
|
||||||
tensor& dest,
|
tensor& dest,
|
||||||
long long dest_row_stride,
|
long long dest_row_stride,
|
||||||
|
@ -1084,6 +1084,36 @@ namespace dlib { namespace tt
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(beta > 0);
|
||||||
|
#ifdef DLIB_USE_CUDA
|
||||||
|
cuda::smelu(dest, src, beta);
|
||||||
|
#else
|
||||||
|
cpu::smelu(dest, src, beta);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void smelu_gradient (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& dest,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
)
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(beta > 0);
|
||||||
|
#ifdef DLIB_USE_CUDA
|
||||||
|
cuda::smelu_gradient(grad, dest, gradient_input, beta);
|
||||||
|
#else
|
||||||
|
cpu::smelu_gradient(grad, dest, gradient_input, beta);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
void resize_bilinear (
|
void resize_bilinear (
|
||||||
|
@ -1747,6 +1747,49 @@ namespace dlib { namespace tt
|
|||||||
is_same_object(grad, gradient_input)==true
|
is_same_object(grad, gradient_input)==true
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void smelu (
|
||||||
|
tensor& dest,
|
||||||
|
const tensor& src,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- have_same_dimensions(dest, src) == true
|
||||||
|
- beta > 0
|
||||||
|
ensures
|
||||||
|
- for all valid i:
|
||||||
|
- if (src.host()[i] > beta) then
|
||||||
|
- #dest.host()[i] == src.host()[i]
|
||||||
|
- else if (src.host()[i] < -beta) then
|
||||||
|
- #dest.host()[i] == 0
|
||||||
|
- else
|
||||||
|
- #dest.host()[i] == std::pow(src.host()[i] + beta), 2) / (4 * beta)
|
||||||
|
!*/
|
||||||
|
|
||||||
|
void smelu_gradient (
|
||||||
|
tensor& grad,
|
||||||
|
const tensor& dest,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
const float beta
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- have_same_dimensions(dest,gradient_input) == true
|
||||||
|
- have_same_dimensions(dest,grad) == true
|
||||||
|
- beta > 0
|
||||||
|
ensures
|
||||||
|
- Recalling that dest is the output of smelu(dest,SRC) for some SRC tensor,
|
||||||
|
let f(SRC) == dot(gradient_input,dest). Then this function computes the
|
||||||
|
gradient of f() with respect to SRC and stores it to grad. Moreover, if
|
||||||
|
is_same_object(grad,gradient_input)==true then the output is assigned to
|
||||||
|
grad, replacing its previous contents. Otherwise the output is added to
|
||||||
|
grad.
|
||||||
|
- This function supports in-place operation, i.e. having
|
||||||
|
is_same_object(grad, gradient_input)==true
|
||||||
|
!*/
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
void resize_bilinear (
|
void resize_bilinear (
|
||||||
|
@ -3113,6 +3113,7 @@ namespace dlib
|
|||||||
using prelu = add_layer<prelu_, SUBNET>;
|
using prelu = add_layer<prelu_, SUBNET>;
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
class leaky_relu_
|
class leaky_relu_
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -3629,6 +3630,85 @@ namespace dlib
|
|||||||
template <typename SUBNET>
|
template <typename SUBNET>
|
||||||
using gelu = add_layer<gelu_, SUBNET>;
|
using gelu = add_layer<gelu_, SUBNET>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class smelu_
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit smelu_(
|
||||||
|
float beta_ = 1
|
||||||
|
) : beta(beta_)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
float get_beta(
|
||||||
|
) const {
|
||||||
|
return beta;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SUBNET>
|
||||||
|
void setup(const SUBNET& /*sub*/)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward_inplace(const tensor& input, tensor& output)
|
||||||
|
{
|
||||||
|
tt::smelu(output, input, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward_inplace(
|
||||||
|
const tensor& computed_output,
|
||||||
|
const tensor& gradient_input,
|
||||||
|
tensor& data_grad,
|
||||||
|
tensor&
|
||||||
|
)
|
||||||
|
{
|
||||||
|
tt::smelu_gradient(data_grad, computed_output, gradient_input, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
|
||||||
|
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
|
||||||
|
|
||||||
|
const tensor& get_layer_params() const { return params; }
|
||||||
|
tensor& get_layer_params() { return params; }
|
||||||
|
|
||||||
|
friend void serialize(const smelu_& item, std::ostream& out)
|
||||||
|
{
|
||||||
|
serialize("smelu_", out);
|
||||||
|
serialize(item.beta, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend void deserialize(smelu_& item, std::istream& in)
|
||||||
|
{
|
||||||
|
std::string version;
|
||||||
|
deserialize(version, in);
|
||||||
|
if (version != "smelu_")
|
||||||
|
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::smelu_.");
|
||||||
|
deserialize(item.beta, in);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& out, const smelu_& item)
|
||||||
|
{
|
||||||
|
out << "smelu\t("
|
||||||
|
<< "beta=" << item.beta
|
||||||
|
<< ")";
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend void to_xml(const smelu_& item, std::ostream& out)
|
||||||
|
{
|
||||||
|
out << "<smelu beta='"<< item.beta << "'>\n";
|
||||||
|
out << "<smelu/>\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
resizable_tensor params;
|
||||||
|
float beta;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SUBNET>
|
||||||
|
using smelu = add_layer<smelu_, SUBNET>;
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
class softmax_
|
class softmax_
|
||||||
|
@ -2638,6 +2638,59 @@ namespace dlib
|
|||||||
template <typename SUBNET>
|
template <typename SUBNET>
|
||||||
using gelu = add_layer<gelu_, SUBNET>;
|
using gelu = add_layer<gelu_, SUBNET>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class smelu_
|
||||||
|
{
|
||||||
|
/*!
|
||||||
|
WHAT THIS OBJECT REPRESENTS
|
||||||
|
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
|
||||||
|
defined above. In particular, it defines a smooth rectified linear
|
||||||
|
layer. Therefore, it passes its inputs through the function f(x):
|
||||||
|
- if (x > beta) 1
|
||||||
|
- if (x < -beta) 0
|
||||||
|
- else std::pow(x + beta, 2) / (4 * beta)
|
||||||
|
where f() is applied pointwise across the input tensor and beta is a
|
||||||
|
non-learned scalar.
|
||||||
|
|
||||||
|
This is the layer type introduced in the paper:
|
||||||
|
"Smooth activations and reproducibility in deep networks" by
|
||||||
|
Gil I. Shamir, Dong Lin, Lorenzo Coviello (https://arxiv.org/abs/2010.09931)
|
||||||
|
!*/
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit smelu_(
|
||||||
|
float beta = 1
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- the beta parameter will be initialized with the beta value
|
||||||
|
!*/
|
||||||
|
|
||||||
|
float get_beta(
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns the beta parameter of the smelu
|
||||||
|
!*/
|
||||||
|
|
||||||
|
template <typename SUBNET> void setup(const SUBNET& sub);
|
||||||
|
void forward_inplace(const tensor& input, tensor& output);
|
||||||
|
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
|
||||||
|
dpoint map_input_to_output(dpoint p) const;
|
||||||
|
dpoint map_output_to_input(dpoint p) const;
|
||||||
|
const tensor& get_layer_params() const;
|
||||||
|
tensor& get_layer_params();
|
||||||
|
/*!
|
||||||
|
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
|
||||||
|
interface. Note that this layer doesn't have any parameters, so the tensor
|
||||||
|
returned by get_layer_params() is always empty.
|
||||||
|
!*/
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SUBNET>
|
||||||
|
using smelu = add_layer<prelu_, SUBNET>;
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
class softmax_
|
class softmax_
|
||||||
|
@ -223,13 +223,13 @@ namespace
|
|||||||
// make sure that cuda::mish and cpu::mish return the same results
|
// make sure that cuda::mish and cpu::mish return the same results
|
||||||
using namespace dlib::tt;
|
using namespace dlib::tt;
|
||||||
print_spinner();
|
print_spinner();
|
||||||
const long n = 5;
|
const long n = 4;
|
||||||
const long k = 5;
|
const long k = 5;
|
||||||
const long nr = 3;
|
const long nr = 3;
|
||||||
const long nc = 3;
|
const long nc = 3;
|
||||||
resizable_tensor src(n,k,nr,nc);
|
resizable_tensor src(n,k,nr,nc);
|
||||||
tt::tensor_rand rnd;
|
tt::tensor_rand rnd;
|
||||||
rnd.fill_uniform(src);
|
rnd.fill_gaussian(src);
|
||||||
|
|
||||||
resizable_tensor dest1, dest2;
|
resizable_tensor dest1, dest2;
|
||||||
dest1.copy_size(src);
|
dest1.copy_size(src);
|
||||||
@ -239,7 +239,7 @@ namespace
|
|||||||
dest2 = 2;
|
dest2 = 2;
|
||||||
cuda::mish(dest1, src);
|
cuda::mish(dest1, src);
|
||||||
cpu::mish(dest2, src);
|
cpu::mish(dest2, src);
|
||||||
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
|
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
|
||||||
#endif // DLIB_USE_CUDA
|
#endif // DLIB_USE_CUDA
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,14 +248,14 @@ namespace
|
|||||||
#ifdef DLIB_USE_CUDA
|
#ifdef DLIB_USE_CUDA
|
||||||
using namespace dlib::tt;
|
using namespace dlib::tt;
|
||||||
print_spinner();
|
print_spinner();
|
||||||
const long n = 5;
|
const long n = 4;
|
||||||
const long k = 5;
|
const long k = 5;
|
||||||
const long nr = 3;
|
const long nr = 3;
|
||||||
const long nc = 3;
|
const long nc = 3;
|
||||||
const float alpha = 0.01;
|
const float alpha = 0.01;
|
||||||
resizable_tensor src(n, k, nr, nc);
|
resizable_tensor src(n, k, nr, nc);
|
||||||
tt::tensor_rand rnd;
|
tt::tensor_rand rnd;
|
||||||
rnd.fill_uniform(src);
|
rnd.fill_gaussian(src);
|
||||||
resizable_tensor dest_cuda, dest_cpu;
|
resizable_tensor dest_cuda, dest_cpu;
|
||||||
dest_cuda.copy_size(src);
|
dest_cuda.copy_size(src);
|
||||||
dest_cpu.copy_size(src);
|
dest_cpu.copy_size(src);
|
||||||
@ -352,13 +352,13 @@ namespace
|
|||||||
// make sure that cuda::gelu and cpu::gelu return the same results
|
// make sure that cuda::gelu and cpu::gelu return the same results
|
||||||
using namespace dlib::tt;
|
using namespace dlib::tt;
|
||||||
print_spinner();
|
print_spinner();
|
||||||
const long n = 5;
|
const long n = 4;
|
||||||
const long k = 5;
|
const long k = 5;
|
||||||
const long nr = 3;
|
const long nr = 3;
|
||||||
const long nc = 3;
|
const long nc = 3;
|
||||||
resizable_tensor src(n,k,nr,nc);
|
resizable_tensor src(n,k,nr,nc);
|
||||||
tt::tensor_rand rnd;
|
tt::tensor_rand rnd;
|
||||||
rnd.fill_uniform(src);
|
rnd.fill_gaussian(src);
|
||||||
|
|
||||||
resizable_tensor dest1, dest2;
|
resizable_tensor dest1, dest2;
|
||||||
dest1.copy_size(src);
|
dest1.copy_size(src);
|
||||||
@ -368,7 +368,33 @@ namespace
|
|||||||
dest2 = 2;
|
dest2 = 2;
|
||||||
cuda::gelu(dest1, src);
|
cuda::gelu(dest1, src);
|
||||||
cpu::gelu(dest2, src);
|
cpu::gelu(dest2, src);
|
||||||
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
|
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
|
||||||
|
#endif // DLIB_USE_CUDA
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_smelu()
|
||||||
|
{
|
||||||
|
#ifdef DLIB_USE_CUDA
|
||||||
|
using namespace dlib::tt;
|
||||||
|
print_spinner();
|
||||||
|
const long n = 4;
|
||||||
|
const long k = 5;
|
||||||
|
const long nr = 3;
|
||||||
|
const long nc = 3;
|
||||||
|
const float beta = 1;
|
||||||
|
resizable_tensor src(n, k, nr, nc);
|
||||||
|
tt::tensor_rand rnd;
|
||||||
|
rnd.fill_gaussian(src);
|
||||||
|
resizable_tensor dest_cuda, dest_cpu;
|
||||||
|
dest_cuda.copy_size(src);
|
||||||
|
dest_cpu.copy_size(src);
|
||||||
|
// initialize to different values in order to make sure the output is actually changed
|
||||||
|
dest_cuda = 1;
|
||||||
|
dest_cpu = 2;
|
||||||
|
cuda::smelu(dest_cuda, src, beta);
|
||||||
|
cpu::smelu(dest_cpu, src, beta);
|
||||||
|
|
||||||
|
DLIB_TEST_MSG(max(abs(mat(dest_cuda) - mat(dest_cpu))) < 1e-7, max(abs(mat(dest_cuda) - mat(dest_cpu))));
|
||||||
#endif // DLIB_USE_CUDA
|
#endif // DLIB_USE_CUDA
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2103,6 +2129,12 @@ namespace
|
|||||||
auto res = test_layer(l);
|
auto res = test_layer(l);
|
||||||
DLIB_TEST_MSG(res, res);
|
DLIB_TEST_MSG(res, res);
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
print_spinner();
|
||||||
|
smelu_ l;
|
||||||
|
auto res = test_layer(l);
|
||||||
|
DLIB_TEST_MSG(res, res);
|
||||||
|
}
|
||||||
{
|
{
|
||||||
print_spinner();
|
print_spinner();
|
||||||
softmax_ l;
|
softmax_ l;
|
||||||
@ -4286,6 +4318,7 @@ namespace
|
|||||||
test_clipped_relu();
|
test_clipped_relu();
|
||||||
test_elu();
|
test_elu();
|
||||||
test_gelu();
|
test_gelu();
|
||||||
|
test_smelu();
|
||||||
test_batch_normalize();
|
test_batch_normalize();
|
||||||
test_batch_normalize_conv();
|
test_batch_normalize_conv();
|
||||||
test_layer_normalize();
|
test_layer_normalize();
|
||||||
|
Loading…
Reference in New Issue
Block a user