diff --git a/dlib/all/source.cpp b/dlib/all/source.cpp index 8716974d8..4b8d250af 100644 --- a/dlib/all/source.cpp +++ b/dlib/all/source.cpp @@ -23,6 +23,7 @@ // include this first so that it can disable the older version // of the winsock API when compiled in windows. #include "../sockets/sockets_kernel_1.cpp" +#include "../bsp/bsp.cpp" #include "../dir_nav/dir_nav_kernel_1.cpp" #include "../dir_nav/dir_nav_kernel_2.cpp" @@ -56,7 +57,6 @@ #endif #ifndef DLIB_NO_GUI_SUPPORT - #include "../gui_widgets/fonts.cpp" #include "../gui_widgets/widgets.cpp" #include "../gui_widgets/drawable.cpp" @@ -65,7 +65,6 @@ #include "../gui_widgets/base_widgets.cpp" #include "../gui_core/gui_core_kernel_1.cpp" #include "../gui_core/gui_core_kernel_2.cpp" - #endif // DLIB_NO_GUI_SUPPORT #endif // DLIB_ISO_CPP_ONLY diff --git a/dlib/bsp.h b/dlib/bsp.h new file mode 100644 index 000000000..ec971641d --- /dev/null +++ b/dlib/bsp.h @@ -0,0 +1,12 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSP__ +#define DLIB_BSP__ + + +#include "bsp/bsp.h" + +#endif // DLIB_BSP__ + + + diff --git a/dlib/bsp/bsp.cpp b/dlib/bsp/bsp.cpp new file mode 100644 index 000000000..69f104cf4 --- /dev/null +++ b/dlib/bsp/bsp.cpp @@ -0,0 +1,488 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "bsp.h" +#include "../ref.h" + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace impl + { + + struct hostinfo + { + hostinfo() {} + hostinfo ( + const std::string& ip_, + unsigned short port_, + unsigned long node_id_ + ) : + ip(ip_), + port(port_), + node_id(node_id_) + { + } + + std::string ip; + unsigned short port; + unsigned long node_id; + }; + + void connect_all ( + map_id_to_con& cons, + const std::vector >& hosts, + unsigned long node_id + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + scoped_ptr con(new bsp_con(hosts[i])); + serialize(node_id, con->stream); // tell the other end our node_id + unsigned long id = i+1; + cons.add(id, con); + } + } + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + scoped_ptr con(new bsp_con(make_pair(hosts[i].ip,hosts[i].port))); + serialize(node_id, con->stream); // tell the other end our node_id + con->stream.flush(); + unsigned long id = hosts[i].node_id; + cons.add(id, con); + } + } + + + void serialize ( + const hostinfo& item, + std::ostream& out + ) + { + dlib::serialize(item.ip, out); + dlib::serialize(item.port, out); + dlib::serialize(item.node_id, out); + } + + void deserialize ( + hostinfo& item, + std::istream& in + ) + { + dlib::deserialize(item.ip, in); + dlib::deserialize(item.port, in); + dlib::deserialize(item.node_id, in); + } + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector >& hosts + ) + { + // tell everyone their node ids + cons.reset(); + while (cons.move_next()) + { + dlib::serialize(cons.element().key(), cons.element().value()->stream); + } + + // now tell them who to connect to + std::vector targets; + for (unsigned long i = 0; i < hosts.size(); ++i) + { + hostinfo info(hosts[i].first, hosts[i].second, i+1); + + dlib::serialize(targets, cons[info.node_id]->stream); + targets.push_back(info); + + // let the other host know how many incoming connections to expect + const unsigned long num = hosts.size()-targets.size(); + dlib::serialize(num, cons[info.node_id]->stream); + cons[info.node_id]->stream.flush(); + } + } + + // ------------------------------------------------------------------------------------ + + // These control bytes are sent before each message nodes send to each other. + const static char MESSAGE_HEADER = 0; + const static char WAITING_ON_RECEIVE = 1; + const static char NOT_WAITING_ON_RECEIVE = 2; + const static char ALL_NODES_WAITING = 3; + const static char SENT_MESSAGE = 4; + const static char GOT_MESSAGE = 5; + + // ------------------------------------------------------------------------------------ + + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port + ) + { + cons.clear(); + scoped_ptr list; + const int status = create_listener(list, port); + if (status == PORTINUSE) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) + + ". The port is already in use"); + } + else if (status != 0) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) ); + } + + scoped_ptr con; + if (list->accept(con)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + scoped_ptr temp(new bsp_con(con)); + + unsigned long remote_node_id; + dlib::deserialize(remote_node_id, temp->stream); + dlib::deserialize(node_id, temp->stream); + std::vector targets; + dlib::deserialize(targets, temp->stream); + unsigned long num_incoming_connections; + dlib::deserialize(num_incoming_connections, temp->stream); + + cons.add(remote_node_id,temp); + + // make a thread that will connect to all the targets + map_id_to_con cons2; + thread_function thread(impl::connect_all_hostinfo, ref(cons2), ref(targets), node_id); + + // accept any incoming connections + for (unsigned long i = 0; i < num_incoming_connections; ++i) + { + // If it takes more than 10 seconds for the other nodes to connect to us + // then something has gone horribly wrong and it almost certainly will + // never connect at all. So just give up if that happens. + const unsigned long timeout_milliseconds = 10000; + if (list->accept(con, timeout_milliseconds)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + temp.reset(new bsp_con(con)); + + dlib::deserialize(remote_node_id, temp->stream); + cons.add(remote_node_id,temp); + } + + + // put all the connections created by the thread into cons + thread.wait(); + while (cons2.size() > 0) + { + unsigned long id; + scoped_ptr temp; + cons2.remove_any(id,temp); + cons.add(id,temp); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION OF bsp OBJECT MEMBERS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bsp:: + ~bsp() + { + _cons.reset(); + while (_cons.move_next()) + { + _cons.element().value()->con->shutdown(); + } + + + // this will wait for all the threads to terminate + threads.clear(); + } + +// ---------------------------------------------------------------------------------------- + + bsp:: + bsp( + unsigned long node_id_, + impl::map_id_to_con& cons_ + ) : + read_thread_terminated(false), + outstanding_messages(0), + num_waiting_nodes(0), + buf_not_empty(class_mutex), + _cons(cons_), + _node_id(node_id_) + { + // spawn a bunch of read threads, one for each connection + member_function_pointer::kernel_1a_c mfp; + mfp.set(*this, &bsp::read_thread); + _cons.reset(); + while (_cons.move_next()) + { + scoped_ptr ptr(new thread_function(mfp, + _cons.element().value().get(), + _cons.element().key())); + threads.push_back(ptr); + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bsp:: + receive_data ( + shared_ptr& item, + unsigned long& sending_node_id + ) + { + using namespace impl; + // If there aren't any other nodes then you will never receive anything. + if (_cons.size() == 0) + return false; + + { + auto_mutex lock(class_mutex); + if (msg_buffer.size() == 0) + { + send_to_master_node(WAITING_ON_RECEIVE); + while (msg_buffer.size() == 0 && !read_thread_terminated) + { + buf_not_empty.wait(); + } + if (read_thread_terminated) + { + throw dlib::socket_error("A connection between processing nodes has been lost."); + } + send_to_master_node(NOT_WAITING_ON_RECEIVE); + } + + sending_node_id = msg_sender_id.front(); + msg_sender_id.pop_front(); + item = msg_buffer.front(); + msg_buffer.pop_front(); + } + + // if this is a message from another node rather than the + // "everyone is blocked on receive() message". + if (item) + { + send_to_master_node(GOT_MESSAGE); + return true; + } + else + { + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp:: + send_to_master_node ( + char msg + ) + { + using namespace impl; + // if we aren't the special controlling node then send the + // controller a message. + if (_cons.is_in_domain(0)) + { + serialize(msg, _cons[0]->stream); + _cons[0]->stream.flush(); + } + else if (_node_id == 0) // if this is the master node + { + // since we are the master node we will just modify our state directly + auto_mutex lock(class_mutex); + switch(msg) + { + case WAITING_ON_RECEIVE: { + ++num_waiting_nodes; + notify_everyone_if_all_blocked(); + } break; + + case NOT_WAITING_ON_RECEIVE: { + --num_waiting_nodes; + } break; + + case SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case GOT_MESSAGE: { + --outstanding_messages; + } break; + + default: + DLIB_CASSERT(false,"this should not happen"); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp:: + notify_everyone_if_all_blocked( + ) + { + using namespace impl; + // if all the nodes are blocked on receive() and there aren't any + // messages in flight. + if (_node_id == 0 && num_waiting_nodes == number_of_nodes() && outstanding_messages == 0) + { + // send notifications + _cons.reset(); + while (_cons.move_next()) + { + try + { + serialize(ALL_NODES_WAITING, _cons.element().value()->stream); + _cons.element().value()->stream.flush(); + if (!_cons.element().value()->stream) + throw dlib::error("Error writing data to TCP connection"); + } + catch (std::exception& e) + { + const connection* const con = _cons.element().value()->con.get(); + std::ostringstream sout; + sout << "An exception occurred in the controlling node while it was trying to communicate with a listening node.\n"; + sout << " Listening processing node address: " << con->get_foreign_ip() << ":" << con->get_foreign_port() << std::endl; + sout << " Controlling processing node address: " << con->get_local_ip() << ":" << con->get_local_port() << std::endl; + sout << " Error message in the exception: " << e.what() << std::endl; + error_message = sout.str(); + } + } + + // unblock the control node itself + shared_ptr msg; + msg_buffer.push_back(msg); + msg_sender_id.push_back(0); + buf_not_empty.signal(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp:: + read_thread ( + impl::bsp_con* con, + unsigned long sender_id + ) + { + try + { + using namespace impl; + while (con->stream.peek() != EOF) + { + char header; + deserialize(header, con->stream); + switch (header) + { + case MESSAGE_HEADER: { + shared_ptr msg(new std::string); + deserialize(*msg, con->stream); + + auto_mutex lock(class_mutex); + msg_buffer.push_back(msg); + msg_sender_id.push_back(sender_id); + buf_not_empty.signal(); + } break; + + case WAITING_ON_RECEIVE: { + auto_mutex lock(class_mutex); + ++num_waiting_nodes; + notify_everyone_if_all_blocked(); + } break; + + case NOT_WAITING_ON_RECEIVE: { + auto_mutex lock(class_mutex); + --num_waiting_nodes; + } break; + + case ALL_NODES_WAITING: { + // put something into the message buffer that lets + // receive() know to return false. We do this using + // a null msg pointer. + auto_mutex lock(class_mutex); + shared_ptr msg; + msg_buffer.push_back(msg); + msg_sender_id.push_back(sender_id); + buf_not_empty.signal(); + } break; + + case SENT_MESSAGE: { + auto_mutex lock(class_mutex); + ++outstanding_messages; + } break; + + case GOT_MESSAGE: { + auto_mutex lock(class_mutex); + --outstanding_messages; + } break; + } + } + } + catch (std::exception& e) + { + std::ostringstream sout; + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Error message in the exception: " << e.what() << std::endl; + auto_mutex lock(class_mutex); + error_message = sout.str(); + } + + auto_mutex lock(class_mutex); + read_thread_terminated = true; + buf_not_empty.signal(); + } + +// ---------------------------------------------------------------------------------------- + + void bsp:: + check_for_errors() + { + auto_mutex lock(class_mutex); + if (error_message.size() != 0) + throw dlib::socket_error(error_message); + } + +// ---------------------------------------------------------------------------------------- + + void bsp:: + send_data( + const std::string& item, + unsigned long target_node_id + ) + { + using namespace impl; + serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); + serialize(item, _cons[target_node_id]->stream); + _cons[target_node_id]->stream.flush(); + send_to_master_node(SENT_MESSAGE); + } + +// ---------------------------------------------------------------------------------------- + +} + diff --git a/dlib/bsp/bsp.h b/dlib/bsp/bsp.h new file mode 100644 index 000000000..11b8917c9 --- /dev/null +++ b/dlib/bsp/bsp.h @@ -0,0 +1,318 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BsP_H__ +#define DLIB_BsP_H__ + +#include "bsp_abstract.h" +#include "../sockets.h" +#include "../array.h" +#include "../smart_pointers.h" +#include "../sockstreambuf.h" +#include "../string.h" +#include "../serialize.h" +#include "../map.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct bsp_con + { + bsp_con( + const std::pair& dest + ) : + con(connect(dest.first,dest.second)), + buf(con), + stream(&buf) + {} + + bsp_con( + scoped_ptr& conptr + ) : + buf(conptr), + stream(&buf) + { + // make sure we own the connection + conptr.swap(con); + } + + scoped_ptr con; + sockstreambuf::kernel_2a buf; + std::iostream stream; + }; + + typedef dlib::map >::kernel_1a_c map_id_to_con; + + void connect_all ( + map_id_to_con& cons, + const std::vector >& hosts, + unsigned long node_id + ); + /*! + ensures + - creates connections to all the given hosts and stores them into cons + !*/ + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector >& hosts + ); + + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port + ); + } + +// ---------------------------------------------------------------------------------------- + + class bsp : noncopyable + { + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ) + /*! + requires + - item is serializable + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + { + std::ostringstream sout; + serialize(item, sout); + send_data(sout.str(), target_node_id); + } + + template + void broadcast ( + const T& item + ) + /*! + ensures + - sends a copy of item to all other processing nodes. + !*/ + { + std::ostringstream sout; + serialize(item, sout); + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + if (i == node_id()) + continue; + send_data(sout.str(), i); + } + } + + unsigned long node_id ( + ) const { return _node_id; } + /*! + ensures + - Returns the id of the current processing node. That is, + returns a number N such that: + - N < number_of_nodes() + - N == the node id of the processing node that called + node_id(). + !*/ + + unsigned long number_of_nodes ( + ) const { return _cons.size()+1; } + /*! + ensures + - returns the number of processing nodes participating in the + BSP computation. + !*/ + + template + bool receive ( + T& item + ) + /*! + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - else + - There were no other messages to receive and all other processing + nodes are blocked on calls to receive(). + !*/ + { + unsigned long sending_node_id; + return receive(item, sending_node_id); + } + + template + bool receive ( + T& item, + unsigned long& sending_node_id + ) + /*! + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - else + - There were no other messages to receive and all other processing + nodes are blocked on calls to receive(). + !*/ + { + shared_ptr temp; + if (receive_data(temp, sending_node_id)) + { + std::istringstream sin(*temp); + deserialize(item, sin); + return true; + } + else + { + return false; + } + } + + ~bsp(); + + private: + + bsp(); + + bsp( + unsigned long node_id_, + impl::map_id_to_con& cons_ + ); + + bool receive_data ( + shared_ptr& item, + unsigned long& sending_node_id + ); + + void send_to_master_node ( + char msg + ); + + void notify_everyone_if_all_blocked( + ); + /*! + requires + - class_mutex is locked + ensures + - sends out notifications to all the nodes if we are all blocked on receive. This + will cause all receive calls to unblock and return false. + !*/ + + void read_thread ( + impl::bsp_con* con, + unsigned long sender_id + ); + + + void check_for_errors(); + + void send_data( + const std::string& item, + unsigned long target_node_id + ); + /*! + requires + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + + + + rmutex class_mutex; // used to lock any class members touched from more than one thread. + std::string error_message; + bool read_thread_terminated; // true if any of our connections goes down. + unsigned long outstanding_messages; + unsigned long num_waiting_nodes; + rsignaler buf_not_empty; // used to signal when msg_buffer isn't empty + std::deque > msg_buffer; + std::deque msg_sender_id; + + impl::map_id_to_con& _cons; + const unsigned long _node_id; + array > threads; + + template < + typename funct_type + > + friend void bsp_connect ( + funct_type& funct, + const std::vector >& hosts + ); + + template < + typename funct_type + > + friend void bsp_listen ( + funct_type& funct, + unsigned short listening_port + ); + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + funct_type& funct, + const std::vector >& hosts + ) + { + impl::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + + send_out_connection_orders(cons, hosts); + + bsp obj(node_id, cons); + funct(obj); + + obj.check_for_errors(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + funct_type& funct, + unsigned short listening_port + ) + { + impl::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port); + + bsp obj(node_id, cons); + funct(obj); + + obj.check_for_errors(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bsp.cpp" +#endif + +#endif // DLIB_BsP_H__ + diff --git a/dlib/bsp/bsp_abstract.h b/dlib/bsp/bsp_abstract.h new file mode 100644 index 000000000..139597f9c --- /dev/null +++ b/dlib/bsp/bsp_abstract.h @@ -0,0 +1,2 @@ + +