diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h
index 337099d73..84b311ce3 100644
--- a/dlib/dnn/trainer.h
+++ b/dlib/dnn/trainer.h
@@ -131,9 +131,10 @@ namespace dlib
         }
 
         net_type& get_net (
-        ) const 
+        ) 
         { 
             wait_for_thread_to_pause();
+            sync_to_disk(true);
             propagate_exception();
             return net; 
         }
@@ -266,7 +267,6 @@ namespace dlib
         {
             DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
 
-            bool updated_the_network = false;
             // The reason these two loops don't initialize their counter variables but
             // instead use class members is so we can include the state of the loops in the
             // stuff written by sync_to_disk()
@@ -297,7 +297,6 @@ namespace dlib
                     send_job(data.begin()+epoch_pos, 
                              data.begin()+std::min(epoch_pos+mini_batch_size,data.size()), 
                              labels.begin()+epoch_pos);
-                    updated_the_network = true;
                 }
                 epoch_pos = 0;
 
@@ -313,7 +312,7 @@ namespace dlib
             }
             wait_for_thread_to_pause();
             // if we modified the network at all then be sure to sync the final result.
-            sync_to_disk(updated_the_network);
+            sync_to_disk(true);
         }
 
         void train (
@@ -326,7 +325,6 @@ namespace dlib
             static_assert(has_unsupervised_loss, 
                 "You can only call this version of train() when using an unsupervised loss.");
 
-            bool updated_the_network = false;
             // The reason these two loops don't initialize their counter variables but
             // instead use class members is so we can include the state of the loops in the
             // stuff written by sync_to_disk()
@@ -356,7 +354,6 @@ namespace dlib
                     sync_to_disk();
                     send_job(data.begin()+epoch_pos, 
                              data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
-                    updated_the_network = true;
                 }
                 epoch_pos = 0;
 
@@ -372,7 +369,7 @@ namespace dlib
             }
             wait_for_thread_to_pause();
             // if we modified the network at all then be sure to sync the final result.
-            sync_to_disk(updated_the_network);
+            sync_to_disk(true);
         }
 
         void set_synchronization_file (
@@ -588,6 +585,7 @@ namespace dlib
             main_iteration_counter = 0;
             while(job_pipe.dequeue(next_job))
             {
+                updated_net_since_last_sync = true;
                 ++main_iteration_counter;
                 // Call compute_parameter_gradients() and update_parameters() but pick the
                 // right version for unsupervised or supervised training based on the type
@@ -746,6 +744,7 @@ namespace dlib
             prob_loss_increasing_thresh_default_value = 0.99;
             prob_loss_increasing_thresh_max_value = 0.99999;
             prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
+            updated_net_since_last_sync = false;
 
             start();
         }
@@ -832,10 +831,15 @@ namespace dlib
                 dlib::cuda::set_device(prev_dev);
             }
         }
+
         void sync_to_disk (
             bool do_it_now = false
-        ) 
+        )
         {
+            // don't sync anything if we haven't updated the network since the last sync
+            if (!updated_net_since_last_sync)
+                return;
+
             // If the sync file isn't set then don't do anything. 
             if (sync_filename.size() == 0)
                 return;
@@ -879,6 +883,7 @@ namespace dlib
 
                 last_sync_time = std::chrono::system_clock::now();
                 main_iteration_counter_at_last_disk_sync = main_iteration_counter;
+                updated_net_since_last_sync = false;
             }
         }
 
@@ -1069,6 +1074,7 @@ namespace dlib
         double prob_loss_increasing_thresh_default_value;
         double prob_loss_increasing_thresh_max_value;
         double prob_loss_increasing_thresh;
+        std::atomic<bool> updated_net_since_last_sync;
     };
 
 // ----------------------------------------------------------------------------------------
diff --git a/dlib/dnn/trainer_abstract.h b/dlib/dnn/trainer_abstract.h
index 1faf9548e..24c25fb1d 100644
--- a/dlib/dnn/trainer_abstract.h
+++ b/dlib/dnn/trainer_abstract.h
@@ -88,8 +88,8 @@ namespace dlib
               cuda_extra_devices.
         !*/
 
-        net_type& get_net (
-        ) const; 
+        net_type& get_net (
+        );
        /*!
            ensures
                - returns the neural network object used by this trainer.  This is the
@@ -99,6 +99,8 @@ namespace dlib
                  dnn_trainer's constructor.
                - This function blocks until all threads inside the dnn_trainer have
                  stopped touching the net.
+                - This function will sync the trainer state to disk if the current state
+                  hasn't already been synced to disk since the last network modification.
        !*/

        const std::vector<solver_type>& get_solvers (
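For reference, a minimal usage sketch of the new behavior (the toy net_type and file names here are hypothetical; dnn_trainer, set_synchronization_file, train, and get_net are the existing dlib API):

#include <dlib/dnn.h>

using namespace dlib;

// Hypothetical toy network, just for illustration.
using net_type = loss_multiclass_log<fc<10, input<matrix<float>>>>;

int main()
{
    net_type net;
    dnn_trainer<net_type> trainer(net);
    trainer.set_synchronization_file("toy_sync", std::chrono::seconds(20));

    // ... call trainer.train(samples, labels) or trainer.train_one_step() here ...

    // Before this patch, the net grabbed here could be newer than the state in
    // "toy_sync" if the periodic sync hadn't fired yet.  get_net() now calls
    // sync_to_disk(true) first, and sync_to_disk() itself becomes a no-op when
    // updated_net_since_last_sync is false, so repeated get_net() calls don't
    // needlessly rewrite the sync file.
    net_type& trained = trainer.get_net();
    serialize("trained_net.dat") << trained;
}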