add option to not zero out gradients and method to do it (#2477)

Adrià Arrufat 2022-01-05 14:46:55 +01:00 committed by GitHub
parent a54cea44ae
commit 994df341a2
2 changed files with 182 additions and 47 deletions

View File

@@ -729,6 +729,14 @@ namespace dlib
};
}
// ----------------------------------------------------------------------------------------
enum class zero_gradients : uint8_t
{
no = 0,
yes = 1
};
// ----------------------------------------------------------------------------------------
template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
@@ -1003,21 +1011,28 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return subnetwork->get_final_data_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
params_grad.copy_size(details.get_layer_params());
impl::call_layer_backward(details, private_get_output(),
gradient_input, wsub, static_cast<tensor&>(params_grad));
subnetwork->back_propagate_error(x);
subnetwork->back_propagate_error(x, zero_grads);
// zero out get_gradient_input()
gradient_input_is_stale = true;
gradient_input_is_stale = zero_grads == zero_gradients::yes;
}
template <typename solver_type>
@@ -1057,6 +1072,12 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
gradient_input_is_stale = true;
subnetwork->set_gradient_inputs_to_zero();
}
void clean()
{
x_grad.clear();
@@ -1374,11 +1395,18 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return grad_final; }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
// make sure grad_final is initialized to 0
if (!have_same_dimensions(x, grad_final))
@@ -1391,7 +1419,7 @@ namespace dlib
gradient_input, wsub, static_cast<tensor&>(params_grad));
// zero out get_gradient_input()
gradient_input_is_stale = true;
gradient_input_is_stale = zero_grads == zero_gradients::yes;
}
template <typename solver_type>
@@ -1430,6 +1458,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
void set_gradient_inputs_to_zero()
{
gradient_input_is_stale = true;
}
void clean()
{
x_grad.clear();
@@ -1642,13 +1675,20 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return subnetwork.get_final_data_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x,gradient_input);
subnetwork.back_propagate_error(x,gradient_input, zero_grads);
}
template <typename solver_type>
@@ -1677,6 +1717,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
subnetwork.clean();
@@ -1934,28 +1979,35 @@ namespace dlib
tensor& get_parameter_gradient (
) { return details[0].get_parameter_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
if (details.size() > 1)
{
details[0].back_propagate_error(details[1].get_output(), gradient_input);
details[0].back_propagate_error(details[1].get_output(), gradient_input, zero_grads);
for (size_t i = 1; i < details.size(); ++i)
{
if (i+1 < details.size())
details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient());
details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient(), zero_grads);
else
details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient());
details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient(), zero_grads);
}
}
else
{
details[0].back_propagate_error(subnetwork.get_output(), gradient_input);
details[0].back_propagate_error(subnetwork.get_output(), gradient_input, zero_grads);
}
subnetwork.back_propagate_error(x, details.back().get_final_data_gradient());
subnetwork.back_propagate_error(x, details.back().get_final_data_gradient(), zero_grads);
}
template <typename solver_type>
@@ -1980,6 +2032,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
temp_tensor.clear();
@@ -2191,11 +2248,19 @@ namespace dlib
return grad_final;
}
void back_propagate_error(const tensor& /*x*/)
void back_propagate_error(
const tensor& /*x*/,
zero_gradients zero_grads = zero_gradients::yes
)
{
// nothing to do
}
void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/)
void back_propagate_error(
const tensor& /*x*/,
const tensor& /*gradient_input*/,
zero_gradients zero_grads = zero_gradients::yes
)
{
// nothing to do
}
@@ -2218,6 +2283,11 @@ namespace dlib
const input_layer_type& input_layer() const { return input_layer_; }
input_layer_type& input_layer() { return input_layer_; }
void set_gradient_inputs_to_zero()
{
// nothing to do
}
void clean()
{
grad_final.clear();
@@ -2518,14 +2588,21 @@ namespace dlib
return results;
}
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnet().back_propagate_error(x);
subnet().back_propagate_error(x, zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnet().back_propagate_error(x, gradient_input);
subnet().back_propagate_error(x, gradient_input, zero_grads);
}
const tensor& get_final_data_gradient(
@@ -2604,43 +2681,47 @@ namespace dlib
template <typename label_iterator>
double compute_parameter_gradients (
const tensor& x,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.forward(x);
dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub);
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
return l;
}
template <typename forward_iterator, typename label_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
)
{
to_tensor(ibegin,iend,temp_tensor);
return compute_parameter_gradients(temp_tensor, lbegin);
return compute_parameter_gradients(temp_tensor, lbegin, zero_grads);
}
double compute_parameter_gradients (
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.forward(x);
dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
double l = loss.compute_loss_value_and_gradient(x, wsub);
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
return l;
}
template <typename forward_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend
forward_iterator iend,
zero_gradients zero_grads = zero_gradients::yes
)
{
to_tensor(ibegin,iend,temp_tensor);
return compute_parameter_gradients(temp_tensor);
return compute_parameter_gradients(temp_tensor, zero_grads);
}
template <typename solver_type>
@@ -2667,6 +2748,12 @@ namespace dlib
const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; }
void set_gradient_inputs_to_zero (
)
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean (
)
{
@@ -3022,9 +3109,12 @@ namespace dlib
return subnetwork.get_final_data_gradient();
}
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
}
template <typename solver_type>
@@ -3061,6 +3151,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
subnetwork.clean();

View File

@@ -275,6 +275,14 @@ namespace dlib
- returns a sstack that sits on top of the given std::vector.
!*/
// ----------------------------------------------------------------------------------------
enum class zero_gradients : uint8_t
{
no,
yes
};
// ----------------------------------------------------------------------------------------
template <
@@ -603,7 +611,8 @@ namespace dlib
!*/
void back_propagate_error(
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -617,7 +626,7 @@ namespace dlib
network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also,
for each layer, the tensor returned by get_parameter_gradient().
- All elements of #get_gradient_input() are set to 0.
- All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
- #get_final_data_gradient() contains the gradient of the network with
@@ -626,7 +635,8 @@ namespace dlib
void back_propagate_error(
const tensor& x,
const tensor& gradient_input
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -643,7 +653,7 @@ namespace dlib
back_propagate_error(x);
Except that calling back_propagate_error(x,gradient_input) avoids the
copy and is therefore slightly more efficient.
- All elements of #get_gradient_input() are set to 0.
- All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
- #get_final_data_gradient() contains the gradient of the network with
@@ -681,6 +691,20 @@ namespace dlib
Convenience method for calling update_parameters()
!*/
void set_gradient_inputs_to_zero(
);
/*!
ensures
- Sets all elements in all gradient inputs in the network to 0.
That is, for each layer, we will have:
- get_gradient_input() == 0
- Note that you only need to call this method if you manually called either
- back_propagate_error
- compute_parameter_gradients
with the zero_grads parameter set to zero_gradients::no.
- invokes subnet().set_gradient_inputs_to_zero()
!*/
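
For illustration, the following sketch shows the intended call pattern; the network type, input tensor, and loss-gradient tensors are placeholders assumed for this example and are not defined by this patch. The backward passes are told not to zero the gradient inputs, so they are cleared explicitly afterwards.

#include <dlib/dnn.h>
using namespace dlib;

// Sketch only: net, x, grad_a, and grad_b are assumptions for this example.
void backprop_twice_sketch(
    relu<fc<10, input<matrix<float>>>>& net,  // placeholder architecture
    const tensor& x,        // input tensor, e.g. produced by to_tensor()
    const tensor& grad_a,   // gradient of a first loss w.r.t. net.get_output()
    const tensor& grad_b    // gradient of a second loss w.r.t. net.get_output()
)
{
    net.forward(x);
    // Neither call below zeroes the per-layer gradient inputs on completion.
    net.back_propagate_error(x, grad_a, zero_gradients::no);
    net.back_propagate_error(x, grad_b, zero_gradients::no);

    // ... inspect get_parameter_gradient(), update parameters, etc. ...

    // Since zero_gradients::no was used above, reset the gradient inputs by
    // hand before the next, unrelated backward pass.
    net.set_gradient_inputs_to_zero();
}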
void clean(
);
/*!
@@ -1147,7 +1171,8 @@ namespace dlib
template <typename label_iterator>
double compute_parameter_gradients (
const tensor& x,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1164,6 +1189,7 @@ namespace dlib
respect to the loss, via backpropagation. Specifically, this function
updates get_final_data_gradient() and also, for each layer, the tensor
returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- for all valid k:
- the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
- returns compute_loss(x,lbegin)
@@ -1173,7 +1199,8 @@ namespace dlib
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1187,13 +1214,15 @@ namespace dlib
gradients with respect to the loss, via backpropagation. Specifically,
this function updates get_final_data_gradient() and also, for each layer,
the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- for all valid k:
- the expected label of *(ibegin+k) is *(lbegin+k).
- returns compute_loss(ibegin,iend,lbegin)
!*/
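
To make the call pattern concrete at the loss-layer level, here is a hedged, self-contained sketch; the toy network, solver setup, and dummy data are illustrative assumptions and not part of this change.

#include <dlib/dnn.h>
#include <vector>
using namespace dlib;

int main()
{
    // Toy two-class network; any architecture would do.
    using net_type = loss_multiclass_log<fc<2, input<matrix<float>>>>;
    net_type net;
    std::vector<adam> solvers(net_type::num_computational_layers, adam());

    // Dummy data so the sketch runs end to end.
    const matrix<float> img = zeros_matrix<float>(4, 4);
    std::vector<matrix<float>> imgs_a(2, img), imgs_b(2, img);
    std::vector<unsigned long> labels_a(2, 0), labels_b(2, 1);

    // Backpropagate two chunks, telling each call not to zero the gradient
    // inputs when it finishes.
    net.compute_parameter_gradients(imgs_a.begin(), imgs_a.end(),
                                    labels_a.begin(), zero_gradients::no);
    net.compute_parameter_gradients(imgs_b.begin(), imgs_b.end(),
                                    labels_b.begin(), zero_gradients::no);

    net.update_parameters(make_sstack(solvers), 0.01);

    // zero_gradients::no was used above, so clear the gradient inputs
    // explicitly before the next round of backpropagation.
    net.set_gradient_inputs_to_zero();
}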
double compute_parameter_gradients (
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1208,13 +1237,15 @@ namespace dlib
respect to the loss, via backpropagation. Specifically, this function
updates get_final_data_gradient() and also, for each layer, the tensor
returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- returns compute_loss(x)
!*/
template <typename forward_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend
forward_iterator iend,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1226,6 +1257,7 @@ namespace dlib
gradients with respect to the loss, via backpropagation. Specifically,
this function updates get_final_data_gradient() and also, for each layer,
the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- returns compute_loss(ibegin,iend)
!*/
@@ -1262,6 +1294,7 @@ namespace dlib
void back_propagate_error(
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1276,7 +1309,7 @@ namespace dlib
network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also,
for each layer, the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0.
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
@@ -1301,7 +1334,7 @@ namespace dlib
back_propagate_error(x);
Except that calling back_propagate_error(x,gradient_input) avoids the
copy and is therefore slightly more efficient.
- All elements of #subnet.get_gradient_input() are set to 0.
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
@@ -1319,6 +1352,13 @@ namespace dlib
not one per layer, since there is only one input to the entire network.
!*/
void set_gradient_inputs_to_zero(
);
/*!
ensures
- Sets all elements in all gradient inputs in the network to 0.
- invokes subnet().set_gradient_inputs_to_zero()
!*/
// -------------