Added the ability for the user to set the per class loss.

This commit is contained in:
Davis King 2012-05-06 11:51:18 -04:00
parent 2c45ab5e91
commit efb1a12dc4
2 changed files with 87 additions and 3 deletions

View File

@ -149,6 +149,8 @@ namespace dlib
<< "\n\t labels.size(): " << labels.size() << "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this ); << "\n\t this: " << this );
loss_pos = 1.0;
loss_neg = 1.0;
// figure out how many dimensions are in the node and edge vectors. // figure out how many dimensions are in the node and edge vectors.
node_dims = 0; node_dims = 0;
@ -172,6 +174,41 @@ namespace dlib
return edge_dims; return edge_dims;
} }
void set_loss_on_positive_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
void set_loss_on_negative_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
double get_loss_on_negative_class (
) const { return loss_neg; }
double get_loss_on_positive_class (
) const { return loss_pos; }
private: private:
virtual long get_num_dimensions ( virtual long get_num_dimensions (
) const ) const
@ -303,9 +340,9 @@ namespace dlib
// max when we use find_max_factor_graph_potts() below. // max when we use find_max_factor_graph_potts() below.
const bool label_i = (labels[idx][i]!=0); const bool label_i = (labels[idx][i]!=0);
if (label_i) if (label_i)
g.node(i).data -= 1.0; g.node(i).data -= loss_pos;
else else
g.node(i).data += 1.0; g.node(i).data += loss_neg;
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
{ {
@ -331,7 +368,12 @@ namespace dlib
const bool true_label = (labels[idx][i]!= 0); const bool true_label = (labels[idx][i]!= 0);
const bool pred_label = (labeling[i]!= 0); const bool pred_label = (labeling[i]!= 0);
if (true_label != pred_label) if (true_label != pred_label)
++loss; {
if (true_label == true)
loss += loss_pos;
else
loss += loss_neg;
}
} }
// compute psi // compute psi
@ -343,6 +385,8 @@ namespace dlib
long node_dims; long node_dims;
long edge_dims; long edge_dims;
double loss_pos;
double loss_neg;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------

View File

@ -120,6 +120,46 @@ namespace dlib
part of the total weight vector. You can do this by passing get_num_edge_weights() part of the total weight vector. You can do this by passing get_num_edge_weights()
to the third argument to oca::operator(). to the third argument to oca::operator().
!*/ !*/
void set_loss_on_positive_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_positive_class() == loss
!*/
void set_loss_on_negative_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_negative_class() == loss
!*/
double get_loss_on_positive_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of true gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as true. Larger
loss values indicate that we care more strongly than smaller values.
!*/
double get_loss_on_negative_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of false gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as false. Larger
loss values indicate that we care more strongly than smaller values.
!*/
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------