mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Updated the interface to the structural_graph_labeling_trainer so the user
can set the per class loss to whatever they want.
This commit is contained in:
parent
42c123f298
commit
3a5f99f497
@ -32,6 +32,8 @@ namespace dlib
|
||||
eps = 0.1;
|
||||
num_threads = 2;
|
||||
max_cache_size = 40;
|
||||
loss_pos = 1.0;
|
||||
loss_neg = 1.0;
|
||||
}
|
||||
|
||||
void set_num_threads (
|
||||
@ -124,6 +126,42 @@ namespace dlib
|
||||
return C;
|
||||
}
|
||||
|
||||
|
||||
void set_loss_on_positive_class (
|
||||
double loss
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(loss >= 0,
|
||||
"\t structural_graph_labeling_trainer::set_loss_on_positive_class()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t loss: " << loss
|
||||
<< "\n\t this: " << this );
|
||||
|
||||
loss_pos = loss;
|
||||
}
|
||||
|
||||
void set_loss_on_negative_class (
|
||||
double loss
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(loss >= 0,
|
||||
"\t structural_graph_labeling_trainer::set_loss_on_negative_class()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t loss: " << loss
|
||||
<< "\n\t this: " << this );
|
||||
|
||||
loss_neg = loss;
|
||||
}
|
||||
|
||||
double get_loss_on_negative_class (
|
||||
) const { return loss_neg; }
|
||||
|
||||
double get_loss_on_positive_class (
|
||||
) const { return loss_pos; }
|
||||
|
||||
|
||||
template <
|
||||
typename graph_type
|
||||
>
|
||||
@ -150,6 +188,8 @@ namespace dlib
|
||||
prob.set_c(C);
|
||||
prob.set_epsilon(eps);
|
||||
prob.set_max_cache_size(max_cache_size);
|
||||
prob.set_loss_on_positive_class(loss_pos);
|
||||
prob.set_loss_on_negative_class(loss_neg);
|
||||
|
||||
matrix<double,0,1> w;
|
||||
solver(prob, w, prob.get_num_edge_weights());
|
||||
@ -203,6 +243,8 @@ namespace dlib
|
||||
bool verbose;
|
||||
unsigned long num_threads;
|
||||
unsigned long max_cache_size;
|
||||
double loss_pos;
|
||||
double loss_neg;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -48,6 +48,8 @@ namespace dlib
|
||||
- #get_epsilon() == 0.1
|
||||
- #get_num_threads() == 2
|
||||
- #get_max_cache_size() == 40
|
||||
- #get_loss_on_positive_class() == 1.0
|
||||
- #get_loss_on_negative_class() == 1.0
|
||||
!*/
|
||||
|
||||
void set_num_threads (
|
||||
@ -159,6 +161,46 @@ namespace dlib
|
||||
better generalization.
|
||||
!*/
|
||||
|
||||
void set_loss_on_positive_class (
|
||||
double loss
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- loss >= 0
|
||||
ensures
|
||||
- #get_loss_on_positive_class() == loss
|
||||
!*/
|
||||
|
||||
void set_loss_on_negative_class (
|
||||
double loss
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- loss >= 0
|
||||
ensures
|
||||
- #get_loss_on_negative_class() == loss
|
||||
!*/
|
||||
|
||||
double get_loss_on_positive_class (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the loss incurred when a graph node which is supposed to have
|
||||
a label of true gets misclassified. This value controls how much we care
|
||||
about correctly classifying nodes which should be labeled as true. Larger
|
||||
loss values indicate that we care more strongly than smaller values.
|
||||
!*/
|
||||
|
||||
double get_loss_on_negative_class (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the loss incurred when a graph node which is supposed to have
|
||||
a label of false gets misclassified. This value controls how much we care
|
||||
about correctly classifying nodes which should be labeled as false. Larger
|
||||
loss values indicate that we care more strongly than smaller values.
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename graph_type
|
||||
>
|
||||
|
Loading…
Reference in New Issue
Block a user