mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Cleaned up python svm struct code a little.
This commit is contained in:
parent
d0a054f15a
commit
cc9ff97a29
@ -1,7 +1,10 @@
|
||||
#!/usr/bin/python
|
||||
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
#
|
||||
#
|
||||
# This is an example illustrating the use of the structural SVM solver from the dlib C++
|
||||
# Library. This example will briefly introduce it and then walk through an example showing
|
||||
# how to use it to create a simple multi-class classifier.
|
||||
#
|
||||
#
|
||||
# COMPILING THE DLIB PYTHON INTERFACE
|
||||
# Dlib comes with a compiled python interface for python 2.7 on MS Windows. If
|
||||
@ -15,6 +18,7 @@
|
||||
import dlib
|
||||
|
||||
def dot(a, b):
|
||||
"Compute the dot product between the two vectors a and b."
|
||||
return sum(i*j for i,j in zip(a,b))
|
||||
|
||||
|
||||
@ -23,30 +27,35 @@ class three_class_classifier_problem:
|
||||
be_verbose = True
|
||||
epsilon = 0.0001
|
||||
|
||||
|
||||
def __init__(self, samples, labels):
|
||||
self.num_samples = len(samples)
|
||||
self.num_dimensions = len(samples[0])*3
|
||||
self.samples = samples
|
||||
self.labels = labels
|
||||
|
||||
def make_psi(self, psi, vector, label):
|
||||
|
||||
def make_psi(self, vector, label):
|
||||
psi = dlib.vector()
|
||||
psi.resize(self.num_dimensions)
|
||||
dims = len(vector)
|
||||
if (label == 1):
|
||||
if (label == 0):
|
||||
for i in range(0,dims):
|
||||
psi[i] = vector[i]
|
||||
elif (label == 2):
|
||||
elif (label == 1):
|
||||
for i in range(dims,2*dims):
|
||||
psi[i] = vector[i-dims]
|
||||
else:
|
||||
else: # the label must be 2
|
||||
for i in range(2*dims,3*dims):
|
||||
psi[i] = vector[i-2*dims]
|
||||
return psi
|
||||
|
||||
|
||||
def get_truth_joint_feature_vector(self, idx, psi):
|
||||
self.make_psi(psi, self.samples[idx], self.labels[idx])
|
||||
def get_truth_joint_feature_vector(self, idx):
|
||||
return self.make_psi(self.samples[idx], self.labels[idx])
|
||||
|
||||
def separation_oracle(self, idx, current_solution, psi):
|
||||
|
||||
def separation_oracle(self, idx, current_solution):
|
||||
samp = samples[idx]
|
||||
dims = len(samp)
|
||||
scores = [0,0,0]
|
||||
@ -56,29 +65,28 @@ class three_class_classifier_problem:
|
||||
scores[2] = dot(current_solution[2*dims:3*dims], samp)
|
||||
|
||||
# Add in the loss-augmentation
|
||||
if (labels[idx] != 1):
|
||||
if (labels[idx] != 0):
|
||||
scores[0] += 1
|
||||
if (labels[idx] != 2):
|
||||
if (labels[idx] != 1):
|
||||
scores[1] += 1
|
||||
if (labels[idx] != 3):
|
||||
if (labels[idx] != 2):
|
||||
scores[2] += 1
|
||||
|
||||
|
||||
# Now figure out which classifier has the largest loss-augmented score.
|
||||
max_scoring_label = scores.index(max(scores))+1
|
||||
max_scoring_label = scores.index(max(scores))
|
||||
if (max_scoring_label == labels[idx]):
|
||||
loss = 0
|
||||
else:
|
||||
loss = 1
|
||||
|
||||
self.make_psi(psi, samp, max_scoring_label)
|
||||
psi = self.make_psi(samp, max_scoring_label)
|
||||
|
||||
return loss
|
||||
return loss,psi
|
||||
|
||||
|
||||
|
||||
samples = [ [0,0,1], [0,1,0], [1,0,0]];
|
||||
labels = [1, 2, 3]
|
||||
samples = [[0,0,1], [0,1,0], [1,0,0]];
|
||||
labels = [0,1,2]
|
||||
|
||||
problem = three_class_classifier_problem(samples, labels)
|
||||
weights = dlib.solve_structural_svm_problem(problem)
|
||||
|
@ -37,7 +37,7 @@ public:
|
||||
feature_vector_type& psi
|
||||
) const
|
||||
{
|
||||
problem.attr("get_truth_joint_feature_vector")(idx,boost::ref(psi));
|
||||
psi = extract<feature_vector_type&>(problem.attr("get_truth_joint_feature_vector")(idx));
|
||||
}
|
||||
|
||||
virtual void separation_oracle (
|
||||
@ -47,7 +47,19 @@ public:
|
||||
feature_vector_type& psi
|
||||
) const
|
||||
{
|
||||
loss = extract<double>(problem.attr("separation_oracle")(idx,boost::ref(current_solution),boost::ref(psi)));
|
||||
object res = problem.attr("separation_oracle")(idx,boost::ref(current_solution));
|
||||
pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector");
|
||||
// let the user supply the output arguments in any order.
|
||||
if (extract<double>(res[0]).check())
|
||||
{
|
||||
loss = extract<double>(res[0]);
|
||||
psi = extract<feature_vector_type&>(res[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
psi = extract<feature_vector_type&>(res[0]);
|
||||
loss = extract<double>(res[1]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user