mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a multiclass example program
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404052
This commit is contained in:
parent
b4b7ed28bf
commit
d20baf6bba
@ -64,6 +64,7 @@ add_example(matrix_expressions_ex)
|
||||
add_example(member_function_pointer_ex)
|
||||
add_example(mlp_ex)
|
||||
add_example(model_selection_ex)
|
||||
add_example(multiclass_classification_ex)
|
||||
add_example(multithreaded_object_ex)
|
||||
add_example(optimization_ex)
|
||||
add_example(pipe_ex)
|
||||
|
222
examples/multiclass_classification_ex.cpp
Normal file
222
examples/multiclass_classification_ex.cpp
Normal file
@ -0,0 +1,222 @@
|
||||
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
/*
|
||||
This is an example illustrating the use of the multiclass classification tools
|
||||
from the dlib C++ Library. Specifically, this example will make points from
|
||||
three classes and show you how to train a multiclass classifier to recognize
|
||||
these three classes.
|
||||
|
||||
The classes are as follows:
|
||||
- class 1: points very close to the origin (on a circle of radius 0.5 around it)
|
||||
- class 2: points on the circle of radius 10 around the origin
|
||||
- class 3: points on a circle of radius 4 centered at (25,25), i.e. well away from the origin
|
||||
*/
|
||||
|
||||
#include "dlib/svm.h"

#include <exception>
#include <fstream>
#include <iostream>
#include <vector>

