mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added another overload of find_max_factor_graph_potts() that works on
graphs that are regular grids.
This commit is contained in:
parent
6bf0920d30
commit
1b69ed2e46
@ -9,6 +9,7 @@
|
||||
#include "general_potts_problem.h"
|
||||
#include "../algs.h"
|
||||
#include "../graph_utils.h"
|
||||
#include "../array2d.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
@ -410,6 +411,139 @@ namespace dlib
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename label_image_type,
|
||||
typename image_potts_model
|
||||
>
|
||||
class potts_grid_problem
|
||||
{
|
||||
label_image_type& label_img;
|
||||
long nc;
|
||||
long num_nodes;
|
||||
unsigned char* labels;
|
||||
const image_potts_model& model;
|
||||
|
||||
public:
|
||||
const static unsigned long max_number_of_neighbors = 4;
|
||||
|
||||
potts_grid_problem (
|
||||
label_image_type& label_img_,
|
||||
const image_potts_model& image_potts_model_
|
||||
) :
|
||||
label_img(label_img_),
|
||||
model(image_potts_model_)
|
||||
{
|
||||
num_nodes = model.nr()*model.nc();
|
||||
nc = model.nc();
|
||||
labels = &label_img[0][0];
|
||||
}
|
||||
|
||||
unsigned long number_of_nodes (
|
||||
) const { return num_nodes; }
|
||||
|
||||
unsigned long number_of_neighbors (
|
||||
unsigned long
|
||||
) const
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
|
||||
unsigned long get_neighbor_idx (
|
||||
long node_id1,
|
||||
long node_id2
|
||||
) const
|
||||
{
|
||||
long diff = node_id2-node_id1;
|
||||
if (diff > nc)
|
||||
diff -= (long)number_of_nodes();
|
||||
else if (diff < -nc)
|
||||
diff += (long)number_of_nodes();
|
||||
|
||||
if (diff == 1)
|
||||
return 0;
|
||||
else if (diff == -1)
|
||||
return 1;
|
||||
else if (diff == nc)
|
||||
return 2;
|
||||
else
|
||||
return 3;
|
||||
}
|
||||
|
||||
unsigned long get_neighbor (
|
||||
long node_id,
|
||||
long idx
|
||||
) const
|
||||
{
|
||||
switch(idx)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
long temp = node_id+1;
|
||||
if (temp < (long)number_of_nodes())
|
||||
return temp;
|
||||
else
|
||||
return temp - (long)number_of_nodes();
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
long temp = node_id-1;
|
||||
if (node_id >= 1)
|
||||
return temp;
|
||||
else
|
||||
return temp + (long)number_of_nodes();
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
long temp = node_id+nc;
|
||||
if (temp < (long)number_of_nodes())
|
||||
return temp;
|
||||
else
|
||||
return temp - (long)number_of_nodes();
|
||||
}
|
||||
case 3:
|
||||
{
|
||||
long temp = node_id-nc;
|
||||
if (node_id >= nc)
|
||||
return temp;
|
||||
else
|
||||
return temp + (long)number_of_nodes();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void set_label (
|
||||
const unsigned long& idx,
|
||||
node_label value
|
||||
)
|
||||
{
|
||||
*(labels+idx) = value;
|
||||
}
|
||||
|
||||
node_label get_label (
|
||||
const unsigned long& idx
|
||||
) const
|
||||
{
|
||||
return *(labels+idx);
|
||||
}
|
||||
|
||||
typedef typename image_potts_model::value_type value_type;
|
||||
|
||||
value_type factor_value (unsigned long idx) const
|
||||
{
|
||||
return model.factor_value(idx);
|
||||
}
|
||||
|
||||
value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
|
||||
{
|
||||
return model.factor_value_disagreement(idx1,idx2);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
@ -531,6 +665,29 @@ namespace dlib
|
||||
return score;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename potts_grid_problem,
|
||||
typename mem_manager
|
||||
>
|
||||
typename potts_grid_problem::value_type potts_model_score (
|
||||
const potts_grid_problem& prob,
|
||||
const array2d<node_label,mem_manager>& labels
|
||||
)
|
||||
{
|
||||
DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(),
|
||||
"\t value_type potts_model_score(prob,labels)"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t prob.nr(): " << labels.nr()
|
||||
<< "\n\t prob.nc(): " << labels.nc()
|
||||
);
|
||||
typedef array2d<node_label,mem_manager> image_type;
|
||||
// This const_cast is ok because the model object won't actually modify labels
|
||||
dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(const_cast<image_type&>(labels),prob);
|
||||
return potts_model_score(model);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
@ -642,6 +799,23 @@ namespace dlib
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename potts_grid_problem,
|
||||
typename mem_manager
|
||||
>
|
||||
void find_max_factor_graph_potts (
|
||||
const potts_grid_problem& prob,
|
||||
array2d<node_label,mem_manager>& labels
|
||||
)
|
||||
{
|
||||
typedef array2d<node_label,mem_manager> image_type;
|
||||
labels.set_size(prob.nr(), prob.nc());
|
||||
dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(labels,prob);
|
||||
find_max_factor_graph_potts(model);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_H__
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "../matrix.h"
|
||||
#include "min_cut_abstract.h"
|
||||
#include "../graph_utils.h"
|
||||
#include "../array2d/array2d_kernel_abstract.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
@ -159,6 +160,83 @@ namespace dlib
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct potts_grid_problem
|
||||
{
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This object is a specialization of a potts_problem to the case where
|
||||
the graph is a regular grid where each node is connected to its four
|
||||
neighbors. An example of this is an image where each pixel is a node
|
||||
and is connected to its four immediate neighboring pixels. Therefore,
|
||||
this object defines the interface this special kind of MAP problem
|
||||
must implement if it is to be solved by the find_max_factor_graph_potts(potts_grid_problem,array2d)
|
||||
routine defined at the end of this file.
|
||||
|
||||
|
||||
Note that all nodes always have four neighbors, even nodes on the edge
|
||||
of the graph. This is because these border nodes are connected to
|
||||
the border nodes on the other side of the graph. That is, the graph
|
||||
"wraps" around at the borders.
|
||||
!*/
|
||||
|
||||
// This typedef should be for a type like int or double. It
|
||||
// must also be capable of representing signed values.
|
||||
typedef an_integer_or_real_type value_type;
|
||||
|
||||
long nr(
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the number of rows in the grid
|
||||
!*/
|
||||
|
||||
long nc(
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the number of columns in the grid
|
||||
!*/
|
||||
|
||||
value_type factor_value (
|
||||
unsigned long idx
|
||||
) const;
|
||||
/*!
|
||||
requires
|
||||
- idx < nr()*nc()
|
||||
ensures
|
||||
- The grid is represented in row-major-order format. Therefore, idx
|
||||
identifies a node according to its position in the row-major-order
|
||||
representation of the grid graph. Or in other words, idx corresponds
|
||||
to the following row and column location:
|
||||
- row == idx/nc()
|
||||
- col == idx%nc()
|
||||
- returns a value which indicates how "good" it is to assign the idx-th
|
||||
node the label of true. The larger the value, the more desirable it is
|
||||
to give it this label. Similarly, a negative value indicates that it is
|
||||
better to give the node a label of false.
|
||||
!*/
|
||||
|
||||
value_type factor_value_disagreement (
|
||||
unsigned long idx1,
|
||||
unsigned long idx2
|
||||
) const;
|
||||
/*!
|
||||
requires
|
||||
- idx1 < nr()*nc()
|
||||
- idx2 < nr()*nc()
|
||||
- idx1 != idx2
|
||||
- the idx1-th node and idx2-th node are neighbors in the grid graph.
|
||||
ensures
|
||||
- returns a number >= 0. This is the penalty for giving node idx1 and idx2
|
||||
different labels. Larger values indicate a larger penalty.
|
||||
- this function is symmetric. That is, it is true that:
|
||||
factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
|
||||
!*/
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
@ -230,6 +308,44 @@ namespace dlib
|
||||
- Then this function returns F - D
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename potts_grid_problem,
|
||||
typename mem_manager
|
||||
>
|
||||
typename potts_grid_problem::value_type potts_model_score (
|
||||
const potts_grid_problem& prob,
|
||||
const array2d<node_label,mem_manager>& labels
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- prob.nr() == labels.nr()
|
||||
- prob.nc() == labels.nc()
|
||||
- potts_grid_problem == an object with an interface compatible with the
|
||||
potts_grid_problem object defined above.
|
||||
- for all valid i and j:
|
||||
- prob.factor_value_disagreement(i,j) >= 0
|
||||
- prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
|
||||
ensures
|
||||
- computes the model score for the given potts_grid_problem. We define this
|
||||
precisely below:
|
||||
- let L(i) == the boolean label of the i-th variable in prob. Or in other
|
||||
words, L(i) == (labels[i/labels.nc()][i%labels.nc()] != 0).
|
||||
- let F == the sum of values of prob.factor_value(i) for only i values
|
||||
where L(i) == true.
|
||||
- Let D == the sum of values of prob.factor_value_disagreement(i,j)
|
||||
for only i and j values which meet the following conditions:
|
||||
- i and j are neighbors in the graph defined by prob, that is,
|
||||
it is valid to call prob.factor_value_disagreement(i,j).
|
||||
- L(i) != L(j)
|
||||
- i < j
|
||||
(i.e. We want to make sure to only count the edge between i and j once)
|
||||
|
||||
- Then this function returns F - D
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
@ -283,6 +399,33 @@ namespace dlib
|
||||
- the factor_value_disagreement(i,j) is stored in edge(g,i,j).
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename potts_grid_problem,
|
||||
typename mem_manager
|
||||
>
|
||||
void find_max_factor_graph_potts (
|
||||
const potts_grid_problem& prob,
|
||||
array2d<node_label,mem_manager>& labels
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- potts_grid_problem == an object with an interface compatible with the
|
||||
potts_grid_problem object defined above.
|
||||
- for all valid i and j:
|
||||
- prob.factor_value_disagreement(i,j) >= 0
|
||||
- prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
|
||||
ensures
|
||||
- This routine solves a version of a potts problem where the graph is a
|
||||
regular grid where each node is connected to its four immediate neighbors.
|
||||
In particular, this means that this function finds the assignments
|
||||
to all the labels in prob which maximizes potts_model_score(prob,#labels).
|
||||
- The optimal labels are stored in #labels.
|
||||
- #labels.nr() == prob.nr()
|
||||
- #labels.nc() == prob.nc()
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user