mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a function that lets you test and train at the same time
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402251
This commit is contained in:
parent
230b8b9509
commit
43946fccac
@ -105,10 +105,74 @@ namespace dlib
|
||||
return std::sqrt(kernel(x,x) + bias - 2*temp);
|
||||
}
|
||||
|
||||
scalar_type test_and_train (
|
||||
const sample_type& x
|
||||
)
|
||||
{
|
||||
return train_and_maybe_test(x,true);
|
||||
}
|
||||
|
||||
void train (
|
||||
const sample_type& x
|
||||
)
|
||||
{
|
||||
train_and_maybe_test(x,false);
|
||||
}
|
||||
|
||||
void swap (
|
||||
one_class& item
|
||||
)
|
||||
{
|
||||
exchange(kernel, item.kernel);
|
||||
dictionary.swap(item.dictionary);
|
||||
alpha.swap(item.alpha);
|
||||
K_inv.swap(item.K_inv);
|
||||
K.swap(item.K);
|
||||
exchange(tolerance, item.tolerance);
|
||||
exchange(samples_seen, item.samples_seen);
|
||||
exchange(bias, item.bias);
|
||||
exchange(max_dis, item.max_dis);
|
||||
a.swap(item.a);
|
||||
k.swap(item.k);
|
||||
}
|
||||
|
||||
unsigned long dictionary_size (
|
||||
) const { return dictionary.size(); }
|
||||
|
||||
friend void serialize(const one_class& item, std::ostream& out)
|
||||
{
|
||||
serialize(item.kernel, out);
|
||||
serialize(item.dictionary, out);
|
||||
serialize(item.alpha, out);
|
||||
serialize(item.K_inv, out);
|
||||
serialize(item.K, out);
|
||||
serialize(item.tolerance, out);
|
||||
serialize(item.samples_seen, out);
|
||||
serialize(item.bias, out);
|
||||
serialize(item.max_dis, out);
|
||||
}
|
||||
|
||||
friend void deserialize(one_class& item, std::istream& in)
|
||||
{
|
||||
deserialize(item.kernel, in);
|
||||
deserialize(item.dictionary, in);
|
||||
deserialize(item.alpha, in);
|
||||
deserialize(item.K_inv, in);
|
||||
deserialize(item.K, in);
|
||||
deserialize(item.tolerance, in);
|
||||
deserialize(item.samples_seen, in);
|
||||
deserialize(item.bias, in);
|
||||
deserialize(item.max_dis, in);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
scalar_type train_and_maybe_test (
|
||||
const sample_type& x,
|
||||
bool do_test
|
||||
)
|
||||
{
|
||||
scalar_type test_result = 0;
|
||||
const scalar_type kx = kernel(x,x);
|
||||
if (alpha.size() == 0)
|
||||
{
|
||||
@ -129,6 +193,11 @@ namespace dlib
|
||||
for (long r = 0; r < k.nr(); ++r)
|
||||
k(r) = kernel(x,dictionary[r]);
|
||||
|
||||
if (do_test)
|
||||
{
|
||||
test_result = std::sqrt(kx + bias - 2*trans(vector_to_matrix(alpha))*k);
|
||||
}
|
||||
|
||||
// compute the error we would have if we approximated the new x sample
|
||||
// with the dictionary. That is, do the ALD test from the KRLS paper.
|
||||
a = K_inv*k;
|
||||
@ -214,56 +283,10 @@ namespace dlib
|
||||
|
||||
if (samples_seen > max_dis)
|
||||
samples_seen = max_dis;
|
||||
|
||||
return test_result;
|
||||
}
|
||||
|
||||
void swap (
|
||||
one_class& item
|
||||
)
|
||||
{
|
||||
exchange(kernel, item.kernel);
|
||||
dictionary.swap(item.dictionary);
|
||||
alpha.swap(item.alpha);
|
||||
K_inv.swap(item.K_inv);
|
||||
K.swap(item.K);
|
||||
exchange(tolerance, item.tolerance);
|
||||
exchange(samples_seen, item.samples_seen);
|
||||
exchange(bias, item.bias);
|
||||
exchange(max_dis, item.max_dis);
|
||||
a.swap(item.a);
|
||||
k.swap(item.k);
|
||||
}
|
||||
|
||||
unsigned long dictionary_size (
|
||||
) const { return dictionary.size(); }
|
||||
|
||||
friend void serialize(const one_class& item, std::ostream& out)
|
||||
{
|
||||
serialize(item.kernel, out);
|
||||
serialize(item.dictionary, out);
|
||||
serialize(item.alpha, out);
|
||||
serialize(item.K_inv, out);
|
||||
serialize(item.K, out);
|
||||
serialize(item.tolerance, out);
|
||||
serialize(item.samples_seen, out);
|
||||
serialize(item.bias, out);
|
||||
serialize(item.max_dis, out);
|
||||
}
|
||||
|
||||
friend void deserialize(one_class& item, std::istream& in)
|
||||
{
|
||||
deserialize(item.kernel, in);
|
||||
deserialize(item.dictionary, in);
|
||||
deserialize(item.alpha, in);
|
||||
deserialize(item.K_inv, in);
|
||||
deserialize(item.K, in);
|
||||
deserialize(item.tolerance, in);
|
||||
deserialize(item.samples_seen, in);
|
||||
deserialize(item.bias, in);
|
||||
deserialize(item.max_dis, in);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
|
||||
typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
|
||||
typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
|
||||
|
@ -125,6 +125,18 @@ namespace dlib
|
||||
to this object so far.
|
||||
!*/
|
||||
|
||||
scalar_type test_and_train (
|
||||
const sample_type& x
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- calls train(x)
|
||||
- returns (*this)(x)
|
||||
- The reason this function exists is because train() and operator()
|
||||
both compute some of the same things. So this function is more efficient
|
||||
than calling both individually.
|
||||
!*/
|
||||
|
||||
void train (
|
||||
const sample_type& x
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user