mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
ab8bfefaa0
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404251
281 lines
9.2 KiB
C++
281 lines
9.2 KiB
C++
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
|
/*
|
|
This example program shows you how to create your own custom binary classification
|
|
trainer object and use it with the multiclass classification tools in the dlib C++
|
|
library. This example assumes you have already become familiar with the concepts
|
|
introduced in the multiclass_classification_ex.cpp example program.
|
|
|
|
|
|
In this example we will create a very simple trainer object that takes a binary
|
|
classification problem and produces a decision rule which says a test point has the
|
|
same class as whichever centroid it is closest to.
|
|
|
|
The multiclass training dataset will consist of four classes. Each class will be a blob
|
|
of points in one of the quadrants of the cartesian plane. For fun, we will use
|
|
std::string labels and therefore the labels of these classes will be the following:
|
|
"upper_left",
|
|
"upper_right",
|
|
"lower_left",
|
|
"lower_right"
|
|
*/
|
|
|
|
#include "dlib/svm.h"
|
|
|
|
#include <iostream>
|
|
#include <vector>
|
|
|
|
#include "dlib/rand.h"
|
|
|
|
using namespace std;
|
|
using namespace dlib;
|
|
|
|
// Our data will be 2-dimensional data. So declare an appropriate type to contain these points.
|
|
typedef matrix<double,2,1> sample_type;
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
struct custom_decision_function
|
|
{
|
|
/*!
|
|
WHAT THIS OBJECT REPRESENTS
|
|
This object is the representation of our binary decision rule.
|
|
!*/
|
|
|
|
// centers of the two classes
|
|
sample_type positive_center, negative_center;
|
|
|
|
double operator() (
|
|
const sample_type& x
|
|
) const
|
|
{
|
|
// if x is closer to the positive class then return +1
|
|
if (length(positive_center - x) < length(negative_center - x))
|
|
return +1;
|
|
else
|
|
return -1;
|
|
}
|
|
};
|
|
|
|
// Later on in this example we will save our decision functions to disk. This
|
|
// pair of routines is needed for this functionality.
|
|
void serialize (const custom_decision_function& item, std::ostream& out)
|
|
{
|
|
// write the state of item to the output stream
|
|
serialize(item.positive_center, out);
|
|
serialize(item.negative_center, out);
|
|
}
|
|
|
|
void deserialize (custom_decision_function& item, std::istream& in)
|
|
{
|
|
// read the data from the input stream and store it in item
|
|
deserialize(item.positive_center, in);
|
|
deserialize(item.negative_center, in);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class simple_custom_trainer
|
|
{
|
|
/*!
|
|
WHAT THIS OBJECT REPRESENTS
|
|
This is our example custom binary classifier trainer object. It simply
|
|
computes the means of the +1 and -1 classes, puts them into our
|
|
custom_decision_function, and returns the results.
|
|
|
|
Below we define the train() function. I have also included the
|
|
requires/ensures definition for a generic binary classifier's train()
|
|
!*/
|
|
public:
|
|
|
|
|
|
custom_decision_function train (
|
|
const std::vector<sample_type>& samples,
|
|
const std::vector<double>& labels
|
|
) const
|
|
/*!
|
|
requires
|
|
- is_binary_classification_problem(samples, labels) == true
|
|
(e.g. labels consists of only +1 and -1 values, samples.size() == labels.size())
|
|
ensures
|
|
- returns a decision function F with the following properties:
|
|
- if (new_x is a sample predicted have +1 label) then
|
|
- F(new_x) >= 0
|
|
- else
|
|
- F(new_x) < 0
|
|
!*/
|
|
{
|
|
sample_type positive_center, negative_center;
|
|
|
|
// compute sums of each class
|
|
positive_center = 0;
|
|
negative_center = 0;
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
{
|
|
if (labels[i] == +1)
|
|
positive_center += samples[i];
|
|
else // this is a -1 sample
|
|
negative_center += samples[i];
|
|
}
|
|
|
|
// divide by number of +1 samples
|
|
positive_center /= sum(vector_to_matrix(labels) == +1);
|
|
// divide by number of -1 samples
|
|
negative_center /= sum(vector_to_matrix(labels) == -1);
|
|
|
|
custom_decision_function df;
|
|
df.positive_center = positive_center;
|
|
df.negative_center = negative_center;
|
|
|
|
return df;
|
|
}
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
void generate_data (
|
|
std::vector<sample_type>& samples,
|
|
std::vector<string>& labels
|
|
);
|
|
/*!
|
|
ensures
|
|
- make some four class data as described above.
|
|
- each class will have 50 samples in it
|
|
!*/
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
int main()
|
|
{
|
|
std::vector<sample_type> samples;
|
|
std::vector<string> labels;
|
|
|
|
// First, get our labeled set of training data
|
|
generate_data(samples, labels);
|
|
|
|
cout << "samples.size(): "<< samples.size() << endl;
|
|
|
|
// Define the trainer we will use. The second template argument specifies the type
|
|
// of label used, which is string in this case.
|
|
typedef one_vs_one_trainer<any_trainer<sample_type>, string> ovo_trainer;
|
|
|
|
|
|
ovo_trainer trainer;
|
|
|
|
// Now tell the one_vs_one_trainer that, by default, it should use the simple_custom_trainer
|
|
// to solve the individual binary classification subproblems.
|
|
trainer.set_trainer(simple_custom_trainer());
|
|
|
|
// Next, to make things a little more interesting, we will setup the one_vs_one_trainer
|
|
// to use kernel ridge regression to solve the upper_left vs lower_right binary classification
|
|
// subproblem.
|
|
typedef radial_basis_kernel<sample_type> rbf_kernel;
|
|
krr_trainer<rbf_kernel> rbf_trainer;
|
|
rbf_trainer.set_kernel(rbf_kernel(0.1));
|
|
trainer.set_trainer(rbf_trainer, "upper_left", "lower_right");
|
|
|
|
|
|
// Now lets do 5-fold cross-validation using the one_vs_one_trainer we just setup.
|
|
// As an aside, always shuffle the order of the samples before doing cross validation.
|
|
// For a discussion of why this is a good idea see the svm_ex.cpp example.
|
|
randomize_samples(samples, labels);
|
|
cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl;
|
|
// This dataset is very easy and everything is correctly classified. Therefore, the output of
|
|
// cross validation is the following confusion matrix.
|
|
/*
|
|
50 0 0 0
|
|
0 50 0 0
|
|
0 0 50 0
|
|
0 0 0 50
|
|
*/
|
|
|
|
|
|
// We can also obtain the decision rule as always.
|
|
one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
|
|
|
|
cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl;
|
|
cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl;
|
|
// The output is:
|
|
/*
|
|
predicted label: upper_right, true label: upper_right
|
|
predicted label: lower_left, true label: lower_left
|
|
*/
|
|
|
|
|
|
// Finally, lets save our multiclass decision rule to disk. Remember that we have
|
|
// to specify the types of binary decision function used inside the one_vs_one_decision_function.
|
|
one_vs_one_decision_function<ovo_trainer,
|
|
custom_decision_function, // This is the output of the simple_custom_trainer
|
|
decision_function<radial_basis_kernel<sample_type> > // This is the output of the rbf_trainer
|
|
> df2, df3;
|
|
|
|
|
|
df2 = df;
|
|
ofstream fout("df.dat", ios::binary);
|
|
serialize(df2, fout);
|
|
fout.close();
|
|
|
|
// load the function back in from disk and store it in df3.
|
|
ifstream fin("df.dat", ios::binary);
|
|
deserialize(df3, fin);
|
|
|
|
|
|
// Test df3 to see that this worked.
|
|
cout << endl;
|
|
cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl;
|
|
cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl;
|
|
// Test df3 on the samples and labels and print the confusion matrix.
|
|
cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl;
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
void generate_data (
|
|
std::vector<sample_type>& samples,
|
|
std::vector<string>& labels
|
|
)
|
|
{
|
|
const long num = 50;
|
|
|
|
sample_type m;
|
|
|
|
dlib::rand rnd;
|
|
|
|
|
|
// add some points in the upper right quadrant
|
|
m = 10, 10;
|
|
for (long i = 0; i < num; ++i)
|
|
{
|
|
samples.push_back(m + randm(2,1,rnd));
|
|
labels.push_back("upper_right");
|
|
}
|
|
|
|
// add some points in the upper left quadrant
|
|
m = -10, 10;
|
|
for (long i = 0; i < num; ++i)
|
|
{
|
|
samples.push_back(m + randm(2,1,rnd));
|
|
labels.push_back("upper_left");
|
|
}
|
|
|
|
// add some points in the lower right quadrant
|
|
m = 10, -10;
|
|
for (long i = 0; i < num; ++i)
|
|
{
|
|
samples.push_back(m + randm(2,1,rnd));
|
|
labels.push_back("lower_right");
|
|
}
|
|
|
|
// add some points in the lower left quadrant
|
|
m = -10, -10;
|
|
for (long i = 0; i < num; ++i)
|
|
{
|
|
samples.push_back(m + randm(2,1,rnd));
|
|
labels.push_back("lower_left");
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|