add option to not zero out gradients and method to do it (#2477)
parent a54cea44ae
commit 994df341a2
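The commit threads a new zero_gradients flag (default zero_gradients::yes, which preserves the old behaviour) through back_propagate_error() and compute_parameter_gradients(), and adds set_gradient_inputs_to_zero() so the gradient inputs can be cleared explicitly when automatic zeroing was skipped. Below is a minimal sketch (not part of the commit) of the call pattern this enables; the network type, data, chunk size, solver, and learning rate are illustrative placeholders.

#include <dlib/dnn.h>
#include <vector>

using namespace dlib;

// Illustrative toy network; any dlib loss network exposes the same interface.
using net_type = loss_multiclass_log<fc<10, input<matrix<float>>>>;

void accumulate_and_update(
    const std::vector<matrix<float>>& samples,
    const std::vector<unsigned long>& labels
)
{
    net_type net;
    std::vector<sgd> solvers(net_type::num_computational_layers, sgd());

    // Backpropagate several chunks without zeroing the gradient inputs between
    // calls, then apply a single parameter update.
    const size_t chunk = 32;
    for (size_t i = 0; i + chunk <= samples.size(); i += chunk)
    {
        net.compute_parameter_gradients(samples.begin() + i, samples.begin() + i + chunk,
                                        labels.begin() + i, zero_gradients::no);
    }
    net.update_parameters(make_sstack(solvers), 0.01);

    // Automatic zeroing was skipped above, so clear the gradient inputs by hand
    // before the next round of gradient computations.
    net.set_gradient_inputs_to_zero();
}

Passing zero_gradients::yes (the default) keeps the previous behaviour, so existing callers are unaffected.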
dlib/dnn/core.h (169 lines changed)

@@ -729,6 +729,14 @@ namespace dlib
     };

     }

+// ----------------------------------------------------------------------------------------
+
+    enum class zero_gradients : uint8_t
+    {
+        no = 0,
+        yes = 1
+    };
+
 // ----------------------------------------------------------------------------------------

     template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
@@ -1003,21 +1011,28 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return subnetwork->get_final_data_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }
-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
                 gradient_input, wsub, static_cast<tensor&>(params_grad));

-            subnetwork->back_propagate_error(x);
+            subnetwork->back_propagate_error(x, zero_grads);

             // zero out get_gradient_input()
-            gradient_input_is_stale = true;
+            gradient_input_is_stale = zero_grads == zero_gradients::yes;
         }

         template <typename solver_type>
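As the comment in the hunk above notes, gradient_input_is_stale drives the zeroing of get_gradient_input(): marking the buffer stale causes it to be zero-filled before its next use, while leaving it unmarked (zero_gradients::no) keeps whatever previous backward passes accumulated. A simplified, self-contained stand-in for that pattern (not dlib's actual code) looks like this:

#include <algorithm>
#include <cstddef>
#include <vector>

// Simplified illustration of lazy zeroing: the buffer is cleared only on the
// next request, and only if it was marked stale.
struct gradient_buffer
{
    std::vector<float> grad;
    bool stale = true;           // a stale buffer hands out zeros on next access

    std::vector<float>& get(std::size_t n)
    {
        grad.resize(n);
        if (stale)
        {
            std::fill(grad.begin(), grad.end(), 0.0f);
            stale = false;
        }
        return grad;             // backward passes add their contributions here
    }

    // zero_grads == true corresponds to zero_gradients::yes (clear on next access),
    // false to zero_gradients::no (keep accumulating).
    void finish_backward(bool zero_grads) { stale = zero_grads; }
};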
@@ -1057,6 +1072,12 @@ namespace dlib

         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            gradient_input_is_stale = true;
+            subnetwork->set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             x_grad.clear();
@@ -1374,11 +1395,18 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return grad_final; }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }
-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // make sure grad_final is initialized to 0
             if (!have_same_dimensions(x, grad_final))
@@ -1391,7 +1419,7 @@ namespace dlib
                 gradient_input, wsub, static_cast<tensor&>(params_grad));

             // zero out get_gradient_input()
-            gradient_input_is_stale = true;
+            gradient_input_is_stale = zero_grads == zero_gradients::yes;
         }

         template <typename solver_type>
@@ -1430,6 +1458,11 @@ namespace dlib

         unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }

+        void set_gradient_inputs_to_zero()
+        {
+            gradient_input_is_stale = true;
+        }
+
         void clean()
         {
             x_grad.clear();
@@ -1642,13 +1675,20 @@ namespace dlib
         const tensor& get_final_data_gradient(
         ) const { return subnetwork.get_final_data_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
         }
-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x,gradient_input);
+            subnetwork.back_propagate_error(x,gradient_input, zero_grads);
         }

         template <typename solver_type>
@@ -1677,6 +1717,11 @@ namespace dlib

         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             subnetwork.clean();
@@ -1934,28 +1979,35 @@ namespace dlib
        tensor& get_parameter_gradient (
        ) { return details[0].get_parameter_gradient(); }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            back_propagate_error(x, private_get_gradient_input());
+            back_propagate_error(x, private_get_gradient_input(), zero_grads);
         }
-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             if (details.size() > 1)
             {
-                details[0].back_propagate_error(details[1].get_output(), gradient_input);
+                details[0].back_propagate_error(details[1].get_output(), gradient_input, zero_grads);
                 for (size_t i = 1; i < details.size(); ++i)
                 {
                     if (i+1 < details.size())
-                        details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient());
+                        details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient(), zero_grads);
                     else
-                        details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient());
+                        details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient(), zero_grads);
                 }
             }
             else
             {
-                details[0].back_propagate_error(subnetwork.get_output(), gradient_input);
+                details[0].back_propagate_error(subnetwork.get_output(), gradient_input, zero_grads);
             }
-            subnetwork.back_propagate_error(x, details.back().get_final_data_gradient());
+            subnetwork.back_propagate_error(x, details.back().get_final_data_gradient(), zero_grads);
         }

         template <typename solver_type>
@@ -1980,6 +2032,11 @@ namespace dlib

         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             temp_tensor.clear();
@@ -2191,11 +2248,19 @@ namespace dlib
             return grad_final;
         }

-        void back_propagate_error(const tensor& /*x*/)
+
+        void back_propagate_error(
+            const tensor& /*x*/,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // nothing to do
         }
-        void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/)
+        void back_propagate_error(
+            const tensor& /*x*/,
+            const tensor& /*gradient_input*/,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
             // nothing to do
         }
@@ -2218,6 +2283,11 @@ namespace dlib
        const input_layer_type& input_layer() const { return input_layer_; }
        input_layer_type& input_layer() { return input_layer_; }

+        void set_gradient_inputs_to_zero()
+        {
+            // nothing to do
+        }
+
         void clean()
         {
             grad_final.clear();
@@ -2518,14 +2588,21 @@ namespace dlib
             return results;
         }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnet().back_propagate_error(x);
+            subnet().back_propagate_error(x, zero_grads);
         }

-        void back_propagate_error(const tensor& x, const tensor& gradient_input)
+        void back_propagate_error(
+            const tensor& x,
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnet().back_propagate_error(x, gradient_input);
+            subnet().back_propagate_error(x, gradient_input, zero_grads);
         }

         const tensor& get_final_data_gradient(
@@ -2604,43 +2681,47 @@ namespace dlib
         template <typename label_iterator>
         double compute_parameter_gradients (
             const tensor& x,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             subnetwork.forward(x);
             dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
             double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub);
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
             return l;
         }
         template <typename forward_iterator, typename label_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
             forward_iterator iend,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             to_tensor(ibegin,iend,temp_tensor);
-            return compute_parameter_gradients(temp_tensor, lbegin);
+            return compute_parameter_gradients(temp_tensor, lbegin, zero_grads);
         }
         double compute_parameter_gradients (
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         )
         {
             subnetwork.forward(x);
             dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
             double l = loss.compute_loss_value_and_gradient(x, wsub);
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
             return l;
         }
         template <typename forward_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
-            forward_iterator iend
+            forward_iterator iend,
+            zero_gradients zero_grads = zero_gradients::yes
        )
         {
             to_tensor(ibegin,iend,temp_tensor);
-            return compute_parameter_gradients(temp_tensor);
+            return compute_parameter_gradients(temp_tensor, zero_grads);
         }

         template <typename solver_type>
@@ -2667,6 +2748,12 @@ namespace dlib
         const loss_details_type& loss_details() const { return loss; }
         loss_details_type& loss_details() { return loss; }

+        void set_gradient_inputs_to_zero (
+        )
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean (
         )
         {
@@ -3022,9 +3109,12 @@ namespace dlib
             return subnetwork.get_final_data_gradient();
         }

-        void back_propagate_error(const tensor& x)
+        void back_propagate_error(
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
+        )
         {
-            subnetwork.back_propagate_error(x);
+            subnetwork.back_propagate_error(x, zero_grads);
         }

         template <typename solver_type>
@@ -3061,6 +3151,11 @@ namespace dlib

         unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }

+        void set_gradient_inputs_to_zero()
+        {
+            subnetwork.set_gradient_inputs_to_zero();
+        }
+
         void clean()
         {
             subnetwork.clean();
dlib/dnn/core_abstract.h

@@ -275,6 +275,14 @@ namespace dlib
             - returns a sstack that sits on top of the given std::vector.
     !*/

+// ----------------------------------------------------------------------------------------
+
+    enum class zero_gradients : uint8_t
+    {
+        no,
+        yes
+    };
+
 // ----------------------------------------------------------------------------------------

     template <
@@ -603,7 +611,8 @@ namespace dlib
         !*/

         void back_propagate_error(
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -617,7 +626,7 @@ namespace dlib
                   network and computes parameter and data gradients, via backpropagation.
                   Specifically, this function populates get_final_data_gradient() and also,
                   for each layer, the tensor returned by get_parameter_gradient().
-                - All elements of #get_gradient_input() are set to 0.
+                - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
@@ -626,7 +635,8 @@ namespace dlib

         void back_propagate_error(
             const tensor& x,
-            const tensor& gradient_input
+            const tensor& gradient_input,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -643,7 +653,7 @@ namespace dlib
                       back_propagate_error(x);
                   Except that calling back_propagate_error(x,gradient_input) avoids the
                   copy and is therefore slightly more efficient.
-                - All elements of #get_gradient_input() are set to 0.
+                - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
@@ -681,6 +691,20 @@ namespace dlib
                   Convenience method for calling update_parameters()
         !*/

+        void set_gradient_inputs_to_zero(
+        );
+        /*!
+            ensures
+                - Sets all elements in all gradient inputs in the network to 0.
+                  That is, for each layer, we will have:
+                    - get_gradient_input() == 0
+                - Note that you only need to call this method if you manually called either
+                    - back_propagate_error
+                    - compute_parameter_gradients
+                  with the zero_grads parameter set to zero_gradients::no.
+                - invokes subnet().set_gradient_inputs_to_zero()
+        !*/
+
         void clean(
         );
         /*!
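Per the contract above, manual zeroing is only required after calling back_propagate_error or compute_parameter_gradients with zero_gradients::no. A hedged sketch of that situation at the subnet level follows; the network type, tensors, solvers, and learning rate are placeholders, not part of the commit.

#include <dlib/dnn.h>
#include <vector>

using namespace dlib;

// Placeholder network; the pattern is the same for any dlib network.
using net_type = loss_mean_squared<fc<1, input<matrix<float>>>>;

// Drive the subnet directly with an externally supplied gradient, keep the
// gradient inputs across calls (zero_gradients::no), and clear them explicitly.
void manual_backprop(
    net_type& net,
    const resizable_tensor& x,       // already filled via net.to_tensor(...)
    std::vector<sgd>& solvers,
    double learning_rate
)
{
    net.subnet().forward(x);

    resizable_tensor grad_of_loss;
    grad_of_loss.copy_size(net.subnet().get_output());
    grad_of_loss = 1;  // stand-in for the gradient of some custom loss

    net.subnet().back_propagate_error(x, grad_of_loss, zero_gradients::no);

    // ... further forward/backward passes could add into the same buffers here ...

    net.update_parameters(make_sstack(solvers), learning_rate);
    net.set_gradient_inputs_to_zero();  // our job, since automatic zeroing was skipped
}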
@@ -1147,7 +1171,8 @@ namespace dlib
         template <typename label_iterator>
         double compute_parameter_gradients (
             const tensor& x,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1164,6 +1189,7 @@ namespace dlib
                   respect to the loss, via backpropagation. Specifically, this function
                   updates get_final_data_gradient() and also, for each layer, the tensor
                   returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - for all valid k:
                     - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
                 - returns compute_loss(x,lbegin)
@@ -1173,7 +1199,8 @@ namespace dlib
         double compute_parameter_gradients (
             forward_iterator ibegin,
             forward_iterator iend,
-            label_iterator lbegin
+            label_iterator lbegin,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1187,13 +1214,15 @@ namespace dlib
                   gradients with respect to the loss, via backpropagation. Specifically,
                   this function updates get_final_data_gradient() and also, for each layer,
                   the tensor returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - for all valid k:
                     - the expected label of *(ibegin+k) is *(lbegin+k).
                 - returns compute_loss(ibegin,iend,lbegin)
         !*/

         double compute_parameter_gradients (
-            const tensor& x
+            const tensor& x,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1208,13 +1237,15 @@ namespace dlib
                   respect to the loss, via backpropagation. Specifically, this function
                   updates get_final_data_gradient() and also, for each layer, the tensor
                   returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - returns compute_loss(x)
         !*/

         template <typename forward_iterator>
         double compute_parameter_gradients (
             forward_iterator ibegin,
-            forward_iterator iend
+            forward_iterator iend,
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1226,6 +1257,7 @@ namespace dlib
                   gradients with respect to the loss, via backpropagation. Specifically,
                   this function updates get_final_data_gradient() and also, for each layer,
                   the tensor returned by get_parameter_gradient().
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - returns compute_loss(ibegin,iend)
         !*/

@@ -1262,6 +1294,7 @@ namespace dlib

         void back_propagate_error(
             const tensor& x
+            zero_gradients zero_grads = zero_gradients::yes
         );
         /*!
             requires
@@ -1276,7 +1309,7 @@ namespace dlib
                   network and computes parameter and data gradients, via backpropagation.
                   Specifically, this function populates get_final_data_gradient() and also,
                   for each layer, the tensor returned by get_parameter_gradient().
-                - All elements of #subnet().get_gradient_input() are set to 0.
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
                   respect to x.
@@ -1301,7 +1334,7 @@ namespace dlib
                       back_propagate_error(x);
                   Except that calling back_propagate_error(x,gradient_input) avoids the
                   copy and is therefore slightly more efficient.
-                - All elements of #subnet.get_gradient_input() are set to 0.
+                - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
                 - have_same_dimensions(#get_final_data_gradient(), x) == true.
                 - #get_final_data_gradient() contains the gradient of the network with
                   respect to x.
@@ -1319,6 +1352,13 @@ namespace dlib
                   not one per layer, since there is only one input to the entire network.
         !*/

+        void set_gradient_inputs_to_zero(
+        );
+        /*!
+            ensures
+                - Sets all elements in all gradient inputs in the network to 0.
+                - invokes subnet().set_gradient_inputs_to_zero()
+        !*/

     // -------------