mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Changed code to avoid recreating thread_local cuda context objects.
This commit is contained in:
parent
e55afabd1a
commit
974743767f
@ -535,7 +535,15 @@ namespace dlib
|
||||
std::vector<tensor*> reference_params;
|
||||
visit_layer_parameters(devices[0]->net, [&](size_t, tensor& t) { reference_params.push_back(&t); });
|
||||
|
||||
thread_pool tp(devices.size());
|
||||
// We make separate thread pools with just one thread in them because we want
|
||||
// to make sure each device is always executed on the same thread. We care
|
||||
// about this because there are thread_local context variables for some cuda
|
||||
// components and they get regenerated when the current cuda device changes.
|
||||
// Recreating them over and over is somewhat expensive so we want to avoid
|
||||
// that.
|
||||
std::vector<std::shared_ptr<thread_pool>> tp;
|
||||
for (size_t i = 0; i < devices.size(); ++i)
|
||||
tp.push_back(std::make_shared<thread_pool>(1));
|
||||
|
||||
|
||||
size_t iteration = 0;
|
||||
@ -546,7 +554,7 @@ namespace dlib
|
||||
// right version for unsupervised or supervised training based on the type
|
||||
// of label_type.
|
||||
for (size_t i = 0; i < devices.size(); ++i)
|
||||
tp.add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
|
||||
tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
|
||||
// aggregate loss values from all the network computations.
|
||||
double theloss = 0;
|
||||
for (auto&& loss : losses)
|
||||
@ -597,9 +605,10 @@ namespace dlib
|
||||
|
||||
// Now apply all the updates to each device.
|
||||
for (size_t i = 0; i < devices.size(); ++i)
|
||||
tp.add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); });
|
||||
tp[i]->add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); });
|
||||
// and wait for the updates to all happen.
|
||||
tp.wait_for_all_tasks();
|
||||
for (size_t i = 0; i < devices.size(); ++i)
|
||||
tp[i]->wait_for_all_tasks();
|
||||
|
||||
|
||||
// Evey now and then force all the parameters to be the same just to make
|
||||
|
Loading…
Reference in New Issue
Block a user