From efb1a12dc4cf5e747b38689b406388d2cb1f9d07 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 6 May 2012 11:51:18 -0400 Subject: [PATCH] Added the ability for the user to set the per class loss. --- .../structural_svm_graph_labeling_problem.h | 50 +++++++++++++++++-- ...ural_svm_graph_labeling_problem_abstract.h | 40 +++++++++++++++ 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/dlib/svm/structural_svm_graph_labeling_problem.h b/dlib/svm/structural_svm_graph_labeling_problem.h index 36c290a09..79b490e97 100644 --- a/dlib/svm/structural_svm_graph_labeling_problem.h +++ b/dlib/svm/structural_svm_graph_labeling_problem.h @@ -149,6 +149,8 @@ namespace dlib << "\n\t labels.size(): " << labels.size() << "\n\t this: " << this ); + loss_pos = 1.0; + loss_neg = 1.0; // figure out how many dimensions are in the node and edge vectors. node_dims = 0; @@ -172,6 +174,41 @@ namespace dlib return edge_dims; } + void set_loss_on_positive_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t structural_svm_graph_labeling_problem::set_loss_on_positive_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_pos = loss; + } + + void set_loss_on_negative_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t structural_svm_graph_labeling_problem::set_loss_on_negative_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_pos = loss; + } + + double get_loss_on_negative_class ( + ) const { return loss_neg; } + + double get_loss_on_positive_class ( + ) const { return loss_pos; } + + private: virtual long get_num_dimensions ( ) const @@ -303,9 +340,9 @@ namespace dlib // max when we use find_max_factor_graph_potts() below. const bool label_i = (labels[idx][i]!=0); if (label_i) - g.node(i).data -= 1.0; + g.node(i).data -= loss_pos; else - g.node(i).data += 1.0; + g.node(i).data += loss_neg; for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) { @@ -331,7 +368,12 @@ namespace dlib const bool true_label = (labels[idx][i]!= 0); const bool pred_label = (labeling[i]!= 0); if (true_label != pred_label) - ++loss; + { + if (true_label == true) + loss += loss_pos; + else + loss += loss_neg; + } } // compute psi @@ -343,6 +385,8 @@ namespace dlib long node_dims; long edge_dims; + double loss_pos; + double loss_neg; }; // ---------------------------------------------------------------------------------------- diff --git a/dlib/svm/structural_svm_graph_labeling_problem_abstract.h b/dlib/svm/structural_svm_graph_labeling_problem_abstract.h index a18abbc2c..8241b9cc3 100644 --- a/dlib/svm/structural_svm_graph_labeling_problem_abstract.h +++ b/dlib/svm/structural_svm_graph_labeling_problem_abstract.h @@ -120,6 +120,46 @@ namespace dlib part of the total weight vector. You can do this by passing get_num_edge_weights() to the third argument to oca::operator(). !*/ + + void set_loss_on_positive_class ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_on_positive_class() == loss + !*/ + + void set_loss_on_negative_class ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_on_negative_class() == loss + !*/ + + double get_loss_on_positive_class ( + ) const; + /*! + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of true gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as true. Larger + loss values indicate that we care more strongly than smaller values. + !*/ + + double get_loss_on_negative_class ( + ) const; + /*! + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of false gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as false. Larger + loss values indicate that we care more strongly than smaller values. + !*/ }; // ----------------------------------------------------------------------------------------