mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added the ability for the user to set the per class loss.
This commit is contained in:
parent
2c45ab5e91
commit
efb1a12dc4
@ -149,6 +149,8 @@ namespace dlib
|
|||||||
<< "\n\t labels.size(): " << labels.size()
|
<< "\n\t labels.size(): " << labels.size()
|
||||||
<< "\n\t this: " << this );
|
<< "\n\t this: " << this );
|
||||||
|
|
||||||
|
loss_pos = 1.0;
|
||||||
|
loss_neg = 1.0;
|
||||||
|
|
||||||
// figure out how many dimensions are in the node and edge vectors.
|
// figure out how many dimensions are in the node and edge vectors.
|
||||||
node_dims = 0;
|
node_dims = 0;
|
||||||
@ -172,6 +174,41 @@ namespace dlib
|
|||||||
return edge_dims;
|
return edge_dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_loss_on_positive_class (
|
||||||
|
double loss
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(loss >= 0,
|
||||||
|
"\t structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
|
<< "\n\t loss: " << loss
|
||||||
|
<< "\n\t this: " << this );
|
||||||
|
|
||||||
|
loss_pos = loss;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_loss_on_negative_class (
|
||||||
|
double loss
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(loss >= 0,
|
||||||
|
"\t structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
|
<< "\n\t loss: " << loss
|
||||||
|
<< "\n\t this: " << this );
|
||||||
|
|
||||||
|
loss_pos = loss;
|
||||||
|
}
|
||||||
|
|
||||||
|
double get_loss_on_negative_class (
|
||||||
|
) const { return loss_neg; }
|
||||||
|
|
||||||
|
double get_loss_on_positive_class (
|
||||||
|
) const { return loss_pos; }
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
virtual long get_num_dimensions (
|
virtual long get_num_dimensions (
|
||||||
) const
|
) const
|
||||||
@ -303,9 +340,9 @@ namespace dlib
|
|||||||
// max when we use find_max_factor_graph_potts() below.
|
// max when we use find_max_factor_graph_potts() below.
|
||||||
const bool label_i = (labels[idx][i]!=0);
|
const bool label_i = (labels[idx][i]!=0);
|
||||||
if (label_i)
|
if (label_i)
|
||||||
g.node(i).data -= 1.0;
|
g.node(i).data -= loss_pos;
|
||||||
else
|
else
|
||||||
g.node(i).data += 1.0;
|
g.node(i).data += loss_neg;
|
||||||
|
|
||||||
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
|
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
|
||||||
{
|
{
|
||||||
@ -331,7 +368,12 @@ namespace dlib
|
|||||||
const bool true_label = (labels[idx][i]!= 0);
|
const bool true_label = (labels[idx][i]!= 0);
|
||||||
const bool pred_label = (labeling[i]!= 0);
|
const bool pred_label = (labeling[i]!= 0);
|
||||||
if (true_label != pred_label)
|
if (true_label != pred_label)
|
||||||
++loss;
|
{
|
||||||
|
if (true_label == true)
|
||||||
|
loss += loss_pos;
|
||||||
|
else
|
||||||
|
loss += loss_neg;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute psi
|
// compute psi
|
||||||
@ -343,6 +385,8 @@ namespace dlib
|
|||||||
|
|
||||||
long node_dims;
|
long node_dims;
|
||||||
long edge_dims;
|
long edge_dims;
|
||||||
|
double loss_pos;
|
||||||
|
double loss_neg;
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
@ -120,6 +120,46 @@ namespace dlib
|
|||||||
part of the total weight vector. You can do this by passing get_num_edge_weights()
|
part of the total weight vector. You can do this by passing get_num_edge_weights()
|
||||||
to the third argument to oca::operator().
|
to the third argument to oca::operator().
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
void set_loss_on_positive_class (
|
||||||
|
double loss
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- loss >= 0
|
||||||
|
ensures
|
||||||
|
- #get_loss_on_positive_class() == loss
|
||||||
|
!*/
|
||||||
|
|
||||||
|
void set_loss_on_negative_class (
|
||||||
|
double loss
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- loss >= 0
|
||||||
|
ensures
|
||||||
|
- #get_loss_on_negative_class() == loss
|
||||||
|
!*/
|
||||||
|
|
||||||
|
double get_loss_on_positive_class (
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns the loss incurred when a graph node which is supposed to have
|
||||||
|
a label of true gets misclassified. This value controls how much we care
|
||||||
|
about correctly classifying nodes which should be labeled as true. Larger
|
||||||
|
loss values indicate that we care more strongly than smaller values.
|
||||||
|
!*/
|
||||||
|
|
||||||
|
double get_loss_on_negative_class (
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns the loss incurred when a graph node which is supposed to have
|
||||||
|
a label of false gets misclassified. This value controls how much we care
|
||||||
|
about correctly classifying nodes which should be labeled as false. Larger
|
||||||
|
loss values indicate that we care more strongly than smaller values.
|
||||||
|
!*/
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
Loading…
Reference in New Issue
Block a user