mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added an example program for the new assignment learning tools.
This commit is contained in:
parent
b4ba2fce4b
commit
8ce84862a9
@ -35,6 +35,7 @@ ENDMACRO()
|
||||
|
||||
|
||||
#here we apply our macros
|
||||
add_example(assignment_learning_ex)
|
||||
add_example(bayes_net_ex)
|
||||
add_example(bayes_net_from_disk_ex)
|
||||
add_example(bayes_net_gui_ex)
|
||||
|
325
examples/assignment_learning_ex.cpp
Normal file
325
examples/assignment_learning_ex.cpp
Normal file
@ -0,0 +1,325 @@
|
||||
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
/*
|
||||
|
||||
This is an example illustrating the use of the dlib machine learning tools for
|
||||
learning to solve the assignment problem.
|
||||
|
||||
Many tasks in computer vision or natural language processing can be thought of
|
||||
as assignment problems. For example, in a computer vision application where
|
||||
you are trying to track objects moving around in video, you likely need to solve
|
||||
an association problem every time you get a new video frame. That is, each new
|
||||
frame will contain objects (e.g. people, cars, etc.) and you will want to
|
||||
determine which of these objects are actually things you have seen in previous
|
||||
frames.
|
||||
|
||||
The assignment problem can be optimally solved using the well known Hungarian
|
||||
algorithm. However, this algorithm requires the user to supply some function
|
||||
which measure the "goodness" of an individual association. In many cases the
|
||||
best way to measure this goodness isn't obvious and therefore machine learning
|
||||
methods are used.
|
||||
|
||||
The remainder of this example program will show you how to learn a goodness
|
||||
function which is optimal, in a certain sense, for use with the Hungarian
|
||||
algorithm. To do this, we will make a simple dataset of example associations
|
||||
and use them to train a supervised machine learning method.
|
||||
*/
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include "dlib/svm_threaded.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
/*
|
||||
In an association problem, we will talk about the "Left Hand Set" (LHS) and the
|
||||
"Right Hand Set" (RHS). The task will be to learn to map all elements of LHS to
|
||||
unique elements of RHS. If an element of LHS can't be mapped to a unique element of
|
||||
RHS for any reason (e.g. LHS is bigger than RHS) then it can also be mapped to the
|
||||
special -1 output indicating no mapping.
|
||||
|
||||
So the first step is to define the type of elements in each of these sets. In the
|
||||
code below we will use column vectors in both LHS and RHS. However, in general,
|
||||
they can each contain any type you like. LHS can even contain a different type
|
||||
than RHS.
|
||||
*/
|
||||
|
||||
typedef dlib::matrix<double,0,1> column_vector;
|
||||
|
||||
// This type represents a pair of LHS and RHS. That is, sample_type::first
|
||||
// contains a left hand set and sample_type::second contains a right hand set.
|
||||
typedef std::pair<std::vector<column_vector>, std::vector<column_vector> > sample_type;
|
||||
|
||||
// This type will contain the association information between LHS and RHS. That is,
|
||||
// it will determine which elements of LHS map to which elements of RHS.
|
||||
typedef std::vector<long> label_type;
|
||||
|
||||
// In this example, all our LHS and RHS elements will be 3-dimensional vectors.
|
||||
const unsigned long num_dims = 3;
|
||||
|
||||
void make_data (
|
||||
std::vector<sample_type>& samples,
|
||||
std::vector<label_type>& labels
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- This function creates a training dataset of 5 example associations.
|
||||
- #samples.size() == 5
|
||||
- #labels.size() == 5
|
||||
- for all valid i:
|
||||
- #samples[i].first == a left hand set
|
||||
- #samples[i].second == a right hand set
|
||||
- #labels[i] == a set of integers indicating how to map LHS to RHS. To be
|
||||
precise:
|
||||
- #samples[i].first.size() == #labels[i].size()
|
||||
- for all valid j:
|
||||
-1 <= #labels[i][j] < #samples[i].second.size()
|
||||
(A value of -1 indicates that #samples[i].first[j] isn't associated with anything.
|
||||
All other values indicate the associating element of #samples[i].second)
|
||||
- All elements of #labels[i] which are not equal to -1 are unique. That is,
|
||||
multiple elements of #samples[i].first can't associate to the same element
|
||||
in #samples[i].second.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct feature_extractor
|
||||
{
|
||||
/*!
|
||||
Recall that our task is to learn the "goodness of assignment" function for
|
||||
use with the Hungarian algorithm. The dlib tools assume this function
|
||||
can be written as:
|
||||
match_score(l,r) == dot(w, PSI(l,r))
|
||||
where l is an element of LHS, r is an element of RHS, w is a parameter vector,
|
||||
and PSI() is a user supplied feature extractor.
|
||||
|
||||
This feature_extractor is where we implement PSI(). How you implement this
|
||||
is highly problem dependent.
|
||||
!*/
|
||||
|
||||
// The type of feature vector returned from get_features(). This must be either
|
||||
// a dlib::matrix or a sparse vector.
|
||||
typedef column_vector feature_vector_type;
|
||||
|
||||
// The types of elements in the LHS and RHS sets
|
||||
typedef column_vector lhs_element;
|
||||
typedef column_vector rhs_element;
|
||||
|
||||
|
||||
unsigned long num_features() const
|
||||
{
|
||||
// Return the dimensionality of feature vectors produced by get_features()
|
||||
return num_dims + 1;
|
||||
}
|
||||
|
||||
void get_features (
|
||||
const lhs_element& left,
|
||||
const rhs_element& right,
|
||||
feature_vector_type& feats
|
||||
) const
|
||||
/*!
|
||||
ensures
|
||||
- #feats == PSI(left,right)
|
||||
(i.e. This function computes a feature vector which, in some sense,
|
||||
captures information useful for deciding if matching left to right
|
||||
is "good").
|
||||
!*/
|
||||
{
|
||||
// We will have:
|
||||
// - feats(i) == std::pow(left(i) - right(i), 2.0)
|
||||
// Except for the last element of feats which will be equal to 1 and
|
||||
// therefore function as a bias term. Again, how you define this feature
|
||||
// extractor is highly problem dependent.
|
||||
feats = join_cols(squared(left - right), ones_matrix<double>(1,1));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// We need to define serialize() and deserialize() for our feature extractor if we want
|
||||
// to be able to serialize and deserialize our learned models. In this case the
|
||||
// implementation is empty since our feature_extractor doesn't have any state. But you
|
||||
// might define more complex feature extractors which have state that needs to be saved.
|
||||
void serialize (const feature_extractor& , std::ostream& ) {}
|
||||
void deserialize (feature_extractor& , std::istream& ) {}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main()
|
||||
{
|
||||
try
|
||||
{
|
||||
// Get a small bit of training data.
|
||||
std::vector<sample_type> samples;
|
||||
std::vector<label_type> labels;
|
||||
make_data(samples, labels);
|
||||
|
||||
|
||||
structural_assignment_trainer<feature_extractor> trainer;
|
||||
// This is the common SVM C parameter. Larger values encourage the
|
||||
// trainer to attempt to fit the data exactly but might overfit.
|
||||
// In general, you determine this parameter by cross-validation.
|
||||
trainer.set_c(10);
|
||||
// This trainer can use multiple CPU cores to speed up the training.
|
||||
// So set this to the number of available CPU cores.
|
||||
trainer.set_num_threads(4);
|
||||
|
||||
// Do the training and save the results in assigner.
|
||||
assignment_function<feature_extractor> assigner = trainer.train(samples, labels);
|
||||
|
||||
|
||||
// Test the assigner on our data. The output will indicate that it makes the
|
||||
// correct associations on all samples.
|
||||
cout << "Test the learned assignment function: " << endl;
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
// Predict the assignments for the LHS and RHS in samples[i].
|
||||
std::vector<long> predicted_assignments = assigner(samples[i]);
|
||||
cout << "true labels: " << trans(vector_to_matrix(labels[i]));
|
||||
cout << "predicted labels: " << trans(vector_to_matrix(predicted_assignments)) << endl;
|
||||
}
|
||||
|
||||
// We can also call this tool to compute the percentage of assignments predicted correctly.
|
||||
cout << "training accuracy: " << test_assignment_function(assigner, samples, labels) << endl;
|
||||
|
||||
|
||||
// Since testing on your training data is a really bad idea, we can also do 5-fold cross validation.
|
||||
// Happily, this also indicates that all associations were made correctly.
|
||||
randomize_samples(samples, labels);
|
||||
cout << "cv accuracy: " << cross_validate_assignment_trainer(trainer, samples, labels, 5) << endl;
|
||||
|
||||
|
||||
|
||||
// Finally, the assigner can be serialized to disk just like most dlib objects.
|
||||
ofstream fout("assigner.dat", ios::binary);
|
||||
serialize(assigner, fout);
|
||||
fout.close();
|
||||
|
||||
// recall from disk
|
||||
ifstream fin("assigner.dat", ios::binary);
|
||||
deserialize(assigner, fin);
|
||||
}
|
||||
catch (std::exception& e)
|
||||
{
|
||||
cout << "EXCEPTION THROWN" << endl;
|
||||
cout << e.what() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void make_data (
|
||||
std::vector<sample_type>& samples,
|
||||
std::vector<label_type>& labels
|
||||
)
|
||||
{
|
||||
// Make four different vectors. We will use them to make example assignments.
|
||||
column_vector A(num_dims), B(num_dims), C(num_dims), D(num_dims);
|
||||
A = 1,0,0;
|
||||
B = 0,1,0;
|
||||
C = 0,0,1;
|
||||
D = 0,1,1;
|
||||
|
||||
std::vector<column_vector> lhs;
|
||||
std::vector<column_vector> rhs;
|
||||
label_type mapping;
|
||||
|
||||
// In all the assignments to follow, we will only say an element of LHS
|
||||
// matches an element of RHS if the two are equal. So A matches with A,
|
||||
// B with B, etc. But never A with C, for example.
|
||||
// ------------------------
|
||||
|
||||
lhs.resize(3);
|
||||
lhs[0] = A;
|
||||
lhs[1] = B;
|
||||
lhs[2] = C;
|
||||
|
||||
rhs.resize(3);
|
||||
rhs[0] = B;
|
||||
rhs[1] = A;
|
||||
rhs[2] = C;
|
||||
|
||||
mapping.resize(3);
|
||||
mapping[0] = 1; // lhs[0] matches rhs[1]
|
||||
mapping[1] = 0; // lhs[1] matches rhs[0]
|
||||
mapping[2] = 2; // lhs[2] matches rhs[2]
|
||||
|
||||
samples.push_back(make_pair(lhs,rhs));
|
||||
labels.push_back(mapping);
|
||||
|
||||
// ------------------------
|
||||
|
||||
lhs[0] = C;
|
||||
lhs[1] = A;
|
||||
lhs[2] = B;
|
||||
|
||||
rhs[0] = A;
|
||||
rhs[1] = B;
|
||||
rhs[2] = D;
|
||||
|
||||
mapping[0] = -1; // The -1 indicates that lhs[0] doesn't match anything in rhs.
|
||||
mapping[1] = 0; // lhs[1] matches rhs[0]
|
||||
mapping[2] = 1; // lhs[2] matches rhs[1]
|
||||
|
||||
samples.push_back(make_pair(lhs,rhs));
|
||||
labels.push_back(mapping);
|
||||
|
||||
// ------------------------
|
||||
|
||||
lhs[0] = A;
|
||||
lhs[1] = B;
|
||||
lhs[2] = C;
|
||||
|
||||
rhs.resize(4);
|
||||
rhs[0] = C;
|
||||
rhs[1] = B;
|
||||
rhs[2] = A;
|
||||
rhs[3] = D;
|
||||
|
||||
mapping[0] = 2;
|
||||
mapping[1] = 1;
|
||||
mapping[2] = 0;
|
||||
|
||||
samples.push_back(make_pair(lhs,rhs));
|
||||
labels.push_back(mapping);
|
||||
|
||||
// ------------------------
|
||||
|
||||
lhs.resize(2);
|
||||
lhs[0] = B;
|
||||
lhs[1] = C;
|
||||
|
||||
rhs.resize(3);
|
||||
rhs[0] = C;
|
||||
rhs[1] = A;
|
||||
rhs[2] = D;
|
||||
|
||||
mapping.resize(2);
|
||||
mapping[0] = -1;
|
||||
mapping[1] = 0;
|
||||
|
||||
samples.push_back(make_pair(lhs,rhs));
|
||||
labels.push_back(mapping);
|
||||
|
||||
// ------------------------
|
||||
|
||||
lhs.resize(3);
|
||||
lhs[0] = D;
|
||||
lhs[1] = B;
|
||||
lhs[2] = C;
|
||||
|
||||
// rhs will be empty. So none of the items in lhs can match anything.
|
||||
rhs.resize(0);
|
||||
|
||||
mapping.resize(3);
|
||||
mapping[0] = -1;
|
||||
mapping[1] = -1;
|
||||
mapping[2] = -1;
|
||||
|
||||
samples.push_back(make_pair(lhs,rhs));
|
||||
labels.push_back(mapping);
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user