Mirror of https://github.com/davisking/dlib.git (synced 2024-11-01 10:14:53 +08:00)
Added test_one_step() to the dnn_trainer. This allows you to do automatic early stopping based on observing the loss on held-out data.
Commit c20c11af90 (parent 5f5684a8fb)
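
For context, a rough usage sketch of the new API (not part of the commit itself): interleave test_one_step() calls on held-out data with the usual train_one_step() loop, and the trainer will shrink the learning rate once the test loss stops making statistically significant progress, which ends the loop. The network definition, data variables, batch size, and the 30-step cadence below are placeholder assumptions; the dnn_trainer member functions are real, and the ones marked "new in this patch" are the ones this commit adds.

    #include <dlib/dnn.h>
    using namespace dlib;

    // Placeholder network type; any supervised dlib network would do here.
    using net_type = loss_multiclass_log<fc<10, relu<fc<84, input<matrix<unsigned char>>>>>>;

    void train_with_early_stopping(
        const std::vector<matrix<unsigned char>>& training_images,
        const std::vector<unsigned long>& training_labels,
        const std::vector<matrix<unsigned char>>& testing_images,
        const std::vector<unsigned long>& testing_labels
    )
    {
        net_type net;
        dnn_trainer<net_type> trainer(net, sgd(0.0005, 0.9));
        trainer.set_learning_rate(0.1);
        trainer.set_min_learning_rate(1e-5);
        trainer.set_test_iterations_without_progress_threshold(200); // new in this patch
        trainer.be_verbose();

        std::vector<matrix<unsigned char>> mini_batch_data;
        std::vector<unsigned long> mini_batch_labels;
        dlib::rand rnd;
        size_t step_count = 0;
        while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
        {
            // Build a random training mini-batch.
            mini_batch_data.clear();
            mini_batch_labels.clear();
            while (mini_batch_data.size() < 128)
            {
                auto idx = rnd.get_random_32bit_number() % training_images.size();
                mini_batch_data.push_back(training_images[idx]);
                mini_batch_labels.push_back(training_labels[idx]);
            }
            trainer.train_one_step(mini_batch_data, mini_batch_labels);

            // Every 30 steps, run held-out data through the net.  test_one_step() (new in
            // this patch) records the loss without updating parameters, and the trainer
            // drops the learning rate when that loss flattens out.
            if (++step_count % 30 == 0)
                trainer.test_one_step(testing_images, testing_labels);
        }
        trainer.get_net(); // block until the trainer's threads have stopped touching net
    }
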
@@ -40,6 +40,7 @@ namespace dlib
             std::vector<std::vector<training_label_type>> labels;
             std::vector<resizable_tensor> t;
             std::vector<int> have_data; // have_data[i] is true if there is data in labels[i] and t[i].
+            bool test_only = false;
         };

         template <typename training_label_type>
@@ -48,6 +49,7 @@ namespace dlib
             a.labels.swap(b.labels);
             a.t.swap(b.t);
             a.have_data.swap(b.have_data);
+            std::swap(a.test_only,b.test_only);
         }
     }

@@ -205,22 +207,9 @@ namespace dlib
         {
             DLIB_CASSERT(std::distance(dbegin, dend) > 0);

-            if (verbose)
-            {
-                using namespace std::chrono;
-                auto now_time = system_clock::now();
-                if (now_time-last_time > seconds(40))
-                {
-                    last_time = now_time;
-                    std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
-                              << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
-                              << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
-                    print_progress();
-                    clear_average_loss();
-                }
-            }
+            print_periodic_verbose_status();
             sync_to_disk();
-            send_job(dbegin, dend, lbegin);
+            send_job(false, dbegin, dend, lbegin);

             ++train_one_step_calls;
         }
@@ -241,22 +230,60 @@ namespace dlib
         )
         {
             DLIB_CASSERT(std::distance(dbegin, dend) > 0);
-            if (verbose)
-            {
-                using namespace std::chrono;
-                auto now_time = system_clock::now();
-                if (now_time-last_time > seconds(40))
-                {
-                    last_time = now_time;
-                    std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
-                              << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
-                              << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
-                    print_progress();
-                    clear_average_loss();
-                }
-            }
+            print_periodic_verbose_status();
             sync_to_disk();
-            send_job(dbegin, dend);
+            send_job(false, dbegin, dend);
+
+            ++train_one_step_calls;
+        }
+
+        void test_one_step (
+            const std::vector<input_type>& data,
+            const std::vector<training_label_type>& labels
+        )
+        {
+            DLIB_CASSERT(data.size() == labels.size());
+
+            test_one_step(data.begin(), data.end(), labels.begin());
+        }
+
+        template <
+            typename data_iterator,
+            typename label_iterator
+            >
+        void test_one_step (
+            data_iterator dbegin,
+            data_iterator dend,
+            label_iterator lbegin
+        )
+        {
+            DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+
+            print_periodic_verbose_status();
+            sync_to_disk();
+            send_job(true, dbegin, dend, lbegin);
+
+            ++train_one_step_calls;
+        }
+
+        void test_one_step (
+            const std::vector<input_type>& data
+        )
+        {
+            test_one_step(data.begin(), data.end());
+        }
+
+        template <
+            typename data_iterator
+            >
+        void test_one_step (
+            data_iterator dbegin,
+            data_iterator dend
+        )
+        {
+            DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+            print_periodic_verbose_status();
+            sync_to_disk();
+            send_job(true, dbegin, dend);

             ++train_one_step_calls;
         }
@@ -294,7 +321,7 @@ namespace dlib
                 }

                 sync_to_disk();
-                send_job(data.begin()+epoch_pos,
+                send_job(false, data.begin()+epoch_pos,
                          data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
                          labels.begin()+epoch_pos);
             }
@@ -352,7 +379,7 @@ namespace dlib
                 }

                 sync_to_disk();
-                send_job(data.begin()+epoch_pos,
+                send_job(false, data.begin()+epoch_pos,
                          data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
             }
             epoch_pos = 0;
@@ -394,6 +421,16 @@ namespace dlib
             return rs.mean();
         }

+        double get_average_test_loss (
+        ) const
+        {
+            wait_for_thread_to_pause();
+            running_stats<double> tmp;
+            for (auto& x : test_previous_loss_values)
+                tmp.add(x);
+            return tmp.mean();
+        }
+
         void clear_average_loss (
         )
         {
@@ -410,7 +447,9 @@ namespace dlib
             if (learning_rate != lr)
             {
                 steps_without_progress = 0;
+                test_steps_without_progress = 0;
                 previous_loss_values.clear();
+                test_previous_loss_values.clear();
             }
             learning_rate = lr;
             lr_schedule.set_size(0);
@@ -479,6 +518,27 @@ namespace dlib
             return steps_without_progress;
         }

+        void set_test_iterations_without_progress_threshold (
+            unsigned long thresh
+        )
+        {
+            wait_for_thread_to_pause();
+            lr_schedule.set_size(0);
+            test_iter_without_progress_thresh = thresh;
+        }
+
+        unsigned long get_test_iterations_without_progress_threshold (
+        ) const
+        {
+            return test_iter_without_progress_thresh;
+        }
+
+        unsigned long get_test_steps_without_progress (
+        ) const
+        {
+            return test_steps_without_progress;
+        }
+
         void set_learning_rate_shrink_factor (
             double shrink
         )
@@ -488,6 +548,7 @@ namespace dlib
             lr_schedule.set_size(0);
             learning_rate_shrink = shrink;
             steps_without_progress = 0;
+            test_steps_without_progress = 0;
         }

         double get_learning_rate_shrink_factor (
@@ -504,6 +565,14 @@ namespace dlib

     private:

+        void record_test_loss(double loss)
+        {
+            test_previous_loss_values.push_back(loss);
+            // discard really old loss values.
+            while (test_previous_loss_values.size() > test_iter_without_progress_thresh)
+                test_previous_loss_values.pop_front();
+        }
+
         void record_loss(double loss)
         {
             // Say that we will check if the gradient is bad 200 times during each
@@ -527,7 +596,10 @@ namespace dlib
             {
                 auto&& dev = *devices[device];
                 dlib::cuda::set_device(dev.device_id);
-                return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
+                if (next_job.test_only)
+                    return dev.net.compute_loss(next_job.t[device], next_job.labels[device].begin());
+                else
+                    return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
             }
             else
             {
@@ -542,7 +614,10 @@ namespace dlib
                 auto&& dev = *devices[device];
                 dlib::cuda::set_device(dev.device_id);
                 no_label_type pick_which_run_update;
-                return dev.net.compute_parameter_gradients(next_job.t[device]);
+                if (next_job.test_only)
+                    return dev.net.compute_loss(next_job.t[device]);
+                else
+                    return dev.net.compute_parameter_gradients(next_job.t[device]);
             }
             else
             {
@@ -585,6 +660,37 @@ namespace dlib
             main_iteration_counter = 0;
             while(job_pipe.dequeue(next_job))
             {
+                if (next_job.test_only)
+                {
+                    // compute the testing loss
+                    for (size_t i = 0; i < devices.size(); ++i)
+                        tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
+                    // aggregate loss values from all the network computations.
+                    double theloss = 0;
+                    for (auto&& loss : losses)
+                        theloss += loss.get();
+                    record_test_loss(theloss/losses.size());

+                    // Check if we should shrink the learning rate based on how the test
+                    // error has been doing lately.
+                    if (learning_rate_shrink != 1)
+                    {
+                        test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
+                        if (test_steps_without_progress >= test_iter_without_progress_thresh)
+                        {
+                            test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values);
+                            if (test_steps_without_progress >= test_iter_without_progress_thresh)
+                            {
+                                // optimization has flattened out, so drop the learning rate.
+                                learning_rate = learning_rate_shrink*learning_rate;
+                                test_steps_without_progress = 0;
+                                test_previous_loss_values.clear();
+                            }
+                        }
+                    }
+                    continue;
+                }
+
                 updated_net_since_last_sync = true;
                 ++main_iteration_counter;
                 // Call compute_parameter_gradients() and update_parameters() but pick the
@@ -719,7 +825,7 @@ namespace dlib
             job_pipe.wait_for_num_blocked_dequeues(1);
         }

-        const static long string_pad = 10;
+        const static long string_pad = 11;
         const static long epoch_string_pad = 4;
         const static long lr_string_pad = 4;

@@ -732,6 +838,9 @@ namespace dlib
             min_learning_rate = 1e-5;
             iter_without_progress_thresh = 2000;
             steps_without_progress = 0;
+            test_iter_without_progress_thresh = 200;
+            test_steps_without_progress = 0;
+
             learning_rate_shrink = 0.1;
             epoch_iteration = 0;
             epoch_pos = 0;
@@ -755,7 +864,7 @@ namespace dlib
         friend void serialize(const dnn_trainer& item, std::ostream& out)
         {
             item.wait_for_thread_to_pause();
-            int version = 7;
+            int version = 8;
             serialize(version, out);

             size_t nl = dnn_trainer::num_layers;
@@ -777,13 +886,17 @@ namespace dlib
             serialize(item.train_one_step_calls, out);
             serialize(item.lr_schedule, out);
             serialize(item.lr_schedule_pos, out);
+            serialize(item.test_iter_without_progress_thresh.load(), out);
+            serialize(item.test_steps_without_progress.load(), out);
+            serialize(item.test_previous_loss_values, out);
+
         }
         friend void deserialize(dnn_trainer& item, std::istream& in)
         {
             item.wait_for_thread_to_pause();
             int version = 0;
             deserialize(version, in);
-            if (version != 7)
+            if (version != 8)
                 throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");

             size_t num_layers = 0;
@@ -815,6 +928,9 @@ namespace dlib
             deserialize(item.train_one_step_calls, in);
             deserialize(item.lr_schedule, in);
             deserialize(item.lr_schedule_pos, in);
+            deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
+            deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
+            deserialize(item.test_previous_loss_values, in);

             if (item.devices.size() > 1)
             {
@@ -966,6 +1082,7 @@ namespace dlib
             typename label_iterator
             >
         void send_job (
+            bool test_only,
             data_iterator dbegin,
             data_iterator dend,
             label_iterator lbegin
@@ -977,6 +1094,7 @@ namespace dlib
             job.t.resize(devs);
             job.labels.resize(devs);
             job.have_data.resize(devs);
+            job.test_only = test_only;

             // chop the data into devs blocks, each of about block_size elements.
             size_t block_size = (num+devs-1)/devs;
@@ -1009,19 +1127,23 @@ namespace dlib
             typename data_iterator
             >
         void send_job (
+            bool test_only,
             data_iterator dbegin,
             data_iterator dend
         )
         {
             typename std::vector<training_label_type>::iterator nothing;
-            send_job(dbegin, dend, nothing);
+            send_job(test_only, dbegin, dend, nothing);
         }

         void print_progress()
         {
             if (lr_schedule.size() == 0)
             {
-                std::cout << "steps without apparent progress: " << steps_without_progress;
+                if (test_previous_loss_values.size() == 0)
+                    std::cout << "steps without apparent progress: " << steps_without_progress;
+                else
+                    std::cout << "steps without apparent progress: train=" << steps_without_progress << ", test=" << test_steps_without_progress;
             }
             else
             {
@@ -1032,6 +1154,32 @@ namespace dlib
             std::cout << std::endl;
         }

+        void print_periodic_verbose_status()
+        {
+            if (verbose)
+            {
+                using namespace std::chrono;
+                auto now_time = system_clock::now();
+                if (now_time-last_time > seconds(40))
+                {
+                    last_time = now_time;
+                    std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
+                              << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " ";
+                    if (test_previous_loss_values.size() == 0)
+                    {
+                        std::cout << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+                    }
+                    else
+                    {
+                        std::cout << "train loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+                        std::cout << "test loss: " << rpad(cast_to_string(get_average_test_loss()),string_pad) << " ";
+                    }
+                    print_progress();
+                    clear_average_loss();
+                }
+            }
+        }
+
         std::vector<std::shared_ptr<device_data>> devices;
         dlib::pipe<job_t> job_pipe;
         job_t job;
@@ -1047,6 +1195,11 @@ namespace dlib
         double min_learning_rate;
         std::atomic<unsigned long> iter_without_progress_thresh;
         std::atomic<unsigned long> steps_without_progress;
+
+        std::atomic<unsigned long> test_iter_without_progress_thresh;
+        std::atomic<unsigned long> test_steps_without_progress;
+        std::deque<double> test_previous_loss_values;
+
         std::atomic<double> learning_rate_shrink;
         std::chrono::time_point<std::chrono::system_clock> last_sync_time;
         std::string sync_filename;
@@ -76,6 +76,7 @@ namespace dlib
                 - #get_learning_rate() == 1e-2
                 - #get_min_learning_rate() == 1e-5
                 - #get_iterations_without_progress_threshold() == 2000
+                - #get_test_iterations_without_progress_threshold() == 200
                 - #get_learning_rate_shrink_factor() == 0.1
                 - #get_learning_rate_schedule().size() == 0
                 - #get_train_one_step_calls() == 0
@@ -536,6 +537,170 @@ namespace dlib
                   stopped touching the net.
         !*/

+        // ----------------------
+
+        double get_average_test_loss (
+        ) const;
+        /*!
+            ensures
+                - returns the average loss value observed during previous calls to
+                  test_one_step().
+                - This function blocks until all threads inside the dnn_trainer have
+                  stopped touching the net.
+        !*/
+
+        void test_one_step (
+            const std::vector<input_type>& data,
+            const std::vector<training_label_type>& labels
+        );
+        /*!
+            requires
+                - data.size() == labels.size()
+                - data.size() > 0
+                - net_type uses a supervised loss.
+                  i.e. net_type::training_label_type != no_label_type.
+            ensures
+                - Runs the given data through the network and computes and records the loss.
+                - This call does not modify network parameters. The point of
+                  test_one_step() is two fold, to allow you to observe the accuracy of the
+                  network on hold out data during training, and to allow the trainer to
+                  automatically adjust the learning rate when the test loss stops
+                  improving. It should be noted that you are not required to use
+                  test_one_step() at all, but if you want to do this kind of thing it is
+                  available.
+                - You can observe the current average loss value by calling get_average_test_loss().
+                - The computation will happen in another thread. Therefore, after calling
+                  this function you should call get_net() before you touch the net object
+                  from the calling thread to ensure no other threads are still accessing
+                  the network.
+        !*/
+
+        template <
+            typename data_iterator,
+            typename label_iterator
+            >
+        void test_one_step (
+            data_iterator dbegin,
+            data_iterator dend,
+            label_iterator lbegin
+        );
+        /*!
+            requires
+                - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
+                - std::distance(dbegin, dend) > 0
+                - net_type uses a supervised loss.
+                  i.e. net_type::training_label_type != no_label_type.
+            ensures
+                - Runs the given data through the network and computes and records the loss.
+                - This call does not modify network parameters. The point of
+                  test_one_step() is two fold, to allow you to observe the accuracy of the
+                  network on hold out data during training, and to allow the trainer to
+                  automatically adjust the learning rate when the test loss stops
+                  improving. It should be noted that you are not required to use
+                  test_one_step() at all, but if you want to do this kind of thing it is
+                  available.
+                - You can observe the current average loss value by calling get_average_test_loss().
+                - The computation will happen in another thread. Therefore, after calling
+                  this function you should call get_net() before you touch the net object
+                  from the calling thread to ensure no other threads are still accessing
+                  the network.
+        !*/
+
+        void test_one_step (
+            const std::vector<input_type>& data
+        );
+        /*!
+            requires
+                - data.size() > 0
+                - net_type uses an unsupervised loss.
+                  i.e. net_type::training_label_type == no_label_type.
+            ensures
+                - Runs the given data through the network and computes and records the loss.
+                - This call does not modify network parameters. The point of
+                  test_one_step() is two fold, to allow you to observe the accuracy of the
+                  network on hold out data during training, and to allow the trainer to
+                  automatically adjust the learning rate when the test loss stops
+                  improving. It should be noted that you are not required to use
+                  test_one_step() at all, but if you want to do this kind of thing it is
+                  available.
+                - You can observe the current average loss value by calling get_average_test_loss().
+                - The computation will happen in another thread. Therefore, after calling
+                  this function you should call get_net() before you touch the net object
+                  from the calling thread to ensure no other threads are still accessing
+                  the network.
+        !*/
+
+        template <
+            typename data_iterator
+            >
+        void test_one_step (
+            data_iterator dbegin,
+            data_iterator dend
+        );
+        /*!
+            requires
+                - std::distance(dbegin, dend) > 0
+                - net_type uses an unsupervised loss.
+                  i.e. net_type::training_label_type == no_label_type.
+            ensures
+                - Runs the given data through the network and computes and records the loss.
+                - This call does not modify network parameters. The point of
+                  test_one_step() is two fold, to allow you to observe the accuracy of the
+                  network on hold out data during training, and to allow the trainer to
+                  automatically adjust the learning rate when the test loss stops
+                  improving. It should be noted that you are not required to use
+                  test_one_step() at all, but if you want to do this kind of thing it is
+                  available.
+                - You can observe the current average loss value by calling get_average_test_loss().
+                - The computation will happen in another thread. Therefore, after calling
+                  this function you should call get_net() before you touch the net object
+                  from the calling thread to ensure no other threads are still accessing
+                  the network.
+        !*/
+
+        void set_test_iterations_without_progress_threshold (
+            unsigned long thresh
+        );
+        /*!
+            ensures
+                - #get_test_iterations_without_progress_threshold() == thresh
+                - #get_learning_rate_schedule().size() == 0
+                - This function blocks until all threads inside the dnn_trainer have
+                  stopped touching the net.
+        !*/
+
+        unsigned long get_test_iterations_without_progress_threshold (
+        ) const;
+        /*!
+            ensures
+                - This object monitors the progress of training and estimates if the
+                  testing error is being reduced. It does this by looking at the previous
+                  get_test_iterations_without_progress_threshold() mini-batch results from
+                  test_one_step() and applying the statistical test defined by the
+                  running_gradient object to see if the testing error is getting smaller.
+                  If it isn't being reduced then get_learning_rate() is made smaller by a
+                  factor of get_learning_rate_shrink_factor().
+
+                  Therefore, get_test_iterations_without_progress_threshold() should always be
+                  set to something sensibly large so that this test can be done with
+                  reasonably high confidence. Think of this test as saying "if the testing loss
+                  hasn't decreased for the previous get_test_iterations_without_progress_threshold()
+                  calls to test_one_step() then shrink the learning rate".
+        !*/
+
+        unsigned long get_test_steps_without_progress (
+        ) const;
+        /*!
+            ensures
+                - if (get_learning_rate_shrink_factor() != 1) then
+                    - returns an estimate of how many mini-batches have executed without us
+                      observing a statistically significant decrease in the testing error
+                      (i.e. the error on the data given to the trainer via test_one_step()
+                      calls).
+                - else
+                    - returns 0
+        !*/
+
     };

// ----------------------------------------------------------------------------------------