mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Made get_net() sync to disk if the disk state is out of date. This way, when
using train_one_step(), you will get the behavior of automatic disk syncs at the end of training.
This commit is contained in:
parent
cb198afc3f
commit
29047a2269
@ -131,9 +131,10 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
|
|
||||||
net_type& get_net (
|
net_type& get_net (
|
||||||
) const
|
)
|
||||||
{
|
{
|
||||||
wait_for_thread_to_pause();
|
wait_for_thread_to_pause();
|
||||||
|
sync_to_disk(true);
|
||||||
propagate_exception();
|
propagate_exception();
|
||||||
return net;
|
return net;
|
||||||
}
|
}
|
||||||
@ -266,7 +267,6 @@ namespace dlib
|
|||||||
{
|
{
|
||||||
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
|
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
|
||||||
|
|
||||||
bool updated_the_network = false;
|
|
||||||
// The reason these two loops don't initialize their counter variables but
|
// The reason these two loops don't initialize their counter variables but
|
||||||
// instead use class members is so we can include the state of the loops in the
|
// instead use class members is so we can include the state of the loops in the
|
||||||
// stuff written by sync_to_disk()
|
// stuff written by sync_to_disk()
|
||||||
@ -297,7 +297,6 @@ namespace dlib
|
|||||||
send_job(data.begin()+epoch_pos,
|
send_job(data.begin()+epoch_pos,
|
||||||
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
|
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
|
||||||
labels.begin()+epoch_pos);
|
labels.begin()+epoch_pos);
|
||||||
updated_the_network = true;
|
|
||||||
}
|
}
|
||||||
epoch_pos = 0;
|
epoch_pos = 0;
|
||||||
|
|
||||||
@ -313,7 +312,7 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
wait_for_thread_to_pause();
|
wait_for_thread_to_pause();
|
||||||
// if we modified the network at all then be sure to sync the final result.
|
// if we modified the network at all then be sure to sync the final result.
|
||||||
sync_to_disk(updated_the_network);
|
sync_to_disk(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void train (
|
void train (
|
||||||
@ -326,7 +325,6 @@ namespace dlib
|
|||||||
static_assert(has_unsupervised_loss,
|
static_assert(has_unsupervised_loss,
|
||||||
"You can only call this version of train() when using an unsupervised loss.");
|
"You can only call this version of train() when using an unsupervised loss.");
|
||||||
|
|
||||||
bool updated_the_network = false;
|
|
||||||
// The reason these two loops don't initialize their counter variables but
|
// The reason these two loops don't initialize their counter variables but
|
||||||
// instead use class members is so we can include the state of the loops in the
|
// instead use class members is so we can include the state of the loops in the
|
||||||
// stuff written by sync_to_disk()
|
// stuff written by sync_to_disk()
|
||||||
@ -356,7 +354,6 @@ namespace dlib
|
|||||||
sync_to_disk();
|
sync_to_disk();
|
||||||
send_job(data.begin()+epoch_pos,
|
send_job(data.begin()+epoch_pos,
|
||||||
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
|
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
|
||||||
updated_the_network = true;
|
|
||||||
}
|
}
|
||||||
epoch_pos = 0;
|
epoch_pos = 0;
|
||||||
|
|
||||||
@ -372,7 +369,7 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
wait_for_thread_to_pause();
|
wait_for_thread_to_pause();
|
||||||
// if we modified the network at all then be sure to sync the final result.
|
// if we modified the network at all then be sure to sync the final result.
|
||||||
sync_to_disk(updated_the_network);
|
sync_to_disk(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_synchronization_file (
|
void set_synchronization_file (
|
||||||
@ -588,6 +585,7 @@ namespace dlib
|
|||||||
main_iteration_counter = 0;
|
main_iteration_counter = 0;
|
||||||
while(job_pipe.dequeue(next_job))
|
while(job_pipe.dequeue(next_job))
|
||||||
{
|
{
|
||||||
|
updated_net_since_last_sync = true;
|
||||||
++main_iteration_counter;
|
++main_iteration_counter;
|
||||||
// Call compute_parameter_gradients() and update_parameters() but pick the
|
// Call compute_parameter_gradients() and update_parameters() but pick the
|
||||||
// right version for unsupervised or supervised training based on the type
|
// right version for unsupervised or supervised training based on the type
|
||||||
@ -746,6 +744,7 @@ namespace dlib
|
|||||||
prob_loss_increasing_thresh_default_value = 0.99;
|
prob_loss_increasing_thresh_default_value = 0.99;
|
||||||
prob_loss_increasing_thresh_max_value = 0.99999;
|
prob_loss_increasing_thresh_max_value = 0.99999;
|
||||||
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
|
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
|
||||||
|
updated_net_since_last_sync = false;
|
||||||
start();
|
start();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -832,10 +831,15 @@ namespace dlib
|
|||||||
dlib::cuda::set_device(prev_dev);
|
dlib::cuda::set_device(prev_dev);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void sync_to_disk (
|
void sync_to_disk (
|
||||||
bool do_it_now = false
|
bool do_it_now = false
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
// don't sync anything if we haven't updated the network since the last sync
|
||||||
|
if (!updated_net_since_last_sync)
|
||||||
|
return;
|
||||||
|
|
||||||
// If the sync file isn't set then don't do anything.
|
// If the sync file isn't set then don't do anything.
|
||||||
if (sync_filename.size() == 0)
|
if (sync_filename.size() == 0)
|
||||||
return;
|
return;
|
||||||
@ -879,6 +883,7 @@ namespace dlib
|
|||||||
|
|
||||||
last_sync_time = std::chrono::system_clock::now();
|
last_sync_time = std::chrono::system_clock::now();
|
||||||
main_iteration_counter_at_last_disk_sync = main_iteration_counter;
|
main_iteration_counter_at_last_disk_sync = main_iteration_counter;
|
||||||
|
updated_net_since_last_sync = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1069,6 +1074,7 @@ namespace dlib
|
|||||||
double prob_loss_increasing_thresh_default_value;
|
double prob_loss_increasing_thresh_default_value;
|
||||||
double prob_loss_increasing_thresh_max_value;
|
double prob_loss_increasing_thresh_max_value;
|
||||||
double prob_loss_increasing_thresh;
|
double prob_loss_increasing_thresh;
|
||||||
|
std::atomic<bool> updated_net_since_last_sync;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -88,8 +88,8 @@ namespace dlib
|
|||||||
cuda_extra_devices.
|
cuda_extra_devices.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
net_type& get_net (
|
net_type& get_net (
|
||||||
) const;
|
);
|
||||||
/*!
|
/*!
|
||||||
ensures
|
ensures
|
||||||
- returns the neural network object used by this trainer. This is the
|
- returns the neural network object used by this trainer. This is the
|
||||||
@ -99,6 +99,8 @@ namespace dlib
|
|||||||
dnn_trainer's constructor.
|
dnn_trainer's constructor.
|
||||||
- This function blocks until all threads inside the dnn_trainer have
|
- This function blocks until all threads inside the dnn_trainer have
|
||||||
stopped touching the net.
|
stopped touching the net.
|
||||||
|
- This function will sync the trainer state to disk if the current state
|
||||||
|
hasn't already been synced to disk since the last network modification.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
const std::vector<solver_type>& get_solvers (
|
const std::vector<solver_type>& get_solvers (
|
||||||
|
Loading…
Reference in New Issue
Block a user