add option to not zero out gradients and method to do it (#2477)

Adrià Arrufat, 2022-01-05 14:46:55 +01:00, committed by GitHub
parent a54cea44ae
commit 994df341a2
2 changed files with 182 additions and 47 deletions

File: dlib/dnn/core.h

@@ -729,6 +729,14 @@ namespace dlib
         };
     }

+// ----------------------------------------------------------------------------------------
+
+    enum class zero_gradients : uint8_t
+    {
+        no = 0,
+        yes = 1
+    };
+
 // ----------------------------------------------------------------------------------------

     template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
@@ -1003,21 +1011,28 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return subnetwork->get_final_data_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
                 gradient_input, wsub, static_cast<tensor&>(params_grad));
-            subnetwork->back_propagate_error(x);
+            subnetwork->back_propagate_error(x, zero_grads);
             // zero out get_gradient_input()
-            gradient_input_is_stale = true;
+            gradient_input_is_stale = zero_grads == zero_gradients::yes;
         }

         template <typename solver_type>
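The gradient_input_is_stale flag drives dlib's lazy zeroing: the gradient-input tensor is not cleared when back-propagation finishes, it is cleared the next time it is read while the flag is set. Passing zero_gradients::no therefore just leaves whatever has accumulated in get_gradient_input() untouched. A simplified sketch of the accessor (not the literal dlib code) illustrates the idea:

    tensor& get_gradient_input()
    {
        if (gradient_input_is_stale)        // set when back-propagation ran with zero_gradients::yes
        {
            gradient_input_is_stale = false;
            x_grad.copy_size(get_output());
            x_grad = 0;                     // start accumulating from zero again
        }
        return x_grad;                      // layers above += their gradient contributions into this
    }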
@@ -1057,6 +1072,12 @@ namespace dlib
         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            gradient_input_is_stale = true;
+            subnetwork->set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             x_grad.clear();
@@ -1374,11 +1395,18 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return grad_final; }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // make sure grad_final is initialized to 0
             if (!have_same_dimensions(x, grad_final))
@@ -1391,7 +1419,7 @@ namespace dlib
                 gradient_input, wsub, static_cast<tensor&>(params_grad));
             // zero out get_gradient_input()
-            gradient_input_is_stale = true;
+            gradient_input_is_stale = zero_grads == zero_gradients::yes;
         }

         template <typename solver_type>
@@ -1430,6 +1458,11 @@ namespace dlib
         unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }

+        void set_gradient_inputs_to_zero()
+        {
+            gradient_input_is_stale = true;
+        }
+
         void clean()
         {
             x_grad.clear();
@@ -1642,13 +1675,20 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return subnetwork.get_final_data_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x,gradient_input);
+            subnetwork.back_propagate_error(x,gradient_input, zero_grads);
         }

         template <typename solver_type>
@@ -1677,6 +1717,11 @@ namespace dlib
         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             subnetwork.clean();
@@ -1934,28 +1979,35 @@ namespace dlib
         tensor& get_parameter_gradient (
         ) { return details[0].get_parameter_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             if (details.size() > 1)
             {
-                details[0].back_propagate_error(details[1].get_output(), gradient_input);
+                details[0].back_propagate_error(details[1].get_output(), gradient_input, zero_grads);
                 for (size_t i = 1; i < details.size(); ++i)
                 {
                     if (i+1 < details.size())
-                        details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient());
+                        details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient(), zero_grads);
                     else
-                        details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient());
+                        details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient(), zero_grads);
                 }
             }
             else
             {
-                details[0].back_propagate_error(subnetwork.get_output(), gradient_input);
+                details[0].back_propagate_error(subnetwork.get_output(), gradient_input, zero_grads);
             }
-            subnetwork.back_propagate_error(x, details.back().get_final_data_gradient());
+            subnetwork.back_propagate_error(x, details.back().get_final_data_gradient(), zero_grads);
         }

         template <typename solver_type>
@@ -1980,6 +2032,11 @@ namespace dlib
         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             temp_tensor.clear();
@@ -2191,11 +2248,19 @@ namespace dlib
             return grad_final;
         }

-        void back_propagate_error(const tensor& /*x*/)
+        void back_propagate_error(
+            const tensor& /*x*/,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // nothing to do
         }

-        void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/)
+        void back_propagate_error(
+            const tensor& /*x*/,
+            const tensor& /*gradient_input*/,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // nothing to do
         }
@@ -2218,6 +2283,11 @@ namespace dlib
         const input_layer_type& input_layer() const { return input_layer_; }
         input_layer_type& input_layer() { return input_layer_; }

+        void set_gradient_inputs_to_zero()
+        {
+            // nothing to do
+        }
+
         void clean()
         {
             grad_final.clear();
@@ -2518,14 +2588,21 @@ namespace dlib
             return results;
         }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnet().back_propagate_error(x);
+            subnet().back_propagate_error(x, zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnet().back_propagate_error(x, gradient_input);
+            subnet().back_propagate_error(x, gradient_input, zero_grads);
         }

         const tensor& get_final_data_gradient(
@@ -2604,43 +2681,47 @@ namespace dlib
         template <typename label_iterator>
         double compute_parameter_gradients (
             const tensor& x,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             subnetwork.forward(x);
             dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
             double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub);
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
             return l;
         }

         template <typename forward_iterator, typename label_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
             forward_iterator iend,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             to_tensor(ibegin,iend,temp_tensor);
-            return compute_parameter_gradients(temp_tensor, lbegin);
+            return compute_parameter_gradients(temp_tensor, lbegin, zero_grads);
         }

         double compute_parameter_gradients (
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             subnetwork.forward(x);
             dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
             double l = loss.compute_loss_value_and_gradient(x, wsub);
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
             return l;
         }

         template <typename forward_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
-            forward_iterator iend
+            forward_iterator iend,
+            zero_gradients zero_grads = zero_gradients::yes
        )
         {
             to_tensor(ibegin,iend,temp_tensor);
-            return compute_parameter_gradients(temp_tensor);
+            return compute_parameter_gradients(temp_tensor, zero_grads);
         }

         template <typename solver_type>
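Note that the iterator overloads just convert their inputs with to_tensor() and delegate to the tensor overloads, so the zero_grads argument is only acted on in the two places that actually call back_propagate_error().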
@@ -2667,6 +2748,12 @@ namespace dlib
         const loss_details_type& loss_details() const { return loss; }
         loss_details_type& loss_details() { return loss; }

+        void set_gradient_inputs_to_zero (
+        )
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean (
         )
         {
@@ -3022,9 +3109,12 @@ namespace dlib
             return subnetwork.get_final_data_gradient();
         }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
         }

         template <typename solver_type>
@@ -3061,6 +3151,11 @@ namespace dlib
         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             subnetwork.clean();
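Taken together, these changes give every network object a zero_grads switch on its back-propagation entry points and a set_gradient_inputs_to_zero() method that forwards down the layer stack. The following sketch shows the intended call pattern from user code; it is a hypothetical usage example rather than code from this commit, and net, samples, labels, and solvers are placeholder names.

    #include <dlib/dnn.h>
    using namespace dlib;

    template <typename net_type, typename sample_type, typename label_type>
    void train_step(
        net_type& net,
        const std::vector<sample_type>& samples,
        const std::vector<label_type>& labels,
        std::vector<adam>& solvers,
        double learning_rate
    )
    {
        // Default behaviour, identical to dlib before this change: the per-layer
        // gradient inputs are flagged to be zeroed once back-propagation finishes.
        net.compute_parameter_gradients(samples.begin(), samples.end(), labels.begin());
        net.update_parameters(solvers, learning_rate);

        // Opting out: the data gradients accumulated in each layer's
        // get_gradient_input() are kept, so further backward passes add to them.
        net.compute_parameter_gradients(samples.begin(), samples.end(), labels.begin(),
                                        zero_gradients::no);
        // ... additional backward passes that should accumulate go here ...

        // Whoever opted out is responsible for resetting the gradient inputs
        // before the next independent backward pass.
        net.set_gradient_inputs_to_zero();
    }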

File: dlib/dnn/core_abstract.h

@@ -275,6 +275,14 @@ namespace dlib
                 - returns a sstack that sits on top of the given std::vector.
         !*/

+// ----------------------------------------------------------------------------------------
+
+    enum class zero_gradients : uint8_t
+    {
+        no,
+        yes
+    };
+
 // ----------------------------------------------------------------------------------------

     template <
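This enum is the switch taken by back_propagate_error() and compute_parameter_gradients() below. zero_gradients::yes reproduces the historical behaviour (the gradient inputs are reset after back-propagation) and is the default everywhere, so existing code keeps behaving as before; zero_gradients::no keeps the accumulated gradient inputs around for further backward passes. A call site then reads, with placeholder names:

    net.back_propagate_error(x, gradient_input, zero_gradients::no);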
@@ -603,7 +611,8 @@ namespace dlib
         !*/

         void back_propagate_error(
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -617,7 +626,7 @@ namespace dlib
                   network and computes parameter and data gradients, via backpropagation.
                   Specifically, this function populates get_final_data_gradient() and also,
                   for each layer, the tensor returned by get_parameter_gradient().
-                - All elements of #get_gradient_input() are set to 0.
+                - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
@@ -626,7 +635,8 @@ namespace dlib
         void back_propagate_error(
             const tensor& x,
-            const tensor& gradient_input
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -643,7 +653,7 @@ namespace dlib
                     back_propagate_error(x);
                   Except that calling back_propagate_error(x,gradient_input) avoids the
                   copy and is therefore slightly more efficient.
-                - All elements of #get_gradient_input() are set to 0.
+                - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
@@ -681,6 +691,20 @@ namespace dlib
                   Convenience method for calling update_parameters()
         !*/

+        void set_gradient_inputs_to_zero(
+        );
+        /*!
+            ensures
+                - Sets all elements in all gradient inputs in the network to 0.
+                  That is, for each layer, we will have:
+                    - get_gradient_input() == 0
+                - Note that you only need to call this method if you manually called either
+                  back_propagate_error() or compute_parameter_gradients() with the zero_grads
+                  parameter set to zero_gradients::no.
+                - invokes subnet().set_gradient_inputs_to_zero()
+        !*/
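In practice the pairing looks like the sketch below (placeholder names, not code from this commit): back-propagate with the automatic zeroing disabled, accumulate whatever else is needed, then reset by hand.

    auto& sub = net.subnet();     // `net` is some loss-capped network, `x` a tensor made with to_tensor()
    sub.forward(x);
    sub.back_propagate_error(x, grad, zero_gradients::no);   // `grad` has the same dimensions as sub.get_output()
    // ... further backward passes whose data gradients should add up ...
    sub.set_gradient_inputs_to_zero();   // needed before the next independent backward pass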

         void clean(
         );
         /*!
@@ -1147,7 +1171,8 @@ namespace dlib
         template <typename label_iterator>
         double compute_parameter_gradients (
             const tensor& x,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1164,6 +1189,7 @@ namespace dlib
                   respect to the loss, via backpropagation.  Specifically, this function
                   updates get_final_data_gradient() and also, for each layer, the tensor
                   returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - for all valid k:
                     - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
                 - returns compute_loss(x,lbegin)
@@ -1173,7 +1199,8 @@ namespace dlib
         double compute_parameter_gradients (
             forward_iterator ibegin,
             forward_iterator iend,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1187,13 +1214,15 @@ namespace dlib
                   gradients with respect to the loss, via backpropagation.  Specifically,
                   this function updates get_final_data_gradient() and also, for each layer,
                   the tensor returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - for all valid k:
                     - the expected label of *(ibegin+k) is *(lbegin+k).
                 - returns compute_loss(ibegin,iend,lbegin)
         !*/

         double compute_parameter_gradients (
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1208,13 +1237,15 @@ namespace dlib
                   respect to the loss, via backpropagation.  Specifically, this function
                   updates get_final_data_gradient() and also, for each layer, the tensor
                   returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - returns compute_loss(x)
         !*/

         template <typename forward_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
-            forward_iterator iend
+            forward_iterator iend,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1226,6 +1257,7 @@ namespace dlib
                   gradients with respect to the loss, via backpropagation.  Specifically,
                   this function updates get_final_data_gradient() and also, for each layer,
                   the tensor returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - returns compute_loss(ibegin,iend)
         !*/
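All of these overloads share the same contract: they populate the parameter gradients and, by default, flag the gradient inputs so that the next call starts from zero. A minimal sketch of opting out with the already-converted-tensor overload (placeholder names, assuming a loss layer that does not need labels):

    resizable_tensor x;
    net.to_tensor(samples.begin(), samples.end(), x);    // convert once, reuse the tensor
    const double loss = net.compute_parameter_gradients(x, zero_gradients::no);
    // ... further gradient computations that should accumulate into the gradient inputs ...
    net.set_gradient_inputs_to_zero();                   // reset before the next regular step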
@@ -1262,6 +1294,7 @@ namespace dlib
         void back_propagate_error(
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1276,7 +1309,7 @@ namespace dlib
                   network and computes parameter and data gradients, via backpropagation.
                   Specifically, this function populates get_final_data_gradient() and also,
                   for each layer, the tensor returned by get_parameter_gradient().
-                - All elements of #subnet().get_gradient_input() are set to 0.
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
                   respect to x.
@@ -1301,7 +1334,7 @@ namespace dlib
                     back_propagate_error(x);
                   Except that calling back_propagate_error(x,gradient_input) avoids the
                   copy and is therefore slightly more efficient.
-                - All elements of #subnet.get_gradient_input() are set to 0.
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
                   respect to x.
@@ -1319,6 +1352,13 @@ namespace dlib
                   not one per layer, since there is only one input to the entire network.
         !*/

+        void set_gradient_inputs_to_zero(
+        );
+        /*!
+            ensures
+                - Sets all elements in all gradient inputs in the network to 0.
+                - invokes subnet().set_gradient_inputs_to_zero()
+        !*/

 // -------------