Updated find_max_factor_graph_nmplp() to use the version of the algorithm from

the 2011 paper "Introduction to dual decomposition for inference" by David
Sontag, Amir Globerson, and Tommi Jaakkola, since the original 2008 paper had an
error in the algorithm that negatively affected its convergence.  Thanks to
James Gunning for pointing this out.
This commit is contained in:
Davis King 2013-11-29 17:05:38 -05:00
parent 251196c34a
commit 29381bcccb
3 changed files with 364 additions and 150 deletions

View File

@ -143,12 +143,19 @@ namespace dlib
/*
This function is an implementation of the NMPLP algorithm introduced in the
following paper:
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations
following papers:
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
by Amir Globerson and Tommi Jaakkola
In particular, see the pseudocode in Figure 1. The code in this function
follows what is described there.
Introduction to dual decomposition for inference (2011)
by David Sontag, Amir Globerson, and Tommi Jaakkola
In particular, this function implements the star MPLP update equations shown as
equation 1.20 from the paper Introduction to dual decomposition for inference
(the method was called NMPLP in the first paper). It should also be noted that
the original description of the NMPLP in the first paper had an error in the
equations and the second paper contains corrected equations, which is what this
function uses.
*/
typedef typename map_problem::node_iterator node_iterator;
@ -161,14 +168,15 @@ namespace dlib
return;
std::vector<double> gamma_elements;
gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
std::vector<double> delta_elements;
delta_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
impl::simple_hash_map gamma_idx;
impl::simple_hash_map delta_idx;
// initialize gamma according to the initialization instructions at top of Figure 1
// Initialize delta to zero and fill up the hash table with the appropriate values
// so we can index into delta later on.
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
{
const unsigned long id_i = prob.node_id(i);
@ -176,106 +184,124 @@ namespace dlib
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{
const unsigned long id_j = prob.node_id(j);
gamma_idx.insert(id_i, id_j, gamma_elements.size());
delta_idx.insert(id_i, id_j, delta_elements.size());
const unsigned long num_states_xj = prob.num_states(j);
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
const unsigned long num_states_xi = prob.num_states(i);
double best_val = -std::numeric_limits<double>::infinity();
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
double val = prob.factor_value(i,j,xi,xj);
double sum_temp = 0;
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
{
if (j == k)
continue;
double max_val = -std::numeric_limits<double>::infinity();
for (unsigned long xk = 0; xk < prob.num_states(k); ++xk)
{
double temp = prob.factor_value(k,i,xk,xi);
if (temp > max_val)
max_val = temp;
}
sum_temp += max_val;
}
val += 0.5*sum_temp;
if (val > best_val)
best_val = val;
}
gamma_elements.push_back(best_val);
}
delta_elements.push_back(0);
}
}
std::vector<double> gamma_i;
std::vector<std::vector<double> > gamma_ji;
std::vector<std::vector<double> > delta_to_j_no_i;
// These arrays will end up with a length equal to the maximum number of neighbors
// of any node in the graph. So reserve a biggish number of slots so that we are
// very unlikely to need to perform an expensive reallocation during the
// optimization.
gamma_ji.reserve(10000);
delta_to_j_no_i.reserve(10000);
double max_change = eps + 1;
// Now do the main body of the optimization.
for (unsigned long iter = 0; iter < max_iter && max_change > eps; ++iter)
unsigned long iter;
for (iter = 0; iter < max_iter && max_change > eps; ++iter)
{
max_change = -std::numeric_limits<double>::infinity();
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
{
const unsigned long id_i = prob.node_id(i);
const unsigned long num_states_xi = prob.num_states(i);
gamma_i.assign(num_states_xi, 0);
double num_neighbors = 0;
unsigned int jcnt = 0;
// first we fill in the gamma vectors
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{
// Make sure these arrays are big enough to hold all the neighbor
// information.
if (jcnt >= gamma_ji.size())
{
gamma_ji.resize(gamma_ji.size()+1);
delta_to_j_no_i.resize(delta_to_j_no_i.size()+1);
}
++num_neighbors;
const unsigned long id_j = prob.node_id(j);
const unsigned long num_states_xj = prob.num_states(j);
gamma_ji[jcnt].assign(num_states_xi, -std::numeric_limits<double>::infinity());
delta_to_j_no_i[jcnt].assign(num_states_xj, 0);
// compute delta_j^{-i} and store it in delta_to_j_no_i[jcnt]
for (neighbor_iterator k = prob.begin(j); k != prob.end(j); ++k)
{
const unsigned long id_k = prob.node_id(k);
if (id_k==id_i)
continue;
const double* const delta_kj = &delta_elements[delta_idx(id_k,id_j)];
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
delta_to_j_no_i[jcnt][xj] += delta_kj[xj];
}
}
// now compute gamma values
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
gamma_ji[jcnt][xi] = std::max(gamma_ji[jcnt][xi], prob.factor_value(i,j,xi,xj) + delta_to_j_no_i[jcnt][xj]);
}
gamma_i[xi] += gamma_ji[jcnt][xi];
}
++jcnt;
}
// now update the delta values
jcnt = 0;
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{
const unsigned long id_j = prob.node_id(j);
double* const gamma_ji = &gamma_elements[gamma_idx(id_j,id_i)];
double* const gamma_ij = &gamma_elements[gamma_idx(id_i,id_j)];
const unsigned long num_states_xj = prob.num_states(j);
// messages from j to i
double* const delta_ji = &delta_elements[delta_idx(id_j,id_i)];
// messages from i to j
double* const delta_ij = &delta_elements[delta_idx(id_i,id_j)];
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
{
const unsigned long num_states_xi = prob.num_states(i);
double best_val = -std::numeric_limits<double>::infinity();
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
double val = prob.factor_value(i,j,xi,xj) - gamma_ji[xi];
double sum_temp = 0;
int num_neighbors = 0;
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
{
const unsigned long id_k = prob.node_id(k);
++num_neighbors;
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
sum_temp += gamma_ki[xi];
}
val += 2.0/(num_neighbors + 1.0)*sum_temp;
double val = prob.factor_value(i,j,xi,xj) + 2/(num_neighbors+1)*gamma_i[xi] -gamma_ji[jcnt][xi];
if (val > best_val)
best_val = val;
}
best_val = -0.5*delta_to_j_no_i[jcnt][xj] + 0.5*best_val;
if (std::abs(delta_ij[xj] - best_val) > max_change)
max_change = std::abs(delta_ij[xj] - best_val);
if (std::abs(gamma_ij[xj] - best_val) > max_change)
max_change = std::abs(gamma_ij[xj] - best_val);
gamma_ij[xj] = best_val;
delta_ij[xj] = best_val;
}
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
{
double new_val = -1/(num_neighbors+1)*gamma_i[xi] + gamma_ji[jcnt][xi];
if (std::abs(delta_ji[xi] - new_val) > max_change)
max_change = std::abs(delta_ji[xi] - new_val);
delta_ji[xi] = new_val;
}
++jcnt;
}
}
}
@ -294,8 +320,8 @@ namespace dlib
for (unsigned long xi = 0; xi < b.size(); ++xi)
{
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
b[xi] += gamma_ki[xi];
const double* const delta_ki = &delta_elements[delta_idx(id_k,id_i)];
b[xi] += delta_ki[xi];
}
}

