Simplify code as per review comments

This commit is contained in:
Juha Reunanen 2019-11-12 14:13:48 +02:00
parent d48b406a11
commit 3746b3c1d3

View File

@ -458,9 +458,10 @@ namespace dlib
{
steps_without_progress = 0;
test_steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
previous_loss_values.clear();
test_previous_loss_values.clear();
additional_previous_loss_values_to_keep_until_disk_sync.clear();
previous_loss_values_to_keep_until_disk_sync.clear();
}
learning_rate = lr;
lr_schedule.set_size(0);
@ -560,6 +561,7 @@ namespace dlib
learning_rate_shrink = shrink;
steps_without_progress = 0;
test_steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
}
double get_learning_rate_shrink_factor (
@ -602,11 +604,12 @@ namespace dlib
previous_loss_values.push_back(loss);
// discard really old loss values.
while (previous_loss_values.size() > iter_without_progress_thresh)
{
if (!sync_filename.empty())
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
previous_loss_values.pop_front();
}
// separately keep another loss history until disk sync
// (but only if disk sync is enabled)
if (!sync_filename.empty())
previous_loss_values_to_keep_until_disk_sync.push_back(loss);
}
template <typename T>
@ -694,7 +697,7 @@ namespace dlib
// Check if we should shrink the learning rate based on how the test
// error has been doing lately.
if (learning_rate_shrink != 1)
if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
{
test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
if (test_steps_without_progress >= test_iter_without_progress_thresh)
@ -705,10 +708,13 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
test_steps_without_progress = 0;
// Empty out some of the previous loss values so that test_steps_without_progress
// will decrease below test_iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
test_previous_loss_values.pop_front();
// Decrease steps_since_last_learning_rate_shrink, so that we
// do not get here again right away.
steps_since_last_learning_rate_shrink -= std::min(
test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10,
steps_since_last_learning_rate_shrink
);
}
}
}
@ -717,6 +723,7 @@ namespace dlib
updated_net_since_last_sync = true;
++main_iteration_counter;
++steps_since_last_learning_rate_shrink;
// Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type
// of training_label_type.
@ -802,7 +809,9 @@ namespace dlib
// have a "budget" that prevents us from calling
// count_steps_without_decrease() every iteration. We do this because
// it can be expensive to compute when previous_loss_values is large.
if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1)
if (gradient_check_budget > iter_without_progress_thresh
&& learning_rate_shrink != 1
&& steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
{
gradient_check_budget = 0;
steps_without_progress = count_steps_without_decrease(previous_loss_values);
@ -825,14 +834,13 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
steps_without_progress = 0;
// Empty out some of the previous loss values so that steps_without_progress
// will decrease below iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
{
if (!sync_filename.empty())
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
previous_loss_values.pop_front();
}
// Decrease steps_since_last_learning_rate_shrink, so that we
// do not get here again right away.
steps_since_last_learning_rate_shrink -= std::min(
previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
steps_since_last_learning_rate_shrink
);
}
}
}
@ -871,6 +879,7 @@ namespace dlib
min_learning_rate = 1e-5;
iter_without_progress_thresh = 2000;
steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
test_iter_without_progress_thresh = 500;
test_steps_without_progress = 0;
@ -933,7 +942,8 @@ namespace dlib
serialize(item.test_previous_loss_values, out);
serialize(item.previous_loss_values_dump_amount, out);
serialize(item.test_previous_loss_values_dump_amount, out);
serialize(item.additional_previous_loss_values_to_keep_until_disk_sync, out);
serialize(item.steps_since_last_learning_rate_shrink, out);
serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
@ -979,7 +989,8 @@ namespace dlib
deserialize(item.test_previous_loss_values, in);
deserialize(item.previous_loss_values_dump_amount, in);
deserialize(item.test_previous_loss_values_dump_amount, in);
deserialize(item.additional_previous_loss_values_to_keep_until_disk_sync, in);
deserialize(item.steps_since_last_learning_rate_shrink, in);
deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
if (item.devices.size() > 1)
{
@ -1038,6 +1049,12 @@ namespace dlib
{
std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
learning_rate = learning_rate_shrink * learning_rate;
steps_without_progress = 0;
steps_since_last_learning_rate_shrink -= std::min(
previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
steps_since_last_learning_rate_shrink
);
}
}
else
@ -1080,38 +1097,12 @@ namespace dlib
if (gradient_updates_since_last_sync < 30)
return false;
// how long a loss history do we actually need to maintain?
const size_t max_previous_loss_values_to_consider = 2 * gradient_updates_since_last_sync;
const auto total_previous_loss_values = [this]()
{
return additional_previous_loss_values_to_keep_until_disk_sync.size()
+ previous_loss_values.size();
};
const auto can_drop_an_additional_previous_loss_value = [&]() {
return !additional_previous_loss_values_to_keep_until_disk_sync.empty()
&& total_previous_loss_values() > max_previous_loss_values_to_consider;
};
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
previous_loss_values_to_keep_until_disk_sync.pop_front();
while (can_drop_an_additional_previous_loss_value())
additional_previous_loss_values_to_keep_until_disk_sync.pop_front();
// collect all previous loss values we want to consider
std::deque<double> previous_loss_values_to_consider;
std::copy(
additional_previous_loss_values_to_keep_until_disk_sync.begin(),
additional_previous_loss_values_to_keep_until_disk_sync.end(),
std::back_inserter(previous_loss_values_to_consider)
);
std::copy(
previous_loss_values.begin(),
previous_loss_values.end(),
std::back_inserter(previous_loss_values_to_consider)
);
while (previous_loss_values_to_consider.size() > max_previous_loss_values_to_consider)
previous_loss_values_to_consider.pop_front();
for (auto x : previous_loss_values_to_consider)
for (auto x : previous_loss_values_to_keep_until_disk_sync)
{
// If we get a NaN value of loss assume things have gone horribly wrong and
// we should reload the state of the trainer.
@ -1119,11 +1110,9 @@ namespace dlib
return true;
}
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
running_gradient g;
for (size_t i = 0; i < previous_loss_values_to_consider.size(); ++i)
g.add(previous_loss_values_to_consider[i]);
for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i)
g.add(previous_loss_values_to_keep_until_disk_sync[i]);
// if the loss is very likely to be increasing then return true
const double prob = g.probability_gradient_greater_than(0);
@ -1299,7 +1288,7 @@ namespace dlib
std::atomic<unsigned long> test_steps_without_progress;
std::deque<double> test_previous_loss_values;
std::deque<double> additional_previous_loss_values_to_keep_until_disk_sync;
std::deque<double> previous_loss_values_to_keep_until_disk_sync;
std::atomic<double> learning_rate_shrink;
std::chrono::time_point<std::chrono::system_clock> last_sync_time;
@ -1313,6 +1302,7 @@ namespace dlib
matrix<double,0,1> lr_schedule;
long lr_schedule_pos;
unsigned long gradient_check_budget;
unsigned long steps_since_last_learning_rate_shrink;
std::exception_ptr eptr = nullptr;
mutable std::mutex eptr_mutex;