#include "dlib/rand.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
// Our data will be 2-dimensional, so declare an appropriate type to contain these points.
|
||||
typedef matrix<double,2,1> sample_type;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void generate_data (std::vector<sample_type>& samples, std::vector<double>& labels);
/*!
    ensures
        - Appends the 3 class data set described at the top of this file to samples,
          and appends the matching class label (1, 2, or 3) of each point to labels.
          Neither vector is cleared first.
        - 60 points labeled 1 are added
        - 70 points labeled 2 are added
        - 80 points labeled 3 are added
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main()
|
||||
{
|
||||
std::vector<sample_type> samples;
|
||||
std::vector<double> labels;
|
||||
|
||||
// First, get our labeled set of training data
|
||||
generate_data(samples, labels);
|
||||
|
||||
cout << "samples.size(): "<< samples.size() << endl;
|
||||
|
||||
// The main object in this example program is the one_vs_one_trainer. It is essentially
|
||||
// a container class for regular binary classifier trainer objects. In particular, it
|
||||
// uses the any_trainer object to store any kind of trainer object that implements a
|
||||
// .train(samples,labels) function which returns some kind of learned decision function.
|
||||
// It uses these binary classifiers to construct a voting multiclass classifier. If
|
||||
// there are N classes then it trains N*(N-1)/2 binary classifiers, one for each pair of
|
||||
// labels, which then vote on the label of a sample.
|
||||
//
|
||||
// In this example program we will work with a one_vs_one_trainer object which stores any
|
||||
// kind of trainer that uses our sample_type samples.
|
||||
typedef one_vs_one_trainer<any_trainer<sample_type> > ovo_trainer;
|
||||
|
||||
|
||||
// Finally, make the one_vs_one_trainer.
|
||||
ovo_trainer trainer;
|
||||
|
||||
|
||||
// Next, we will make two different binary classification trainer objects. One
|
||||
// which uses kernel ridge regression and RBF kernels and another which uses a
|
||||
// support vector machine and polynomial kernels. The particular details don't matter.
|
||||
// The point of this part of the example is that you can use any kind of trainer object
|
||||
// with the one_vs_one_trainer.
|
||||
typedef polynomial_kernel<sample_type> poly_kernel;
|
||||
typedef radial_basis_kernel<sample_type> rbf_kernel;
|
||||
|
||||
// make the binary trainers and set some parameters
|
||||
krr_trainer<rbf_kernel> rbf_trainer;
|
||||
svm_nu_trainer<poly_kernel> poly_trainer;
|
||||
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
|
||||
rbf_trainer.set_kernel(rbf_kernel(0.1));
|
||||
|
||||
|
||||
// Now tell the one_vs_one_trainer that, by default, it should use the rbf_trainer
|
||||
// to solve the individual binary classification subproblems.
|
||||
trainer.set_trainer(rbf_trainer);
|
||||
// We can also get more specific. Here we tell the one_vs_one_trainer to use the
|
||||
// poly_trainer to solve the class 1 vs class 2 subproblem. All the others will
|
||||
// still be solved with the rbf_trainer.
|
||||
trainer.set_trainer(poly_trainer, 1, 2);
|
||||
|
||||
// Now lets do 5-fold cross-validation using the one_vs_one_trainer we just setup.
|
||||
// As an aside, always shuffle the order of the samples before doing cross validation.
|
||||
// For a discussion of why this is a good idea see the svm_ex.cpp example.
|
||||
randomize_samples(samples, labels);
|
||||
cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl;
|
||||
// The output is shown below. It is the confusion matrix which describes the results. Each row
|
||||
// corresponds to a class of data and each column to a prediction. Reading from top to bottom,
|
||||
// the rows correspond to the class labels if the labels have been listed in sorted order. So the
|
||||
// top row corresponds to class 1, the middle row to class 2, and the bottom row to class 3. The
|
||||
// columns are organized similarly, with the left most column showing how many samples were predicted
|
||||
// as members of class 1.
|
||||
//
|
||||
// So in the results below we can see that, for the class 1 samples, 60 of them were correctly predicted
|
||||
// to be members of class 1 and 0 were incorrectly classified. Similarly, the other two classes of data
|
||||
// are perfectly classified.
|
||||
/*
|
||||
cross validation:
|
||||
60 0 0
|
||||
0 70 0
|
||||
0 0 80
|
||||
*/
|
||||
|
||||
// Next, if you wanted to obtain the decision rule learned by a one_vs_one_trainer you
|
||||
// would store it into a one_vs_one_decision_function.
|
||||
one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
|
||||
|
||||
cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl;
|
||||
cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl;
|
||||
// The output is:
|
||||
/*
|
||||
predicted label: 2, true label: 2
|
||||
predicted label: 1, true label: 1
|
||||
*/
|
||||
|
||||
|
||||
// Finally, if you want to save a one_vs_one_decision_function to disk, you can do
|
||||
// so. However, you must declare what kind of decision functions it contains.
|
||||
one_vs_one_decision_function<ovo_trainer,
|
||||
decision_function<poly_kernel>, // This is the output of the poly_trainer
|
||||
decision_function<rbf_kernel> // This is the output of the rbf_trainer
|
||||
> df2, df3;
|
||||
|
||||
|
||||
// Put df into df2 and then save df2 to disk. Note that we could have also said
|
||||
// df2 = trainer.train(samples, labels); But doing it this way avoids retraining.
|
||||
df2 = df;
|
||||
ofstream fout("df.dat", ios::binary);
|
||||
serialize(df2, fout);
|
||||
fout.close();
|
||||
|
||||
// load the function back in from disk and store it in df3.
|
||||
ifstream fin("df.dat", ios::binary);
|
||||
deserialize(df3, fin);
|
||||
|
||||
|
||||
// Test df3 to see that this worked.
|
||||
cout << endl;
|
||||
cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl;
|
||||
cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl;
|
||||
// Test df3 on the samples and labels and print the confusion matrix.
|
||||
cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl;
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void generate_data (
|
||||
std::vector<sample_type>& samples,
|
||||
std::vector<double>& labels
|
||||
)
|
||||
{
|
||||
const long num = 50;
|
||||
|
||||
sample_type m;
|
||||
|
||||
dlib::rand::float_1a rnd;
|
||||
|
||||
|
||||
// make some samples near the origin
|
||||
double radius = 0.5;
|
||||
for (long i = 0; i < num+10; ++i)
|
||||
{
|
||||
double sign = 1;
|
||||
if (rnd.get_random_double() < 0.5)
|
||||
sign = -1;
|
||||
m(0) = 2*radius*rnd.get_random_double()-radius;
|
||||
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
|
||||
|
||||
// add this sample to our set of samples we will run k-means
|
||||
samples.push_back(m);
|
||||
labels.push_back(1);
|
||||
}
|
||||
|
||||
// make some samples in a circle around the origin but far away
|
||||
radius = 10.0;
|
||||
for (long i = 0; i < num+20; ++i)
|
||||
{
|
||||
double sign = 1;
|
||||
if (rnd.get_random_double() < 0.5)
|
||||
sign = -1;
|
||||
m(0) = 2*radius*rnd.get_random_double()-radius;
|
||||
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
|
||||
|
||||
// add this sample to our set of samples we will run k-means
|
||||
samples.push_back(m);
|
||||
labels.push_back(2);
|
||||
}
|
||||
|
||||
// make some samples in a circle around the point (25,25)
|
||||
radius = 4.0;
|
||||
for (long i = 0; i < num+30; ++i)
|
||||
{
|
||||
double sign = 1;
|
||||
if (rnd.get_random_double() < 0.5)
|
||||
sign = -1;
|
||||
m(0) = 2*radius*rnd.get_random_double()-radius;
|
||||
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
|
||||
|
||||
// translate this point away from the origin
|
||||
m(0) += 25;
|
||||
m(1) += 25;
|
||||
|
||||
// add this sample to our set of samples we will run k-means
|
||||
samples.push_back(m);
|
||||
labels.push_back(3);
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue
Block a user