View File

@ -347,10 +347,13 @@ namespace dlib
to the MAP problem. However, for graphs with cycles, the solution may be approximate.
- This function is an implementation of the NMPLP algorithm introduced in the
following paper:
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations
- This function is an implementation of the NMPLP algorithm introduced in the
following papers:
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
by Amir Globerson and Tommi Jaakkola
Introduction to dual decomposition for inference (2011)
by David Sontag, Amir Globerson, and Tommi Jaakkola
!*/
// ----------------------------------------------------------------------------------------

View File

@ -5,6 +5,7 @@
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/unordered_pair.h>
#include <dlib/rand.h>
#include "tester.h"
@ -19,14 +20,32 @@ namespace
// ----------------------------------------------------------------------------------------
dlib::rand rnd;
template <bool fully_connected>
class map_problem
{
/*
This is a simple 8 node problem with two cycles in it.
This is a simple 8 node problem with two cycles in it unless fully_connected is true,
in which case it's a fully connected 8 node graph.
*/
public:
mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
map_problem()
{
for (int i = 0; i < 8; ++i)
{
for (int j = i; j < 8; ++j)
{
weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian();
weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian();
weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian();
weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian();
}
}
}
struct node_iterator
{
@ -58,6 +77,14 @@ namespace
unsigned long node_id () const
{
if (fully_connected)
{
if (count < home_node)
return count;
else
return count+1;
}
if (home_node < 4)
{
if (count == 0)
@ -127,7 +154,7 @@ namespace
) const
{
neighbor_iterator temp;
temp.home_node = 8;
temp.home_node = 9;
temp.count = 8;
return temp;
}
@ -137,7 +164,7 @@ namespace
) const
{
neighbor_iterator temp;
temp.home_node = 8;
temp.home_node = 9;
temp.count = 8;
return temp;
}
@ -195,76 +222,221 @@ namespace
swap(n1,n2);
swap(s1,s2);
}
if (n1 == 0 && n2 == 1)
{
if (s1 && s2) return 100;
if (!s1 && s2) return 0;
if (s1 && !s2) return 0;
if (!s1 && !s2) return 0;
}
if (n1 == 1 && n2 == 2)
{
if (s1 && s2) return 5;
if (!s1 && s2) return 9;
if (s1 && !s2) return -100;
if (!s1 && !s2) return 9;
}
if (n1 == 2 && n2 == 3)
{
if (s1 && s2) return 10;
if (!s1 && s2) return 0;
if (s1 && !s2) return 10;
if (!s1 && !s2) return 10;
}
if (n1 == 0 && n2 == 3)
{
if (s1 && s2) return -5;
if (!s1 && s2) return 0;
if (s1 && !s2) return 0;
if (!s1 && !s2) return 0;
}
if (n1 == 4 && n2 == 5)
{
if (s1 && s2) return -100;
if (!s1 && s2) return 0;
if (s1 && !s2) return 0;
if (!s1 && !s2) return 0;
}
if (n1 == 5 && n2 == 6)
{
if (s1 && s2) return -5;
if (!s1 && s2) return -9;
if (s1 && !s2) return 100;
if (!s1 && !s2) return -9;
}
if (n1 == 6 && n2 == 7)
{
if (s1 && s2) return -10;
if (!s1 && s2) return 0;
if (s1 && !s2) return -10;
if (!s1 && !s2) return -10;
}
if (n1 == 4 && n2 == 7)
{
if (s1 && s2) return 5;
if (!s1 && s2) return 0;
if (s1 && !s2) return 0;
if (!s1 && !s2) return 0;
}
DLIB_CASSERT(false, "n1: "<< n1 << " n2: "<< n2 << " s1: "<< s1 << " s2: "<< s2);
return 0;
return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
}
};
// ----------------------------------------------------------------------------------------
class map_problem_chain
{
    /*
        This is a chain structured 8 node graph (so no cycles).

        Nodes have IDs 0 through 7 and node i is connected only to nodes i-1 and
        i+1 (when those exist).  Every node has 2 states.  The factor value of each
        edge/state combination is drawn from rnd when the object is constructed.
        The ID 8 is used as the "one past the end" sentinel by both iterator types.
    */
public:

