// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example that shows how you can perform model selection with the
    dlib C++ Library.

    It will create a simple dataset and show you how to use cross validation and
    global optimization to determine good parameters for the purpose of training
    an svm to classify the data.

    The data used in this example will be 2 dimensional data and will come from a
    distribution where points within a distance of 10 (inclusive) from the origin
    are labeled +1 and all other points are labeled as -1.


    As an aside, you should probably read the svm_ex.cpp and matrix_ex.cpp example
    programs before you read this one.
*/


#include <cmath>
#include <iomanip>
#include <iostream>

#include <dlib/svm.h>
#include <dlib/global_optimization.h>

using namespace std;
using namespace dlib;

int main() try
|
|
|
|
{
|
|
|
|
// The svm functions use column vectors to contain a lot of the data on which they
|
|
|
|
// operate. So the first thing we do here is declare a convenient typedef.
|
2009-10-17 02:05:58 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
// This typedef declares a matrix with 2 rows and 1 column. It will be the
|
|
|
|
// object that contains each of our 2 dimensional samples.
|
|
|
|
typedef matrix<double, 2, 1> sample_type;
|
2009-10-17 02:05:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
// Now we make objects to contain our samples and their respective labels.
|
|
|
|
std::vector<sample_type> samples;
|
|
|
|
std::vector<double> labels;
|
2009-10-19 10:43:41 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
// Now let's put some data into our samples and labels objects. We do this
|
|
|
|
// by looping over a bunch of points and labeling them according to their
|
|
|
|
// distance from the origin.
|
|
|
|
for (double r = -20; r <= 20; r += 0.8)
|
2009-10-17 02:05:58 +08:00
|
|
|
{
|
2017-11-25 11:04:25 +08:00
|
|
|
for (double c = -20; c <= 20; c += 0.8)
|
|
|
|
{
|
|
|
|
sample_type samp;
|
|
|
|
samp(0) = r;
|
|
|
|
samp(1) = c;
|
|
|
|
samples.push_back(samp);
|
|
|
|
|
|
|
|
// if this point is less than 10 from the origin
|
|
|
|
if (sqrt(r*r + c*c) <= 10)
|
|
|
|
labels.push_back(+1);
|
|
|
|
else
|
|
|
|
labels.push_back(-1);
|
|
|
|
}
|
|
|
|
}
|
2009-10-17 02:05:58 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
cout << "Generated " << samples.size() << " points" << endl;
|
|
|
|
|
|
|
|
|
|
|
|
// Here we normalize all the samples by subtracting their mean and dividing by their
|
|
|
|
// standard deviation. This is generally a good idea since it often heads off
|
|
|
|
// numerical stability problems and also prevents one large feature from smothering
|
|
|
|
// others. Doing this doesn't matter much in this example so I'm just doing this here
|
|
|
|
// so you can see an easy way to accomplish this with the library.
|
|
|
|
vector_normalizer<sample_type> normalizer;
|
|
|
|
// let the normalizer learn the mean and standard deviation of the samples
|
|
|
|
normalizer.train(samples);
|
|
|
|
// now normalize each sample
|
|
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
|
|
samples[i] = normalizer(samples[i]);
|
|
|
|
|
|
|
|
|
|
|
|
// Now that we have some data we want to train on it. We are going to train a
|
2017-11-25 21:26:16 +08:00
|
|
|
// binary SVM with the RBF kernel to classify the data. However, there are
|
|
|
|
// three parameters to the training. These are the SVM C parameters for each
|
|
|
|
// class and the RBF kernel's gamma parameter. Our choice for these
|
|
|
|
// parameters will influence how good the resulting decision function is. To
|
|
|
|
// test how good a particular choice of these parameters is we can use the
|
2017-11-25 11:04:25 +08:00
|
|
|
// cross_validate_trainer() function to perform n-fold cross validation on our
|
2017-11-25 21:26:16 +08:00
|
|
|
// training data. However, there is a problem with the way we have sampled
|
|
|
|
// our distribution above. The problem is that there is a definite ordering
|
|
|
|
// to the samples. That is, the first half of the samples look like they are
|
|
|
|
// from a different distribution than the second half. This would screw up
|
|
|
|
// the cross validation process, but we can fix it by randomizing the order of
|
|
|
|
// the samples with the following function call.
|
2017-11-25 11:04:25 +08:00
|
|
|
randomize_samples(samples, labels);
|
|
|
|
|
|
|
|
|
|
|
|
// And now we get to the important bit. Here we define a function,
|
|
|
|
// cross_validation_score(), that will do the cross-validation we
|
|
|
|
// mentioned and return a number indicating how good a particular setting
|
2017-11-25 21:26:16 +08:00
|
|
|
// of gamma, c1, and c2 is.
|
|
|
|
auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
|
2017-11-25 11:04:25 +08:00
|
|
|
{
|
|
|
|
// Make a RBF SVM trainer and tell it what the parameters are supposed to be.
|
|
|
|
typedef radial_basis_kernel<sample_type> kernel_type;
|
2017-11-25 21:26:16 +08:00
|
|
|
svm_c_trainer<kernel_type> trainer;
|
2009-10-17 02:05:58 +08:00
|
|
|
trainer.set_kernel(kernel_type(gamma));
|
2017-11-25 21:26:16 +08:00
|
|
|
trainer.set_c_class1(c1);
|
|
|
|
trainer.set_c_class2(c2);
|
2009-10-17 02:05:58 +08:00
|
|
|
|
2009-10-19 10:43:41 +08:00
|
|
|
// Finally, perform 10-fold cross validation and then print and return the results.
|
2009-10-17 02:05:58 +08:00
|
|
|
matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
|
2017-11-25 21:26:16 +08:00
|
|
|
cout << "gamma: " << setw(11) << gamma << " c1: " << setw(11) << c1 << " c2: " << setw(11) << c2 << " cross validation accuracy: " << result;
|
2009-10-19 10:43:41 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
// Now return a number indicating how good the parameters are. Bigger is
|
|
|
|
// better in this example. Here I'm returning the harmonic mean between the
|
|
|
|
// accuracies of each class. However, you could do something else. For
|
|
|
|
// example, you might care a lot more about correctly predicting the +1 class,
|
|
|
|
// so you could penalize results that didn't obtain a high accuracy on that
|
|
|
|
// class. You might do this by using something like a weighted version of the
|
|
|
|
// F1-score (see http://en.wikipedia.org/wiki/F1_score).
|
2014-06-24 04:06:23 +08:00
|
|
|
return 2*prod(result)/sum(result);
|
2017-11-25 11:04:25 +08:00
|
|
|
};
|
2009-10-17 02:05:58 +08:00
|
|
|
|
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
// And finally, we call this global optimizer that will search for the best parameters.
|
2017-12-19 05:20:21 +08:00
|
|
|
// It will call cross_validation_score() 50 times with different settings and return
|
2017-11-25 11:04:25 +08:00
|
|
|
// the best parameter setting it finds. find_max_global() uses a global optimization
|
|
|
|
// method based on a combination of non-parametric global function modeling and
|
|
|
|
// quadratic trust region modeling to efficiently find a global maximizer. It usually
|
|
|
|
// does a good job with a relatively small number of calls to cross_validation_score().
|
|
|
|
// In this example, you should observe that it finds settings that give perfect binary
|
2017-11-25 21:26:16 +08:00
|
|
|
// classification of the data.
|
2017-11-25 11:04:25 +08:00
|
|
|
auto result = find_max_global(cross_validation_score,
|
2017-11-25 21:26:16 +08:00
|
|
|
{1e-5, 1e-5, 1e-5}, // lower bound constraints on gamma, c1, and c2, respectively
|
|
|
|
{100, 1e6, 1e6}, // upper bound constraints on gamma, c1, and c2, respectively
|
2017-12-19 05:20:21 +08:00
|
|
|
max_function_calls(50));
|
2009-10-17 02:05:58 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
double best_gamma = result.x(0);
|
2017-11-25 21:26:16 +08:00
|
|
|
double best_c1 = result.x(1);
|
|
|
|
double best_c2 = result.x(2);
|
2009-10-17 02:05:58 +08:00
|
|
|
|
2017-11-25 11:04:25 +08:00
|
|
|
cout << " best cross-validation score: " << result.y << endl;
|
2017-11-25 21:26:16 +08:00
|
|
|
cout << " best gamma: " << best_gamma << " best c1: " << best_c1 << " best c2: "<< best_c2 << endl;
|
2017-11-25 11:04:25 +08:00
|
|
|
}
|
|
|
|
catch (exception& e)
|
|
|
|
{
|
|
|
|
cout << e.what() << endl;
|
2009-10-17 02:05:58 +08:00
|
|
|
}
|
|
|
|
|