Made get_net() sync to disk if the state on disk is out of date. This way, when
using train_one_step(), you will get the behavior of automatic disk syncs at
the end of training.
pull/374/head
Davis King 8 years ago
parent cb198afc3f
commit 29047a2269
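
In practice this means a training loop driven by train_one_step() no longer needs an explicit final sync before the program exits: asking the trainer for the network also flushes the trainer state. A minimal usage sketch, assuming a toy network type, file names, and data that are only illustrative and not part of this commit:

    #include <dlib/dnn.h>
    #include <chrono>
    #include <vector>

    using namespace dlib;

    // Illustrative toy network; any dlib network type behaves the same way here.
    using net_type = loss_multiclass_log<fc<2, input<matrix<float>>>>;

    int main()
    {
        std::vector<matrix<float>> mini_batch_samples;   // filled by your data loader
        std::vector<unsigned long> mini_batch_labels;

        net_type net;
        dnn_trainer<net_type> trainer(net);

        // Periodic syncs happen on this schedule while training runs.
        trainer.set_synchronization_file("trainer_state.dat", std::chrono::minutes(5));

        // ... drive training yourself by repeatedly calling
        // trainer.train_one_step(mini_batch_samples, mini_batch_labels);

        // With this commit, get_net() also writes the trainer state to disk if the
        // network changed since the last sync, so the last steps of training are
        // preserved even when no periodic sync happened to fire at the end.
        net_type& trained_net = trainer.get_net();
        serialize("final_net.dat") << trained_net;
    }
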

@@ -131,9 +131,10 @@ namespace dlib
}
net_type& get_net (
) const
)
{
wait_for_thread_to_pause();
sync_to_disk(true);
propagate_exception();
return net;
}
@@ -266,7 +267,6 @@ namespace dlib
{
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
bool updated_the_network = false;
// The reason these two loops don't initialize their counter variables but
// instead use class members is so we can include the state of the loops in the
// stuff written by sync_to_disk()
@@ -297,7 +297,6 @@ namespace dlib
send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
labels.begin()+epoch_pos);
updated_the_network = true;
}
epoch_pos = 0;
@@ -313,7 +312,7 @@ namespace dlib
}
wait_for_thread_to_pause();
// if we modified the network at all then be sure to sync the final result.
sync_to_disk(updated_the_network);
sync_to_disk(true);
}
void train (
@@ -326,7 +325,6 @@ namespace dlib
static_assert(has_unsupervised_loss,
"You can only call this version of train() when using an unsupervised loss.");
bool updated_the_network = false;
// The reason these two loops don't initialize their counter variables but
// instead use class members is so we can include the state of the loops in the
// stuff written by sync_to_disk()
@@ -356,7 +354,6 @@ namespace dlib
sync_to_disk();
send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
updated_the_network = true;
}
epoch_pos = 0;
@@ -372,7 +369,7 @@ namespace dlib
}
wait_for_thread_to_pause();
// if we modified the network at all then be sure to sync the final result.
sync_to_disk(updated_the_network);
sync_to_disk(true);
}
void set_synchronization_file (
@@ -588,6 +585,7 @@ namespace dlib
main_iteration_counter = 0;
while(job_pipe.dequeue(next_job))
{
updated_net_since_last_sync = true;
++main_iteration_counter;
// Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type
@@ -746,6 +744,7 @@ namespace dlib
prob_loss_increasing_thresh_default_value = 0.99;
prob_loss_increasing_thresh_max_value = 0.99999;
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
updated_net_since_last_sync = false;
start();
}
@@ -832,10 +831,15 @@ namespace dlib
dlib::cuda::set_device(prev_dev);
}
}
void sync_to_disk (
bool do_it_now = false
)
)
{
// don't sync anything if we haven't updated the network since the last sync
if (!updated_net_since_last_sync)
return;
// If the sync file isn't set then don't do anything.
if (sync_filename.size() == 0)
return;
@@ -879,6 +883,7 @@ namespace dlib
last_sync_time = std::chrono::system_clock::now();
main_iteration_counter_at_last_disk_sync = main_iteration_counter;
updated_net_since_last_sync = false;
}
}
@@ -1069,6 +1074,7 @@ namespace dlib
double prob_loss_increasing_thresh_default_value;
double prob_loss_increasing_thresh_max_value;
double prob_loss_increasing_thresh;
std::atomic<bool> updated_net_since_last_sync;
};
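
For reference, the guard added to sync_to_disk() above follows a standard dirty-flag pattern: a boolean is set whenever the network is updated and cleared after a successful write, so sync requests can be issued unconditionally without causing redundant disk traffic. A standalone sketch of that idea, using simplified stand-in names rather than dlib's actual internals:

    #include <atomic>
    #include <fstream>
    #include <string>

    class checkpointer
    {
    public:
        // Called after every parameter update.
        void mark_updated() { updated_since_last_sync = true; }

        void sync_to_disk(const std::string& filename)
        {
            // Nothing changed since the last write, so skip the disk traffic.
            if (!updated_since_last_sync)
                return;
            // No sync file configured, so there is nowhere to write.
            if (filename.empty())
                return;

            std::ofstream out(filename, std::ios::binary);
            out << "serialized state would go here";

            updated_since_last_sync = false;
        }

    private:
        std::atomic<bool> updated_since_last_sync{false};
    };

    int main()
    {
        checkpointer c;
        c.sync_to_disk("state.dat");   // no-op: nothing has changed yet
        c.mark_updated();
        c.sync_to_disk("state.dat");   // writes once, then clears the flag
    }
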

@@ -88,8 +88,8 @@ namespace dlib
cuda_extra_devices.
!*/
net_type& get_net (
) const;
net_type& get_net (
);
/*!
ensures
- returns the neural network object used by this trainer. This is the
@@ -99,6 +99,8 @@ namespace dlib
dnn_trainer's constructor.
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
- This function will sync the trainer state to disk if the current state
hasn't already been synced to disk since the last network modification.
!*/
const std::vector<solver_type>& get_solvers (