    // weights[edge][state pair] is the factor value for that edge and that
    // assignment of states.  It is mutable because basic_factor_value() is const
    // but reads values through std::map::operator[].
    mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;

    map_problem_chain()
    {
        // Fill in the 4 state combinations for each of the 7 chain edges.  The
        // draws are made in the order (0,0), (0,1), (1,0), (1,1) for each edge.
        for (int i = 0; i < 7; ++i)
        {
            std::map<std::pair<int,int>,double>& edge_weights = weights[make_unordered_pair(i,i+1)];
            edge_weights[make_pair(0,0)] = rnd.get_random_gaussian();
            edge_weights[make_pair(0,1)] = rnd.get_random_gaussian();
            edge_weights[make_pair(1,0)] = rnd.get_random_gaussian();
            edge_weights[make_pair(1,1)] = rnd.get_random_gaussian();
        }
    }

    struct node_iterator
    {
        // Iterates over node IDs.  A default constructed iterator refers to node 0
        // so that nid never holds an indeterminate value.
        node_iterator() : nid(0) {}
        node_iterator(unsigned long nid_): nid(nid_) {}
        bool operator== (const node_iterator& item) const { return item.nid == nid; }
        bool operator!= (const node_iterator& item) const { return item.nid != nid; }

        node_iterator& operator++()
        {
            ++nid;
            return *this;
        }

        unsigned long nid;
    };

    struct neighbor_iterator
    {
        // Iterates over the (at most 2) neighbors of a node.  Unused slots of nid
        // hold the sentinel 8.  A default constructed iterator has both slots set
        // to the sentinel and therefore compares equal to the end iterator.
        neighbor_iterator() : count(0) { nid[0] = 8; nid[1] = 8; }

        // Two neighbor iterators are equal when they refer to the same node ID.
        bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
        bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }

        neighbor_iterator& operator++()
        {
            ++count;
            return *this;
        }

        unsigned long node_id () const
        {
            // Once both slots have been passed we report the end sentinel.
            if (count >= 2)
                return 8;
            return nid[count];
        }

