Updated the interface to the structural_graph_labeling_trainer so the user

can set the per class loss to whatever they want.
This commit is contained in:
Davis King 2012-05-06 12:00:22 -04:00
parent 42c123f298
commit 3a5f99f497
2 changed files with 84 additions and 0 deletions

View File

@ -32,6 +32,8 @@ namespace dlib
eps = 0.1;
num_threads = 2;
max_cache_size = 40;
loss_pos = 1.0;
loss_neg = 1.0;
}
void set_num_threads (
@ -124,6 +126,42 @@ namespace dlib
return C;
}
void set_loss_on_positive_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_graph_labeling_trainer::set_loss_on_positive_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
void set_loss_on_negative_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_graph_labeling_trainer::set_loss_on_negative_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_neg = loss;
}
double get_loss_on_negative_class (
) const { return loss_neg; }
double get_loss_on_positive_class (
) const { return loss_pos; }
template <
typename graph_type
>
@ -150,6 +188,8 @@ namespace dlib
prob.set_c(C);
prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size);
prob.set_loss_on_positive_class(loss_pos);
prob.set_loss_on_negative_class(loss_neg);
matrix<double,0,1> w;
solver(prob, w, prob.get_num_edge_weights());
@ -203,6 +243,8 @@ namespace dlib
bool verbose;
unsigned long num_threads;
unsigned long max_cache_size;
double loss_pos;
double loss_neg;
};
// ----------------------------------------------------------------------------------------

View File

@ -48,6 +48,8 @@ namespace dlib
- #get_epsilon() == 0.1
- #get_num_threads() == 2
- #get_max_cache_size() == 40
- #get_loss_on_positive_class() == 1.0
- #get_loss_on_negative_class() == 1.0
!*/
void set_num_threads (
@ -159,6 +161,46 @@ namespace dlib
better generalization.
!*/
void set_loss_on_positive_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_positive_class() == loss
!*/
void set_loss_on_negative_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_negative_class() == loss
!*/
double get_loss_on_positive_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of true gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as true. Larger
loss values indicate that we care more strongly than smaller values.
!*/
double get_loss_on_negative_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of false gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as false. Larger
loss values indicate that we care more strongly than smaller values.
!*/
template <
typename graph_type
>