mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Simplify code as per review comments
This commit is contained in:
parent
d48b406a11
commit
3746b3c1d3
@ -458,9 +458,10 @@ namespace dlib
|
||||
{
|
||||
steps_without_progress = 0;
|
||||
test_steps_without_progress = 0;
|
||||
steps_since_last_learning_rate_shrink = 0;
|
||||
previous_loss_values.clear();
|
||||
test_previous_loss_values.clear();
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.clear();
|
||||
previous_loss_values_to_keep_until_disk_sync.clear();
|
||||
}
|
||||
learning_rate = lr;
|
||||
lr_schedule.set_size(0);
|
||||
@ -560,6 +561,7 @@ namespace dlib
|
||||
learning_rate_shrink = shrink;
|
||||
steps_without_progress = 0;
|
||||
test_steps_without_progress = 0;
|
||||
steps_since_last_learning_rate_shrink = 0;
|
||||
}
|
||||
|
||||
double get_learning_rate_shrink_factor (
|
||||
@ -602,11 +604,12 @@ namespace dlib
|
||||
previous_loss_values.push_back(loss);
|
||||
// discard really old loss values.
|
||||
while (previous_loss_values.size() > iter_without_progress_thresh)
|
||||
{
|
||||
if (!sync_filename.empty())
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
|
||||
previous_loss_values.pop_front();
|
||||
}
|
||||
|
||||
// separately keep another loss history until disk sync
|
||||
// (but only if disk sync is enabled)
|
||||
if (!sync_filename.empty())
|
||||
previous_loss_values_to_keep_until_disk_sync.push_back(loss);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -694,7 +697,7 @@ namespace dlib
|
||||
|
||||
// Check if we should shrink the learning rate based on how the test
|
||||
// error has been doing lately.
|
||||
if (learning_rate_shrink != 1)
|
||||
if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
|
||||
{
|
||||
test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
|
||||
if (test_steps_without_progress >= test_iter_without_progress_thresh)
|
||||
@ -705,10 +708,13 @@ namespace dlib
|
||||
// optimization has flattened out, so drop the learning rate.
|
||||
learning_rate = learning_rate_shrink*learning_rate;
|
||||
test_steps_without_progress = 0;
|
||||
// Empty out some of the previous loss values so that test_steps_without_progress
|
||||
// will decrease below test_iter_without_progress_thresh.
|
||||
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
|
||||
test_previous_loss_values.pop_front();
|
||||
|
||||
// Decrease steps_since_last_learning_rate_shrink, so that we
|
||||
// do not get here again right away.
|
||||
steps_since_last_learning_rate_shrink -= std::min(
|
||||
test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10,
|
||||
steps_since_last_learning_rate_shrink
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -717,6 +723,7 @@ namespace dlib
|
||||
|
||||
updated_net_since_last_sync = true;
|
||||
++main_iteration_counter;
|
||||
++steps_since_last_learning_rate_shrink;
|
||||
// Call compute_parameter_gradients() and update_parameters() but pick the
|
||||
// right version for unsupervised or supervised training based on the type
|
||||
// of training_label_type.
|
||||
@ -802,7 +809,9 @@ namespace dlib
|
||||
// have a "budget" that prevents us from calling
|
||||
// count_steps_without_decrease() every iteration. We do this because
|
||||
// it can be expensive to compute when previous_loss_values is large.
|
||||
if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1)
|
||||
if (gradient_check_budget > iter_without_progress_thresh
|
||||
&& learning_rate_shrink != 1
|
||||
&& steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
|
||||
{
|
||||
gradient_check_budget = 0;
|
||||
steps_without_progress = count_steps_without_decrease(previous_loss_values);
|
||||
@ -825,14 +834,13 @@ namespace dlib
|
||||
// optimization has flattened out, so drop the learning rate.
|
||||
learning_rate = learning_rate_shrink*learning_rate;
|
||||
steps_without_progress = 0;
|
||||
// Empty out some of the previous loss values so that steps_without_progress
|
||||
// will decrease below iter_without_progress_thresh.
|
||||
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
|
||||
{
|
||||
if (!sync_filename.empty())
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.push_back(previous_loss_values.front());
|
||||
previous_loss_values.pop_front();
|
||||
}
|
||||
|
||||
// Decrease steps_since_last_learning_rate_shrink, so that we
|
||||
// do not get here again right away.
|
||||
steps_since_last_learning_rate_shrink -= std::min(
|
||||
previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
|
||||
steps_since_last_learning_rate_shrink
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -871,6 +879,7 @@ namespace dlib
|
||||
min_learning_rate = 1e-5;
|
||||
iter_without_progress_thresh = 2000;
|
||||
steps_without_progress = 0;
|
||||
steps_since_last_learning_rate_shrink = 0;
|
||||
test_iter_without_progress_thresh = 500;
|
||||
test_steps_without_progress = 0;
|
||||
|
||||
@ -933,7 +942,8 @@ namespace dlib
|
||||
serialize(item.test_previous_loss_values, out);
|
||||
serialize(item.previous_loss_values_dump_amount, out);
|
||||
serialize(item.test_previous_loss_values_dump_amount, out);
|
||||
serialize(item.additional_previous_loss_values_to_keep_until_disk_sync, out);
|
||||
serialize(item.steps_since_last_learning_rate_shrink, out);
|
||||
serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
|
||||
}
|
||||
friend void deserialize(dnn_trainer& item, std::istream& in)
|
||||
{
|
||||
@ -979,7 +989,8 @@ namespace dlib
|
||||
deserialize(item.test_previous_loss_values, in);
|
||||
deserialize(item.previous_loss_values_dump_amount, in);
|
||||
deserialize(item.test_previous_loss_values_dump_amount, in);
|
||||
deserialize(item.additional_previous_loss_values_to_keep_until_disk_sync, in);
|
||||
deserialize(item.steps_since_last_learning_rate_shrink, in);
|
||||
deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
|
||||
|
||||
if (item.devices.size() > 1)
|
||||
{
|
||||
@ -1038,6 +1049,12 @@ namespace dlib
|
||||
{
|
||||
std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
|
||||
learning_rate = learning_rate_shrink * learning_rate;
|
||||
|
||||
steps_without_progress = 0;
|
||||
steps_since_last_learning_rate_shrink -= std::min(
|
||||
previous_loss_values_dump_amount + iter_without_progress_thresh / 10,
|
||||
steps_since_last_learning_rate_shrink
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -1080,38 +1097,12 @@ namespace dlib
|
||||
if (gradient_updates_since_last_sync < 30)
|
||||
return false;
|
||||
|
||||
// how long a loss history do we actually need to maintain?
|
||||
const size_t max_previous_loss_values_to_consider = 2 * gradient_updates_since_last_sync;
|
||||
const auto total_previous_loss_values = [this]()
|
||||
{
|
||||
return additional_previous_loss_values_to_keep_until_disk_sync.size()
|
||||
+ previous_loss_values.size();
|
||||
};
|
||||
const auto can_drop_an_additional_previous_loss_value = [&]() {
|
||||
return !additional_previous_loss_values_to_keep_until_disk_sync.empty()
|
||||
&& total_previous_loss_values() > max_previous_loss_values_to_consider;
|
||||
};
|
||||
// Now look at the data since a little before the last disk sync. We will
|
||||
// check if the loss is getting better or worse.
|
||||
while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
|
||||
previous_loss_values_to_keep_until_disk_sync.pop_front();
|
||||
|
||||
while (can_drop_an_additional_previous_loss_value())
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.pop_front();
|
||||
|
||||
// collect all previous loss values we want to consider
|
||||
std::deque<double> previous_loss_values_to_consider;
|
||||
std::copy(
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.begin(),
|
||||
additional_previous_loss_values_to_keep_until_disk_sync.end(),
|
||||
std::back_inserter(previous_loss_values_to_consider)
|
||||
);
|
||||
std::copy(
|
||||
previous_loss_values.begin(),
|
||||
previous_loss_values.end(),
|
||||
std::back_inserter(previous_loss_values_to_consider)
|
||||
);
|
||||
|
||||
while (previous_loss_values_to_consider.size() > max_previous_loss_values_to_consider)
|
||||
previous_loss_values_to_consider.pop_front();
|
||||
|
||||
for (auto x : previous_loss_values_to_consider)
|
||||
for (auto x : previous_loss_values_to_keep_until_disk_sync)
|
||||
{
|
||||
// If we get a NaN value of loss assume things have gone horribly wrong and
|
||||
// we should reload the state of the trainer.
|
||||
@ -1119,11 +1110,9 @@ namespace dlib
|
||||
return true;
|
||||
}
|
||||
|
||||
// Now look at the data since a little before the last disk sync. We will
|
||||
// check if the loss is getting better or worse.
|
||||
running_gradient g;
|
||||
for (size_t i = 0; i < previous_loss_values_to_consider.size(); ++i)
|
||||
g.add(previous_loss_values_to_consider[i]);
|
||||
for (size_t i = 0; i < previous_loss_values_to_keep_until_disk_sync.size(); ++i)
|
||||
g.add(previous_loss_values_to_keep_until_disk_sync[i]);
|
||||
|
||||
// if the loss is very likely to be increasing then return true
|
||||
const double prob = g.probability_gradient_greater_than(0);
|
||||
@ -1299,7 +1288,7 @@ namespace dlib
|
||||
std::atomic<unsigned long> test_steps_without_progress;
|
||||
std::deque<double> test_previous_loss_values;
|
||||
|
||||
std::deque<double> additional_previous_loss_values_to_keep_until_disk_sync;
|
||||
std::deque<double> previous_loss_values_to_keep_until_disk_sync;
|
||||
|
||||
std::atomic<double> learning_rate_shrink;
|
||||
std::chrono::time_point<std::chrono::system_clock> last_sync_time;
|
||||
@ -1313,6 +1302,7 @@ namespace dlib
|
||||
matrix<double,0,1> lr_schedule;
|
||||
long lr_schedule_pos;
|
||||
unsigned long gradient_check_budget;
|
||||
unsigned long steps_since_last_learning_rate_shrink;
|
||||
|
||||
std::exception_ptr eptr = nullptr;
|
||||
mutable std::mutex eptr_mutex;
|
||||
|
Loading…
Reference in New Issue
Block a user