diff --git a/dlib/test/CMakeLists.txt b/dlib/test/CMakeLists.txt index 1283facbc..51258bde3 100644 --- a/dlib/test/CMakeLists.txt +++ b/dlib/test/CMakeLists.txt @@ -9,7 +9,6 @@ cmake_minimum_required(VERSION 2.6) # into the regression test suite. set (tests example.cpp - example_args.cpp any.cpp any_function.cpp array2d.cpp @@ -23,6 +22,7 @@ set (tests binary_search_tree_mm1.cpp binary_search_tree_mm2.cpp bridge.cpp + bsp.cpp byte_orderer.cpp cmd_line_parser.cpp cmd_line_parser_wchar_t.cpp @@ -40,6 +40,7 @@ set (tests empirical_kernel_map.cpp entropy_coder.cpp entropy_encoder_model.cpp + example_args.cpp filtering.cpp find_max_factor_graph_nmplp.cpp find_max_factor_graph_viterbi.cpp diff --git a/dlib/test/bsp.cpp b/dlib/test/bsp.cpp new file mode 100644 index 000000000..378f3bf2a --- /dev/null +++ b/dlib/test/bsp.cpp @@ -0,0 +1,285 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.bsp"); + + + template + struct callfunct_helper + { + callfunct_helper ( + funct f_, + int port_, + bool& error_occurred_ + ) :f(f_), port(port_), error_occurred(error_occurred_) {} + + funct f; + int port; + bool& error_occurred; + + void operator() ( + ) const + { + try + { + bsp_listen(port, f); + } + catch (exception& e) + { + dlog << LERROR << "error calling bsp_listen(): " << e.what(); + error_occurred = true; + } + } + }; + + template + callfunct_helper callfunct(funct f, int port, bool& error_occurred) + { + return callfunct_helper(f,port,error_occurred); + + } + +// ---------------------------------------------------------------------------------------- + + void sum_array_driver ( + bsp_context& obj, + const std::vector& v, + int& result + ) + { + obj.broadcast(v); + + result = 0; + int val; + while(obj.receive(val)) + result += val; + } + + void sum_array_other ( + bsp_context& obj + ) + { + std::vector v; + obj.receive(v); + + int sum = 0; + for (unsigned long i = 0; i < v.size(); ++i) + sum += v[i]; + + obj.send(sum, 0); + + + } + + + void dotest1() + { + dlog << LINFO << "start dotest1()"; + print_spinner(); + bool error_occurred = false; + { + thread_function t1(callfunct(sum_array_other, 12345, error_occurred)); + thread_function t2(callfunct(sum_array_other, 12346, error_occurred)); + thread_function t3(callfunct(sum_array_other, 12347, error_occurred)); + std::vector v; + int true_value = 0; + for (int i = 0; i < 10; ++i) + { + v.push_back(i); + true_value += i; + } + + // wait a little bit for the threads to start up + dlib::sleep(200); + + try + { + int result; + std::vector > hosts; + hosts.push_back(make_pair("127.0.0.1",12345)); + hosts.push_back(make_pair("127.0.0.1",12346)); + hosts.push_back(make_pair("127.0.0.1",12347)); + bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result)); + + dlog << LINFO << "result: "<< result; + dlog << LINFO << "should be: "<< 3*true_value; + DLIB_TEST(result == 3*true_value); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + template + void test2_job(bsp_context& obj) + { + if (obj.node_id() == id) + dlib::sleep(100); + } + + template + void dotest2() + { + dlog << LINFO << "start dotest2()"; + print_spinner(); + bool error_occurred = false; + { + thread_function t1(callfunct(test2_job, 12345, error_occurred)); + thread_function t2(callfunct(test2_job, 12346, error_occurred)); + thread_function t3(callfunct(test2_job, 12347, error_occurred)); + + // wait a little bit for the threads to start up + dlib::sleep(200); + + try + { + std::vector > hosts; + hosts.push_back(make_pair("127.0.0.1",12345)); + hosts.push_back(make_pair("127.0.0.1",12346)); + hosts.push_back(make_pair("127.0.0.1",12347)); + bsp_connect(hosts, test2_job); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + void test3_job_driver(bsp_context& obj, int& result) + { + + obj.broadcast(obj.node_id()); + + int accum = 0; + int temp = 0; + while(obj.receive(temp)) + accum += temp; + + // send to node 1 so it can sum everything + if (obj.node_id() != 1) + obj.send(accum, 1); + + while(obj.receive(temp)) + accum += temp; + + // Now hop the accum values along the nodes until the value from node 1 gets to + // node 0. + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + DLIB_TEST(obj.receive(accum)); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + DLIB_TEST(obj.receive(accum)); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + DLIB_TEST(obj.receive(accum)); + + // this whole block is a noop since it doesn't end up doing anything. + for (int k = 0; k < 100; ++k) + { + dlog << LINFO << "k: " << k; + for (int i = 0; i < 4; ++i) + { + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + DLIB_TEST(obj.receive(accum)); + } + } + + + dlog << LINFO << "TERMINATE"; + if (obj.node_id() == 0) + result = accum; + } + + + void test3_job(bsp_context& obj) + { + int junk; + test3_job_driver(obj, junk); + } + + + void dotest3() + { + dlog << LINFO << "start dotest3()"; + print_spinner(); + bool error_occurred = false; + { + thread_function t1(callfunct(test3_job, 12345, error_occurred)); + thread_function t2(callfunct(test3_job, 12346, error_occurred)); + thread_function t3(callfunct(test3_job, 12347, error_occurred)); + + // wait a little bit for the threads to start up + dlib::sleep(200); + + try + { + std::vector > hosts; + hosts.push_back(make_pair("127.0.0.1",12345)); + hosts.push_back(make_pair("127.0.0.1",12346)); + hosts.push_back(make_pair("127.0.0.1",12347)); + int result = 0; + const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2; + bsp_connect(hosts, test3_job_driver, dlib::ref(result)); + + dlog << LINFO << "result: " << result; + dlog << LINFO << "should be: " << expected; + DLIB_TEST(result == expected); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + class bsp_tester : public tester + { + + public: + bsp_tester ( + ) : + tester ("test_bsp", + "Runs tests on the BSP components.") + {} + + void perform_test ( + ) + { + dotest1(); + dotest2<0>(); + dotest2<1>(); + dotest2<2>(); + dotest3(); + } + } a; + +} + diff --git a/dlib/test/makefile b/dlib/test/makefile index e5dbc6d1b..aeeb346ea 100644 --- a/dlib/test/makefile +++ b/dlib/test/makefile @@ -38,6 +38,7 @@ SRC += binary_search_tree_kernel_2a.cpp SRC += binary_search_tree_mm1.cpp SRC += binary_search_tree_mm2.cpp SRC += bridge.cpp +SRC += bsp.cpp SRC += byte_orderer.cpp SRC += cmd_line_parser.cpp SRC += cmd_line_parser_wchar_t.cpp