mirror of https://github.com/davisking/dlib.git

Commit 3746b3c1d3 (parent d48b406a11)
Simplify code as per review comments
@@ -458,9 +458,10 @@ namespace dlib
             {
                 steps_without_progress = 0;
                 test_steps_without_progress = 0;
+                steps_since_last_learning_rate_shrink = 0;
                 previous_loss_values.clear();
                 test_previous_loss_values.clear();
-                additional_previous_loss_values_to_keep_until_disk_sync.clear();
+                previous_loss_values_to_keep_until_disk_sync.clear();
             }
             learning_rate = lr;
             lr_schedule.set_size(0);
@@ -560,6 +561,7 @@ namespace dlib
             learning_rate_shrink = shrink;
             steps_without_progress = 0;
             test_steps_without_progress = 0;
+            steps_since_last_learning_rate_shrink = 0;
         }

         double get_learning_rate_shrink_factor (
@@ -602,11 +604,12 @@ namespace dlib
             previous_loss_values.push_back(loss);
             // discard really old loss values.
             while (previous_loss_values.size() > iter_without_progress_thresh)
-            {
-                if (!sync_filename.empty())
-                    additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
                 previous_loss_values.pop_front();
-            }
+
+            // separately keep another loss history until disk sync
+            // (but only if disk sync is enabled)
+            if (!sync_filename.empty())
+                previous_loss_values_to_keep_until_disk_sync.push_back(loss);
         }

         template <typename T>
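Note on the hunk above: instead of copying each value that falls out of previous_loss_values into a second "additional" deque, the new code simply appends every recorded loss to previous_loss_values_to_keep_until_disk_sync whenever disk syncing is enabled. Below is a minimal, self-contained C++ sketch of that bookkeeping; loss_history_sketch and record_loss are hypothetical names, not dlib's dnn_trainer API.

    #include <cstddef>
    #include <deque>
    #include <string>

    // Hypothetical stand-in for the trainer's loss bookkeeping: a short rolling
    // window plus a longer history that is kept until the next disk sync.
    struct loss_history_sketch
    {
        std::deque<double> previous_loss_values;
        std::deque<double> previous_loss_values_to_keep_until_disk_sync;
        std::string sync_filename;                       // empty => syncing disabled
        std::size_t iter_without_progress_thresh = 2000;

        void record_loss(double loss)
        {
            previous_loss_values.push_back(loss);
            // discard really old loss values from the short window
            while (previous_loss_values.size() > iter_without_progress_thresh)
                previous_loss_values.pop_front();

            // separately keep another loss history until disk sync
            // (but only if disk sync is enabled)
            if (!sync_filename.empty())
                previous_loss_values_to_keep_until_disk_sync.push_back(loss);
        }
    };

    int main()
    {
        loss_history_sketch h;
        h.sync_filename = "trainer_state.dat";
        for (int i = 0; i < 5000; ++i)
            h.record_loss(1.0 / (i + 1));
        // previous_loss_values stays capped at iter_without_progress_thresh, while
        // previous_loss_values_to_keep_until_disk_sync grows until it is trimmed
        // at sync time (see the later hunk around line 1080).
    }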
@@ -694,7 +697,7 @@ namespace dlib

                     // Check if we should shrink the learning rate based on how the test
                     // error has been doing lately.
-                    if (learning_rate_shrink != 1)
+                    if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
                     {
                         test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
                         if (test_steps_without_progress >= test_iter_without_progress_thresh)
@@ -705,10 +708,13 @@ namespace dlib
                             // optimization has flattened out, so drop the learning rate.
                             learning_rate = learning_rate_shrink*learning_rate;
                             test_steps_without_progress = 0;
-                            // Empty out some of the previous loss values so that test_steps_without_progress
-                            // will decrease below test_iter_without_progress_thresh.
-                            for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
-                                test_previous_loss_values.pop_front();
+
+                            // Decrease steps_since_last_learning_rate_shrink, so that we
+                            // do not get here again right away.
+                            steps_since_last_learning_rate_shrink -= std::min(
+                                test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10,
+                                steps_since_last_learning_rate_shrink
+                            );
                         }
                     }
                 }
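The decrement above is clamped with std::min because steps_since_last_learning_rate_shrink is an unsigned counter (see the member declaration added at the bottom of the diff); subtracting more than it currently holds would wrap around instead of stopping at zero. A tiny sketch of the clamp, with made-up values:

    #include <algorithm>
    #include <cassert>

    int main()
    {
        // Illustrative values only; the point is the clamp, not the numbers.
        unsigned long steps_since_last_learning_rate_shrink = 30;
        const unsigned long test_previous_loss_values_dump_amount = 400;
        const unsigned long test_iter_without_progress_thresh = 500;

        // Without std::min this subtraction would underflow the unsigned counter.
        steps_since_last_learning_rate_shrink -= std::min(
            test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10,
            steps_since_last_learning_rate_shrink
        );
        assert(steps_since_last_learning_rate_shrink == 0);  // clamped to zero, not wrapped
        return 0;
    }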
@@ -717,6 +723,7 @@ namespace dlib

                 updated_net_since_last_sync = true;
                 ++main_iteration_counter;
+                ++steps_since_last_learning_rate_shrink;
                 // Call compute_parameter_gradients() and update_parameters() but pick the
                 // right version for unsupervised or supervised training based on the type
                 // of training_label_type.
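Taken together, the changes in this commit make steps_since_last_learning_rate_shrink act as a cooldown: it is reset to zero when the learning rate or shrink factor is set, incremented once per training step (the line added above), required to exceed the progress threshold before the shrink checks run, and pushed back down after each shrink. A condensed, hypothetical sketch of that control flow (shrink_cooldown_sketch is not the real trainer; the member names mirror the diff):

    #include <algorithm>

    // Condensed, hypothetical sketch of the cooldown introduced by this commit.
    struct shrink_cooldown_sketch
    {
        unsigned long steps_since_last_learning_rate_shrink = 0;
        unsigned long iter_without_progress_thresh = 2000;
        unsigned long previous_loss_values_dump_amount = 400;
        double learning_rate = 0.1;
        double learning_rate_shrink = 0.1;

        void set_learning_rate(double lr)
        {
            learning_rate = lr;
            steps_since_last_learning_rate_shrink = 0;      // restart the cooldown
        }

        void training_step(bool optimization_flattened_out)
        {
            ++steps_since_last_learning_rate_shrink;        // one tick per step

            // Only consider shrinking once the cooldown has elapsed.
            if (learning_rate_shrink != 1
                && steps_since_last_learning_rate_shrink > iter_without_progress_thresh
                && optimization_flattened_out)
            {
                learning_rate = learning_rate_shrink * learning_rate;
                // Knock the counter back so we don't shrink again right away.
                steps_since_last_learning_rate_shrink -= std::min(
                    previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
                    steps_since_last_learning_rate_shrink
                );
            }
        }
    };

    int main()
    {
        shrink_cooldown_sketch t;
        t.set_learning_rate(0.1);
        for (int i = 0; i < 3000; ++i)
            t.training_step(/*optimization_flattened_out=*/ i > 2500);
        return 0;
    }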
@@ -802,7 +809,9 @@ namespace dlib
                 // have a "budget" that prevents us from calling
                 // count_steps_without_decrease() every iteration. We do this because
                 // it can be expensive to compute when previous_loss_values is large.
-                if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1)
+                if (gradient_check_budget > iter_without_progress_thresh
+                    && learning_rate_shrink != 1
+                    && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
                 {
                     gradient_check_budget = 0;
                     steps_without_progress = count_steps_without_decrease(previous_loss_values);
@@ -825,14 +834,13 @@ namespace dlib
                         // optimization has flattened out, so drop the learning rate.
                         learning_rate = learning_rate_shrink*learning_rate;
                         steps_without_progress = 0;
-                        // Empty out some of the previous loss values so that steps_without_progress
-                        // will decrease below iter_without_progress_thresh.
-                        for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
-                        {
-                            if (!sync_filename.empty())
-                                additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
-                            previous_loss_values.pop_front();
-                        }
+
+                        // Decrease steps_since_last_learning_rate_shrink, so that we
+                        // do not get here again right away.
+                        steps_since_last_learning_rate_shrink -= std::min(
+                            previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
+                            steps_since_last_learning_rate_shrink
+                        );
                     }
                 }
             }
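For the training-loss path the existing gradient_check_budget already keeps the relatively expensive count_steps_without_decrease() from running every iteration; the hunks above just add the post-shrink cooldown to the same condition and swap the old loss-dumping loop for the clamped decrement. A self-contained sketch of the budget idea, with a stand-in for the expensive check (the lambda below is not dlib's count_steps_without_decrease):

    #include <cstddef>
    #include <deque>

    int main()
    {
        std::deque<double> previous_loss_values;
        unsigned long gradient_check_budget = 0;
        unsigned long steps_since_last_learning_rate_shrink = 0;
        const unsigned long iter_without_progress_thresh = 2000;
        const double learning_rate_shrink = 0.1;

        // Stand-in for the expensive statistical test over the loss history.
        auto expensive_progress_check = [&]() { return previous_loss_values.size() / 2; };

        for (int step = 0; step < 10000; ++step)
        {
            previous_loss_values.push_back(1.0);
            ++gradient_check_budget;                    // budget grows cheaply each step
            ++steps_since_last_learning_rate_shrink;

            if (gradient_check_budget > iter_without_progress_thresh
                && learning_rate_shrink != 1
                && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
            {
                gradient_check_budget = 0;              // spend the whole budget...
                (void)expensive_progress_check();       // ...on one expensive check
            }
        }
        return 0;
    }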
@@ -871,6 +879,7 @@ namespace dlib
             min_learning_rate = 1e-5;
             iter_without_progress_thresh = 2000;
             steps_without_progress = 0;
+            steps_since_last_learning_rate_shrink = 0;
             test_iter_without_progress_thresh = 500;
             test_steps_without_progress = 0;

@@ -933,7 +942,8 @@ namespace dlib
             serialize(item.test_previous_loss_values, out);
             serialize(item.previous_loss_values_dump_amount, out);
             serialize(item.test_previous_loss_values_dump_amount, out);
-            serialize(item.additional_previous_loss_values_to_keep_until_disk_sync, out);
+            serialize(item.steps_since_last_learning_rate_shrink, out);
+            serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
         }
         friend void deserialize(dnn_trainer& item, std::istream& in)
         {
@@ -979,7 +989,8 @@ namespace dlib
             deserialize(item.test_previous_loss_values, in);
             deserialize(item.previous_loss_values_dump_amount, in);
             deserialize(item.test_previous_loss_values_dump_amount, in);
-            deserialize(item.additional_previous_loss_values_to_keep_until_disk_sync, in);
+            deserialize(item.steps_since_last_learning_rate_shrink, in);
+            deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);

             if (item.devices.size() > 1)
             {
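serialize() and deserialize() are updated symmetrically: the renamed deque and the new counter are written and read in the same relative positions, since dlib's stream serialization is purely positional. A hedged sketch of that pattern for a small stand-in struct (trainer_state_sketch is hypothetical; dlib::serialize/deserialize on unsigned long and std::deque<double> are the same free functions the trainer itself uses in this diff):

    #include <dlib/serialize.h>
    #include <deque>
    #include <sstream>

    // Hypothetical stand-in, not dlib's dnn_trainer: new members must be written
    // and read back in exactly the same order.
    struct trainer_state_sketch
    {
        unsigned long steps_since_last_learning_rate_shrink = 0;
        std::deque<double> previous_loss_values_to_keep_until_disk_sync;
    };

    void serialize(const trainer_state_sketch& item, std::ostream& out)
    {
        dlib::serialize(item.steps_since_last_learning_rate_shrink, out);
        dlib::serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
    }

    void deserialize(trainer_state_sketch& item, std::istream& in)
    {
        dlib::deserialize(item.steps_since_last_learning_rate_shrink, in);
        dlib::deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
    }

    int main()
    {
        trainer_state_sketch a, b;
        a.steps_since_last_learning_rate_shrink = 42;
        a.previous_loss_values_to_keep_until_disk_sync = {0.5, 0.25};

        std::stringstream ss;
        serialize(a, ss);
        deserialize(b, ss);   // b now mirrors a
        return 0;
    }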
@@ -1038,6 +1049,12 @@ namespace dlib
                 {
                     std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
                     learning_rate = learning_rate_shrink * learning_rate;
+
+                    steps_without_progress = 0;
+                    steps_since_last_learning_rate_shrink -= std::min(
+                        previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
+                        steps_since_last_learning_rate_shrink
+                    );
                 }
             }
             else
@@ -1080,38 +1097,12 @@ namespace dlib
             if (gradient_updates_since_last_sync < 30)
                 return false;

-            // how long a loss history do we actually need to maintain?
-            const size_t max_previous_loss_values_to_consider = 2 * gradient_updates_since_last_sync;
-            const auto total_previous_loss_values = [this]()
-            {
-                return additional_previous_loss_values_to_keep_until_disk_sync.size()
-                    + previous_loss_values.size();
-            };
-            const auto can_drop_an_additional_previous_loss_value = [&]() {
-                return !additional_previous_loss_values_to_keep_until_disk_sync.empty()
-                    && total_previous_loss_values() > max_previous_loss_values_to_consider;
-            };
-
-            while (can_drop_an_additional_previous_loss_value())
-                additional_previous_loss_values_to_keep_until_disk_sync.pop_front();
+            // Now look at the data since a little before the last disk sync. We will
+            // check if the loss is getting better or worse.
+            while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
+                previous_loss_values_to_keep_until_disk_sync.pop_front();

-            // collect all previous loss values we want to consider
-            std::deque<double> previous_loss_values_to_consider;
-            std::copy(
-                additional_previous_loss_values_to_keep_until_disk_sync.begin(),
-                additional_previous_loss_values_to_keep_until_disk_sync.end(),
-                std::back_inserter(previous_loss_values_to_consider)
-            );
-            std::copy(
-                previous_loss_values.begin(),
-                previous_loss_values.end(),
-                std::back_inserter(previous_loss_values_to_consider)
-            );
-
-            while (previous_loss_values_to_consider.size() > max_previous_loss_values_to_consider)
-                previous_loss_values_to_consider.pop_front();
-
-            for (auto x : previous_loss_values_to_consider)
+            for (auto x : previous_loss_values_to_keep_until_disk_sync)
             {
                 // If we get a NaN value of loss assume things have gone horribly wrong and
                 // we should reload the state of the trainer.
@@ -1119,11 +1110,9 @@ namespace dlib
                     return true;
             }

-            // Now look at the data since a little before the last disk sync. We will
-            // check if the loss is getting better or worse.
             running_gradient g;
-            for (size_t i = 0; i < previous_loss_values_to_consider.size(); ++i)
-                g.add(previous_loss_values_to_consider[i]);
+            for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i)
+                g.add(previous_loss_values_to_keep_until_disk_sync[i]);

             // if the loss is very likely to be increasing then return true
             const double prob = g.probability_gradient_greater_than(0);
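The last two hunks above replace the lambda-and-copy machinery with a direct check: trim previous_loss_values_to_keep_until_disk_sync to at most twice the number of gradient updates since the last sync, feed it to dlib's running_gradient, and treat a high probability of an increasing loss as the signal to reload the saved state. A minimal sketch of that check in isolation; the 0.99 threshold below is illustrative, since the actual comparison is outside this diff:

    #include <dlib/statistics.h>
    #include <deque>
    #include <iostream>

    int main()
    {
        std::deque<double> previous_loss_values_to_keep_until_disk_sync;
        const unsigned long gradient_updates_since_last_sync = 50;

        // Pretend the loss has been drifting upward since the last sync.
        for (int i = 0; i < 200; ++i)
            previous_loss_values_to_keep_until_disk_sync.push_back(1.0 + 0.01 * i);

        // Now look at the data since a little before the last disk sync.
        while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
            previous_loss_values_to_keep_until_disk_sync.pop_front();

        dlib::running_gradient g;
        for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i)
            g.add(previous_loss_values_to_keep_until_disk_sync[i]);

        // if the loss is very likely to be increasing then reload the old state
        const double prob = g.probability_gradient_greater_than(0);
        const bool should_reload = prob > 0.99;   // illustrative threshold
        std::cout << "P(loss increasing) = " << prob
                  << (should_reload ? " -> reload" : " -> keep going") << "\n";
        return 0;
    }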
@@ -1299,7 +1288,7 @@ namespace dlib
         std::atomic<unsigned long> test_steps_without_progress;
         std::deque<double> test_previous_loss_values;

-        std::deque<double> additional_previous_loss_values_to_keep_until_disk_sync;
+        std::deque<double> previous_loss_values_to_keep_until_disk_sync;

         std::atomic<double> learning_rate_shrink;
         std::chrono::time_point<std::chrono::system_clock> last_sync_time;
@@ -1313,6 +1302,7 @@ namespace dlib
         matrix<double,0,1> lr_schedule;
         long lr_schedule_pos;
         unsigned long gradient_check_budget;
+        unsigned long steps_since_last_learning_rate_shrink;

         std::exception_ptr eptr = nullptr;
         mutable std::mutex eptr_mutex;