diff --git a/dlib/test/thread_pool.cpp b/dlib/test/thread_pool.cpp index 2b5054275..73ccb346e 100644 --- a/dlib/test/thread_pool.cpp +++ b/dlib/test/thread_pool.cpp @@ -369,6 +369,39 @@ namespace DLIB_TEST(d == 4); } + + tp.wait_for_all_tasks(); + + // make sure exception propagation from tasks works correctly. + auto f_throws = []() { throw dlib::error("test exception");}; + bool got_exception = false; + try + { + tp.add_task_by_value(f_throws); + tp.wait_for_all_tasks(); + } + catch(dlib::error& e) + { + DLIB_TEST(e.info == "test exception"); + got_exception = true; + } + DLIB_TEST(got_exception); + + dlib::future aa; + auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");}; + got_exception = false; + try + { + tp.add_task(f_throws2, aa); + aa.get(); + } + catch(dlib::error& e) + { + DLIB_TEST(e.info == "test exception"); + got_exception = true; + } + DLIB_TEST(got_exception); + } } diff --git a/dlib/threads/parallel_for_extension_abstract.h b/dlib/threads/parallel_for_extension_abstract.h index fd6f4645f..ffd2e0c44 100644 --- a/dlib/threads/parallel_for_extension_abstract.h +++ b/dlib/threads/parallel_for_extension_abstract.h @@ -24,7 +24,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This is a convenience function for submitting a block of jobs to a thread_pool. In particular, given the half open range [begin, end), this function will @@ -61,7 +60,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: thread_pool tp(num_threads); @@ -82,7 +80,6 @@ namespace dlib requires - chunks_per_thread > 0 - begin <= end - - funct does not throw any exceptions ensures - This is a convenience function for submitting a block of jobs to a thread_pool. In particular, given the range [begin, end), this function will @@ -117,7 +114,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: thread_pool tp(num_threads); @@ -137,7 +133,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread); @@ -159,7 +154,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following function call: parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) @@ -189,7 +183,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: thread_pool tp(num_threads); @@ -210,7 +203,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following function call: parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) @@ -238,7 +230,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: thread_pool tp(num_threads); @@ -258,7 +249,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is equivalent to the following block of code: parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread); @@ -280,7 +270,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for() routine defined above except that it will print messages to cout showing the progress in executing the @@ -302,7 +291,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for() routine defined above except that it will print messages to cout showing the progress in executing the @@ -323,7 +311,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for() routine defined above except that it will print messages to cout showing the progress in executing the @@ -344,7 +331,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for() routine defined above except that it will print messages to cout showing the progress in executing the @@ -364,7 +350,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for() routine defined above except that it will print messages to cout showing the progress in executing the @@ -388,7 +373,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for_blocked() routine defined above except that it will print messages to cout showing the progress in @@ -410,7 +394,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for_blocked() routine defined above except that it will print messages to cout showing the progress in @@ -431,7 +414,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for_blocked() routine defined above except that it will print messages to cout showing the progress in @@ -452,7 +434,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for_blocked() routine defined above except that it will print messages to cout showing the progress in @@ -472,7 +453,6 @@ namespace dlib requires - begin <= end - chunks_per_thread > 0 - - funct does not throw any exceptions ensures - This function is identical to the parallel_for_blocked() routine defined above except that it will print messages to cout showing the progress in diff --git a/dlib/threads/thread_pool_extension.cpp b/dlib/threads/thread_pool_extension.cpp index 7fdf47840..585390a36 100644 --- a/dlib/threads/thread_pool_extension.cpp +++ b/dlib/threads/thread_pool_extension.cpp @@ -61,6 +61,11 @@ namespace dlib } wait(); + + // Throw any unhandled exceptions. Since shutdown_pool() is only called in the + // destructor this will kill the program. + for (auto&& task : tasks) + task.propagate_exception(); } // ---------------------------------------------------------------------------------------- @@ -94,6 +99,9 @@ namespace dlib const unsigned long idx = task_id_to_index(task_id); while (tasks[idx].task_id == task_id) task_done_signaler.wait(); + + for (auto&& task : tasks) + task.propagate_exception(); } } @@ -124,6 +132,10 @@ namespace dlib if (found_task) task_done_signaler.wait(); } + + // throw any exceptions generated by the tasks + for (auto&& task : tasks) + task.propagate_exception(); } // ---------------------------------------------------------------------------------------- @@ -177,15 +189,23 @@ namespace dlib task = tasks[idx]; } - // now do the task - if (task.bfp) - task.bfp(); - else if (task.mfp0) - task.mfp0(); - else if (task.mfp1) - task.mfp1(task.arg1); - else if (task.mfp2) - task.mfp2(task.arg1, task.arg2); + std::exception_ptr eptr; + try + { + // now do the task + if (task.bfp) + task.bfp(); + else if (task.mfp0) + task.mfp0(); + else if (task.mfp1) + task.mfp1(task.arg1); + else if (task.mfp2) + task.mfp2(task.arg1, task.arg2); + } + catch(...) + { + eptr = std::current_exception(); + } // Now let others know that we finished the task. We do this // by clearing out the state of this task @@ -198,6 +218,7 @@ namespace dlib tasks[idx].mfp2.clear(); tasks[idx].arg1 = 0; tasks[idx].arg2 = 0; + tasks[idx].eptr = eptr; task_done_signaler.broadcast(); } @@ -210,6 +231,9 @@ namespace dlib find_empty_task_slot ( ) const { + for (auto&& task : tasks) + task.propagate_exception(); + for (unsigned long i = 0; i < tasks.size(); ++i) { if (tasks[i].is_empty()) diff --git a/dlib/threads/thread_pool_extension.h b/dlib/threads/thread_pool_extension.h index 13f7b8777..e928cd946 100644 --- a/dlib/threads/thread_pool_extension.h +++ b/dlib/threads/thread_pool_extension.h @@ -13,6 +13,7 @@ #include "../array.h" #include "../smart_pointers_thread_safe.h" #include "../smart_pointers.h" +#include namespace dlib { @@ -451,6 +452,17 @@ namespace dlib bfp_type bfp; shared_ptr function_copy; + mutable std::exception_ptr eptr; // non-null if the task threw an exception + + void propagate_exception() const + { + if (eptr) + { + auto tmp = eptr; + eptr = nullptr; + std::rethrow_exception(tmp); + } + } }; diff --git a/dlib/threads/thread_pool_extension_abstract.h b/dlib/threads/thread_pool_extension_abstract.h index f2fb7efa8..e9dbd996d 100644 --- a/dlib/threads/thread_pool_extension_abstract.h +++ b/dlib/threads/thread_pool_extension_abstract.h @@ -225,9 +225,11 @@ namespace dlib such as mutex objects. EXCEPTIONS - Note that if an exception is thrown inside a task thread and - is not caught then the normal rule for uncaught exceptions in - threads applies. That is, the application will be terminated. + Note that if an exception is thrown inside a task thread and is not caught + then the exception will be trapped inside the thread pool and rethrown at a + later time when someone calls one of the add task or wait member functions + of the thread pool. This allows exceptions to propagate out of task threads + and into the calling code where they can be handled. !*/ public: diff --git a/examples/thread_pool_ex.cpp b/examples/thread_pool_ex.cpp index 23c6be6cf..e0a566ef6 100644 --- a/examples/thread_pool_ex.cpp +++ b/examples/thread_pool_ex.cpp @@ -5,8 +5,8 @@ object from the dlib C++ Library. - This is a very simple example. It creates a thread pool with 3 - threads and then sends a few simple tasks to the pool. + In this example we will crate a thread pool with 3 threads and then show a + few different ways to send tasks to the pool. */ @@ -17,18 +17,19 @@ using namespace dlib; -// We will be using the dlib logger object to print out messages in this example +// We will be using the dlib logger object to print messages in this example // because its output is timestamped and labeled with the thread that the log -// message came from. So this will make it easier to see what is going on in -// this example. Here we make an instance of the logger. See the logger +// message came from. This will make it easier to see what is going on in this +// example. Here we make an instance of the logger. See the logger // documentation and examples for detailed information regarding its use. logger dlog("main"); -// Here we make an instance of the thread pool object +// Here we make an instance of the thread pool object. You could also use the +// global dlib::default_thread_pool(), which automatically selects the number of +// threads based on your hardware. But here let's make our own. thread_pool tp(3); - // ---------------------------------------------------------------------------------------- class test @@ -37,27 +38,27 @@ class test The thread_pool accepts "tasks" from the user and schedules them for execution in one of its threads when one becomes available. Each task is just a request to call a function. So here we create a class called - test with a few member functions which we will have the thread pool call + test with a few member functions, which we will have the thread pool call as tasks. */ public: - void task() + void mytask() { - dlog << LINFO << "task start"; + dlog << LINFO << "mytask start"; - future var; + dlib::future var; var = 1; // Here we ask the thread pool to call this->subtask() and this->subtask2(). // Note that calls to add_task() will return immediately if there is an - // available thread to hand the task off to. However, if there isn't a - // thread ready then add_task() blocks until there is such a thread. - // Also note that since task() is executed within the thread pool (see main() below) - // calls to add_task() will execute the requested task within the calling thread - // in cases where the thread pool is full. This means it is always safe to - // spawn subtasks from within another task, which is what we are doing here. + // available thread. However, if there isn't a thread ready then + // add_task() blocks until there is such a thread. Also, note that if + // mytask() is executed within the thread pool then calls to add_task() + // will execute the requested task within the calling thread in cases + // where the thread pool is full. This means it is always safe to spawn + // subtasks from within another task, which is what we are doing here. tp.add_task(*this,&test::subtask,var); // schedule call to this->subtask(var) tp.add_task(*this,&test::subtask2); // schedule call to this->subtask2() @@ -66,17 +67,16 @@ public: // return the integer it contains. In this case result will be assigned // the value 2 since var was incremented by subtask(). int result = var; - // print out the result dlog << LINFO << "var = " << result; // Wait for all the tasks we have started to finish. Note that - // wait_for_all_tasks() only waits for tasks which were started - // by the calling thread. So you don't have to worry about other - // unrelated parts of your application interfering. In this case - // it just waits for subtask2() to finish. + // wait_for_all_tasks() only waits for tasks which were started by the + // calling thread. So you don't have to worry about other unrelated + // parts of your application interfering. In this case it just waits + // for subtask2() to finish. tp.wait_for_all_tasks(); - dlog << LINFO << "task end" ; + dlog << LINFO << "mytask end" ; } void subtask(int& a) @@ -96,23 +96,7 @@ public: // ---------------------------------------------------------------------------------------- -class add_value -{ -public: - add_value(int value):val(value) { } - - void operator()( int& a ) - { - a += val; - } - -private: - int val; -}; - -// ---------------------------------------------------------------------------------------- - -int main() +int main() try { // tell the logger to print out everything dlog.set_level(LALL); @@ -120,84 +104,80 @@ int main() dlog << LINFO << "schedule a few tasks"; - test mytask; - // Schedule the thread pool to call mytask.task(). Note that all forms of add_task() - // pass in the task object by reference. This means you must make sure, in this case, - // that mytask isn't destructed until after the task has finished executing. - tp.add_task(mytask, &test::task); - - // You can also pass task objects to a thread pool by value. So in this case we don't - // have to worry about keeping our own instance of the task. Here we construct a temporary - // add_value object and pass it right in and everything works like it should. - future num = 3; - tp.add_task_by_value(add_value(7), num); // adds 7 to num - int result = num.get(); - dlog << LINFO << "result = " << result; // prints result = 10 - - + test taskobj; + // Schedule the thread pool to call taskobj.mytask(). Note that all forms of + // add_task() pass in the task object by reference. This means you must make sure, + // in this case, that taskobj isn't destructed until after the task has finished + // executing. + tp.add_task(taskobj, &test::mytask); + // This behavior of add_task() enables it to guarantee that no memory allocations + // occur after the thread_pool has been constructed, so long as the user doesn't + // call any of the add_task_by_value() routines. The future object also doesn't + // perform any memory allocations or contain any system resources such as mutex + // objects. If you don't care about memory allocations then you will likely find + // the add_task_by_value() interface more convenient to use, which is shown below. -// uncomment this line if your compiler supports the new C++0x lambda functions -//#define COMPILER_SUPPORTS_CPP0X_LAMBDA_FUNCTIONS -#ifdef COMPILER_SUPPORTS_CPP0X_LAMBDA_FUNCTIONS - // In the above examples we had to explicitly create task objects which is - // inconvenient. If you have a compiler which supports C++0x lambda functions - // then you can use the following simpler method. + // If we call add_task_by_value() we pass task objects to a thread pool by value. + // So in this case we don't have to worry about keeping our own instance of the + // task. Here we create a lambda function and pass it right in and everything + // works like it should. + dlib::future num = 3; + tp.add_task_by_value([](int& val){val += 7;}, num); // adds 7 to num + int result = num.get(); + dlog << LINFO << "result = " << result; // prints result = 10 - // make a task which will just log a message - tp.add_task_by_value([](){ - dlog << LINFO << "A message from a lambda function running in another thread."; - }); - // Here we make 10 different tasks, each assigns a different value into - // the elements of the vector vect. - std::vector vect(10); + // dlib also contains dlib::async(), which is essentially identical to std::async() + // except that it launches tasks to a dlib::thread_pool (using add_task_by_value) + // rather than starting an unbounded number of threads. As an example, here we + // make 10 different tasks, each assigns a different value into the elements of the + // vector vect. + std::vector> vect(10); for (unsigned long i = 0; i < vect.size(); ++i) - { - // Make a lambda function which takes vect by reference and i by value. So what - // will happen is each assignment statement will run in a thread in the thread_pool. - tp.add_task_by_value([&vect,i](){ - vect[i] = i; - }); - } - // Wait for all tasks which were requested by the main thread to complete. - tp.wait_for_all_tasks(); + vect[i] = dlib::async(tp, [i]() { return i*i; }); + // Print the results for (unsigned long i = 0; i < vect.size(); ++i) - { - dlog << LINFO << "vect["<