Added a function that lets you test and train at the same time

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402251
This commit is contained in:
Davis King 2008-05-23 00:01:44 +00:00
parent 230b8b9509
commit 43946fccac
2 changed files with 83 additions and 48 deletions

View File

@ -105,10 +105,74 @@ namespace dlib
return std::sqrt(kernel(x,x) + bias - 2*temp);
}
scalar_type test_and_train (
const sample_type& x
)
{
return train_and_maybe_test(x,true);
}
void train (
const sample_type& x
)
{
train_and_maybe_test(x,false);
}
void swap (
one_class& item
)
{
exchange(kernel, item.kernel);
dictionary.swap(item.dictionary);
alpha.swap(item.alpha);
K_inv.swap(item.K_inv);
K.swap(item.K);
exchange(tolerance, item.tolerance);
exchange(samples_seen, item.samples_seen);
exchange(bias, item.bias);
exchange(max_dis, item.max_dis);
a.swap(item.a);
k.swap(item.k);
}
unsigned long dictionary_size (
) const { return dictionary.size(); }
friend void serialize(const one_class& item, std::ostream& out)
{
serialize(item.kernel, out);
serialize(item.dictionary, out);
serialize(item.alpha, out);
serialize(item.K_inv, out);
serialize(item.K, out);
serialize(item.tolerance, out);
serialize(item.samples_seen, out);
serialize(item.bias, out);
serialize(item.max_dis, out);
}
friend void deserialize(one_class& item, std::istream& in)
{
deserialize(item.kernel, in);
deserialize(item.dictionary, in);
deserialize(item.alpha, in);
deserialize(item.K_inv, in);
deserialize(item.K, in);
deserialize(item.tolerance, in);
deserialize(item.samples_seen, in);
deserialize(item.bias, in);
deserialize(item.max_dis, in);
}
private:
scalar_type train_and_maybe_test (
const sample_type& x,
bool do_test
)
{
scalar_type test_result = 0;
const scalar_type kx = kernel(x,x);
if (alpha.size() == 0)
{
@ -129,6 +193,11 @@ namespace dlib
for (long r = 0; r < k.nr(); ++r)
k(r) = kernel(x,dictionary[r]);
if (do_test)
{
test_result = std::sqrt(kx + bias - 2*trans(vector_to_matrix(alpha))*k);
}
// compute the error we would have if we approximated the new x sample
// with the dictionary. That is, do the ALD test from the KRLS paper.
a = K_inv*k;
@ -214,56 +283,10 @@ namespace dlib
if (samples_seen > max_dis)
samples_seen = max_dis;
return test_result;
}
void swap (
one_class& item
)
{
exchange(kernel, item.kernel);
dictionary.swap(item.dictionary);
alpha.swap(item.alpha);
K_inv.swap(item.K_inv);
K.swap(item.K);
exchange(tolerance, item.tolerance);
exchange(samples_seen, item.samples_seen);
exchange(bias, item.bias);
exchange(max_dis, item.max_dis);
a.swap(item.a);
k.swap(item.k);
}
unsigned long dictionary_size (
) const { return dictionary.size(); }
friend void serialize(const one_class& item, std::ostream& out)
{
serialize(item.kernel, out);
serialize(item.dictionary, out);
serialize(item.alpha, out);
serialize(item.K_inv, out);
serialize(item.K, out);
serialize(item.tolerance, out);
serialize(item.samples_seen, out);
serialize(item.bias, out);
serialize(item.max_dis, out);
}
friend void deserialize(one_class& item, std::istream& in)
{
deserialize(item.kernel, in);
deserialize(item.dictionary, in);
deserialize(item.alpha, in);
deserialize(item.K_inv, in);
deserialize(item.K, in);
deserialize(item.tolerance, in);
deserialize(item.samples_seen, in);
deserialize(item.bias, in);
deserialize(item.max_dis, in);
}
private:
typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;

View File

@ -125,6 +125,18 @@ namespace dlib
to this object so far.
!*/
scalar_type test_and_train (
const sample_type& x
);
/*!
ensures
- calls train(x)
- returns (*this)(x)
- The reason this function exists is because train() and operator()
both compute some of the same things. So this function is more efficient
than calling both individually.
!*/
void train (
const sample_type& x
);