Fixed the ranking test functions so they correctly compute the MAP values

for ranking functions which output constant values.
This commit is contained in:
Davis King 2013-04-09 17:44:24 -04:00
parent 0c0e744d9c
commit bbefbc17b1
2 changed files with 30 additions and 5 deletions

View File

@ -216,6 +216,24 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace impl
{
inline bool compare_first_reverse_second (
const std::pair<double,bool>& a,
const std::pair<double,bool>& b
)
{
if (a.first < b.first)
return true;
else if (a.first > b.first)
return false;
else if (a.second == true)
return true;
else
return false;
}
}
template <
typename ranking_function,
typename T
@ -264,8 +282,11 @@ namespace dlib
}
// Now compute the average precision for this sample. We need to sort the
// results and the back them into total_ranking.
std::sort(total_scores.rbegin(), total_scores.rend());
// results and the back them into total_ranking. Note that we sort them so
// that, if you get a block of ranking values that are all equal, the elements
// marked as true will come last. This prevents a ranking from outputting a
// constant value for everything and still getting a good MAP score.
std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
total_ranking.clear();
for (unsigned long i = 0; i < total_scores.size(); ++i)
total_ranking.push_back(total_scores[i].second);
@ -390,8 +411,11 @@ namespace dlib
}
// Now compute the average precision for this sample. We need to sort the
// results and the back them into total_ranking.
std::sort(total_scores.rbegin(), total_scores.rend());
// results and the back them into total_ranking. Note that we sort them so
// that, if you get a block of ranking values that are all equal, the elements
// marked as true will come last. This prevents a ranking from outputting a
// constant value for everything and still getting a good MAP score.
std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
total_ranking.clear();
for (unsigned long i = 0; i < total_scores.size(); ++i)
total_ranking.push_back(total_scores[i].second);

View File

@ -181,7 +181,8 @@ namespace dlib
- M(1) == the mean average precision of the rankings induced by funct.
(Mean average precision is a number in the range 0 to 1. Moreover, a
mean average precision of 1 means everything was correctly predicted
while smaller values indicate worse rankings.)
while smaller values indicate worse rankings. See the documentation
for average_precision() for details of its computation.)
!*/
// ----------------------------------------------------------------------------------------