diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4166e75a9..2d86bcde5 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -35,6 +35,7 @@ ENDMACRO() #here we apply our macros +add_example(assignment_learning_ex) add_example(bayes_net_ex) add_example(bayes_net_from_disk_ex) add_example(bayes_net_gui_ex) diff --git a/examples/assignment_learning_ex.cpp b/examples/assignment_learning_ex.cpp new file mode 100644 index 000000000..c646f39fd --- /dev/null +++ b/examples/assignment_learning_ex.cpp @@ -0,0 +1,325 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the dlib machine learning tools for + learning to solve the assignment problem. + + Many tasks in computer vision or natural language processing can be thought of + as assignment problems. For example, in a computer vision application where + you are trying to track objects moving around in video, you likely need to solve + an association problem every time you get a new video frame. That is, each new + frame will contain objects (e.g. people, cars, etc.) and you will want to + determine which of these objects are actually things you have seen in previous + frames. + + The assignment problem can be optimally solved using the well known Hungarian + algorithm. However, this algorithm requires the user to supply some function + which measure the "goodness" of an individual association. In many cases the + best way to measure this goodness isn't obvious and therefore machine learning + methods are used. + + The remainder of this example program will show you how to learn a goodness + function which is optimal, in a certain sense, for use with the Hungarian + algorithm. To do this, we will make a simple dataset of example associations + and use them to train a supervised machine learning method. +*/ + + +#include +#include "dlib/svm_threaded.h" + +using namespace std; +using namespace dlib; + + +// ---------------------------------------------------------------------------------------- + +/* + In an association problem, we will talk about the "Left Hand Set" (LHS) and the + "Right Hand Set" (RHS). The task will be to learn to map all elements of LHS to + unique elements of RHS. If an element of LHS can't be mapped to a unique element of + RHS for any reason (e.g. LHS is bigger than RHS) then it can also be mapped to the + special -1 output indicating no mapping. + + So the first step is to define the type of elements in each of these sets. In the + code below we will use column vectors in both LHS and RHS. However, in general, + they can each contain any type you like. LHS can even contain a different type + than RHS. +*/ + +typedef dlib::matrix column_vector; + +// This type represents a pair of LHS and RHS. That is, sample_type::first +// contains a left hand set and sample_type::second contains a right hand set. +typedef std::pair, std::vector > sample_type; + +// This type will contain the association information between LHS and RHS. That is, +// it will determine which elements of LHS map to which elements of RHS. +typedef std::vector label_type; + +// In this example, all our LHS and RHS elements will be 3-dimensional vectors. +const unsigned long num_dims = 3; + +void make_data ( + std::vector& samples, + std::vector& labels +); +/*! + ensures + - This function creates a training dataset of 5 example associations. + - #samples.size() == 5 + - #labels.size() == 5 + - for all valid i: + - #samples[i].first == a left hand set + - #samples[i].second == a right hand set + - #labels[i] == a set of integers indicating how to map LHS to RHS. To be + precise: + - #samples[i].first.size() == #labels[i].size() + - for all valid j: + -1 <= #labels[i][j] < #samples[i].second.size() + (A value of -1 indicates that #samples[i].first[j] isn't associated with anything. + All other values indicate the associating element of #samples[i].second) + - All elements of #labels[i] which are not equal to -1 are unique. That is, + multiple elements of #samples[i].first can't associate to the same element + in #samples[i].second. +!*/ + +// ---------------------------------------------------------------------------------------- + +struct feature_extractor +{ + /*! + Recall that our task is to learn the "goodness of assignment" function for + use with the Hungarian algorithm. The dlib tools assume this function + can be written as: + match_score(l,r) == dot(w, PSI(l,r)) + where l is an element of LHS, r is an element of RHS, w is a parameter vector, + and PSI() is a user supplied feature extractor. + + This feature_extractor is where we implement PSI(). How you implement this + is highly problem dependent. + !*/ + + // The type of feature vector returned from get_features(). This must be either + // a dlib::matrix or a sparse vector. + typedef column_vector feature_vector_type; + + // The types of elements in the LHS and RHS sets + typedef column_vector lhs_element; + typedef column_vector rhs_element; + + + unsigned long num_features() const + { + // Return the dimensionality of feature vectors produced by get_features() + return num_dims + 1; + } + + void get_features ( + const lhs_element& left, + const rhs_element& right, + feature_vector_type& feats + ) const + /*! + ensures + - #feats == PSI(left,right) + (i.e. This function computes a feature vector which, in some sense, + captures information useful for deciding if matching left to right + is "good"). + !*/ + { + // We will have: + // - feats(i) == std::pow(left(i) - right(i), 2.0) + // Except for the last element of feats which will be equal to 1 and + // therefore function as a bias term. Again, how you define this feature + // extractor is highly problem dependent. + feats = join_cols(squared(left - right), ones_matrix(1,1)); + } + +}; + +// We need to define serialize() and deserialize() for our feature extractor if we want +// to be able to serialize and deserialize our learned models. In this case the +// implementation is empty since our feature_extractor doesn't have any state. But you +// might define more complex feature extractors which have state that needs to be saved. +void serialize (const feature_extractor& , std::ostream& ) {} +void deserialize (feature_extractor& , std::istream& ) {} + +// ---------------------------------------------------------------------------------------- + +int main() +{ + try + { + // Get a small bit of training data. + std::vector samples; + std::vector labels; + make_data(samples, labels); + + + structural_assignment_trainer trainer; + // This is the common SVM C parameter. Larger values encourage the + // trainer to attempt to fit the data exactly but might overfit. + // In general, you determine this parameter by cross-validation. + trainer.set_c(10); + // This trainer can use multiple CPU cores to speed up the training. + // So set this to the number of available CPU cores. + trainer.set_num_threads(4); + + // Do the training and save the results in assigner. + assignment_function assigner = trainer.train(samples, labels); + + + // Test the assigner on our data. The output will indicate that it makes the + // correct associations on all samples. + cout << "Test the learned assignment function: " << endl; + for (unsigned long i = 0; i < samples.size(); ++i) + { + // Predict the assignments for the LHS and RHS in samples[i]. + std::vector predicted_assignments = assigner(samples[i]); + cout << "true labels: " << trans(vector_to_matrix(labels[i])); + cout << "predicted labels: " << trans(vector_to_matrix(predicted_assignments)) << endl; + } + + // We can also call this tool to compute the percentage of assignments predicted correctly. + cout << "training accuracy: " << test_assignment_function(assigner, samples, labels) << endl; + + + // Since testing on your training data is a really bad idea, we can also do 5-fold cross validation. + // Happily, this also indicates that all associations were made correctly. + randomize_samples(samples, labels); + cout << "cv accuracy: " << cross_validate_assignment_trainer(trainer, samples, labels, 5) << endl; + + + + // Finally, the assigner can be serialized to disk just like most dlib objects. + ofstream fout("assigner.dat", ios::binary); + serialize(assigner, fout); + fout.close(); + + // recall from disk + ifstream fin("assigner.dat", ios::binary); + deserialize(assigner, fin); + } + catch (std::exception& e) + { + cout << "EXCEPTION THROWN" << endl; + cout << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +void make_data ( + std::vector& samples, + std::vector& labels +) +{ + // Make four different vectors. We will use them to make example assignments. + column_vector A(num_dims), B(num_dims), C(num_dims), D(num_dims); + A = 1,0,0; + B = 0,1,0; + C = 0,0,1; + D = 0,1,1; + + std::vector lhs; + std::vector rhs; + label_type mapping; + + // In all the assignments to follow, we will only say an element of LHS + // matches an element of RHS if the two are equal. So A matches with A, + // B with B, etc. But never A with C, for example. + // ------------------------ + + lhs.resize(3); + lhs[0] = A; + lhs[1] = B; + lhs[2] = C; + + rhs.resize(3); + rhs[0] = B; + rhs[1] = A; + rhs[2] = C; + + mapping.resize(3); + mapping[0] = 1; // lhs[0] matches rhs[1] + mapping[1] = 0; // lhs[1] matches rhs[0] + mapping[2] = 2; // lhs[2] matches rhs[2] + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs[0] = C; + lhs[1] = A; + lhs[2] = B; + + rhs[0] = A; + rhs[1] = B; + rhs[2] = D; + + mapping[0] = -1; // The -1 indicates that lhs[0] doesn't match anything in rhs. + mapping[1] = 0; // lhs[1] matches rhs[0] + mapping[2] = 1; // lhs[2] matches rhs[1] + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs[0] = A; + lhs[1] = B; + lhs[2] = C; + + rhs.resize(4); + rhs[0] = C; + rhs[1] = B; + rhs[2] = A; + rhs[3] = D; + + mapping[0] = 2; + mapping[1] = 1; + mapping[2] = 0; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs.resize(2); + lhs[0] = B; + lhs[1] = C; + + rhs.resize(3); + rhs[0] = C; + rhs[1] = A; + rhs[2] = D; + + mapping.resize(2); + mapping[0] = -1; + mapping[1] = 0; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs.resize(3); + lhs[0] = D; + lhs[1] = B; + lhs[2] = C; + + // rhs will be empty. So none of the items in lhs can match anything. + rhs.resize(0); + + mapping.resize(3); + mapping[0] = -1; + mapping[1] = -1; + mapping[2] = -1; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + +} +