Make clipped-relu inplace and fix docs for elu (#2345)

This commit is contained in:
Adrià Arrufat 2021-04-13 10:49:49 +09:00 committed by GitHub
parent 1b7c7a6411
commit 7f53d7feb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 13 deletions

View File

@ -3479,24 +3479,19 @@ namespace dlib
{ {
} }
template <typename SUBNET> void forward_inplace(const tensor& input, tensor& output)
void forward(
SUBNET& sub,
resizable_tensor& data_output
)
{ {
data_output.copy_size(sub.get_output()); tt::clipped_relu(output, input, ceiling);
tt::clipped_relu(data_output, sub.get_output(), ceiling);
} }
template <typename SUBNET> void backward_inplace(
void backward( const tensor& computed_output,
const tensor& gradient_input, const tensor& gradient_input,
SUBNET& sub, tensor& data_grad,
tensor& tensor&
) )
{ {
tt::clipped_relu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, ceiling); tt::clipped_relu_gradient(data_grad, computed_output, gradient_input, ceiling);
} }
inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_input_to_output (const dpoint& p) const { return p; }

View File

@ -2537,8 +2537,8 @@ namespace dlib
- returns the alpha parameter of the elu - returns the alpha parameter of the elu
!*/ !*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
void forward_inplace(const tensor& input, tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
dpoint map_input_to_output(dpoint p) const; dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const; dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const; const tensor& get_layer_params() const;