@@ -131,9 +131,10 @@ namespace dlib
         }

         net_type& get_net (
-        ) const
+        )
         {
             wait_for_thread_to_pause();
+            sync_to_disk(true);
             propagate_exception();
             return net;
         }
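The hunk above drops get_net()'s const qualifier because it now calls sync_to_disk(true), so fetching the network also flushes the latest trainer state to the synchronization file. A minimal caller-side sketch of the resulting behavior (illustrative only, not part of this patch; assumes a net_type and training data are already defined):

    #include <dlib/dnn.h>

    net_type net;
    dlib::dnn_trainer<net_type> trainer(net);
    trainer.set_synchronization_file("trainer_state.dat", std::chrono::minutes(5));
    trainer.train(training_images, training_labels);
    // get_net() pauses the trainer thread and syncs before returning, so the
    // returned network and the state file agree.
    net_type& trained = trainer.get_net();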
@@ -266,7 +267,6 @@ namespace dlib
         {
             DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);

-            bool updated_the_network = false;
             // The reason these two loops don't initialize their counter variables but
             // instead use class members is so we can include the state of the loops in the
             // stuff written by sync_to_disk()
@@ -297,7 +297,6 @@ namespace dlib
                     send_job(data.begin()+epoch_pos,
                              data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
                              labels.begin()+epoch_pos);
-                    updated_the_network = true;
                 }
                 epoch_pos = 0;

@@ -313,7 +312,7 @@ namespace dlib
             }
             wait_for_thread_to_pause();
             // if we modified the network at all then be sure to sync the final result.
-            sync_to_disk(updated_the_network);
+            sync_to_disk(true);
         }

         void train (
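The caller-side updated_the_network bookkeeping removed in these hunks is superseded by a dirty flag kept by the trainer itself: train() can now always request a sync because sync_to_disk() (see the hunk at line 832 below) returns early when nothing has changed. A self-contained sketch of that dirty-flag idiom, with hypothetical helper names standing in for the real serialization code:

    #include <atomic>

    struct state_syncer
    {
        std::atomic<bool> dirty{false};

        void on_update() { dirty = true; }      // call whenever state changes

        void sync(bool do_it_now = false)
        {
            if (!dirty)
                return;                         // nothing new: skip the disk write
            if (!do_it_now && !enough_time_elapsed())
                return;                         // rate-limit routine syncs
            write_state_to_disk();              // hypothetical helper
            dirty = false;                      // clean until the next update
        }

        bool enough_time_elapsed() const { return true; }   // placeholder
        void write_state_to_disk() {}                       // placeholder
    };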
@@ -326,7 +325,6 @@ namespace dlib
             static_assert(has_unsupervised_loss,
                 "You can only call this version of train() when using an unsupervised loss.");

-            bool updated_the_network = false;
             // The reason these two loops don't initialize their counter variables but
             // instead use class members is so we can include the state of the loops in the
             // stuff written by sync_to_disk()
@@ -356,7 +354,6 @@ namespace dlib
                     sync_to_disk();
                     send_job(data.begin()+epoch_pos,
                              data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
-                    updated_the_network = true;
                 }
                 epoch_pos = 0;

@@ -372,7 +369,7 @@ namespace dlib
             }
             wait_for_thread_to_pause();
             // if we modified the network at all then be sure to sync the final result.
-            sync_to_disk(updated_the_network);
+            sync_to_disk(true);
         }

         void set_synchronization_file (
@@ -588,6 +585,7 @@ namespace dlib
             main_iteration_counter = 0;
             while(job_pipe.dequeue(next_job))
             {
+                updated_net_since_last_sync = true;
                 ++main_iteration_counter;
                 // Call compute_parameter_gradients() and update_parameters() but pick the
                 // right version for unsupervised or supervised training based on the type
@@ -746,6 +744,7 @@ namespace dlib
             prob_loss_increasing_thresh_default_value = 0.99;
             prob_loss_increasing_thresh_max_value = 0.99999;
             prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
+            updated_net_since_last_sync = false;

             start();
         }
@@ -832,10 +831,15 @@ namespace dlib
                 dlib::cuda::set_device(prev_dev);
             }
         }

         void sync_to_disk (
             bool do_it_now = false
         )
         {
+            // don't sync anything if we haven't updated the network since the last sync
+            if (!updated_net_since_last_sync)
+                return;
+
             // If the sync file isn't set then don't do anything.
             if (sync_filename.size() == 0)
                 return;
@@ -879,6 +883,7 @@ namespace dlib

                 last_sync_time = std::chrono::system_clock::now();
                 main_iteration_counter_at_last_disk_sync = main_iteration_counter;
+                updated_net_since_last_sync = false;
             }
         }
@@ -1069,6 +1074,7 @@ namespace dlib
         double prob_loss_increasing_thresh_default_value;
         double prob_loss_increasing_thresh_max_value;
         double prob_loss_increasing_thresh;
+        std::atomic<bool> updated_net_since_last_sync;

     };
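On the new member itself: the flag is set by the trainer's worker thread (each time a job is dequeued) and tested/cleared by sync_to_disk(), which can run on a different thread, so it is declared std::atomic<bool> to make those cross-thread accesses well-defined without taking a lock. A stripped-down sketch of the same pattern (simplified; the real trainer also pauses the worker before serializing):

    #include <atomic>
    #include <thread>

    std::atomic<bool> updated{false};

    void worker()                    // stands in for the trainer thread
    {
        for (int job = 0; job < 1000; ++job)
            updated = true;          // mark state dirty after each job
    }

    int main()
    {
        std::thread t(worker);
        if (updated)                 // cheap check from another thread, no mutex
            updated = false;         // would serialize to disk, then clear
        t.join();
    }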