Mirror of https://github.com/davisking/dlib.git (synced 2024-11-01 10:14:53 +08:00)

Cleaned up example

commit f28d2f7329, parent 539b416c48
@@ -4,7 +4,26 @@
     dlib C++ Library.  In it, we will show how to use the loss_metric layer to do
     metric learning.
 
+    The main reason you might want to use this kind of algorithm is because you
+    would like to use a k-nearest neighbor classifier or similar algorithm, but
+    you don't know a good way to calculate the distance between two things.  A
+    popular example would be face recognition.  There are a whole lot of papers
+    that train some kind of deep metric learning algorithm that embeds face
+    images in some vector space where images of the same person are close to each
+    other and images of different people are far apart.  Then in that vector
+    space it's very easy to do face recognition with some kind of k-nearest
+    neighbor classifier.
+
+    To keep this example as simple as possible we won't do face recognition.
+    Instead, we will create a very simple network and use it to learn a mapping
+    from 8D vectors to 2D vectors such that vectors with the same class labels
+    are near each other.  If you want to see a more complex example that learns
+    the kind of network you would use for something like face recognition read
+    the dnn_metric_learning_on_images_ex.cpp example.
+
+    You should also have read the examples that introduce the dlib DNN API before
+    continuing.  These are dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp.
 */
 
 
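The comment block added in this hunk motivates metric learning as a front end for a k-nearest-neighbor style classifier: embed everything into a vector space, then label a query by the closest known points. As a minimal sketch of that idea (not part of this commit; the helper name and signature are made up for illustration), a 1-nearest-neighbor rule over dlib embeddings could look like this:

#include <dlib/dnn.h>
#include <vector>

// Hypothetical helper, not in the example: label a query embedding by the
// closest known embedding (1-nearest-neighbor in the learned vector space).
unsigned long nearest_neighbor_label (
    const dlib::matrix<float,0,1>& query,
    const std::vector<dlib::matrix<float,0,1>>& known_embeddings,
    const std::vector<unsigned long>& known_labels
)
{
    DLIB_CASSERT(known_embeddings.size() == known_labels.size() && known_embeddings.size() > 0,
                 "need at least one labeled embedding");
    size_t best = 0;
    for (size_t i = 1; i < known_embeddings.size(); ++i)
    {
        // dlib::length() is the Euclidean norm, so this compares the query's
        // distance to each stored embedding and keeps the closest one.
        if (dlib::length(query - known_embeddings[i]) < dlib::length(query - known_embeddings[best]))
            best = i;
    }
    return known_labels[best];
}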
@@ -17,39 +36,52 @@ using namespace dlib;
 
 int main() try
 {
-    using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>;
-    net_type net;
-    dnn_trainer<net_type> trainer(net);
-    trainer.set_learning_rate(0.1);
-    trainer.set_min_learning_rate(0.00001);
-    trainer.set_mini_batch_size(128);
-    trainer.be_verbose();
-    trainer.set_iterations_without_progress_threshold(100);
-
+    // The API for doing metric learning is very similar to the API for
+    // multi-class classification.  In fact, the inputs are the same, a bunch of
+    // labeled objects.  So here we create our dataset.  We make up some simple
+    // vectors and label them with the integers 1,2,3,4.  The specific values of
+    // the integer labels don't matter.
     std::vector<matrix<double,0,1>> samples;
     std::vector<unsigned long> labels;
 
+    // class 1 training vectors
     samples.push_back({1,0,0,0,0,0,0,0}); labels.push_back(1);
     samples.push_back({0,1,0,0,0,0,0,0}); labels.push_back(1);
 
+    // class 2 training vectors
     samples.push_back({0,0,1,0,0,0,0,0}); labels.push_back(2);
     samples.push_back({0,0,0,1,0,0,0,0}); labels.push_back(2);
 
+    // class 3 training vectors
     samples.push_back({0,0,0,0,1,0,0,0}); labels.push_back(3);
     samples.push_back({0,0,0,0,0,1,0,0}); labels.push_back(3);
 
+    // class 4 training vectors
     samples.push_back({0,0,0,0,0,0,1,0}); labels.push_back(4);
     samples.push_back({0,0,0,0,0,0,0,1}); labels.push_back(4);
 
 
+    // Make a network that simply learns a linear mapping from 8D vectors to 2D
+    // vectors.
+    using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>;
+    net_type net;
+
+    // Now setup the trainer and train the network using our data.
+    dnn_trainer<net_type> trainer(net);
+    trainer.set_learning_rate(0.1);
+    trainer.set_min_learning_rate(0.001);
+    trainer.set_mini_batch_size(128);
+    trainer.be_verbose();
+    trainer.set_iterations_without_progress_threshold(100);
     trainer.train(samples, labels);
 
 
-    // Run all the images through the network to get their vector embeddings.
-    std::vector<matrix<float,0,1>> embedded = net(images);
-
-
+    // Run all the samples through the network to get their 2D vector embeddings.
+    std::vector<matrix<float,0,1>> embedded = net(samples);
 
+    // Print the embedding for each sample to the screen.  If you look at the
+    // outputs carefully you should notice that they are grouped together in 2D
+    // space according to their label.
     for (size_t i = 0; i < embedded.size(); ++i)
         cout << "label: " << labels[i] << "\t" << trans(embedded[i]);
 
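For instance (illustrative only, building on the hypothetical nearest_neighbor_label helper sketched after the first hunk, and using a made-up query vector), an unseen 8D sample could be embedded and labeled against the training embeddings computed above:

    // Not part of the diff: embed a new vector and label it by its nearest training embedding.
    matrix<double,0,1> new_sample = {0.9, 0.1, 0, 0, 0, 0, 0, 0};   // made-up query point
    matrix<float,0,1> new_embedding = net(new_sample);
    cout << "predicted label: " << nearest_neighbor_label(new_embedding, embedded, labels) << endl;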
@@ -65,7 +97,8 @@ int main() try
             {
                 // The loss_metric layer will cause things with the same label to be less
                 // than net.loss_details().get_distance_threshold() distance from each
-                // other.  So we can use that distance value as our testing threshold.
+                // other.  So we can use that distance value as our testing threshold for
+                // "being near to each other".
                 if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold())
                     ++num_right;
                 else
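This hunk only shows the same-label branch of the test. For orientation, the surrounding check walks every pair of embeddings and counts how often the distance threshold agrees with the labels, roughly like the sketch below (a reconstruction for illustration, not text from the diff):

    int num_right = 0;
    int num_wrong = 0;
    for (size_t i = 0; i < embedded.size(); ++i)
    {
        for (size_t j = i+1; j < embedded.size(); ++j)
        {
            if (labels[i] == labels[j])
            {
                // Same label: the pair should be closer than the distance threshold.
                if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold())
                    ++num_right;
                else
                    ++num_wrong;
            }
            else
            {
                // Different labels: the pair should be at least the threshold apart.
                if (length(embedded[i]-embedded[j]) >= net.loss_details().get_distance_threshold())
                    ++num_right;
                else
                    ++num_wrong;
            }
        }
    }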
@@ -83,8 +116,6 @@ int main() try
 
     cout << "num_right: "<< num_right << endl;
     cout << "num_wrong: "<< num_wrong << endl;
-
-
 }
 catch(std::exception& e)
 {
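Beyond the scope of this commit, but useful when adapting the example: a trained metric network is normally written to disk so it can be reused without retraining. dlib's serialize/deserialize stream operators handle the whole network object (the filename below is made up):

    // Hypothetical follow-on, not in the example: save and later reload the trained network.
    serialize("metric_network.dat") << net;

    net_type net2;
    deserialize("metric_network.dat") >> net2;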