Made get_net() sync to disk if the on-disk state is out of date. This way, when
using train_one_step(), you will get the behavior of automatic disk syncs at
the end of training.
This commit is contained in:
Davis King 2016-12-17 13:37:01 -05:00
parent cb198afc3f
commit 29047a2269
2 changed files with 18 additions and 10 deletions

View File

@ -131,9 +131,10 @@ namespace dlib
} }
net_type& get_net ( net_type& get_net (
) const )
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
sync_to_disk(true);
propagate_exception(); propagate_exception();
return net; return net;
} }
@ -266,7 +267,6 @@ namespace dlib
{ {
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0); DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
bool updated_the_network = false;
// The reason these two loops don't initialize their counter variables but // The reason these two loops don't initialize their counter variables but
// instead use class members is so we can include the state of the loops in the // instead use class members is so we can include the state of the loops in the
// stuff written by sync_to_disk() // stuff written by sync_to_disk()
@ -297,7 +297,6 @@ namespace dlib
send_job(data.begin()+epoch_pos, send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()), data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
labels.begin()+epoch_pos); labels.begin()+epoch_pos);
updated_the_network = true;
} }
epoch_pos = 0; epoch_pos = 0;
@ -313,7 +312,7 @@ namespace dlib
} }
wait_for_thread_to_pause(); wait_for_thread_to_pause();
// if we modified the network at all then be sure to sync the final result. // if we modified the network at all then be sure to sync the final result.
sync_to_disk(updated_the_network); sync_to_disk(true);
} }
void train ( void train (
@ -326,7 +325,6 @@ namespace dlib
static_assert(has_unsupervised_loss, static_assert(has_unsupervised_loss,
"You can only call this version of train() when using an unsupervised loss."); "You can only call this version of train() when using an unsupervised loss.");
bool updated_the_network = false;
// The reason these two loops don't initialize their counter variables but // The reason these two loops don't initialize their counter variables but
// instead use class members is so we can include the state of the loops in the // instead use class members is so we can include the state of the loops in the
// stuff written by sync_to_disk() // stuff written by sync_to_disk()
@ -356,7 +354,6 @@ namespace dlib
sync_to_disk(); sync_to_disk();
send_job(data.begin()+epoch_pos, send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size())); data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
updated_the_network = true;
} }
epoch_pos = 0; epoch_pos = 0;
@ -372,7 +369,7 @@ namespace dlib
} }
wait_for_thread_to_pause(); wait_for_thread_to_pause();
// if we modified the network at all then be sure to sync the final result. // if we modified the network at all then be sure to sync the final result.
sync_to_disk(updated_the_network); sync_to_disk(true);
} }
void set_synchronization_file ( void set_synchronization_file (
@ -588,6 +585,7 @@ namespace dlib
main_iteration_counter = 0; main_iteration_counter = 0;
while(job_pipe.dequeue(next_job)) while(job_pipe.dequeue(next_job))
{ {
updated_net_since_last_sync = true;
++main_iteration_counter; ++main_iteration_counter;
// Call compute_parameter_gradients() and update_parameters() but pick the // Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type // right version for unsupervised or supervised training based on the type
@ -746,6 +744,7 @@ namespace dlib
prob_loss_increasing_thresh_default_value = 0.99; prob_loss_increasing_thresh_default_value = 0.99;
prob_loss_increasing_thresh_max_value = 0.99999; prob_loss_increasing_thresh_max_value = 0.99999;
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
updated_net_since_last_sync = false;
start(); start();
} }
@ -832,10 +831,15 @@ namespace dlib
dlib::cuda::set_device(prev_dev); dlib::cuda::set_device(prev_dev);
} }
} }
void sync_to_disk ( void sync_to_disk (
bool do_it_now = false bool do_it_now = false
) )
{ {
// don't sync anything if we haven't updated the network since the last sync
if (!updated_net_since_last_sync)
return;
// If the sync file isn't set then don't do anything. // If the sync file isn't set then don't do anything.
if (sync_filename.size() == 0) if (sync_filename.size() == 0)
return; return;
@ -879,6 +883,7 @@ namespace dlib
last_sync_time = std::chrono::system_clock::now(); last_sync_time = std::chrono::system_clock::now();
main_iteration_counter_at_last_disk_sync = main_iteration_counter; main_iteration_counter_at_last_disk_sync = main_iteration_counter;
updated_net_since_last_sync = false;
} }
} }
@ -1069,6 +1074,7 @@ namespace dlib
double prob_loss_increasing_thresh_default_value; double prob_loss_increasing_thresh_default_value;
double prob_loss_increasing_thresh_max_value; double prob_loss_increasing_thresh_max_value;
double prob_loss_increasing_thresh; double prob_loss_increasing_thresh;
std::atomic<bool> updated_net_since_last_sync;
}; };

View File

@ -88,8 +88,8 @@ namespace dlib
cuda_extra_devices. cuda_extra_devices.
!*/ !*/
net_type& get_net ( net_type& get_net (
) const; );
/*! /*!
ensures ensures
- returns the neural network object used by this trainer. This is the - returns the neural network object used by this trainer. This is the
@ -99,6 +99,8 @@ namespace dlib
dnn_trainer's constructor. dnn_trainer's constructor.
- This function blocks until all threads inside the dnn_trainer have - This function blocks until all threads inside the dnn_trainer have
stopped touching the net. stopped touching the net.
- This function will sync the trainer state to disk if the current state
hasn't already been synced to disk since the last network modification.
!*/ !*/
const std::vector<solver_type>& get_solvers ( const std::vector<solver_type>& get_solvers (