Added test_one_step() to the dnn_trainer. This allows you to do automatic early stopping based on observing the loss on held-out data.
Davis King 2017-01-22 10:25:06 -05:00
parent 5f5684a8fb
commit c20c11af90
2 changed files with 357 additions and 39 deletions
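A minimal usage sketch of the new API follows. The network definition, the data loading, and the batching scheme here are placeholders, not part of this commit; only the trainer calls are. The idea is to interleave train_one_step() with occasional test_one_step() calls on held-out data, and to stop once the learning rate has been shrunk below the minimum.

#include <dlib/dnn.h>
#include <algorithm>
#include <vector>
using namespace dlib;

// Placeholder network; any net_type with a supervised loss works the same way.
using net_type = loss_multiclass_log<fc<10, relu<fc<84, input<matrix<unsigned char>>>>>>;

void train_with_early_stopping(
    const std::vector<matrix<unsigned char>>& train_images,
    const std::vector<unsigned long>& train_labels,
    const std::vector<matrix<unsigned char>>& test_images,
    const std::vector<unsigned long>& test_labels
)
{
    net_type net;
    dnn_trainer<net_type> trainer(net);
    trainer.set_learning_rate(0.01);
    trainer.set_min_learning_rate(1e-5);
    trainer.set_test_iterations_without_progress_threshold(200);
    trainer.be_verbose();

    const size_t batch = 128;
    size_t train_pos = 0, test_pos = 0;
    // The trainer shrinks the learning rate whenever the training or the
    // testing loss stops decreasing, so this loop ends once both flatten out.
    while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
    {
        trainer.train_one_step(train_images.begin()+train_pos,
                               train_images.begin()+std::min(train_pos+batch, train_images.size()),
                               train_labels.begin()+train_pos);
        train_pos = (train_pos+batch < train_images.size()) ? train_pos+batch : 0;

        // Feed a held-out mini-batch every 30 steps to track the test loss.
        if (trainer.get_train_one_step_calls() % 30 == 0)
        {
            trainer.test_one_step(test_images.begin()+test_pos,
                                  test_images.begin()+std::min(test_pos+batch, test_images.size()),
                                  test_labels.begin()+test_pos);
            test_pos = (test_pos+batch < test_images.size()) ? test_pos+batch : 0;
        }
    }
    // Block until the trainer's thread has stopped touching the net.
    trainer.get_net();
}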

dlib/dnn/trainer.h

@@ -40,6 +40,7 @@ namespace dlib
std::vector<std::vector<training_label_type>> labels;
std::vector<resizable_tensor> t;
std::vector<int> have_data; // have_data[i] is true if there is data in labels[i] and t[i].
bool test_only = false;
};
template <typename training_label_type>
@@ -48,6 +49,7 @@ namespace dlib
a.labels.swap(b.labels);
a.t.swap(b.t);
a.have_data.swap(b.have_data);
std::swap(a.test_only,b.test_only);
}
}
@@ -205,22 +207,9 @@ namespace dlib
{
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
if (verbose)
{
using namespace std::chrono;
auto now_time = system_clock::now();
if (now_time-last_time > seconds(40))
{
last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
clear_average_loss();
}
}
print_periodic_verbose_status();
sync_to_disk();
send_job(dbegin, dend, lbegin);
send_job(false, dbegin, dend, lbegin);
++train_one_step_calls;
}
@@ -241,22 +230,60 @@ namespace dlib
)
{
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
if (verbose)
{
using namespace std::chrono;
auto now_time = system_clock::now();
if (now_time-last_time > seconds(40))
{
last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
clear_average_loss();
}
}
print_periodic_verbose_status();
sync_to_disk();
send_job(dbegin, dend);
send_job(false, dbegin, dend);
++train_one_step_calls;
}
void test_one_step (
const std::vector<input_type>& data,
const std::vector<training_label_type>& labels
)
{
DLIB_CASSERT(data.size() == labels.size());
test_one_step(data.begin(), data.end(), labels.begin());
}
template <
typename data_iterator,
typename label_iterator
>
void test_one_step (
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
)
{
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
print_periodic_verbose_status();
sync_to_disk();
send_job(true, dbegin, dend, lbegin);
++train_one_step_calls;
}
void test_one_step (
const std::vector<input_type>& data
)
{
test_one_step(data.begin(), data.end());
}
template <
typename data_iterator
>
void test_one_step (
data_iterator dbegin,
data_iterator dend
)
{
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
print_periodic_verbose_status();
sync_to_disk();
send_job(true, dbegin, dend);
++train_one_step_calls;
}
@@ -294,7 +321,7 @@ namespace dlib
}
sync_to_disk();
send_job(data.begin()+epoch_pos,
send_job(false, data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
labels.begin()+epoch_pos);
}
@@ -352,7 +379,7 @@ namespace dlib
}
sync_to_disk();
send_job(data.begin()+epoch_pos,
send_job(false, data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
}
epoch_pos = 0;
@@ -394,6 +421,16 @@ namespace dlib
return rs.mean();
}
double get_average_test_loss (
) const
{
wait_for_thread_to_pause();
running_stats<double> tmp;
for (auto& x : test_previous_loss_values)
tmp.add(x);
return tmp.mean();
}
void clear_average_loss (
)
{
@@ -410,7 +447,9 @@ namespace dlib
if (learning_rate != lr)
{
steps_without_progress = 0;
test_steps_without_progress = 0;
previous_loss_values.clear();
test_previous_loss_values.clear();
}
learning_rate = lr;
lr_schedule.set_size(0);
@@ -479,6 +518,27 @@ namespace dlib
return steps_without_progress;
}
void set_test_iterations_without_progress_threshold (
unsigned long thresh
)
{
wait_for_thread_to_pause();
lr_schedule.set_size(0);
test_iter_without_progress_thresh = thresh;
}
unsigned long get_test_iterations_without_progress_threshold (
) const
{
return test_iter_without_progress_thresh;
}
unsigned long get_test_steps_without_progress (
) const
{
return test_steps_without_progress;
}
void set_learning_rate_shrink_factor (
double shrink
)
@@ -488,6 +548,7 @@ namespace dlib
lr_schedule.set_size(0);
learning_rate_shrink = shrink;
steps_without_progress = 0;
test_steps_without_progress = 0;
}
double get_learning_rate_shrink_factor (
@@ -504,6 +565,14 @@ namespace dlib
private:
void record_test_loss(double loss)
{
test_previous_loss_values.push_back(loss);
// discard really old loss values.
while (test_previous_loss_values.size() > test_iter_without_progress_thresh)
test_previous_loss_values.pop_front();
}
void record_loss(double loss)
{
// Say that we will check if the gradient is bad 200 times during each
@@ -527,7 +596,10 @@ namespace dlib
{
auto&& dev = *devices[device];
dlib::cuda::set_device(dev.device_id);
return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
if (next_job.test_only)
return dev.net.compute_loss(next_job.t[device], next_job.labels[device].begin());
else
return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
}
else
{
@@ -542,7 +614,10 @@ namespace dlib
auto&& dev = *devices[device];
dlib::cuda::set_device(dev.device_id);
no_label_type pick_which_run_update;
return dev.net.compute_parameter_gradients(next_job.t[device]);
if (next_job.test_only)
return dev.net.compute_loss(next_job.t[device]);
else
return dev.net.compute_parameter_gradients(next_job.t[device]);
}
else
{
@@ -585,6 +660,37 @@ namespace dlib
main_iteration_counter = 0;
while(job_pipe.dequeue(next_job))
{
if (next_job.test_only)
{
// compute the testing loss
for (size_t i = 0; i < devices.size(); ++i)
tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
// aggregate loss values from all the network computations.
double theloss = 0;
for (auto&& loss : losses)
theloss += loss.get();
record_test_loss(theloss/losses.size());
// Check if we should shrink the learning rate based on how the test
// error has been doing lately.
if (learning_rate_shrink != 1)
{
test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
if (test_steps_without_progress >= test_iter_without_progress_thresh)
{
test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values);
if (test_steps_without_progress >= test_iter_without_progress_thresh)
{
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
test_steps_without_progress = 0;
test_previous_loss_values.clear();
}
}
}
continue;
}
updated_net_since_last_sync = true;
++main_iteration_counter;
// Call compute_parameter_gradients() and update_parameters() but pick the
@@ -719,7 +825,7 @@ namespace dlib
job_pipe.wait_for_num_blocked_dequeues(1);
}
const static long string_pad = 10;
const static long string_pad = 11;
const static long epoch_string_pad = 4;
const static long lr_string_pad = 4;
@@ -732,6 +838,9 @@ namespace dlib
min_learning_rate = 1e-5;
iter_without_progress_thresh = 2000;
steps_without_progress = 0;
test_iter_without_progress_thresh = 200;
test_steps_without_progress = 0;
learning_rate_shrink = 0.1;
epoch_iteration = 0;
epoch_pos = 0;
@@ -755,7 +864,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 7;
int version = 8;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
@@ -777,13 +886,17 @@ namespace dlib
serialize(item.train_one_step_calls, out);
serialize(item.lr_schedule, out);
serialize(item.lr_schedule_pos, out);
serialize(item.test_iter_without_progress_thresh.load(), out);
serialize(item.test_steps_without_progress.load(), out);
serialize(item.test_previous_loss_values, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 7)
if (version != 8)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
@@ -815,6 +928,9 @@ namespace dlib
deserialize(item.train_one_step_calls, in);
deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in);
deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
deserialize(item.test_previous_loss_values, in);
if (item.devices.size() > 1)
{
@@ -966,6 +1082,7 @@ namespace dlib
typename label_iterator
>
void send_job (
bool test_only,
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
@@ -977,6 +1094,7 @@ namespace dlib
job.t.resize(devs);
job.labels.resize(devs);
job.have_data.resize(devs);
job.test_only = test_only;
// chop the data into devs blocks, each of about block_size elements.
size_t block_size = (num+devs-1)/devs;
@@ -1009,19 +1127,23 @@ namespace dlib
typename data_iterator
>
void send_job (
bool test_only,
data_iterator dbegin,
data_iterator dend
)
{
typename std::vector<training_label_type>::iterator nothing;
send_job(dbegin, dend, nothing);
send_job(test_only, dbegin, dend, nothing);
}
void print_progress()
{
if (lr_schedule.size() == 0)
{
std::cout << "steps without apparent progress: " << steps_without_progress;
if (test_previous_loss_values.size() == 0)
std::cout << "steps without apparent progress: " << steps_without_progress;
else
std::cout << "steps without apparent progress: train=" << steps_without_progress << ", test=" << test_steps_without_progress;
}
else
{
@@ -1032,6 +1154,32 @@ namespace dlib
std::cout << std::endl;
}
void print_periodic_verbose_status()
{
if (verbose)
{
using namespace std::chrono;
auto now_time = system_clock::now();
if (now_time-last_time > seconds(40))
{
last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " ";
if (test_previous_loss_values.size() == 0)
{
std::cout << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
}
else
{
std::cout << "train loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
std::cout << "test loss: " << rpad(cast_to_string(get_average_test_loss()),string_pad) << " ";
}
print_progress();
clear_average_loss();
}
}
}
std::vector<std::shared_ptr<device_data>> devices;
dlib::pipe<job_t> job_pipe;
job_t job;
@@ -1047,6 +1195,11 @@ namespace dlib
double min_learning_rate;
std::atomic<unsigned long> iter_without_progress_thresh;
std::atomic<unsigned long> steps_without_progress;
std::atomic<unsigned long> test_iter_without_progress_thresh;
std::atomic<unsigned long> test_steps_without_progress;
std::deque<double> test_previous_loss_values;
std::atomic<double> learning_rate_shrink;
std::chrono::time_point<std::chrono::system_clock> last_sync_time;
std::string sync_filename;

dlib/dnn/trainer_abstract.h

@@ -76,6 +76,7 @@ namespace dlib
- #get_learning_rate() == 1e-2
- #get_min_learning_rate() == 1e-5
- #get_iterations_without_progress_threshold() == 2000
- #get_test_iterations_without_progress_threshold() == 200
- #get_learning_rate_shrink_factor() == 0.1
- #get_learning_rate_schedule().size() == 0
- #get_train_one_step_calls() == 0
@@ -536,6 +537,170 @@ namespace dlib
stopped touching the net.
!*/
// ----------------------
double get_average_test_loss (
) const;
/*!
ensures
- returns the average loss value observed during previous calls to
test_one_step().
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
void test_one_step (
const std::vector<input_type>& data,
const std::vector<training_label_type>& labels
);
/*!
requires
- data.size() == labels.size()
- data.size() > 0
- net_type uses a supervised loss.
i.e. net_type::training_label_type != no_label_type.
ensures
- Runs the given data through the network and computes and records the loss.
- This call does not modify network parameters. The point of
test_one_step() is twofold: it lets you observe the accuracy of the
network on held-out data during training, and it lets the trainer
automatically adjust the learning rate when the test loss stops
improving. Note that you are not required to use test_one_step() at
all; it is simply available if you want this kind of monitoring.
- You can observe the current average loss value by calling get_average_test_loss().
- The computation will happen in another thread. Therefore, after calling
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
!*/
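As a small, hedged illustration of the overload documented above (the trainer and the held-out mini-batch are assumed to already exist; only test_one_step() and get_average_test_loss() are part of this commit), here is a helper that reports the held-out loss next to the training loss:

#include <iostream>
#include <vector>

template <typename trainer_type, typename image_type, typename label_type>
void report_losses(
    trainer_type& trainer,
    const std::vector<image_type>& held_out_images,
    const std::vector<label_type>& held_out_labels
)
{
    // Queue the mini-batch for loss-only evaluation; no parameters are updated.
    trainer.test_one_step(held_out_images, held_out_labels);

    // Both getters block until the trainer's thread has paused, so the test
    // value reflects the batch submitted above.  A test loss that keeps rising
    // while the training loss falls is the usual sign of overfitting.
    std::cout << "train loss: " << trainer.get_average_loss()
              << "   test loss: " << trainer.get_average_test_loss() << std::endl;
}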
template <
typename data_iterator,
typename label_iterator
>
void test_one_step (
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
);
/*!
requires
- std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
- std::distance(dbegin, dend) > 0
- net_type uses a supervised loss.
i.e. net_type::training_label_type != no_label_type.
ensures
- Runs the given data through the network and computes and records the loss.
- This call does not modify network parameters. The point of
test_one_step() is twofold: it lets you observe the accuracy of the
network on held-out data during training, and it lets the trainer
automatically adjust the learning rate when the test loss stops
improving. Note that you are not required to use test_one_step() at
all; it is simply available if you want this kind of monitoring.
- You can observe the current average loss value by calling get_average_test_loss().
- The computation will happen in another thread. Therefore, after calling
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
!*/
void test_one_step (
const std::vector<input_type>& data
);
/*!
requires
- data.size() > 0
- net_type uses an unsupervised loss.
i.e. net_type::training_label_type == no_label_type.
ensures
- Runs the given data through the network and computes and records the loss.
- This call does not modify network parameters. The point of
test_one_step() is twofold: it lets you observe the accuracy of the
network on held-out data during training, and it lets the trainer
automatically adjust the learning rate when the test loss stops
improving. Note that you are not required to use test_one_step() at
all; it is simply available if you want this kind of monitoring.
- You can observe the current average loss value by calling get_average_test_loss().
- The computation will happen in another thread. Therefore, after calling
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
!*/
template <
typename data_iterator
>
void test_one_step (
data_iterator dbegin,
data_iterator dend
);
/*!
requires
- std::distance(dbegin, dend) > 0
- net_type uses an unsupervised loss.
i.e. net_type::training_label_type == no_label_type.
ensures
- Runs the given data through the network and computes and records the loss.
- This call does not modify network parameters. The point of
test_one_step() is twofold: it lets you observe the accuracy of the
network on held-out data during training, and it lets the trainer
automatically adjust the learning rate when the test loss stops
improving. Note that you are not required to use test_one_step() at
all; it is simply available if you want this kind of monitoring.
- You can observe the current average loss value by calling get_average_test_loss().
- The computation will happen in another thread. Therefore, after calling
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
!*/
void set_test_iterations_without_progress_threshold (
unsigned long thresh
);
/*!
ensures
- #get_test_iterations_without_progress_threshold() == thresh
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
unsigned long get_test_iterations_without_progress_threshold (
) const;
/*!
ensures
- This object monitors the progress of training and estimates if the
testing error is being reduced. It does this by looking at the previous
get_test_iterations_without_progress_threshold() mini-batch results from
test_one_step() and applying the statistical test defined by the
running_gradient object to see if the testing error is getting smaller.
If it isn't being reduced then get_learning_rate() is made smaller by a
factor of get_learning_rate_shrink_factor().
Therefore, get_test_iterations_without_progress_threshold() should always be
set to something sensibly large so that this test can be done with
reasonably high confidence. Think of this test as saying "if the testing loss
hasn't decreased for the previous get_test_iterations_without_progress_threshold()
calls to test_one_step() then shrink the learning rate".
!*/
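A brief sketch of adjusting this threshold and watching how close the trainer is to shrinking the learning rate (the trainer object is assumed to already exist; the value 400 is just an example):

#include <dlib/dnn.h>
#include <iostream>

template <typename net_type>
void configure_test_schedule(dlib::dnn_trainer<net_type>& trainer)
{
    // Require 400 test mini-batches without a statistically significant
    // decrease in the held-out loss before shrinking (the default is 200).
    trainer.set_test_iterations_without_progress_threshold(400);
    trainer.set_learning_rate_shrink_factor(0.1);

    // At any point you can inspect how close the trainer is to a shrink.
    std::cout << "test steps without progress: "
              << trainer.get_test_steps_without_progress() << " of "
              << trainer.get_test_iterations_without_progress_threshold()
              << std::endl;
}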
unsigned long get_test_steps_without_progress (
) const;
/*!
ensures
- if (get_learning_rate_shrink_factor() != 1) then
- returns an estimate of how many mini-batches have executed without us
observing a statistically significant decrease in the testing error
(i.e. the error on the data given to the trainer via test_one_step()
calls).
- else
- returns 0
!*/
};
// ----------------------------------------------------------------------------------------