diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index b67491e72..18962de20 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -458,9 +458,10 @@ namespace dlib { steps_without_progress = 0; test_steps_without_progress = 0; + steps_since_last_learning_rate_shrink = 0; previous_loss_values.clear(); test_previous_loss_values.clear(); - additional_previous_loss_values_to_keep_until_disk_sync.clear(); + previous_loss_values_to_keep_until_disk_sync.clear(); } learning_rate = lr; lr_schedule.set_size(0); @@ -560,6 +561,7 @@ namespace dlib learning_rate_shrink = shrink; steps_without_progress = 0; test_steps_without_progress = 0; + steps_since_last_learning_rate_shrink = 0; } double get_learning_rate_shrink_factor ( @@ -602,11 +604,12 @@ namespace dlib previous_loss_values.push_back(loss); // discard really old loss values. while (previous_loss_values.size() > iter_without_progress_thresh) - { - if (!sync_filename.empty()) - additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front()); previous_loss_values.pop_front(); - } + + // separately keep another loss history until disk sync + // (but only if disk sync is enabled) + if (!sync_filename.empty()) + previous_loss_values_to_keep_until_disk_sync.push_back(loss); } template @@ -694,7 +697,7 @@ namespace dlib // Check if we should shrink the learning rate based on how the test // error has been doing lately. - if (learning_rate_shrink != 1) + if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh) { test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values); if (test_steps_without_progress >= test_iter_without_progress_thresh) @@ -705,10 +708,13 @@ namespace dlib // optimization has flattened out, so drop the learning rate. learning_rate = learning_rate_shrink*learning_rate; test_steps_without_progress = 0; - // Empty out some of the previous loss values so that test_steps_without_progress - // will decrease below test_iter_without_progress_thresh. - for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt) - test_previous_loss_values.pop_front(); + + // Decrease steps_since_last_learning_rate_shrink, so that we + // do not get here again right away. + steps_since_last_learning_rate_shrink -= std::min( + test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10, + steps_since_last_learning_rate_shrink + ); } } } @@ -717,6 +723,7 @@ namespace dlib updated_net_since_last_sync = true; ++main_iteration_counter; + ++steps_since_last_learning_rate_shrink; // Call compute_parameter_gradients() and update_parameters() but pick the // right version for unsupervised or supervised training based on the type // of training_label_type. @@ -802,7 +809,9 @@ namespace dlib // have a "budget" that prevents us from calling // count_steps_without_decrease() every iteration. We do this because // it can be expensive to compute when previous_loss_values is large. - if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1) + if (gradient_check_budget > iter_without_progress_thresh + && learning_rate_shrink != 1 + && steps_since_last_learning_rate_shrink > iter_without_progress_thresh) { gradient_check_budget = 0; steps_without_progress = count_steps_without_decrease(previous_loss_values); @@ -825,14 +834,13 @@ namespace dlib // optimization has flattened out, so drop the learning rate. learning_rate = learning_rate_shrink*learning_rate; steps_without_progress = 0; - // Empty out some of the previous loss values so that steps_without_progress - // will decrease below iter_without_progress_thresh. - for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt) - { - if (!sync_filename.empty()) - additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front()); - previous_loss_values.pop_front(); - } + + // Decrease steps_since_last_learning_rate_shrink, so that we + // do not get here again right away. + steps_since_last_learning_rate_shrink -= std::min( + previous_loss_values_dump_amount + iter_without_progress_thresh / 10, + steps_since_last_learning_rate_shrink + ); } } } @@ -871,6 +879,7 @@ namespace dlib min_learning_rate = 1e-5; iter_without_progress_thresh = 2000; steps_without_progress = 0; + steps_since_last_learning_rate_shrink = 0; test_iter_without_progress_thresh = 500; test_steps_without_progress = 0; @@ -933,7 +942,8 @@ namespace dlib serialize(item.test_previous_loss_values, out); serialize(item.previous_loss_values_dump_amount, out); serialize(item.test_previous_loss_values_dump_amount, out); - serialize(item.additional_previous_loss_values_to_keep_until_disk_sync, out); + serialize(item.steps_since_last_learning_rate_shrink, out); + serialize(item.previous_loss_values_to_keep_until_disk_sync, out); } friend void deserialize(dnn_trainer& item, std::istream& in) { @@ -979,7 +989,8 @@ namespace dlib deserialize(item.test_previous_loss_values, in); deserialize(item.previous_loss_values_dump_amount, in); deserialize(item.test_previous_loss_values_dump_amount, in); - deserialize(item.additional_previous_loss_values_to_keep_until_disk_sync, in); + deserialize(item.steps_since_last_learning_rate_shrink, in); + deserialize(item.previous_loss_values_to_keep_until_disk_sync, in); if (item.devices.size() > 1) { @@ -1038,6 +1049,12 @@ namespace dlib { std::cout << "(and while at it, also shrinking the learning rate)" << std::endl; learning_rate = learning_rate_shrink * learning_rate; + + steps_without_progress = 0; + steps_since_last_learning_rate_shrink -= std::min( + previous_loss_values_dump_amount + iter_without_progress_thresh / 10, + steps_since_last_learning_rate_shrink + ); } } else @@ -1080,38 +1097,12 @@ namespace dlib if (gradient_updates_since_last_sync < 30) return false; - // how long a loss history do we actually need to maintain? - const size_t max_previous_loss_values_to_consider = 2 * gradient_updates_since_last_sync; - const auto total_previous_loss_values = [this]() - { - return additional_previous_loss_values_to_keep_until_disk_sync.size() - + previous_loss_values.size(); - }; - const auto can_drop_an_additional_previous_loss_value = [&]() { - return !additional_previous_loss_values_to_keep_until_disk_sync.empty() - && total_previous_loss_values() > max_previous_loss_values_to_consider; - }; - - while (can_drop_an_additional_previous_loss_value()) - additional_previous_loss_values_to_keep_until_disk_sync.pop_front(); - - // collect all previous loss values we want to consider - std::deque previous_loss_values_to_consider; - std::copy( - additional_previous_loss_values_to_keep_until_disk_sync.begin(), - additional_previous_loss_values_to_keep_until_disk_sync.end(), - std::back_inserter(previous_loss_values_to_consider) - ); - std::copy( - previous_loss_values.begin(), - previous_loss_values.end(), - std::back_inserter(previous_loss_values_to_consider) - ); - - while (previous_loss_values_to_consider.size() > max_previous_loss_values_to_consider) - previous_loss_values_to_consider.pop_front(); - - for (auto x : previous_loss_values_to_consider) + // Now look at the data since a little before the last disk sync. We will + // check if the loss is getting better or worse. + while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync) + previous_loss_values_to_keep_until_disk_sync.pop_front(); + + for (auto x : previous_loss_values_to_keep_until_disk_sync) { // If we get a NaN value of loss assume things have gone horribly wrong and // we should reload the state of the trainer. @@ -1119,11 +1110,9 @@ namespace dlib return true; } - // Now look at the data since a little before the last disk sync. We will - // check if the loss is getting better or worse. running_gradient g; - for (size_t i = 0; i < previous_loss_values_to_consider.size(); ++i) - g.add(previous_loss_values_to_consider[i]); + for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i) + g.add(previous_loss_values_to_keep_until_disk_sync[i]); // if the loss is very likely to be increasing then return true const double prob = g.probability_gradient_greater_than(0); @@ -1299,7 +1288,7 @@ namespace dlib std::atomic test_steps_without_progress; std::deque test_previous_loss_values; - std::deque additional_previous_loss_values_to_keep_until_disk_sync; + std::deque previous_loss_values_to_keep_until_disk_sync; std::atomic learning_rate_shrink; std::chrono::time_point last_sync_time; @@ -1313,6 +1302,7 @@ namespace dlib matrix lr_schedule; long lr_schedule_pos; unsigned long gradient_check_budget; + unsigned long steps_since_last_learning_rate_shrink; std::exception_ptr eptr = nullptr; mutable std::mutex eptr_mutex;