mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Updated find_max_factor_graph_nmplp() to use the version of the algorithm from
the 2011 paper Introduction to dual decomposition for inference by David Sontag, Amir Globerson, and Tommi Jaakkola since the original 2008 paper had an error in the algorithm that negatively effected its convergence. Thanks to James Gunning for pointing this out.
This commit is contained in:
parent
251196c34a
commit
29381bcccb
@ -143,12 +143,19 @@ namespace dlib
|
||||
|
||||
/*
|
||||
This function is an implementation of the NMPLP algorithm introduced in the
|
||||
following paper:
|
||||
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations
|
||||
following papers:
|
||||
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
|
||||
by Amir Globerson and Tommi Jaakkola
|
||||
|
||||
In particular, see the pseudocode in Figure 1. The code in this function
|
||||
follows what is described there.
|
||||
Introduction to dual decomposition for inference (2011)
|
||||
by David Sontag, Amir Globerson, and Tommi Jaakkola
|
||||
|
||||
In particular, this function implements the star MPLP update equations shown as
|
||||
equation 1.20 from the paper Introduction to dual decomposition for inference
|
||||
(the method was called NMPLP in the first paper). It should also be noted that
|
||||
the original description of the NMPLP in the first paper had an error in the
|
||||
equations and the second paper contains corrected equations, which is what this
|
||||
function uses.
|
||||
*/
|
||||
|
||||
typedef typename map_problem::node_iterator node_iterator;
|
||||
@ -161,14 +168,15 @@ namespace dlib
|
||||
return;
|
||||
|
||||
|
||||
std::vector<double> gamma_elements;
|
||||
gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
|
||||
std::vector<double> delta_elements;
|
||||
delta_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
|
||||
|
||||
impl::simple_hash_map gamma_idx;
|
||||
impl::simple_hash_map delta_idx;
|
||||
|
||||
|
||||
|
||||
// initialize gamma according to the initialization instructions at top of Figure 1
|
||||
// Initialize delta to zero and fill up the hash table with the appropriate values
|
||||
// so we can index into delta later on.
|
||||
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
|
||||
{
|
||||
const unsigned long id_i = prob.node_id(i);
|
||||
@ -176,106 +184,124 @@ namespace dlib
|
||||
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
|
||||
{
|
||||
const unsigned long id_j = prob.node_id(j);
|
||||
|
||||
gamma_idx.insert(id_i, id_j, gamma_elements.size());
|
||||
delta_idx.insert(id_i, id_j, delta_elements.size());
|
||||
|
||||
const unsigned long num_states_xj = prob.num_states(j);
|
||||
|
||||
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
|
||||
{
|
||||
const unsigned long num_states_xi = prob.num_states(i);
|
||||
|
||||
double best_val = -std::numeric_limits<double>::infinity();
|
||||
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
|
||||
{
|
||||
double val = prob.factor_value(i,j,xi,xj);
|
||||
|
||||
double sum_temp = 0;
|
||||
|
||||
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
|
||||
{
|
||||
if (j == k)
|
||||
continue;
|
||||
|
||||
double max_val = -std::numeric_limits<double>::infinity();
|
||||
for (unsigned long xk = 0; xk < prob.num_states(k); ++xk)
|
||||
{
|
||||
double temp = prob.factor_value(k,i,xk,xi);
|
||||
if (temp > max_val)
|
||||
max_val = temp;
|
||||
}
|
||||
|
||||
sum_temp += max_val;
|
||||
}
|
||||
|
||||
|
||||
val += 0.5*sum_temp;
|
||||
|
||||
if (val > best_val)
|
||||
best_val = val;
|
||||
}
|
||||
|
||||
|
||||
gamma_elements.push_back(best_val);
|
||||
}
|
||||
delta_elements.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<double> gamma_i;
|
||||
std::vector<std::vector<double> > gamma_ji;
|
||||
std::vector<std::vector<double> > delta_to_j_no_i;
|
||||
// These arrays will end up with a length equal to the maximum number of neighbors
|
||||
// of any node in the graph. So reserve a bigish number of slots so that we are
|
||||
// very unlikely to need to preform an expensive reallocation during the
|
||||
// optimization.
|
||||
gamma_ji.reserve(10000);
|
||||
delta_to_j_no_i.reserve(10000);
|
||||
|
||||
|
||||
double max_change = eps + 1;
|
||||
// Now do the main body of the optimization.
|
||||
for (unsigned long iter = 0; iter < max_iter && max_change > eps; ++iter)
|
||||
unsigned long iter;
|
||||
for (iter = 0; iter < max_iter && max_change > eps; ++iter)
|
||||
{
|
||||
max_change = -std::numeric_limits<double>::infinity();
|
||||
|
||||
for (node_iterator i = prob.begin(); i != prob.end(); ++i)
|
||||
{
|
||||
const unsigned long id_i = prob.node_id(i);
|
||||
const unsigned long num_states_xi = prob.num_states(i);
|
||||
gamma_i.assign(num_states_xi, 0);
|
||||
|
||||
double num_neighbors = 0;
|
||||
|
||||
unsigned int jcnt = 0;
|
||||
// first we fill in the gamma vectors
|
||||
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
|
||||
{
|
||||
// Make sure these arrays are big enough to hold all the neighbor
|
||||
// information.
|
||||
if (jcnt >= gamma_ji.size())
|
||||
{
|
||||
gamma_ji.resize(gamma_ji.size()+1);
|
||||
delta_to_j_no_i.resize(delta_to_j_no_i.size()+1);
|
||||
}
|
||||
|
||||
++num_neighbors;
|
||||
const unsigned long id_j = prob.node_id(j);
|
||||
const unsigned long num_states_xj = prob.num_states(j);
|
||||
|
||||
gamma_ji[jcnt].assign(num_states_xi, -std::numeric_limits<double>::infinity());
|
||||
delta_to_j_no_i[jcnt].assign(num_states_xj, 0);
|
||||
|
||||
// compute delta_j^{-i} and store it in delta_to_j_no_i[jcnt]
|
||||
for (neighbor_iterator k = prob.begin(j); k != prob.end(j); ++k)
|
||||
{
|
||||
const unsigned long id_k = prob.node_id(k);
|
||||
if (id_k==id_i)
|
||||
continue;
|
||||
const double* const delta_kj = &delta_elements[delta_idx(id_k,id_j)];
|
||||
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
|
||||
{
|
||||
delta_to_j_no_i[jcnt][xj] += delta_kj[xj];
|
||||
}
|
||||
}
|
||||
|
||||
// now compute gamma values
|
||||
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
|
||||
{
|
||||
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
|
||||
{
|
||||
gamma_ji[jcnt][xi] = std::max(gamma_ji[jcnt][xi], prob.factor_value(i,j,xi,xj) + delta_to_j_no_i[jcnt][xj]);
|
||||
}
|
||||
gamma_i[xi] += gamma_ji[jcnt][xi];
|
||||
}
|
||||
++jcnt;
|
||||
}
|
||||
|
||||
// now update the delta values
|
||||
jcnt = 0;
|
||||
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
|
||||
{
|
||||
const unsigned long id_j = prob.node_id(j);
|
||||
double* const gamma_ji = &gamma_elements[gamma_idx(id_j,id_i)];
|
||||
double* const gamma_ij = &gamma_elements[gamma_idx(id_i,id_j)];
|
||||
|
||||
const unsigned long num_states_xj = prob.num_states(j);
|
||||
|
||||
// messages from j to i
|
||||
double* const delta_ji = &delta_elements[delta_idx(id_j,id_i)];
|
||||
|
||||
// messages from i to j
|
||||
double* const delta_ij = &delta_elements[delta_idx(id_i,id_j)];
|
||||
|
||||
for (unsigned long xj = 0; xj < num_states_xj; ++xj)
|
||||
{
|
||||
const unsigned long num_states_xi = prob.num_states(i);
|
||||
|
||||
double best_val = -std::numeric_limits<double>::infinity();
|
||||
|
||||
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
|
||||
{
|
||||
double val = prob.factor_value(i,j,xi,xj) - gamma_ji[xi];
|
||||
|
||||
double sum_temp = 0;
|
||||
|
||||
int num_neighbors = 0;
|
||||
for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
|
||||
{
|
||||
const unsigned long id_k = prob.node_id(k);
|
||||
++num_neighbors;
|
||||
|
||||
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
|
||||
sum_temp += gamma_ki[xi];
|
||||
}
|
||||
|
||||
|
||||
val += 2.0/(num_neighbors + 1.0)*sum_temp;
|
||||
|
||||
double val = prob.factor_value(i,j,xi,xj) + 2/(num_neighbors+1)*gamma_i[xi] -gamma_ji[jcnt][xi];
|
||||
if (val > best_val)
|
||||
best_val = val;
|
||||
}
|
||||
best_val = -0.5*delta_to_j_no_i[jcnt][xj] + 0.5*best_val;
|
||||
|
||||
if (std::abs(delta_ij[xj] - best_val) > max_change)
|
||||
max_change = std::abs(delta_ij[xj] - best_val);
|
||||
|
||||
if (std::abs(gamma_ij[xj] - best_val) > max_change)
|
||||
max_change = std::abs(gamma_ij[xj] - best_val);
|
||||
|
||||
gamma_ij[xj] = best_val;
|
||||
delta_ij[xj] = best_val;
|
||||
}
|
||||
|
||||
for (unsigned long xi = 0; xi < num_states_xi; ++xi)
|
||||
{
|
||||
double new_val = -1/(num_neighbors+1)*gamma_i[xi] + gamma_ji[jcnt][xi];
|
||||
if (std::abs(delta_ji[xi] - new_val) > max_change)
|
||||
max_change = std::abs(delta_ji[xi] - new_val);
|
||||
delta_ji[xi] = new_val;
|
||||
}
|
||||
++jcnt;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -294,8 +320,8 @@ namespace dlib
|
||||
|
||||
for (unsigned long xi = 0; xi < b.size(); ++xi)
|
||||
{
|
||||
const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
|
||||
b[xi] += gamma_ki[xi];
|
||||
const double* const delta_ki = &delta_elements[delta_idx(id_k,id_i)];
|
||||
b[xi] += delta_ki[xi];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -348,9 +348,12 @@ namespace dlib
|
||||
|
||||
|
||||
- This function is an implementation of the NMPLP algorithm introduced in the
|
||||
following paper:
|
||||
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations
|
||||
following papers:
|
||||
Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
|
||||
by Amir Globerson and Tommi Jaakkola
|
||||
|
||||
Introduction to dual decomposition for inference (2011)
|
||||
by David Sontag, Amir Globerson, and Tommi Jaakkola
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <cstdlib>
|
||||
#include <ctime>
|
||||
#include <dlib/optimization.h>
|
||||
#include <dlib/unordered_pair.h>
|
||||
#include <dlib/rand.h>
|
||||
|
||||
#include "tester.h"
|
||||
@ -19,14 +20,32 @@ namespace
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
dlib::rand rnd;
|
||||
|
||||
template <bool fully_connected>
|
||||
class map_problem
|
||||
{
|
||||
/*
|
||||
This is a simple 8 node problem with two cycles in it.
|
||||
This is a simple 8 node problem with two cycles in it unless fully_connected is true
|
||||
and then it's a fully connected 8 note graph.
|
||||
*/
|
||||
|
||||
public:
|
||||
|
||||
mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
|
||||
map_problem()
|
||||
{
|
||||
for (int i = 0; i < 8; ++i)
|
||||
{
|
||||
for (int j = i; j < 8; ++j)
|
||||
{
|
||||
weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct node_iterator
|
||||
{
|
||||
@ -58,6 +77,14 @@ namespace
|
||||
|
||||
unsigned long node_id () const
|
||||
{
|
||||
if (fully_connected)
|
||||
{
|
||||
if (count < home_node)
|
||||
return count;
|
||||
else
|
||||
return count+1;
|
||||
}
|
||||
|
||||
if (home_node < 4)
|
||||
{
|
||||
if (count == 0)
|
||||
@ -127,7 +154,7 @@ namespace
|
||||
) const
|
||||
{
|
||||
neighbor_iterator temp;
|
||||
temp.home_node = 8;
|
||||
temp.home_node = 9;
|
||||
temp.count = 8;
|
||||
return temp;
|
||||
}
|
||||
@ -137,7 +164,7 @@ namespace
|
||||
) const
|
||||
{
|
||||
neighbor_iterator temp;
|
||||
temp.home_node = 8;
|
||||
temp.home_node = 9;
|
||||
temp.count = 8;
|
||||
return temp;
|
||||
}
|
||||
@ -195,76 +222,221 @@ namespace
|
||||
swap(n1,n2);
|
||||
swap(s1,s2);
|
||||
}
|
||||
|
||||
if (n1 == 0 && n2 == 1)
|
||||
{
|
||||
if (s1 && s2) return 100;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return 0;
|
||||
if (!s1 && !s2) return 0;
|
||||
}
|
||||
if (n1 == 1 && n2 == 2)
|
||||
{
|
||||
if (s1 && s2) return 5;
|
||||
if (!s1 && s2) return 9;
|
||||
if (s1 && !s2) return -100;
|
||||
if (!s1 && !s2) return 9;
|
||||
}
|
||||
if (n1 == 2 && n2 == 3)
|
||||
{
|
||||
if (s1 && s2) return 10;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return 10;
|
||||
if (!s1 && !s2) return 10;
|
||||
}
|
||||
if (n1 == 0 && n2 == 3)
|
||||
{
|
||||
if (s1 && s2) return -5;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return 0;
|
||||
if (!s1 && !s2) return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
if (n1 == 4 && n2 == 5)
|
||||
{
|
||||
if (s1 && s2) return -100;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return 0;
|
||||
if (!s1 && !s2) return 0;
|
||||
}
|
||||
if (n1 == 5 && n2 == 6)
|
||||
{
|
||||
if (s1 && s2) return -5;
|
||||
if (!s1 && s2) return -9;
|
||||
if (s1 && !s2) return 100;
|
||||
if (!s1 && !s2) return -9;
|
||||
}
|
||||
if (n1 == 6 && n2 == 7)
|
||||
{
|
||||
if (s1 && s2) return -10;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return -10;
|
||||
if (!s1 && !s2) return -10;
|
||||
}
|
||||
if (n1 == 4 && n2 == 7)
|
||||
{
|
||||
if (s1 && s2) return 5;
|
||||
if (!s1 && s2) return 0;
|
||||
if (s1 && !s2) return 0;
|
||||
if (!s1 && !s2) return 0;
|
||||
}
|
||||
|
||||
DLIB_CASSERT(false, "n1: "<< n1 << " n2: "<< n2 << " s1: "<< s1 << " s2: "<< s2);
|
||||
return 0;
|
||||
return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class map_problem_chain
|
||||
{
|
||||
/*
|
||||
This is a chain structured 8 node graph (so no cycles).
|
||||
*/
|
||||
|
||||
public:
|
||||
|
||||
mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
|
||||
map_problem_chain()
|
||||
{
|
||||
for (int i = 0; i < 7; ++i)
|
||||
{
|
||||
weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian();
|
||||
weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian();
|
||||
}
|
||||
}
|
||||
|
||||
struct node_iterator
|
||||
{
|
||||
node_iterator() {}
|
||||
node_iterator(unsigned long nid_): nid(nid_) {}
|
||||
bool operator== (const node_iterator& item) const { return item.nid == nid; }
|
||||
bool operator!= (const node_iterator& item) const { return item.nid != nid; }
|
||||
|
||||
node_iterator& operator++()
|
||||
{
|
||||
++nid;
|
||||
return *this;
|
||||
}
|
||||
|
||||
unsigned long nid;
|
||||
};
|
||||
|
||||
struct neighbor_iterator
|
||||
{
|
||||
neighbor_iterator() : count(0) {}
|
||||
|
||||
bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
|
||||
bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
|
||||
neighbor_iterator& operator++()
|
||||
{
|
||||
++count;
|
||||
return *this;
|
||||
}
|
||||
|
||||
unsigned long node_id () const
|
||||
{
|
||||
if (count >= 2)
|
||||
return 8;
|
||||
return nid[count];
|
||||
}
|
||||
|
||||
unsigned long nid[2];
|
||||
unsigned int count;
|
||||
};
|
||||
|
||||
unsigned long number_of_nodes (
|
||||
) const
|
||||
{
|
||||
return 8;
|
||||
}
|
||||
|
||||
node_iterator begin(
|
||||
) const
|
||||
{
|
||||
node_iterator temp;
|
||||
temp.nid = 0;
|
||||
return temp;
|
||||
}
|
||||
|
||||
node_iterator end(
|
||||
) const
|
||||
{
|
||||
node_iterator temp;
|
||||
temp.nid = 8;
|
||||
return temp;
|
||||
}
|
||||
|
||||
neighbor_iterator begin(
|
||||
const node_iterator& it
|
||||
) const
|
||||
{
|
||||
neighbor_iterator temp;
|
||||
if (it.nid == 0)
|
||||
{
|
||||
temp.nid[0] = it.nid+1;
|
||||
temp.nid[1] = 8;
|
||||
}
|
||||
else if (it.nid == 7)
|
||||
{
|
||||
temp.nid[0] = it.nid-1;
|
||||
temp.nid[1] = 8;
|
||||
}
|
||||
else
|
||||
{
|
||||
temp.nid[0] = it.nid-1;
|
||||
temp.nid[1] = it.nid+1;
|
||||
}
|
||||
return temp;
|
||||
}
|
||||
|
||||
neighbor_iterator begin(
|
||||
const neighbor_iterator& it
|
||||
) const
|
||||
{
|
||||
const unsigned long nid = it.node_id();
|
||||
neighbor_iterator temp;
|
||||
if (nid == 0)
|
||||
{
|
||||
temp.nid[0] = nid+1;
|
||||
temp.nid[1] = 8;
|
||||
}
|
||||
else if (nid == 7)
|
||||
{
|
||||
temp.nid[0] = nid-1;
|
||||
temp.nid[1] = 8;
|
||||
}
|
||||
else
|
||||
{
|
||||
temp.nid[0] = nid-1;
|
||||
temp.nid[1] = nid+1;
|
||||
}
|
||||
return temp;
|
||||
}
|
||||
|
||||
neighbor_iterator end(
|
||||
const node_iterator&
|
||||
) const
|
||||
{
|
||||
neighbor_iterator temp;
|
||||
temp.nid[0] = 8;
|
||||
temp.nid[1] = 8;
|
||||
return temp;
|
||||
}
|
||||
|
||||
neighbor_iterator end(
|
||||
const neighbor_iterator&
|
||||
) const
|
||||
{
|
||||
neighbor_iterator temp;
|
||||
temp.nid[0] = 8;
|
||||
temp.nid[1] = 8;
|
||||
return temp;
|
||||
}
|
||||
|
||||
|
||||
unsigned long node_id (
|
||||
const node_iterator& it
|
||||
) const
|
||||
{
|
||||
return it.nid;
|
||||
}
|
||||
|
||||
unsigned long node_id (
|
||||
const neighbor_iterator& it
|
||||
) const
|
||||
{
|
||||
return it.node_id();
|
||||
}
|
||||
|
||||
|
||||
unsigned long num_states (
|
||||
const node_iterator&
|
||||
) const
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
unsigned long num_states (
|
||||
const neighbor_iterator&
|
||||
) const
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
|
||||
{ return basic_factor_value(it1.nid, it2.nid, s1, s2); }
|
||||
double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
|
||||
{ return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
|
||||
double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
|
||||
{ return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
|
||||
double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
|
||||
{ return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
|
||||
|
||||
private:
|
||||
|
||||
double basic_factor_value (
|
||||
unsigned long n1,
|
||||
unsigned long n2,
|
||||
unsigned long s1,
|
||||
unsigned long s2
|
||||
) const
|
||||
{
|
||||
if (n1 > n2)
|
||||
{
|
||||
swap(n1,n2);
|
||||
swap(s1,s2);
|
||||
}
|
||||
return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
dlib::rand rnd;
|
||||
|
||||
class map_problem2
|
||||
{
|
||||
@ -451,8 +623,8 @@ namespace
|
||||
|
||||
// basically ignore the other node in this factor. The node we
|
||||
// are ignoring is the center node of this star graph. So we basically
|
||||
// let it always have a value of 0.
|
||||
if (s2 == 0)
|
||||
// let it always have a value of 1.
|
||||
if (s2 == 1)
|
||||
return numbers(n1,s1) + 1;
|
||||
else
|
||||
return numbers(n1,s1);
|
||||
@ -563,7 +735,7 @@ namespace
|
||||
map_assignment2[2] = index_of_max(rowm(prob.numbers,2));
|
||||
map_assignment2[3] = index_of_max(rowm(prob.numbers,3));
|
||||
map_assignment2[4] = index_of_max(rowm(prob.numbers,4));
|
||||
map_assignment2[5] = 0;
|
||||
map_assignment2[5] = 1;
|
||||
const double score2 = find_total_score(prob, map_assignment2);
|
||||
|
||||
dlog << LINFO << "score NMPLP: " << score1;
|
||||
@ -588,8 +760,21 @@ namespace
|
||||
void perform_test (
|
||||
)
|
||||
{
|
||||
do_test<map_problem>();
|
||||
rnd.clear();
|
||||
|
||||
dlog << LINFO << "test on a chain structured graph";
|
||||
for (int i = 0; i < 30; ++i)
|
||||
do_test<map_problem_chain>();
|
||||
|
||||
dlog << LINFO << "test on a 2 cycle graph";
|
||||
for (int i = 0; i < 30; ++i)
|
||||
do_test<map_problem<false> >();
|
||||
|
||||
dlog << LINFO << "test on a fully connected graph";
|
||||
for (int i = 0; i < 5; ++i)
|
||||
do_test<map_problem<true> >();
|
||||
|
||||
dlog << LINFO << "test on a tree structured graph";
|
||||
for (int i = 0; i < 10; ++i)
|
||||
do_test2<map_problem2>();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user