Grouping layer added

pull/114/head
Fm 8 years ago
parent 617ffba652
commit 28c4a48281

@ -568,7 +568,7 @@ namespace dlib
struct is_nonloss_layer_type<add_layer<T,U>> : std::true_type {};
template <typename LAYER_DETAILS, typename SUBNET>
class add_layer<LAYER_DETAILS,SUBNET,
class add_layer<LAYER_DETAILS,SUBNET,
typename std::enable_if<is_nonloss_layer_type<SUBNET>::value>::type>
{
public:
@ -3159,6 +3159,499 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename T>
struct group_helper;
template<typename... R>
struct group_count_helper;
}
// --------------------------------------------------------------------------------------
// this class is used to reference group layer input
class group_input
{
public:
typedef tensor input_type;
const static unsigned int sample_expansion_factor = 1;
friend void serialize(const group_input& item, std::ostream& out)
{
serialize("group_input", out);
}
friend void deserialize(group_input& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "group_input")
throw serialization_error("Unexpected version found while deserializing dlib::group_input.");
}
friend std::ostream& operator<<(std::ostream& out, const group_input& item)
{
out << "group_input";
return out;
}
};
// --------------------------------------------------------------------------------------
template <typename GRP, typename SUBNET>
class depth_group;
template <typename T, typename U>
struct is_nonloss_layer_type<depth_group<T,U>> : std::true_type {};
template <typename GRP, typename SUBNET>
class depth_group
{
public:
typedef GRP grp_type;
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
const static size_t group_size = std::tuple_size<grp_type>::value;
const static size_t num_layers_in_group = impl::group_count_helper<GRP>::num_layers;
const static size_t num_layers = subnet_type::num_layers + num_layers_in_group;
const static size_t num_computational_layers_in_group = impl::group_count_helper<GRP>::num_computational_layers;
const static size_t num_computational_layers = subnet_type::num_computational_layers + num_computational_layers_in_group;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
using group_helper = impl::group_helper<grp_type>;
depth_group(
):
subnetwork(new subnet_type()),
grp(new grp_type()),
gradient_input_is_stale(true),
get_output_and_gradient_input_disabled(false)
{
}
depth_group(const depth_group& item)
{
grp.reset(new grp_type(*item.grp));
subnetwork.reset(new subnet_type(*item.subnetwork));
gradient_input_is_stale = item.gradient_input_is_stale;
get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled;
x_grad = item.x_grad;
cached_output = item.cached_output;
temp_tensor = item.temp_tensor;
}
depth_group& operator=(const depth_group& item) { depth_group(item).swap(*this); return *this;}
depth_group(depth_group&& item) : depth_group() { swap(item); }
depth_group& operator=(depth_group&& item) { swap(item); return *this; }
template <typename T, typename U, typename E>
friend class add_layer;
template <typename T, bool is_first, typename E>
friend class dimpl::subnet_wrapper;
template <unsigned long T, typename U, typename E>
friend class add_tag_layer;
template <template<typename> class T, typename U>
friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
// Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other.
template <typename T, typename U>
depth_group(
const depth_group<T,U>& item
) :
grp(new grp_type(item.detail())),
subnetwork(new subnet_type(item.subnet())),
gradient_input_is_stale(item.gradient_input_is_stale),
get_output_and_gradient_input_disabled(item.get_output_and_gradient_input_disabled),
x_grad(item.x_grad),
cached_output(item.cached_output)
{
}
template <typename input_iterator>
void to_tensor (
input_iterator ibegin,
input_iterator iend,
resizable_tensor& data
) const
{
subnetwork->to_tensor(ibegin,iend,data);
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
{
to_tensor(ibegin,iend,temp_tensor);
return forward(temp_tensor);
}
const tensor& operator() (const input_type& x)
{
return (*this)(&x, &x+1);
}
// forward for group: subnet->for_each_in_group->concat->cached_output
const tensor& forward(const tensor& x)
{
subnetwork->forward(x);
long group_depth = 0;
group_helper::forward(subnetwork->get_output(), detail(), group_depth);
auto& out_0 = std::get<0>(detail()).get_output();
cached_output.set_size(out_0.num_samples(), group_depth, out_0.nr(), out_0.nc());
group_helper::concat(cached_output, detail());
gradient_input_is_stale = true;
return private_get_output();
}
private:
bool this_layer_requires_forward_output(
)
{
return true;
}
tensor& private_get_output() const
{
return const_cast<resizable_tensor&>(cached_output);
}
tensor& private_get_gradient_input()
{
if (gradient_input_is_stale)
{
gradient_input_is_stale = false;
x_grad.copy_size(private_get_output());
x_grad = 0;
}
return x_grad;
}
void disable_output_and_gradient_getters (
) { get_output_and_gradient_input_disabled = true; }
public:
const tensor& get_output() const
{
if (get_output_and_gradient_input_disabled)
throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it.");
return private_get_output();
}
tensor& get_gradient_input()
{
if (get_output_and_gradient_input_disabled)
throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it.");
return private_get_gradient_input();
}
const tensor& get_final_data_gradient(
) const { return subnetwork->get_final_data_gradient(); }
void back_propagate_error(const tensor& x)
{
back_propagate_error(x, private_get_gradient_input());
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
{
group_helper::backward(detail(), get_gradient_input(), subnetwork->get_output(), subnetwork->get_gradient_input());
subnetwork->back_propagate_error(x);
// zero out get_gradient_input()
gradient_input_is_stale = true;
}
template <typename solver_type>
void update_parameters(sstack<solver_type> solvers, double step_size)
{
DLIB_CASSERT(solvers.size()>=num_computational_layers,"");
group_helper::update_parameters(solvers, step_size, detail());
solvers = solvers.pop(num_computational_layers_in_group);
subnetwork->update_parameters(solvers, step_size);
}
const subnet_type& subnet() const { return *subnetwork; }
subnet_type& subnet() { return *subnetwork; }
const grp_type& detail() const { return *grp; }
grp_type& detail() { return *grp; }
void clean()
{
x_grad.clear();
cached_output.clear();
temp_tensor.clear();
gradient_input_is_stale = true;
subnetwork->clean();
}
friend void serialize(const depth_group& item, std::ostream& out)
{
int version = 2;
serialize(version, out);
serialize(*item.subnetwork, out);
group_helper::serialize(*item.grp, out);
serialize(item.gradient_input_is_stale, out);
serialize(item.get_output_and_gradient_input_disabled, out);
serialize(item.x_grad, out);
serialize(item.cached_output, out);
}
friend void deserialize(depth_group& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (!(1 <= version && version <= 2))
throw serialization_error("Unexpected version found while deserializing dlib::depth_group.");
deserialize(*item.subnetwork, in);
group_helper::deserialize(*item.grp, in);
deserialize(item.gradient_input_is_stale, in);
deserialize(item.get_output_and_gradient_input_disabled, in);
deserialize(item.x_grad, in);
deserialize(item.cached_output, in);
}
friend std::ostream& operator<< (std::ostream& out, const depth_group& item)
{
item.print(out, 0);
return out;
}
void print (std::ostream& out, unsigned long idx=0) const
{
out << "layer<" << idx << ">\t";
detail().print(out, idx);
subnet().print(out, idx+1);
}
private:
void swap(depth_group& item)
{
std::swap(subnetwork,item.subnetwork);
std::swap(grp, item.grp);
std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled);
std::swap(x_grad, item.x_grad);
std::swap(cached_output, item.cached_output);
}
std::unique_ptr<subnet_type> subnetwork;
std::unique_ptr<grp_type> grp;
bool gradient_input_is_stale;
bool get_output_and_gradient_input_disabled;
resizable_tensor x_grad;
resizable_tensor cached_output;
// temp_tensor doesn't logically contribute to the state of this object.
// It is here only to prevent it from being reallocated over and over.
resizable_tensor temp_tensor;
};
// define "grp" layer shorter name for usage when creating networks
template <typename GRP, typename SUBNET>
using grp = depth_group<GRP, SUBNET>;
namespace impl {
template<
unsigned int i,
typename T, typename U
>
struct layer_helper<i, depth_group<T, U>,
typename std::enable_if<(i != 0 && i >= depth_group<T, U>::num_layers_in_group)>::type> {
const static size_t num_layers_in_group = depth_group<T, U>::num_layers_in_group;
using next_type = typename depth_group<T, U>::subnet_type;
using type = typename layer_helper<i - num_layers_in_group, next_type>::type;
static type &layer(depth_group<T, U> &n) {
return layer_helper<i - num_layers_in_group, next_type>::layer(n.subnet());
}
};
template<
unsigned int i,
typename T, typename U
>
struct layer_helper<i, depth_group<T, U>,
typename std::enable_if<(i != 0 && i < depth_group<T, U>::num_layers_in_group)>::type> {
const static size_t num_layers_in_group = depth_group<T, U>::num_layers_in_group;
typedef typename depth_group<T, U>::grp_type grp_type;
using type = typename layer_helper<i, grp_type>::type;
static type &layer(depth_group<T, U> &n) {
return layer_helper<i, grp_type>::layer(n.detail());
}
};
template <unsigned int pos, unsigned int i, typename... T>
struct group_pos_search{
const static unsigned int count = sizeof...(T);
const static unsigned int pos_from_begin = count - pos - 1;
using tuple_elem_type = typename std::tuple_element<pos_from_begin, std::tuple<T...>>::type;
static const unsigned int num_layers = tuple_elem_type::num_layers;
static const unsigned int layer_index = i >= num_layers ? group_pos_search<pos - 1, i - num_layers, T...>::layer_index : i;
static const unsigned int tuple_index = i >= num_layers ? group_pos_search<pos - 1, i - num_layers, T...>::tuple_index + 1 : pos;
};
template <unsigned int i, typename... T>
struct group_pos_search<0, i, T...>{
static const unsigned int layer_index = i;
static const unsigned int tuple_index = 0;
};
template<
unsigned int i,
typename... R
>
struct layer_helper<i, std::tuple<R...>, typename std::enable_if<true>::type>{
const static unsigned tuple_size = sizeof...(R);
static const unsigned int layer_index = group_pos_search<tuple_size - 1, i, R...>::layer_index;
static const unsigned int tuple_index = group_pos_search<tuple_size - 1, i, R...>::tuple_index;
using next_type = typename std::tuple_element<tuple_index, std::tuple<R...>>::type;//typename std::remove_reference<decltype(makeT().subnet())>::type;
using type = typename layer_helper<layer_index,next_type>::type;
static type &layer(std::tuple<R...> &n) {
return layer_helper<layer_index, next_type>::layer(std::get<tuple_index>(n));
}
};
// helper classes for layer group processing
template <size_t idx, typename... T>
struct group_helper_impl{
static void serialize_impl(const std::tuple<T...>& data, std::ostream& out){
group_helper_impl<idx - 1, T...>::serialize_impl(data, out);
serialize(std::get<idx>(data), out);
}
static void deserialize_impl(std::tuple<T...>& data, std::istream& in){
group_helper_impl<idx - 1, T...>::deserialize_impl(data, in);
deserialize(std::get<idx>(data), in);
}
static void forward(const tensor& x, std::tuple<T...>& grp, long& group_depth){
group_helper_impl<idx - 1, T...>::forward(x, grp, group_depth);
auto& r = std::get<idx>(grp).forward(x);
group_depth += r.k();
}
static size_t concat(resizable_tensor& cached_output, std::tuple<T...>& grp, size_t offset){
offset += group_helper_impl<idx - 1, T...>::concat(cached_output, grp, offset);
auto& output = std::get<idx>(grp).get_output();
tt::concat_depth(cached_output, offset, output);
return offset + output.nc() * output.nr() * output.k();
}
template<typename solver_type>
static sstack<solver_type> update_parameters(sstack<solver_type> solvers, double step_size, std::tuple<T...>& grp){
sstack<solver_type> sub_solvers = group_helper_impl<idx - 1, T...>::update_parameters(solvers, step_size, grp);
std::get<idx>(grp).update_parameters(sub_solvers, step_size);
using tuple_elem_type = typename std::tuple_element<idx, std::tuple<T...>>::type;
return sub_solvers.pop(tuple_elem_type::num_computational_layers);
}
static size_t backward(std::tuple<T...>& grp, const tensor& group_gradient_in,
const tensor& subnet_out, tensor& group_gradient_out, size_t offset)
{
offset += group_helper_impl<idx - 1, T...>::backward(grp, group_gradient_in, subnet_out, group_gradient_out, offset);
auto& subnet = std::get<idx>(grp);
auto& gr_input = subnet.get_gradient_input();
tt::split_depth(gr_input, offset, group_gradient_in);
subnet.back_propagate_error(subnet_out);
tt::add(group_gradient_out, group_gradient_out, subnet.get_final_data_gradient());
return offset + gr_input.nc() * gr_input.nr() * gr_input.k();
}
};
template <typename... T>
struct group_helper_impl<0, T...>{
static void serialize_impl(const std::tuple<T...>& data, std::ostream& out){
serialize(std::get<0>(data), out);
}
static void deserialize_impl(std::tuple<T...>& data, std::istream& in){
deserialize(std::get<0>(data), in);
}
static void forward(const tensor& x, std::tuple<T...>& grp, long& group_depth){
auto& r = std::get<0>(grp).forward(x);
group_depth += r.k();
}
static size_t concat(resizable_tensor& cached_output, std::tuple<T...>& grp, size_t offset){
auto& output = std::get<0>(grp).get_output();
tt::concat_depth(cached_output, offset, output);
return offset + output.nc() * output.nr() * output.k();
}
template<typename solver_type>
static sstack<solver_type> update_parameters(sstack<solver_type> solvers, double step_size, std::tuple<T...>& grp){
std::get<0>(grp).update_parameters(solvers, step_size);
using tuple_elem_type = typename std::tuple_element<0, std::tuple<T...>>::type;
return solvers.pop(tuple_elem_type::num_computational_layers);
}
static size_t backward(std::tuple<T...>& grp, const tensor& group_gradient_in,
const tensor& subnet_out, tensor& group_gradient_out, size_t offset)
{
auto& item = std::get<0>(grp);
auto& gr_input = item.get_gradient_input();
tt::split_depth(gr_input, offset, group_gradient_in);
item.back_propagate_error(subnet_out);
tt::add(group_gradient_out, group_gradient_out, item.get_final_data_gradient());
return offset + gr_input.nc() * gr_input.nr() * gr_input.k();
}
};
template <typename... T>
struct group_helper<std::tuple<T...>>{
static void serialize(const std::tuple<T...> & data, std::ostream& out){
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::serialize_impl(data, out);
}
static void deserialize(std::tuple<T...>& data, std::istream& in){
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::deserialize_impl(data, in);
}
static void forward(const tensor& x, std::tuple<T...>& grp, long& group_depth){
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::forward(x, grp, group_depth);
}
static void concat(resizable_tensor& out, std::tuple<T...>& grp){
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::concat(out, grp, 0);
}
template<typename solver_type>
static void update_parameters(sstack<solver_type> solvers, double step_size, std::tuple<T...>& grp){
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::update_parameters(solvers, step_size, grp);
}
static void backward(std::tuple<T...>& grp, const tensor& group_gradient_in, const tensor& subnet_out, tensor& group_gradient_out)
{
group_helper_impl<std::tuple_size<std::tuple<T...>>::value - 1, T...>::backward(grp, group_gradient_in, subnet_out, group_gradient_out, 0);
}
};
// helper classes to understand the count of group items layers
template<typename T>
struct group_count_helper<T>{
const static size_t num_layers = T::num_layers;
const static size_t num_computational_layers = T::num_computational_layers;
};
template<typename T, typename... R>
struct group_count_helper<T, R...>{
const static size_t num_layers = group_count_helper<T>::num_layers + group_count_helper<R...>::num_layers;
const static size_t num_computational_layers = group_count_helper<T>::num_computational_layers + group_count_helper<R...>::num_computational_layers;
};
template<typename... R>
struct group_count_helper<std::tuple<R...>>{
const static size_t num_layers = group_count_helper<R...>::num_layers;
const static size_t num_computational_layers = group_count_helper<R...>::num_computational_layers;
};
}
}
#endif // DLIB_DNn_CORE_H_

