Simplify code as per review comments

This commit is contained in:
Juha Reunanen 2019-11-12 14:13:48 +02:00
parent d48b406a11
commit 3746b3c1d3

View File

@ -458,9 +458,10 @@ namespace dlib
{ {
steps_without_progress = 0; steps_without_progress = 0;
test_steps_without_progress = 0; test_steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
previous_loss_values.clear(); previous_loss_values.clear();
test_previous_loss_values.clear(); test_previous_loss_values.clear();
additional_previous_loss_values_to_keep_until_disk_sync.clear(); previous_loss_values_to_keep_until_disk_sync.clear();
} }
learning_rate = lr; learning_rate = lr;
lr_schedule.set_size(0); lr_schedule.set_size(0);
@ -560,6 +561,7 @@ namespace dlib
learning_rate_shrink = shrink; learning_rate_shrink = shrink;
steps_without_progress = 0; steps_without_progress = 0;
test_steps_without_progress = 0; test_steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
} }
double get_learning_rate_shrink_factor ( double get_learning_rate_shrink_factor (
@ -602,11 +604,12 @@ namespace dlib
previous_loss_values.push_back(loss); previous_loss_values.push_back(loss);
// discard really old loss values. // discard really old loss values.
while (previous_loss_values.size() > iter_without_progress_thresh) while (previous_loss_values.size() > iter_without_progress_thresh)
{
if (!sync_filename.empty())
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
previous_loss_values.pop_front(); previous_loss_values.pop_front();
}
// separately keep another loss history until disk sync
// (but only if disk sync is enabled)
if (!sync_filename.empty())
previous_loss_values_to_keep_until_disk_sync.push_back(loss);
} }
template <typename T> template <typename T>
@ -694,7 +697,7 @@ namespace dlib
// Check if we should shrink the learning rate based on how the test // Check if we should shrink the learning rate based on how the test
// error has been doing lately. // error has been doing lately.
if (learning_rate_shrink != 1) if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
{ {
test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values); test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
if (test_steps_without_progress >= test_iter_without_progress_thresh) if (test_steps_without_progress >= test_iter_without_progress_thresh)
@ -705,10 +708,13 @@ namespace dlib
// optimization has flattened out, so drop the learning rate. // optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate; learning_rate = learning_rate_shrink*learning_rate;
test_steps_without_progress = 0; test_steps_without_progress = 0;
// Empty out some of the previous loss values so that test_steps_without_progress
// will decrease below test_iter_without_progress_thresh. // Decrease steps_since_last_learning_rate_shrink, so that we
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt) // do not get here again right away.
test_previous_loss_values.pop_front(); steps_since_last_learning_rate_shrink -= std::min(
test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10,
steps_since_last_learning_rate_shrink
);
} }
} }
} }
@ -717,6 +723,7 @@ namespace dlib
updated_net_since_last_sync = true; updated_net_since_last_sync = true;
++main_iteration_counter; ++main_iteration_counter;
++steps_since_last_learning_rate_shrink;
// Call compute_parameter_gradients() and update_parameters() but pick the // Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type // right version for unsupervised or supervised training based on the type
// of training_label_type. // of training_label_type.
@ -802,7 +809,9 @@ namespace dlib
// have a "budget" that prevents us from calling // have a "budget" that prevents us from calling
// count_steps_without_decrease() every iteration. We do this because // count_steps_without_decrease() every iteration. We do this because
// it can be expensive to compute when previous_loss_values is large. // it can be expensive to compute when previous_loss_values is large.
if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1) if (gradient_check_budget > iter_without_progress_thresh
&& learning_rate_shrink != 1
&& steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
{ {
gradient_check_budget = 0; gradient_check_budget = 0;
steps_without_progress = count_steps_without_decrease(previous_loss_values); steps_without_progress = count_steps_without_decrease(previous_loss_values);
@ -825,14 +834,13 @@ namespace dlib
// optimization has flattened out, so drop the learning rate. // optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate; learning_rate = learning_rate_shrink*learning_rate;
steps_without_progress = 0; steps_without_progress = 0;
// Empty out some of the previous loss values so that steps_without_progress
// will decrease below iter_without_progress_thresh. // Decrease steps_since_last_learning_rate_shrink, so that we
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt) // do not get here again right away.
{ steps_since_last_learning_rate_shrink -= std::min(
if (!sync_filename.empty()) previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front()); steps_since_last_learning_rate_shrink
previous_loss_values.pop_front(); );
}
} }
} }
} }
@ -871,6 +879,7 @@ namespace dlib
min_learning_rate = 1e-5; min_learning_rate = 1e-5;
iter_without_progress_thresh = 2000; iter_without_progress_thresh = 2000;
steps_without_progress = 0; steps_without_progress = 0;
steps_since_last_learning_rate_shrink = 0;
test_iter_without_progress_thresh = 500; test_iter_without_progress_thresh = 500;
test_steps_without_progress = 0; test_steps_without_progress = 0;
@ -933,7 +942,8 @@ namespace dlib
serialize(item.test_previous_loss_values, out); serialize(item.test_previous_loss_values, out);
serialize(item.previous_loss_values_dump_amount, out); serialize(item.previous_loss_values_dump_amount, out);
serialize(item.test_previous_loss_values_dump_amount, out); serialize(item.test_previous_loss_values_dump_amount, out);
serialize(item.additional_previous_loss_values_to_keep_until_disk_sync, out); serialize(item.steps_since_last_learning_rate_shrink, out);
serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
} }
friend void deserialize(dnn_trainer& item, std::istream& in) friend void deserialize(dnn_trainer& item, std::istream& in)
{ {
@ -979,7 +989,8 @@ namespace dlib
deserialize(item.test_previous_loss_values, in); deserialize(item.test_previous_loss_values, in);
deserialize(item.previous_loss_values_dump_amount, in); deserialize(item.previous_loss_values_dump_amount, in);
deserialize(item.test_previous_loss_values_dump_amount, in); deserialize(item.test_previous_loss_values_dump_amount, in);
deserialize(item.additional_previous_loss_values_to_keep_until_disk_sync, in); deserialize(item.steps_since_last_learning_rate_shrink, in);
deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
if (item.devices.size() > 1) if (item.devices.size() > 1)
{ {
@ -1038,6 +1049,12 @@ namespace dlib
{ {
std::cout << "(and while at it, also shrinking the learning rate)" << std::endl; std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
learning_rate = learning_rate_shrink * learning_rate; learning_rate = learning_rate_shrink * learning_rate;
steps_without_progress = 0;
steps_since_last_learning_rate_shrink -= std::min(
previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
steps_since_last_learning_rate_shrink
);
} }
} }
else else
@ -1080,38 +1097,12 @@ namespace dlib
if (gradient_updates_since_last_sync < 30) if (gradient_updates_since_last_sync < 30)
return false; return false;
// how long a loss history do we actually need to maintain? // Now look at the data since a little before the last disk sync. We will
const size_t max_previous_loss_values_to_consider = 2 * gradient_updates_since_last_sync; // check if the loss is getting better or worse.
const auto total_previous_loss_values = [this]() while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
{ previous_loss_values_to_keep_until_disk_sync.pop_front();
return additional_previous_loss_values_to_keep_until_disk_sync.size()
+ previous_loss_values.size();
};
const auto can_drop_an_additional_previous_loss_value = [&]() {
return !additional_previous_loss_values_to_keep_until_disk_sync.empty()
&& total_previous_loss_values() > max_previous_loss_values_to_consider;
};
while (can_drop_an_additional_previous_loss_value()) for (auto x : previous_loss_values_to_keep_until_disk_sync)
additional_previous_loss_values_to_keep_until_disk_sync.pop_front();
// collect all previous loss values we want to consider
std::deque<double> previous_loss_values_to_consider;
std::copy(
additional_previous_loss_values_to_keep_until_disk_sync.begin(),
additional_previous_loss_values_to_keep_until_disk_sync.end(),
std::back_inserter(previous_loss_values_to_consider)
);
std::copy(
previous_loss_values.begin(),
previous_loss_values.end(),
std::back_inserter(previous_loss_values_to_consider)
);
while (previous_loss_values_to_consider.size() > max_previous_loss_values_to_consider)
previous_loss_values_to_consider.pop_front();
for (auto x : previous_loss_values_to_consider)
{ {
// If we get a NaN value of loss assume things have gone horribly wrong and // If we get a NaN value of loss assume things have gone horribly wrong and
// we should reload the state of the trainer. // we should reload the state of the trainer.
@ -1119,11 +1110,9 @@ namespace dlib
return true; return true;
} }
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
running_gradient g; running_gradient g;
for (size_t i = 0; i < previous_loss_values_to_consider.size(); ++i) for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i)
g.add(previous_loss_values_to_consider[i]); g.add(previous_loss_values_to_keep_until_disk_sync[i]);
// if the loss is very likely to be increasing then return true // if the loss is very likely to be increasing then return true
const double prob = g.probability_gradient_greater_than(0); const double prob = g.probability_gradient_greater_than(0);
@ -1299,7 +1288,7 @@ namespace dlib
std::atomic<unsigned long> test_steps_without_progress; std::atomic<unsigned long> test_steps_without_progress;
std::deque<double> test_previous_loss_values; std::deque<double> test_previous_loss_values;
std::deque<double> additional_previous_loss_values_to_keep_until_disk_sync; std::deque<double> previous_loss_values_to_keep_until_disk_sync;
std::atomic<double> learning_rate_shrink; std::atomic<double> learning_rate_shrink;
std::chrono::time_point<std::chrono::system_clock> last_sync_time; std::chrono::time_point<std::chrono::system_clock> last_sync_time;
@ -1313,6 +1302,7 @@ namespace dlib
matrix<double,0,1> lr_schedule; matrix<double,0,1> lr_schedule;
long lr_schedule_pos; long lr_schedule_pos;
unsigned long gradient_check_budget; unsigned long gradient_check_budget;
unsigned long steps_since_last_learning_rate_shrink;
std::exception_ptr eptr = nullptr; std::exception_ptr eptr = nullptr;
mutable std::mutex eptr_mutex; mutable std::mutex eptr_mutex;