mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added overloads of test_graph_labeling_function() and
cross_validate_graph_labeling_trainer() that can incorporate per-node loss values.
This commit is contained in:
parent
6eee12f291
commit
8c8c5bf3ce
@ -20,7 +20,8 @@ namespace dlib
|
||||
matrix<double,1,2> test_graph_labeling_function (
|
||||
const graph_labeler& labeler,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels
|
||||
const std::vector<std::vector<bool> >& labels,
|
||||
const std::vector<std::vector<double> >& losses
|
||||
)
|
||||
{
|
||||
#ifdef ENABLE_ASSERTS
|
||||
@ -31,6 +32,15 @@ namespace dlib
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t reason_for_failure: " << reason_for_failure
|
||||
);
|
||||
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
|
||||
all_values_are_nonnegative(losses) == true,
|
||||
"\t matrix test_graph_labeling_function()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t labels.size(): " << labels.size()
|
||||
<< "\n\t losses.size(): " << losses.size()
|
||||
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
|
||||
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
|
||||
);
|
||||
#endif
|
||||
|
||||
std::vector<bool> temp;
|
||||
@ -45,17 +55,21 @@ namespace dlib
|
||||
|
||||
for (unsigned long j = 0; j < labels[i].size(); ++j)
|
||||
{
|
||||
// What is the loss for this example? It's just 1 unless we have a
|
||||
// per example loss vector.
|
||||
const double loss = (losses.size() == 0) ? 1.0 : losses[i][j];
|
||||
|
||||
if (labels[i][j])
|
||||
{
|
||||
++num_pos;
|
||||
num_pos += loss;
|
||||
if (temp[j])
|
||||
++num_pos_correct;
|
||||
num_pos_correct += loss;
|
||||
}
|
||||
else
|
||||
{
|
||||
++num_neg;
|
||||
num_neg += loss;
|
||||
if (!temp[j])
|
||||
++num_neg_correct;
|
||||
num_neg_correct += loss;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -72,6 +86,20 @@ namespace dlib
|
||||
return res;
|
||||
}
|
||||
|
||||
template <
|
||||
typename graph_labeler,
|
||||
typename graph_type
|
||||
>
|
||||
matrix<double,1,2> test_graph_labeling_function (
|
||||
const graph_labeler& labeler,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels
|
||||
)
|
||||
{
|
||||
std::vector<std::vector<double> > losses;
|
||||
return test_graph_labeling_function(labeler, samples, labels, losses);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
@ -82,6 +110,7 @@ namespace dlib
|
||||
const trainer_type& trainer,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels,
|
||||
const std::vector<std::vector<double> >& losses,
|
||||
const long folds
|
||||
)
|
||||
{
|
||||
@ -98,6 +127,15 @@ namespace dlib
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t folds: " << folds
|
||||
);
|
||||
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
|
||||
all_values_are_nonnegative(losses) == true,
|
||||
"\t matrix cross_validate_graph_labeling_trainer()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t labels.size(): " << labels.size()
|
||||
<< "\n\t losses.size(): " << losses.size()
|
||||
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
|
||||
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
|
||||
);
|
||||
#endif
|
||||
|
||||
typedef std::vector<bool> label_type;
|
||||
@ -108,6 +146,7 @@ namespace dlib
|
||||
|
||||
dlib::array<graph_type> samples_test, samples_train;
|
||||
std::vector<label_type> labels_test, labels_train;
|
||||
std::vector<std::vector<double> > losses_test, losses_train;
|
||||
|
||||
|
||||
long next_test_idx = 0;
|
||||
@ -124,8 +163,10 @@ namespace dlib
|
||||
{
|
||||
samples_test.clear();
|
||||
labels_test.clear();
|
||||
losses_test.clear();
|
||||
samples_train.clear();
|
||||
labels_train.clear();
|
||||
losses_train.clear();
|
||||
|
||||
// load up the test samples
|
||||
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
||||
@ -133,6 +174,8 @@ namespace dlib
|
||||
copy_graph(samples[next_test_idx], gtemp);
|
||||
samples_test.push_back(gtemp);
|
||||
labels_test.push_back(labels[next_test_idx]);
|
||||
if (losses.size() != 0)
|
||||
losses_test.push_back(losses[next_test_idx]);
|
||||
next_test_idx = (next_test_idx + 1)%samples.size();
|
||||
}
|
||||
|
||||
@ -143,11 +186,13 @@ namespace dlib
|
||||
copy_graph(samples[next], gtemp);
|
||||
samples_train.push_back(gtemp);
|
||||
labels_train.push_back(labels[next]);
|
||||
if (losses.size() != 0)
|
||||
losses_train.push_back(losses[next]);
|
||||
next = (next + 1)%samples.size();
|
||||
}
|
||||
|
||||
|
||||
const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train);
|
||||
const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train,losses_train);
|
||||
|
||||
// check how good labeler is on the test data
|
||||
for (unsigned long i = 0; i < samples_test.size(); ++i)
|
||||
@ -155,17 +200,21 @@ namespace dlib
|
||||
labeler(samples_test[i], temp);
|
||||
for (unsigned long j = 0; j < labels_test[i].size(); ++j)
|
||||
{
|
||||
// What is the loss for this example? It's just 1 unless we have a
|
||||
// per example loss vector.
|
||||
const double loss = (losses_test.size() == 0) ? 1.0 : losses_test[i][j];
|
||||
|
||||
if (labels_test[i][j])
|
||||
{
|
||||
++num_pos;
|
||||
num_pos += loss;
|
||||
if (temp[j])
|
||||
++num_pos_correct;
|
||||
num_pos_correct += loss;
|
||||
}
|
||||
else
|
||||
{
|
||||
++num_neg;
|
||||
num_neg += loss;
|
||||
if (!temp[j])
|
||||
++num_neg_correct;
|
||||
num_neg_correct += loss;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -185,6 +234,21 @@ namespace dlib
|
||||
return res;
|
||||
}
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename graph_type
|
||||
>
|
||||
matrix<double,1,2> cross_validate_graph_labeling_trainer (
|
||||
const trainer_type& trainer,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels,
|
||||
const long folds
|
||||
)
|
||||
{
|
||||
std::vector<std::vector<double> > losses;
|
||||
return cross_validate_graph_labeling_trainer(trainer, samples, labels, losses, folds);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
@ -39,6 +39,38 @@ namespace dlib
|
||||
an R of [0,0] indicates that it gets everything wrong.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename graph_labeler,
|
||||
typename graph_type
|
||||
>
|
||||
matrix<double,1,2> test_graph_labeling_function (
|
||||
const graph_labeler& labeler,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels,
|
||||
const std::vector<std::vector<double> >& losses
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_graph_labeling_problem(samples,labels) == true
|
||||
- graph_labeler == an object with an interface compatible with the
|
||||
dlib::graph_labeler object.
|
||||
- the following must be a valid expression: labeler(samples[0]);
|
||||
- if (losses.size() != 0) then
|
||||
- sizes_match(labels, losses) == true
|
||||
- all_values_are_nonnegative(losses) == true
|
||||
ensures
|
||||
- This overload of test_graph_labeling_function() does the same thing as the
|
||||
one defined above, except that instead of counting 1 for each labeling
|
||||
mistake, it weights each mistake according to the corresponding value in
|
||||
losses. That is, instead of counting a value of 1 for making a mistake on
|
||||
samples[i].node(j), this routine counts a value of losses[i][j]. Under this
|
||||
interpretation, the loss values represent how useful it is to correctly label
|
||||
each node. Therefore, the values returned represent fractions of overall
|
||||
labeling utility rather than raw labeling accuracy.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
@ -72,6 +104,39 @@ namespace dlib
|
||||
- The number of folds used is given by the folds argument.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename graph_type
|
||||
>
|
||||
matrix<double,1,2> cross_validate_graph_labeling_trainer (
|
||||
const trainer_type& trainer,
|
||||
const dlib::array<graph_type>& samples,
|
||||
const std::vector<std::vector<bool> >& labels,
|
||||
const std::vector<std::vector<double> >& losses,
|
||||
const long folds
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_graph_labeling_problem(samples,labels) == true
|
||||
- 1 < folds <= samples.size()
|
||||
- trainer_type == an object which trains some kind of graph labeler object
|
||||
(e.g. structural_graph_labeling_trainer)
|
||||
- if (losses.size() != 0) then
|
||||
- sizes_match(labels, losses) == true
|
||||
- all_values_are_nonnegative(losses) == true
|
||||
ensures
|
||||
- This overload of cross_validate_graph_labeling_trainer() does the same thing
|
||||
as the one defined above, except that instead of counting 1 for each labeling
|
||||
mistake, it weights each mistake according to the corresponding value in
|
||||
losses. That is, instead of counting a value of 1 for making a mistake on
|
||||
samples[i].node(j), this routine counts a value of losses[i][j]. Under this
|
||||
interpretation, the loss values represent how useful it is to correctly label
|
||||
each node. Therefore, the values returned represent fractions of overall
|
||||
labeling utility rather than raw labeling accuracy.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user