mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added the ability for the user to set the per class loss.
This commit is contained in:
parent
2c45ab5e91
commit
efb1a12dc4
@ -149,6 +149,8 @@ namespace dlib
|
||||
<< "\n\t labels.size(): " << labels.size()
|
||||
<< "\n\t this: " << this );
|
||||
|
||||
loss_pos = 1.0;
|
||||
loss_neg = 1.0;
|
||||
|
||||
// figure out how many dimensions are in the node and edge vectors.
|
||||
node_dims = 0;
|
||||
@ -172,6 +174,41 @@ namespace dlib
|
||||
return edge_dims;
|
||||
}
|
||||
|
||||
void set_loss_on_positive_class (
|
||||
double loss
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(loss >= 0,
|
||||
"\t structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t loss: " << loss
|
||||
<< "\n\t this: " << this );
|
||||
|
||||
loss_pos = loss;
|
||||
}
|
||||
|
||||
void set_loss_on_negative_class (
|
||||
double loss
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(loss >= 0,
|
||||
"\t structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t loss: " << loss
|
||||
<< "\n\t this: " << this );
|
||||
|
||||
loss_pos = loss;
|
||||
}
|
||||
|
||||
double get_loss_on_negative_class (
|
||||
) const { return loss_neg; }
|
||||
|
||||
double get_loss_on_positive_class (
|
||||
) const { return loss_pos; }
|
||||
|
||||
|
||||
private:
|
||||
virtual long get_num_dimensions (
|
||||
) const
|
||||
@ -303,9 +340,9 @@ namespace dlib
|
||||
// max when we use find_max_factor_graph_potts() below.
|
||||
const bool label_i = (labels[idx][i]!=0);
|
||||
if (label_i)
|
||||
g.node(i).data -= 1.0;
|
||||
g.node(i).data -= loss_pos;
|
||||
else
|
||||
g.node(i).data += 1.0;
|
||||
g.node(i).data += loss_neg;
|
||||
|
||||
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
|
||||
{
|
||||
@ -331,7 +368,12 @@ namespace dlib
|
||||
const bool true_label = (labels[idx][i]!= 0);
|
||||
const bool pred_label = (labeling[i]!= 0);
|
||||
if (true_label != pred_label)
|
||||
++loss;
|
||||
{
|
||||
if (true_label == true)
|
||||
loss += loss_pos;
|
||||
else
|
||||
loss += loss_neg;
|
||||
}
|
||||
}
|
||||
|
||||
// compute psi
|
||||
@ -343,6 +385,8 @@ namespace dlib
|
||||
|
||||
long node_dims;
|
||||
long edge_dims;
|
||||
double loss_pos;
|
||||
double loss_neg;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -120,6 +120,46 @@ namespace dlib
|
||||
part of the total weight vector. You can do this by passing get_num_edge_weights()
|
||||
to the third argument to oca::operator().
|
||||
!*/
|
||||
|
||||
void set_loss_on_positive_class (
|
||||
double loss
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- loss >= 0
|
||||
ensures
|
||||
- #get_loss_on_positive_class() == loss
|
||||
!*/
|
||||
|
||||
void set_loss_on_negative_class (
|
||||
double loss
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- loss >= 0
|
||||
ensures
|
||||
- #get_loss_on_negative_class() == loss
|
||||
!*/
|
||||
|
||||
double get_loss_on_positive_class (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the loss incurred when a graph node which is supposed to have
|
||||
a label of true gets misclassified. This value controls how much we care
|
||||
about correctly classifying nodes which should be labeled as true. Larger
|
||||
loss values indicate that we care more strongly than smaller values.
|
||||
!*/
|
||||
|
||||
double get_loss_on_negative_class (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the loss incurred when a graph node which is supposed to have
|
||||
a label of false gets misclassified. This value controls how much we care
|
||||
about correctly classifying nodes which should be labeled as false. Larger
|
||||
loss values indicate that we care more strongly than smaller values.
|
||||
!*/
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue
Block a user