        unsigned long nid[2];
        unsigned int count;
    };

    unsigned long number_of_nodes (
    ) const
    {
        return 8;
    }

    node_iterator begin(
    ) const
    {
        return node_iterator(0);
    }

    node_iterator end(
    ) const
    {
        return node_iterator(8);
    }

    neighbor_iterator begin(
        const node_iterator& it
    ) const
    {
        return neighbors_of(it.nid);
    }

    neighbor_iterator begin(
        const neighbor_iterator& it
    ) const
    {
        return neighbors_of(it.node_id());
    }

    neighbor_iterator end(
        const node_iterator&
    ) const
    {
        // A default constructed neighbor_iterator is the end sentinel.
        return neighbor_iterator();
    }

    neighbor_iterator end(
        const neighbor_iterator&
    ) const
    {
        // A default constructed neighbor_iterator is the end sentinel.
        return neighbor_iterator();
    }

    unsigned long node_id (
        const node_iterator& it
    ) const
    {
        return it.nid;
    }

    unsigned long node_id (
        const neighbor_iterator& it
    ) const
    {
        return it.node_id();
    }

    unsigned long num_states (
        const node_iterator&
    ) const
    {
        return 2;
    }

    unsigned long num_states (
        const neighbor_iterator&
    ) const
    {
        return 2;
    }

    // All four overloads forward to basic_factor_value() using the node IDs the
    // iterators refer to.
    double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
    double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
    double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
    double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }

private:

    // Returns an iterator positioned at the first neighbor of the node with ID
    // nid.  The two end nodes of the chain (0 and 7) have a single neighbor, so
    // their second slot is set to the sentinel 8.
    static neighbor_iterator neighbors_of (
        unsigned long nid
    )
    {
        neighbor_iterator temp;
        if (nid == 0)
        {
            temp.nid[0] = nid+1;
            temp.nid[1] = 8;
        }
        else if (nid == 7)
        {
            temp.nid[0] = nid-1;
            temp.nid[1] = 8;
        }
        else
        {
            temp.nid[0] = nid-1;
            temp.nid[1] = nid+1;
        }
        return temp;
    }

    // Looks up the factor value for nodes n1 and n2 taking states s1 and s2.
    double basic_factor_value (
        unsigned long n1,
        unsigned long n2,
        unsigned long s1,
        unsigned long s2
    ) const
    {
        // The weights were stored with the smaller node ID first, so put the
        // arguments in that order and keep each state paired with its node.
        if (n1 > n2)
        {
            swap(n1,n2);
            swap(s1,s2);
        }
        return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
    }
};
// ----------------------------------------------------------------------------------------
dlib::rand rnd;
class map_problem2
{
@ -451,8 +623,8 @@ namespace
// basically ignore the other node in this factor. The node we
// are ignoring is the center node of this star graph. So we basically
// let it always have a value of 0.
if (s2 == 0)
// let it always have a value of 1.
if (s2 == 1)
return numbers(n1,s1) + 1;
else
return numbers(n1,s1);
@ -563,7 +735,7 @@ namespace
map_assignment2[2] = index_of_max(rowm(prob.numbers,2));
map_assignment2[3] = index_of_max(rowm(prob.numbers,3));
map_assignment2[4] = index_of_max(rowm(prob.numbers,4));
map_assignment2[5] = 0;
map_assignment2[5] = 1;
const double score2 = find_total_score(prob, map_assignment2);
dlog << LINFO << "score NMPLP: " << score1;
@ -588,8 +760,21 @@ namespace
// Runs the tests visible here: do_test on map_problem_chain, map_problem<false>
// and map_problem<true>, and do_test2 on map_problem2.  Each group is repeated a
// fixed number of times and is preceded by an LINFO log line naming the graph type.
void perform_test (
)
{
do_test<map_problem>();
// NOTE(review): presumably resets rnd to its initial state so the randomly
// generated problems are the same from run to run -- confirm against dlib::rand.
rnd.clear();
dlog << LINFO << "test on a chain structured graph";
// 30 runs of do_test on the chain problem type
for (int i = 0; i < 30; ++i)
do_test<map_problem_chain>();
dlog << LINFO << "test on a 2 cycle graph";
// 30 runs of do_test with fully_connected == false
for (int i = 0; i < 30; ++i)
do_test<map_problem<false> >();
dlog << LINFO << "test on a fully connected graph";
// only 5 runs of do_test with fully_connected == true
for (int i = 0; i < 5; ++i)
do_test<map_problem<true> >();
dlog << LINFO << "test on a tree structured graph";
// 10 runs of do_test2 on map_problem2
for (int i = 0; i < 10; ++i)
do_test2<map_problem2>();
}