mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Make clipped-relu inplace and fix docs for elu (#2345)
This commit is contained in:
parent
1b7c7a6411
commit
7f53d7feb6
@ -3479,24 +3479,19 @@ namespace dlib
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void forward(
|
||||
SUBNET& sub,
|
||||
resizable_tensor& data_output
|
||||
)
|
||||
void forward_inplace(const tensor& input, tensor& output)
|
||||
{
|
||||
data_output.copy_size(sub.get_output());
|
||||
tt::clipped_relu(data_output, sub.get_output(), ceiling);
|
||||
tt::clipped_relu(output, input, ceiling);
|
||||
}
|
||||
|
||||
template <typename SUBNET>
|
||||
void backward(
|
||||
void backward_inplace(
|
||||
const tensor& computed_output,
|
||||
const tensor& gradient_input,
|
||||
SUBNET& sub,
|
||||
tensor& data_grad,
|
||||
tensor&
|
||||
)
|
||||
{
|
||||
tt::clipped_relu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, ceiling);
|
||||
tt::clipped_relu_gradient(data_grad, computed_output, gradient_input, ceiling);
|
||||
}
|
||||
|
||||
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
|
||||
|
@ -2537,8 +2537,8 @@ namespace dlib
|
||||
- returns the alpha parameter of the elu
|
||||
!*/
|
||||
template <typename SUBNET> void setup (const SUBNET& sub);
|
||||
void forward_inplace(const tensor& input, tensor& output);
|
||||
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
|
||||
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
|
||||
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
|
||||
dpoint map_input_to_output(dpoint p) const;
|
||||
dpoint map_output_to_input(dpoint p) const;
|
||||
const tensor& get_layer_params() const;
|
||||
|
Loading…
Reference in New Issue
Block a user