Added a spec for the assignment problem validation functions and added

missing asserts.
This commit is contained in:
Davis King 2011-12-03 21:57:10 -05:00
parent 25e976feea
commit 9de4e129a6
3 changed files with 122 additions and 0 deletions

View File

@ -23,6 +23,28 @@ namespace dlib
const std::vector<typename assignment_function::label_type>& labels
)
{
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
if (assigner.forces_assignment())
{
DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
"\t double test_assignment_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
else
{
DLIB_ASSERT(is_assignment_problem(samples, labels),
"\t double test_assignment_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
#endif
double total_right = 0;
double total = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
@ -55,6 +77,37 @@ namespace dlib
const long folds
)
{
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
if (trainer.forces_assignment())
{
DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
1 < folds && folds <= static_cast<long>(samples.size()),
"\t double cross_validate_assignment_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t folds: " << folds
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
else
{
DLIB_ASSERT(is_assignment_problem(samples, labels) &&
1 < folds && folds <= static_cast<long>(samples.size()),
"\t double cross_validate_assignment_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t folds: " << folds
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
#endif
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::label_type label_type;

View File

@ -1,2 +1,69 @@
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "svm.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename assignment_function
>
double test_assignment_function (
const assignment_function& assigner,
const std::vector<typename assignment_function::sample_type>& samples,
const std::vector<typename assignment_function::label_type>& labels
);
/*!
requires
- is_assignment_problem(samples, labels)
- if (assigner.forces_assignment()) then
- is_forced_assignment_problem(samples, labels)
- assignment_function == an instantiation of the dlib::assignment_function
template or an object with a compatible interface.
ensures
- Tests assigner against the given samples and labels and returns the fraction
of assignments predicted correctly.
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
double cross_validate_assignment_trainer (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& samples,
const std::vector<typename trainer_type::label_type>& labels,
const long folds
);
/*!
requires
- is_assignment_problem(samples, labels)
- if (trainer.forces_assignment()) then
- is_forced_assignment_problem(samples, labels)
- 1 < folds <= samples.size()
- trainer_type == dlib::structural_assignment_trainer or an object
with a compatible interface.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given assignment learning problem for the given number of folds. Each fold
is tested using the output of the trainer and the fraction of assignments
predicted correctly is returned.
- The number of folds used is given by the folds argument.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__

View File

@ -29,6 +29,8 @@ namespace dlib
typedef assignment_function<feature_extractor> trained_function_type;
bool forces_assignment(
) const { return false; } // TODO
const assignment_function<feature_extractor> train (
const std::vector<sample_type>& x,