Make clipped-relu inplace and fix docs for elu (#2345)

This commit is contained in:
Adrià Arrufat 2021-04-13 10:49:49 +09:00 committed by GitHub
parent 1b7c7a6411
commit 7f53d7feb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 13 deletions

View File

@@ -3479,24 +3479,19 @@ namespace dlib
{
}
template <typename SUBNET>
void forward(
SUBNET& sub,
resizable_tensor& data_output
)
void forward_inplace(const tensor& input, tensor& output)
{
data_output.copy_size(sub.get_output());
tt::clipped_relu(data_output, sub.get_output(), ceiling);
tt::clipped_relu(output, input, ceiling);
}
template <typename SUBNET>
void backward(
void backward_inplace(
const tensor& computed_output,
const tensor& gradient_input,
SUBNET& sub,
tensor& data_grad,
tensor&
)
{
tt::clipped_relu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, ceiling);
tt::clipped_relu_gradient(data_grad, computed_output, gradient_input, ceiling);
}
inline dpoint map_input_to_output (const dpoint& p) const { return p; }

View File

@@ -2537,8 +2537,8 @@ namespace dlib
- returns the alpha parameter of the elu
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
void forward_inplace(const tensor& input, tensor& output);
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;