mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added overloads of test_graph_labeling_function() and
cross_validate_graph_labeling_trainer() that can incorporate per node loss values.
This commit is contained in:
parent
6eee12f291
commit
8c8c5bf3ce
@ -20,7 +20,8 @@ namespace dlib
|
|||||||
matrix<double,1,2> test_graph_labeling_function (
|
matrix<double,1,2> test_graph_labeling_function (
|
||||||
const graph_labeler& labeler,
|
const graph_labeler& labeler,
|
||||||
const dlib::array<graph_type>& samples,
|
const dlib::array<graph_type>& samples,
|
||||||
const std::vector<std::vector<bool> >& labels
|
const std::vector<std::vector<bool> >& labels,
|
||||||
|
const std::vector<std::vector<double> >& losses
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
#ifdef ENABLE_ASSERTS
|
#ifdef ENABLE_ASSERTS
|
||||||
@ -31,6 +32,15 @@ namespace dlib
|
|||||||
<< "\n\t samples.size(): " << samples.size()
|
<< "\n\t samples.size(): " << samples.size()
|
||||||
<< "\n\t reason_for_failure: " << reason_for_failure
|
<< "\n\t reason_for_failure: " << reason_for_failure
|
||||||
);
|
);
|
||||||
|
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
|
||||||
|
all_values_are_nonnegative(losses) == true,
|
||||||
|
"\t matrix test_graph_labeling_function()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
|
<< "\n\t labels.size(): " << labels.size()
|
||||||
|
<< "\n\t losses.size(): " << losses.size()
|
||||||
|
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
|
||||||
|
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
|
||||||
|
);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<bool> temp;
|
std::vector<bool> temp;
|
||||||
@ -45,17 +55,21 @@ namespace dlib
|
|||||||
|
|
||||||
for (unsigned long j = 0; j < labels[i].size(); ++j)
|
for (unsigned long j = 0; j < labels[i].size(); ++j)
|
||||||
{
|
{
|
||||||
|
// What is the loss for this example? It's just 1 unless we have a
|
||||||
|
// per example loss vector.
|
||||||
|
const double loss = (losses.size() == 0) ? 1.0 : losses[i][j];
|
||||||
|
|
||||||
if (labels[i][j])
|
if (labels[i][j])
|
||||||
{
|
{
|
||||||
++num_pos;
|
num_pos += loss;
|
||||||
if (temp[j])
|
if (temp[j])
|
||||||
++num_pos_correct;
|
num_pos_correct += loss;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
++num_neg;
|
num_neg += loss;
|
||||||
if (!temp[j])
|
if (!temp[j])
|
||||||
++num_neg_correct;
|
num_neg_correct += loss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -72,6 +86,20 @@ namespace dlib
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename graph_labeler,
|
||||||
|
typename graph_type
|
||||||
|
>
|
||||||
|
matrix<double,1,2> test_graph_labeling_function (
|
||||||
|
const graph_labeler& labeler,
|
||||||
|
const dlib::array<graph_type>& samples,
|
||||||
|
const std::vector<std::vector<bool> >& labels
|
||||||
|
)
|
||||||
|
{
|
||||||
|
std::vector<std::vector<double> > losses;
|
||||||
|
return test_graph_labeling_function(labeler, samples, labels, losses);
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -82,6 +110,7 @@ namespace dlib
|
|||||||
const trainer_type& trainer,
|
const trainer_type& trainer,
|
||||||
const dlib::array<graph_type>& samples,
|
const dlib::array<graph_type>& samples,
|
||||||
const std::vector<std::vector<bool> >& labels,
|
const std::vector<std::vector<bool> >& labels,
|
||||||
|
const std::vector<std::vector<double> >& losses,
|
||||||
const long folds
|
const long folds
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
@ -98,6 +127,15 @@ namespace dlib
|
|||||||
<< "\n\t invalid inputs were given to this function"
|
<< "\n\t invalid inputs were given to this function"
|
||||||
<< "\n\t folds: " << folds
|
<< "\n\t folds: " << folds
|
||||||
);
|
);
|
||||||
|
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
|
||||||
|
all_values_are_nonnegative(losses) == true,
|
||||||
|
"\t matrix cross_validate_graph_labeling_trainer()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
|
<< "\n\t labels.size(): " << labels.size()
|
||||||
|
<< "\n\t losses.size(): " << losses.size()
|
||||||
|
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
|
||||||
|
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
|
||||||
|
);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
typedef std::vector<bool> label_type;
|
typedef std::vector<bool> label_type;
|
||||||
@ -108,6 +146,7 @@ namespace dlib
|
|||||||
|
|
||||||
dlib::array<graph_type> samples_test, samples_train;
|
dlib::array<graph_type> samples_test, samples_train;
|
||||||
std::vector<label_type> labels_test, labels_train;
|
std::vector<label_type> labels_test, labels_train;
|
||||||
|
std::vector<std::vector<double> > losses_test, losses_train;
|
||||||
|
|
||||||
|
|
||||||
long next_test_idx = 0;
|
long next_test_idx = 0;
|
||||||
@ -124,8 +163,10 @@ namespace dlib
|
|||||||
{
|
{
|
||||||
samples_test.clear();
|
samples_test.clear();
|
||||||
labels_test.clear();
|
labels_test.clear();
|
||||||
|
losses_test.clear();
|
||||||
samples_train.clear();
|
samples_train.clear();
|
||||||
labels_train.clear();
|
labels_train.clear();
|
||||||
|
losses_train.clear();
|
||||||
|
|
||||||
// load up the test samples
|
// load up the test samples
|
||||||
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
||||||
@ -133,6 +174,8 @@ namespace dlib
|
|||||||
copy_graph(samples[next_test_idx], gtemp);
|
copy_graph(samples[next_test_idx], gtemp);
|
||||||
samples_test.push_back(gtemp);
|
samples_test.push_back(gtemp);
|
||||||
labels_test.push_back(labels[next_test_idx]);
|
labels_test.push_back(labels[next_test_idx]);
|
||||||
|
if (losses.size() != 0)
|
||||||
|
losses_test.push_back(losses[next_test_idx]);
|
||||||
next_test_idx = (next_test_idx + 1)%samples.size();
|
next_test_idx = (next_test_idx + 1)%samples.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,11 +186,13 @@ namespace dlib
|
|||||||
copy_graph(samples[next], gtemp);
|
copy_graph(samples[next], gtemp);
|
||||||
samples_train.push_back(gtemp);
|
samples_train.push_back(gtemp);
|
||||||
labels_train.push_back(labels[next]);
|
labels_train.push_back(labels[next]);
|
||||||
|
if (losses.size() != 0)
|
||||||
|
losses_train.push_back(losses[next]);
|
||||||
next = (next + 1)%samples.size();
|
next = (next + 1)%samples.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train);
|
const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train,losses_train);
|
||||||
|
|
||||||
// check how good labeler is on the test data
|
// check how good labeler is on the test data
|
||||||
for (unsigned long i = 0; i < samples_test.size(); ++i)
|
for (unsigned long i = 0; i < samples_test.size(); ++i)
|
||||||
@ -155,17 +200,21 @@ namespace dlib
|
|||||||
labeler(samples_test[i], temp);
|
labeler(samples_test[i], temp);
|
||||||
for (unsigned long j = 0; j < labels_test[i].size(); ++j)
|
for (unsigned long j = 0; j < labels_test[i].size(); ++j)
|
||||||
{
|
{
|
||||||
|
// What is the loss for this example? It's just 1 unless we have a
|
||||||
|
// per example loss vector.
|
||||||
|
const double loss = (losses_test.size() == 0) ? 1.0 : losses_test[i][j];
|
||||||
|
|
||||||
if (labels_test[i][j])
|
if (labels_test[i][j])
|
||||||
{
|
{
|
||||||
++num_pos;
|
num_pos += loss;
|
||||||
if (temp[j])
|
if (temp[j])
|
||||||
++num_pos_correct;
|
num_pos_correct += loss;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
++num_neg;
|
num_neg += loss;
|
||||||
if (!temp[j])
|
if (!temp[j])
|
||||||
++num_neg_correct;
|
num_neg_correct += loss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -185,6 +234,21 @@ namespace dlib
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename trainer_type,
|
||||||
|
typename graph_type
|
||||||
|
>
|
||||||
|
matrix<double,1,2> cross_validate_graph_labeling_trainer (
|
||||||
|
const trainer_type& trainer,
|
||||||
|
const dlib::array<graph_type>& samples,
|
||||||
|
const std::vector<std::vector<bool> >& labels,
|
||||||
|
const long folds
|
||||||
|
)
|
||||||
|
{
|
||||||
|
std::vector<std::vector<double> > losses;
|
||||||
|
return cross_validate_graph_labeling_trainer(trainer, samples, labels, losses, folds);
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,38 @@ namespace dlib
|
|||||||
an R of [0,0] indicates that it gets everything wrong.
|
an R of [0,0] indicates that it gets everything wrong.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename graph_labeler,
|
||||||
|
typename graph_type
|
||||||
|
>
|
||||||
|
matrix<double,1,2> test_graph_labeling_function (
|
||||||
|
const graph_labeler& labeler,
|
||||||
|
const dlib::array<graph_type>& samples,
|
||||||
|
const std::vector<std::vector<bool> >& labels,
|
||||||
|
const std::vector<std::vector<double> >& losses
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- is_graph_labeling_problem(samples,labels) == true
|
||||||
|
- graph_labeler == an object with an interface compatible with the
|
||||||
|
dlib::graph_labeler object.
|
||||||
|
- the following must be a valid expression: labeler(samples[0]);
|
||||||
|
- if (losses.size() != 0) then
|
||||||
|
- sizes_match(labels, losses) == true
|
||||||
|
- all_values_are_nonnegative(losses) == true
|
||||||
|
ensures
|
||||||
|
- This overload of test_graph_labeling_function() does the same thing as the
|
||||||
|
one defined above, except that instead of counting 1 for each labeling
|
||||||
|
mistake, it weights each mistake according to the corresponding value in
|
||||||
|
losses. That is, instead of counting a value of 1 for making a mistake on
|
||||||
|
samples[i].node(j), this routine counts a value of losses[i][j]. Under this
|
||||||
|
interpretation, the loss values represent how useful it is to correctly label
|
||||||
|
each node. Therefore, the values returned represent fractions of overall
|
||||||
|
labeling utility rather than raw labeling accuracy.
|
||||||
|
!*/
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -72,6 +104,39 @@ namespace dlib
|
|||||||
- The number of folds used is given by the folds argument.
|
- The number of folds used is given by the folds argument.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename trainer_type,
|
||||||
|
typename graph_type
|
||||||
|
>
|
||||||
|
matrix<double,1,2> cross_validate_graph_labeling_trainer (
|
||||||
|
const trainer_type& trainer,
|
||||||
|
const dlib::array<graph_type>& samples,
|
||||||
|
const std::vector<std::vector<bool> >& labels,
|
||||||
|
const std::vector<std::vector<double> >& losses,
|
||||||
|
const long folds
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- is_graph_labeling_problem(samples,labels) == true
|
||||||
|
- 1 < folds <= samples.size()
|
||||||
|
- trainer_type == an object which trains some kind of graph labeler object
|
||||||
|
(e.g. structural_graph_labeling_trainer)
|
||||||
|
- if (losses.size() != 0) then
|
||||||
|
- sizes_match(labels, losses) == true
|
||||||
|
- all_values_are_nonnegative(losses) == true
|
||||||
|
ensures
|
||||||
|
- This overload of cross_validate_graph_labeling_trainer() does the same thing
|
||||||
|
as the one defined above, except that instead of counting 1 for each labeling
|
||||||
|
mistake, it weights each mistake according to the corresponding value in
|
||||||
|
losses. That is, instead of counting a value of 1 for making a mistake on
|
||||||
|
samples[i].node(j), this routine counts a value of losses[i][j]. Under this
|
||||||
|
interpretation, the loss values represent how useful it is to correctly label
|
||||||
|
each node. Therefore, the values returned represent fractions of overall
|
||||||
|
labeling utility rather than raw labeling accuracy.
|
||||||
|
!*/
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user