diff --git a/dlib/optimization/find_max_factor_graph_nmplp.h b/dlib/optimization/find_max_factor_graph_nmplp.h index f9e0cc278..12e611fb8 100644 --- a/dlib/optimization/find_max_factor_graph_nmplp.h +++ b/dlib/optimization/find_max_factor_graph_nmplp.h @@ -143,12 +143,19 @@ namespace dlib /* This function is an implementation of the NMPLP algorithm introduced in the - following paper: - Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations + following papers: + Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008) by Amir Globerson and Tommi Jaakkola - In particular, see the pseudocode in Figure 1. The code in this function - follows what is described there. + Introduction to dual decomposition for inference (2011) + by David Sontag, Amir Globerson, and Tommi Jaakkola + + In particular, this function implements the star MPLP update equations shown as + equation 1.20 from the paper Introduction to dual decomposition for inference + (the method was called NMPLP in the first paper). It should also be noted that + the original description of the NMPLP in the first paper had an error in the + equations and the second paper contains corrected equations, which is what this + function uses. */ typedef typename map_problem::node_iterator node_iterator; @@ -161,14 +168,15 @@ namespace dlib return; - std::vector gamma_elements; - gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3); + std::vector delta_elements; + delta_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3); - impl::simple_hash_map gamma_idx; + impl::simple_hash_map delta_idx; - // initialize gamma according to the initialization instructions at top of Figure 1 + // Initialize delta to zero and fill up the hash table with the appropriate values + // so we can index into delta later on. for (node_iterator i = prob.begin(); i != prob.end(); ++i) { const unsigned long id_i = prob.node_id(i); @@ -176,106 +184,124 @@ namespace dlib for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) { const unsigned long id_j = prob.node_id(j); - - gamma_idx.insert(id_i, id_j, gamma_elements.size()); + delta_idx.insert(id_i, id_j, delta_elements.size()); const unsigned long num_states_xj = prob.num_states(j); - for (unsigned long xj = 0; xj < num_states_xj; ++xj) - { - const unsigned long num_states_xi = prob.num_states(i); - - double best_val = -std::numeric_limits::infinity(); - for (unsigned long xi = 0; xi < num_states_xi; ++xi) - { - double val = prob.factor_value(i,j,xi,xj); - - double sum_temp = 0; - - for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k) - { - if (j == k) - continue; - - double max_val = -std::numeric_limits::infinity(); - for (unsigned long xk = 0; xk < prob.num_states(k); ++xk) - { - double temp = prob.factor_value(k,i,xk,xi); - if (temp > max_val) - max_val = temp; - } - - sum_temp += max_val; - } - - - val += 0.5*sum_temp; - - if (val > best_val) - best_val = val; - } - - - gamma_elements.push_back(best_val); - } + delta_elements.push_back(0); } } + std::vector gamma_i; + std::vector > gamma_ji; + std::vector > delta_to_j_no_i; + // These arrays will end up with a length equal to the maximum number of neighbors + // of any node in the graph. So reserve a bigish number of slots so that we are + // very unlikely to need to preform an expensive reallocation during the + // optimization. + gamma_ji.reserve(10000); + delta_to_j_no_i.reserve(10000); double max_change = eps + 1; // Now do the main body of the optimization. - for (unsigned long iter = 0; iter < max_iter && max_change > eps; ++iter) + unsigned long iter; + for (iter = 0; iter < max_iter && max_change > eps; ++iter) { max_change = -std::numeric_limits::infinity(); for (node_iterator i = prob.begin(); i != prob.end(); ++i) { const unsigned long id_i = prob.node_id(i); + const unsigned long num_states_xi = prob.num_states(i); + gamma_i.assign(num_states_xi, 0); + double num_neighbors = 0; + + unsigned int jcnt = 0; + // first we fill in the gamma vectors + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + // Make sure these arrays are big enough to hold all the neighbor + // information. + if (jcnt >= gamma_ji.size()) + { + gamma_ji.resize(gamma_ji.size()+1); + delta_to_j_no_i.resize(delta_to_j_no_i.size()+1); + } + + ++num_neighbors; + const unsigned long id_j = prob.node_id(j); + const unsigned long num_states_xj = prob.num_states(j); + + gamma_ji[jcnt].assign(num_states_xi, -std::numeric_limits::infinity()); + delta_to_j_no_i[jcnt].assign(num_states_xj, 0); + + // compute delta_j^{-i} and store it in delta_to_j_no_i[jcnt] + for (neighbor_iterator k = prob.begin(j); k != prob.end(j); ++k) + { + const unsigned long id_k = prob.node_id(k); + if (id_k==id_i) + continue; + const double* const delta_kj = &delta_elements[delta_idx(id_k,id_j)]; + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + { + delta_to_j_no_i[jcnt][xj] += delta_kj[xj]; + } + } + + // now compute gamma values + for (unsigned long xi = 0; xi < num_states_xi; ++xi) + { + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + { + gamma_ji[jcnt][xi] = std::max(gamma_ji[jcnt][xi], prob.factor_value(i,j,xi,xj) + delta_to_j_no_i[jcnt][xj]); + } + gamma_i[xi] += gamma_ji[jcnt][xi]; + } + ++jcnt; + } + + // now update the delta values + jcnt = 0; for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) { const unsigned long id_j = prob.node_id(j); - double* const gamma_ji = &gamma_elements[gamma_idx(id_j,id_i)]; - double* const gamma_ij = &gamma_elements[gamma_idx(id_i,id_j)]; - const unsigned long num_states_xj = prob.num_states(j); + // messages from j to i + double* const delta_ji = &delta_elements[delta_idx(id_j,id_i)]; + + // messages from i to j + double* const delta_ij = &delta_elements[delta_idx(id_i,id_j)]; + for (unsigned long xj = 0; xj < num_states_xj; ++xj) { - const unsigned long num_states_xi = prob.num_states(i); - double best_val = -std::numeric_limits::infinity(); + for (unsigned long xi = 0; xi < num_states_xi; ++xi) { - double val = prob.factor_value(i,j,xi,xj) - gamma_ji[xi]; - - double sum_temp = 0; - - int num_neighbors = 0; - for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k) - { - const unsigned long id_k = prob.node_id(k); - ++num_neighbors; - - const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)]; - sum_temp += gamma_ki[xi]; - } - - - val += 2.0/(num_neighbors + 1.0)*sum_temp; - + double val = prob.factor_value(i,j,xi,xj) + 2/(num_neighbors+1)*gamma_i[xi] -gamma_ji[jcnt][xi]; if (val > best_val) best_val = val; } + best_val = -0.5*delta_to_j_no_i[jcnt][xj] + 0.5*best_val; + if (std::abs(delta_ij[xj] - best_val) > max_change) + max_change = std::abs(delta_ij[xj] - best_val); - if (std::abs(gamma_ij[xj] - best_val) > max_change) - max_change = std::abs(gamma_ij[xj] - best_val); - - gamma_ij[xj] = best_val; + delta_ij[xj] = best_val; } + + for (unsigned long xi = 0; xi < num_states_xi; ++xi) + { + double new_val = -1/(num_neighbors+1)*gamma_i[xi] + gamma_ji[jcnt][xi]; + if (std::abs(delta_ji[xi] - new_val) > max_change) + max_change = std::abs(delta_ji[xi] - new_val); + delta_ji[xi] = new_val; + } + ++jcnt; } } } @@ -294,8 +320,8 @@ namespace dlib for (unsigned long xi = 0; xi < b.size(); ++xi) { - const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)]; - b[xi] += gamma_ki[xi]; + const double* const delta_ki = &delta_elements[delta_idx(id_k,id_i)]; + b[xi] += delta_ki[xi]; } } diff --git a/dlib/optimization/find_max_factor_graph_nmplp_abstract.h b/dlib/optimization/find_max_factor_graph_nmplp_abstract.h index 96a7349f4..df2ec9c9b 100644 --- a/dlib/optimization/find_max_factor_graph_nmplp_abstract.h +++ b/dlib/optimization/find_max_factor_graph_nmplp_abstract.h @@ -347,10 +347,13 @@ namespace dlib to the MAP problem. However, for graphs with cycles, the solution may be approximate. - - This function is an implementation of the NMPLP algorithm introduced in the - following paper: - Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations + - This function is an implementation of the NMPLP algorithm introduced in the + following papers: + Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008) by Amir Globerson and Tommi Jaakkola + + Introduction to dual decomposition for inference (2011) + by David Sontag, Amir Globerson, and Tommi Jaakkola !*/ // ---------------------------------------------------------------------------------------- diff --git a/dlib/test/find_max_factor_graph_nmplp.cpp b/dlib/test/find_max_factor_graph_nmplp.cpp index 84f2ea15e..2260e92a1 100644 --- a/dlib/test/find_max_factor_graph_nmplp.cpp +++ b/dlib/test/find_max_factor_graph_nmplp.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "tester.h" @@ -19,14 +20,32 @@ namespace // ---------------------------------------------------------------------------------------- + dlib::rand rnd; + + template class map_problem { /* - This is a simple 8 node problem with two cycles in it. + This is a simple 8 node problem with two cycles in it unless fully_connected is true + and then it's a fully connected 8 note graph. */ public: + mutable std::map,std::map,double> > weights; + map_problem() + { + for (int i = 0; i < 8; ++i) + { + for (int j = i; j < 8; ++j) + { + weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + } struct node_iterator { @@ -58,6 +77,14 @@ namespace unsigned long node_id () const { + if (fully_connected) + { + if (count < home_node) + return count; + else + return count+1; + } + if (home_node < 4) { if (count == 0) @@ -127,7 +154,7 @@ namespace ) const { neighbor_iterator temp; - temp.home_node = 8; + temp.home_node = 9; temp.count = 8; return temp; } @@ -137,7 +164,7 @@ namespace ) const { neighbor_iterator temp; - temp.home_node = 8; + temp.home_node = 9; temp.count = 8; return temp; } @@ -195,76 +222,221 @@ namespace swap(n1,n2); swap(s1,s2); } - - if (n1 == 0 && n2 == 1) - { - if (s1 && s2) return 100; - if (!s1 && s2) return 0; - if (s1 && !s2) return 0; - if (!s1 && !s2) return 0; - } - if (n1 == 1 && n2 == 2) - { - if (s1 && s2) return 5; - if (!s1 && s2) return 9; - if (s1 && !s2) return -100; - if (!s1 && !s2) return 9; - } - if (n1 == 2 && n2 == 3) - { - if (s1 && s2) return 10; - if (!s1 && s2) return 0; - if (s1 && !s2) return 10; - if (!s1 && !s2) return 10; - } - if (n1 == 0 && n2 == 3) - { - if (s1 && s2) return -5; - if (!s1 && s2) return 0; - if (s1 && !s2) return 0; - if (!s1 && !s2) return 0; - } - - - - if (n1 == 4 && n2 == 5) - { - if (s1 && s2) return -100; - if (!s1 && s2) return 0; - if (s1 && !s2) return 0; - if (!s1 && !s2) return 0; - } - if (n1 == 5 && n2 == 6) - { - if (s1 && s2) return -5; - if (!s1 && s2) return -9; - if (s1 && !s2) return 100; - if (!s1 && !s2) return -9; - } - if (n1 == 6 && n2 == 7) - { - if (s1 && s2) return -10; - if (!s1 && s2) return 0; - if (s1 && !s2) return -10; - if (!s1 && !s2) return -10; - } - if (n1 == 4 && n2 == 7) - { - if (s1 && s2) return 5; - if (!s1 && s2) return 0; - if (s1 && !s2) return 0; - if (!s1 && !s2) return 0; - } - - DLIB_CASSERT(false, "n1: "<< n1 << " n2: "<< n2 << " s1: "<< s1 << " s2: "<< s2); - return 0; + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; + } + + }; + +// ---------------------------------------------------------------------------------------- + + class map_problem_chain + { + /* + This is a chain structured 8 node graph (so no cycles). + */ + + public: + + mutable std::map,std::map,double> > weights; + map_problem_chain() + { + for (int i = 0; i < 7; ++i) + { + weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (count >= 2) + return 8; + return nid[count]; + } + + unsigned long nid[2]; + unsigned int count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 8; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 8; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + if (it.nid == 0) + { + temp.nid[0] = it.nid+1; + temp.nid[1] = 8; + } + else if (it.nid == 7) + { + temp.nid[0] = it.nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = it.nid-1; + temp.nid[1] = it.nid+1; + } + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + const unsigned long nid = it.node_id(); + neighbor_iterator temp; + if (nid == 0) + { + temp.nid[0] = nid+1; + temp.nid[1] = 8; + } + else if (nid == 7) + { + temp.nid[0] = nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = nid-1; + temp.nid[1] = nid+1; + } + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 2; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 2; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; } }; // ---------------------------------------------------------------------------------------- - dlib::rand rnd; class map_problem2 { @@ -451,8 +623,8 @@ namespace // basically ignore the other node in this factor. The node we // are ignoring is the center node of this star graph. So we basically - // let it always have a value of 0. - if (s2 == 0) + // let it always have a value of 1. + if (s2 == 1) return numbers(n1,s1) + 1; else return numbers(n1,s1); @@ -563,7 +735,7 @@ namespace map_assignment2[2] = index_of_max(rowm(prob.numbers,2)); map_assignment2[3] = index_of_max(rowm(prob.numbers,3)); map_assignment2[4] = index_of_max(rowm(prob.numbers,4)); - map_assignment2[5] = 0; + map_assignment2[5] = 1; const double score2 = find_total_score(prob, map_assignment2); dlog << LINFO << "score NMPLP: " << score1; @@ -588,8 +760,21 @@ namespace void perform_test ( ) { - do_test(); + rnd.clear(); + dlog << LINFO << "test on a chain structured graph"; + for (int i = 0; i < 30; ++i) + do_test(); + + dlog << LINFO << "test on a 2 cycle graph"; + for (int i = 0; i < 30; ++i) + do_test >(); + + dlog << LINFO << "test on a fully connected graph"; + for (int i = 0; i < 5; ++i) + do_test >(); + + dlog << LINFO << "test on a tree structured graph"; for (int i = 0; i < 10; ++i) do_test2(); }