@ -634,6 +634,66 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void concat_depth(tensor& dest, size_t sample_offset, const tensor& src)
{
const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
DLIB_CASSERT(dest_sample_size >= src_sample_size + sample_offset, "Not enough space in dest tensor");
#ifdef DLIB_USE_CUDA
float* dest_p = dest.device_write_only() + sample_offset;
const float* src_p = src.device();
#else
float* dest_p = dest.host_write_only() + sample_offset;
const float* src_p = src.host();
#endif
for (unsigned long i = 0; i < src.num_samples(); ++i)
{
#ifdef DLIB_USE_CUDA
CHECK_CUDA(cudaMemcpy(dest_p, src_p, src_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
#else
::memcpy(dest_p, src_p, src_sample_size * sizeof(float));
#endif
dest_p += dest_sample_size;
src_p += src_sample_size;
}
}
void split_depth(tensor& dest, size_t sample_offset, const tensor& src)
{
const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
DLIB_CASSERT(dest_sample_size <= src_sample_size - sample_offset, "Not enough space in dest tensor");
#ifdef DLIB_USE_CUDA
float* dest_p = dest.device_write_only();
const float* src_p = src.device() + sample_offset;
#else
float* dest_p = dest.host_write_only();
const float* src_p = src.host() + sample_offset;
#endif
for (unsigned long i = 0; i < src.num_samples(); ++i)
{
#ifdef DLIB_USE_CUDA
CHECK_CUDA(cudaMemcpy(dest_p, src_p, dest_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
#else
::memcpy(dest_p, src_p, dest_sample_size * sizeof(float));
#endif
dest_p += dest_sample_size;
src_p += src_sample_size;
}
}
// ----------------------------------------------------------------------------------------
}}

@ -1171,6 +1171,43 @@ namespace dlib { namespace tt
resizable_tensor accum_buffer;
};
// ----------------------------------------------------------------------------------------
void concat_depth(
tensor& dest,
size_t sample_offset,
const tensor& src
);
/*!
requires
- dest.nc() == src.nc()
- dest.nr() == src.nr()
- dest.num_samples() == src.num_samples()
- dest.k() >= src.k() + sample_offset
- is_same_object(dest,src) == false
- sample_offset a count of elements, not bytes
ensures
- performs: dest[i, k + sample_offset, r, c] = src[i, k, r, c], where k in [0..src.k()]
Copies content of each sample from src in to corresponding place of sample at dst
!*/
void split_depth(
tensor& dest,
size_t sample_offset,
const tensor& src
);
/*!
requires
- dest.nc() == src.nc()
- dest.nr() == src.nr()
- dest.num_samples() == src.num_samples()
- dest.k() <= src.k() - sample_offset
- is_same_object(dest,src) == false
- sample_offset a count of elements, not bytes
ensures
- performs: dest[i, k, r, c] = src[i, k + sample_offset, r, c], where k in [0..dest.k()]
Fills each sample of dst from the corresponding part of each sample at src
!*/
// ----------------------------------------------------------------------------------------

@ -33,6 +33,7 @@ ENDMACRO()
if (COMPILER_CAN_DO_CPP_11)
add_example(dnn_mnist_ex)
add_example(dnn_mnist_advanced_ex)
add_example(dnn_inception_ex)
endif()
#here we apply our macros

@ -0,0 +1,145 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the
dlib C++ Library. I'm assuming you have already read the dnn_mnist_ex.cpp
example. So in this example program I'm going to go over a number of more
advanced parts of the API, including:
- Using grp layer for constructing inception layer
Inception layer is a kind of NN architecture for running sevelar convolution types
on the same input area and joining all convolution results into one output.
For further reading refer http://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf
*/
#include <dlib/dnn.h>
#include <iostream>
#include <dlib/data_io.h>
#include <tuple>
using namespace std;
using namespace dlib;
// Here we define inception module as described in GoogLeNet specification. The depth of each sublayer can be changed
template<typename SUBNET>
using inception = grp<std::tuple<con<8,1,1,1,1, group_input>,
con<8,3,3,1,1, con<8,1,1,1,1, group_input>>,
con<8,5,5,1,1, con<8,1,1,1,1, group_input>>,
con<8,1,1,1,1, max_pool<3,3,1,1, group_input>>>,
SUBNET>;
int main(int argc, char** argv) try
{
// This example is going to run on the MNIST dataset.
if (argc != 2)
{
cout << "This example needs the MNIST dataset to run!" << endl;
cout << "You can get MNIST from http://yann.lecun.com/exdb/mnist/" << endl;
cout << "Download the 4 files that comprise the dataset, decompress them, and" << endl;
cout << "put them in a folder. Then give that folder as input to this program." << endl;
return 1;
}
std::vector<matrix<unsigned char>> training_images;
std::vector<unsigned long> training_labels;
std::vector<matrix<unsigned char>> testing_images;
std::vector<unsigned long> testing_labels;
load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
// Create a the same network as in dnn_mnist_ex, but use inception layer insteam of convolution
// in the middle
using net_type = loss_multiclass_log<
fc<10,
relu<fc<84,
relu<fc<120,
max_pool<2,2,2,2,relu<inception<
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
input<matrix<unsigned char>>
>>>>>>>>>>>>;
// Create a network as defined above. This network will produce 10 outputs
// because that's how we defined net_type. However, fc layers can have the
// number of outputs they produce changed at runtime.
net_type net;
// the following training process is the same as in dnn_mnist_ex sample
// And then train it using the MNIST data. The code below uses mini-batch stochastic
// gradient descent with an initial learning rate of 0.01 to accomplish this.
dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate(0.01);
trainer.set_min_learning_rate(0.00001);
trainer.set_mini_batch_size(128);
trainer.be_verbose();
// Since DNN training can take a long time, we can ask the trainer to save its state to
// a file named "mnist_sync" every 20 seconds. This way, if we kill this program and
// start it again it will begin where it left off rather than restarting the training
// from scratch. This is because, when the program restarts, this call to
// set_synchronization_file() will automatically reload the settings from mnist_sync if
// the file exists.
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
// Finally, this line begins training. By default, it runs SGD with our specified
// learning rate until the loss stops decreasing. Then it reduces the learning rate by
// a factor of 10 and continues running until the loss stops decreasing again. It will
// keep doing this until the learning rate has dropped below the min learning rate
// defined above or the maximum number of epochs as been executed (defaulted to 10000).
trainer.train(training_images, training_labels);
// At this point our net object should have learned how to classify MNIST images. But
// before we try it out let's save it to disk. Note that, since the trainer has been
// running images through the network, net will have a bunch of state in it related to
// the last batch of images it processed (e.g. outputs from each layer). Since we
// don't care about saving that kind of stuff to disk we can tell the network to forget
// about that kind of transient data so that our file will be smaller. We do this by
// "cleaning" the network before saving it.
net.clean();
serialize("mnist_network.dat") << net;
// Now if we later wanted to recall the network from disk we can simply say:
// deserialize("mnist_network.dat") >> net;
// Now let's run the training images through the network. This statement runs all the
// images through it and asks the loss layer to convert the network's raw output into
// labels. In our case, these labels are the numbers between 0 and 9.
std::vector<unsigned long> predicted_labels = net(training_images);
int num_right = 0;
int num_wrong = 0;
// And then let's see if it classified them correctly.
for (size_t i = 0; i < training_images.size(); ++i)
{
if (predicted_labels[i] == training_labels[i])
++num_right;
else
++num_wrong;
}
cout << "training num_right: " << num_right << endl;
cout << "training num_wrong: " << num_wrong << endl;
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
// Let's also see if the network can correctly classify the testing images. Since
// MNIST is an easy dataset, we should see at least 99% accuracy.
predicted_labels = net(testing_images);
num_right = 0;
num_wrong = 0;
for (size_t i = 0; i < testing_images.size(); ++i)
{
if (predicted_labels[i] == testing_labels[i])
++num_right;
else
++num_wrong;
}
cout << "testing num_right: " << num_right << endl;
cout << "testing num_wrong: " << num_wrong << endl;
cout << "testing accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
}
catch(std::exception& e)
{
cout << e.what() << endl;
}
Loading…
Cancel
Save