mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a spec for the assignment problem validation functions and added
missing asserts.
This commit is contained in:
parent
25e976feea
commit
9de4e129a6
@ -23,6 +23,28 @@ namespace dlib
|
||||
const std::vector<typename assignment_function::label_type>& labels
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
#ifdef ENABLE_ASSERTS
|
||||
if (assigner.forces_assignment())
|
||||
{
|
||||
DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
|
||||
"\t double test_assignment_function()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
|
||||
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
||||
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
DLIB_ASSERT(is_assignment_problem(samples, labels),
|
||||
"\t double test_assignment_function()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
||||
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
double total_right = 0;
|
||||
double total = 0;
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
@ -55,6 +77,37 @@ namespace dlib
|
||||
const long folds
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
#ifdef ENABLE_ASSERTS
|
||||
if (trainer.forces_assignment())
|
||||
{
|
||||
DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
|
||||
1 < folds && folds <= static_cast<long>(samples.size()),
|
||||
"\t double cross_validate_assignment_trainer()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t folds: " << folds
|
||||
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
|
||||
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
||||
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
DLIB_ASSERT(is_assignment_problem(samples, labels) &&
|
||||
1 < folds && folds <= static_cast<long>(samples.size()),
|
||||
"\t double cross_validate_assignment_trainer()"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t samples.size(): " << samples.size()
|
||||
<< "\n\t folds: " << folds
|
||||
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
||||
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
typedef typename trainer_type::sample_type sample_type;
|
||||
typedef typename trainer_type::label_type label_type;
|
||||
|
||||
|
@ -1,2 +1,69 @@
|
||||
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
|
||||
#ifdef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
|
||||
|
||||
#include <vector>
|
||||
#include "../matrix.h"
|
||||
#include "svm.h"
|
||||
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename assignment_function
|
||||
>
|
||||
double test_assignment_function (
|
||||
const assignment_function& assigner,
|
||||
const std::vector<typename assignment_function::sample_type>& samples,
|
||||
const std::vector<typename assignment_function::label_type>& labels
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_assignment_problem(samples, labels)
|
||||
- if (assigner.forces_assignment()) then
|
||||
- is_forced_assignment_problem(samples, labels)
|
||||
- assignment_function == an instantiation of the dlib::assignment_function
|
||||
template or an object with a compatible interface.
|
||||
ensures
|
||||
- Tests assigner against the given samples and labels and returns the fraction
|
||||
of assignments predicted correctly.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename trainer_type
|
||||
>
|
||||
double cross_validate_assignment_trainer (
|
||||
const trainer_type& trainer,
|
||||
const std::vector<typename trainer_type::sample_type>& samples,
|
||||
const std::vector<typename trainer_type::label_type>& labels,
|
||||
const long folds
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_assignment_problem(samples, labels)
|
||||
- if (trainer.forces_assignment()) then
|
||||
- is_forced_assignment_problem(samples, labels)
|
||||
- 1 < folds <= samples.size()
|
||||
- trainer_type == dlib::structural_assignment_trainer or an object
|
||||
with a compatible interface.
|
||||
ensures
|
||||
- performs k-fold cross validation by using the given trainer to solve the
|
||||
given assignment learning problem for the given number of folds. Each fold
|
||||
is tested using the output of the trainer and the fraction of assignments
|
||||
predicted correctly is returned.
|
||||
- The number of folds used is given by the folds argument.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
|
||||
|
||||
|
||||
|
@ -29,6 +29,8 @@ namespace dlib
|
||||
|
||||
typedef assignment_function<feature_extractor> trained_function_type;
|
||||
|
||||
bool forces_assignment(
|
||||
) const { return false; } // TODO
|
||||
|
||||
const assignment_function<feature_extractor> train (
|
||||
const std::vector<sample_type>& x,
|
||||
|
Loading…
Reference in New Issue
Block a user