From 28c4a48281f9fddcd7be4721e71b7acd968e4e32 Mon Sep 17 00:00:00 2001 From: Fm Date: Tue, 17 May 2016 13:07:04 +0300 Subject: [PATCH] Grouping layer added --- dlib/dnn/core.h | 495 +++++++++++++++++++++++++++++++++- dlib/dnn/tensor_tools.cpp | 60 +++++ dlib/dnn/tensor_tools.h | 37 +++ examples/CMakeLists.txt | 1 + examples/dnn_inception_ex.cpp | 145 ++++++++++ 5 files changed, 737 insertions(+), 1 deletion(-) create mode 100644 examples/dnn_inception_ex.cpp diff --git a/dlib/dnn/core.h b/dlib/dnn/core.h index b7492673a..47daea746 100644 --- a/dlib/dnn/core.h +++ b/dlib/dnn/core.h @@ -568,7 +568,7 @@ namespace dlib struct is_nonloss_layer_type> : std::true_type {}; template - class add_layer::value>::type> { public: @@ -3159,6 +3159,499 @@ namespace dlib // ---------------------------------------------------------------------------------------- + namespace impl + { + template + struct group_helper; + template + struct group_count_helper; + } + + + // -------------------------------------------------------------------------------------- + // this class is used to reference group layer input + class group_input + { + public: + typedef tensor input_type; + const static unsigned int sample_expansion_factor = 1; + friend void serialize(const group_input& item, std::ostream& out) + { + serialize("group_input", out); + } + + friend void deserialize(group_input& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "group_input") + throw serialization_error("Unexpected version found while deserializing dlib::group_input."); + } + + friend std::ostream& operator<<(std::ostream& out, const group_input& item) + { + out << "group_input"; + return out; + } + }; + // -------------------------------------------------------------------------------------- + + template + class depth_group; + + + template + struct is_nonloss_layer_type> : std::true_type {}; + + template + class depth_group + { + public: + typedef GRP grp_type; + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + const static size_t group_size = std::tuple_size::value; + const static size_t num_layers_in_group = impl::group_count_helper::num_layers; + const static size_t num_layers = subnet_type::num_layers + num_layers_in_group; + const static size_t num_computational_layers_in_group = impl::group_count_helper::num_computational_layers; + const static size_t num_computational_layers = subnet_type::num_computational_layers + num_computational_layers_in_group; + const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor; + + using group_helper = impl::group_helper; + + depth_group( + ): + subnetwork(new subnet_type()), + grp(new grp_type()), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + } + + depth_group(const depth_group& item) + { + grp.reset(new grp_type(*item.grp)); + subnetwork.reset(new subnet_type(*item.subnetwork)); + gradient_input_is_stale = item.gradient_input_is_stale; + get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled; + x_grad = item.x_grad; + cached_output = item.cached_output; + temp_tensor = item.temp_tensor; + } + depth_group& operator=(const depth_group& item) { depth_group(item).swap(*this); return *this;} + depth_group(depth_group&& item) : depth_group() { swap(item); } + depth_group& operator=(depth_group&& item) { swap(item); return *this; } + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + 
friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + // Allow copying networks from one to another as long as their corresponding + // layers can be constructed from each other. + template + depth_group( + const depth_group& item + ) : + grp(new grp_type(item.detail())), + subnetwork(new subnet_type(item.subnet())), + gradient_input_is_stale(item.gradient_input_is_stale), + get_output_and_gradient_input_disabled(item.get_output_and_gradient_input_disabled), + x_grad(item.x_grad), + cached_output(item.cached_output) + { + } + + template + void to_tensor ( + input_iterator ibegin, + input_iterator iend, + resizable_tensor& data + ) const + { + subnetwork->to_tensor(ibegin,iend,data); + } + + template + const tensor& operator() ( + input_iterator ibegin, + input_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return forward(temp_tensor); + } + + + const tensor& operator() (const input_type& x) + { + return (*this)(&x, &x+1); + } + + + // forward for group: subnet->for_each_in_group->concat->cached_output + const tensor& forward(const tensor& x) + { + + subnetwork->forward(x); + long group_depth = 0; + + group_helper::forward(subnetwork->get_output(), detail(), group_depth); + + auto& out_0 = std::get<0>(detail()).get_output(); + cached_output.set_size(out_0.num_samples(), group_depth, out_0.nr(), out_0.nc()); + + group_helper::concat(cached_output, detail()); + + + gradient_input_is_stale = true; + return private_get_output(); + } + + private: + bool this_layer_requires_forward_output( + ) + { + return true; + } + + tensor& private_get_output() const + { + return const_cast(cached_output); + } + tensor& private_get_gradient_input() + { + if (gradient_input_is_stale) + { + gradient_input_is_stale = false; + x_grad.copy_size(private_get_output()); + x_grad = 0; + } + return x_grad; + } + void disable_output_and_gradient_getters ( + ) { get_output_and_gradient_input_disabled = true; } + public: + const tensor& get_output() const + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it."); + return private_get_output(); + } + tensor& get_gradient_input() + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it."); + return private_get_gradient_input(); + } + + const tensor& get_final_data_gradient( + ) const { return subnetwork->get_final_data_gradient(); } + + void back_propagate_error(const tensor& x) + { + back_propagate_error(x, private_get_gradient_input()); + } + void back_propagate_error(const tensor& x, const tensor& gradient_input) + { + group_helper::backward(detail(), get_gradient_input(), subnetwork->get_output(), subnetwork->get_gradient_input()); + + subnetwork->back_propagate_error(x); + + // zero out get_gradient_input() + gradient_input_is_stale = true; + } + + template + void update_parameters(sstack solvers, double step_size) + { + DLIB_CASSERT(solvers.size()>=num_computational_layers,""); + group_helper::update_parameters(solvers, step_size, detail()); + solvers = solvers.pop(num_computational_layers_in_group); + subnetwork->update_parameters(solvers, step_size); + } + + const subnet_type& subnet() const { return *subnetwork; } + subnet_type& subnet() { return *subnetwork; } + + const grp_type& detail() const { return *grp; } + 
grp_type& detail() { return *grp; } + + void clean() + { + x_grad.clear(); + cached_output.clear(); + temp_tensor.clear(); + gradient_input_is_stale = true; + subnetwork->clean(); + } + + friend void serialize(const depth_group& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(*item.subnetwork, out); + group_helper::serialize(*item.grp, out); + serialize(item.gradient_input_is_stale, out); + serialize(item.get_output_and_gradient_input_disabled, out); + serialize(item.x_grad, out); + serialize(item.cached_output, out); + } + + friend void deserialize(depth_group& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (!(1 <= version && version <= 2)) + throw serialization_error("Unexpected version found while deserializing dlib::depth_group."); + deserialize(*item.subnetwork, in); + group_helper::deserialize(*item.grp, in); + deserialize(item.gradient_input_is_stale, in); + deserialize(item.get_output_and_gradient_input_disabled, in); + deserialize(item.x_grad, in); + deserialize(item.cached_output, in); + } + + friend std::ostream& operator<< (std::ostream& out, const depth_group& item) + { + item.print(out, 0); + return out; + } + + void print (std::ostream& out, unsigned long idx=0) const + { + out << "layer<" << idx << ">\t"; + detail().print(out, idx); + subnet().print(out, idx+1); + } + + private: + + + void swap(depth_group& item) + { + std::swap(subnetwork,item.subnetwork); + std::swap(grp, item.grp); + std::swap(gradient_input_is_stale, item.gradient_input_is_stale); + std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); + std::swap(x_grad, item.x_grad); + std::swap(cached_output, item.cached_output); + } + + + std::unique_ptr subnetwork; + std::unique_ptr grp; + + bool gradient_input_is_stale; + bool get_output_and_gradient_input_disabled; + + resizable_tensor x_grad; + resizable_tensor cached_output; + + // temp_tensor doesn't logically contribute to the state of this object. + // It is here only to prevent it from being reallocated over and over. + resizable_tensor temp_tensor; + }; + + // define "grp" layer shorter name for usage when creating networks + template + using grp = depth_group; + + namespace impl { + template< + unsigned int i, + typename T, typename U + > + struct layer_helper, + typename std::enable_if<(i != 0 && i >= depth_group::num_layers_in_group)>::type> { + const static size_t num_layers_in_group = depth_group::num_layers_in_group; + + using next_type = typename depth_group::subnet_type; + using type = typename layer_helper::type; + + static type &layer(depth_group &n) { + return layer_helper::layer(n.subnet()); + } + }; + + template< + unsigned int i, + typename T, typename U + > + struct layer_helper, + typename std::enable_if<(i != 0 && i < depth_group::num_layers_in_group)>::type> { + const static size_t num_layers_in_group = depth_group::num_layers_in_group; + typedef typename depth_group::grp_type grp_type; + using type = typename layer_helper::type; + + static type &layer(depth_group &n) { + return layer_helper::layer(n.detail()); + } + }; + + template + struct group_pos_search{ + const static unsigned int count = sizeof...(T); + const static unsigned int pos_from_begin = count - pos - 1; + using tuple_elem_type = typename std::tuple_element>::type; + static const unsigned int num_layers = tuple_elem_type::num_layers; + + static const unsigned int layer_index = i >= num_layers ? 
group_pos_search::layer_index : i; + static const unsigned int tuple_index = i >= num_layers ? group_pos_search::tuple_index + 1 : pos; + }; + template + struct group_pos_search<0, i, T...>{ + static const unsigned int layer_index = i; + static const unsigned int tuple_index = 0; + }; + + + template< + unsigned int i, + typename... R + > + struct layer_helper, typename std::enable_if::type>{ + const static unsigned tuple_size = sizeof...(R); + + static const unsigned int layer_index = group_pos_search::layer_index; + static const unsigned int tuple_index = group_pos_search::tuple_index; + + using next_type = typename std::tuple_element>::type;//typename std::remove_reference::type; + using type = typename layer_helper::type; + + static type &layer(std::tuple &n) { + return layer_helper::layer(std::get(n)); + } + }; + + // helper classes for layer group processing + template + struct group_helper_impl{ + static void serialize_impl(const std::tuple& data, std::ostream& out){ + group_helper_impl::serialize_impl(data, out); + serialize(std::get(data), out); + } + static void deserialize_impl(std::tuple& data, std::istream& in){ + group_helper_impl::deserialize_impl(data, in); + deserialize(std::get(data), in); + } + static void forward(const tensor& x, std::tuple& grp, long& group_depth){ + group_helper_impl::forward(x, grp, group_depth); + auto& r = std::get(grp).forward(x); + group_depth += r.k(); + } + static size_t concat(resizable_tensor& cached_output, std::tuple& grp, size_t offset){ + offset += group_helper_impl::concat(cached_output, grp, offset); + auto& output = std::get(grp).get_output(); + tt::concat_depth(cached_output, offset, output); + return offset + output.nc() * output.nr() * output.k(); + } + template + static sstack update_parameters(sstack solvers, double step_size, std::tuple& grp){ + sstack sub_solvers = group_helper_impl::update_parameters(solvers, step_size, grp); + std::get(grp).update_parameters(sub_solvers, step_size); + using tuple_elem_type = typename std::tuple_element>::type; + return sub_solvers.pop(tuple_elem_type::num_computational_layers); + } + static size_t backward(std::tuple& grp, const tensor& group_gradient_in, + const tensor& subnet_out, tensor& group_gradient_out, size_t offset) + { + offset += group_helper_impl::backward(grp, group_gradient_in, subnet_out, group_gradient_out, offset); + + auto& subnet = std::get(grp); + auto& gr_input = subnet.get_gradient_input(); + tt::split_depth(gr_input, offset, group_gradient_in); + + subnet.back_propagate_error(subnet_out); + + tt::add(group_gradient_out, group_gradient_out, subnet.get_final_data_gradient()); + return offset + gr_input.nc() * gr_input.nr() * gr_input.k(); + } + }; + template + struct group_helper_impl<0, T...>{ + static void serialize_impl(const std::tuple& data, std::ostream& out){ + serialize(std::get<0>(data), out); + } + static void deserialize_impl(std::tuple& data, std::istream& in){ + deserialize(std::get<0>(data), in); + } + static void forward(const tensor& x, std::tuple& grp, long& group_depth){ + auto& r = std::get<0>(grp).forward(x); + group_depth += r.k(); + } + static size_t concat(resizable_tensor& cached_output, std::tuple& grp, size_t offset){ + auto& output = std::get<0>(grp).get_output(); + tt::concat_depth(cached_output, offset, output); + return offset + output.nc() * output.nr() * output.k(); + } + template + static sstack update_parameters(sstack solvers, double step_size, std::tuple& grp){ + std::get<0>(grp).update_parameters(solvers, step_size); + using 
tuple_elem_type = typename std::tuple_element<0, std::tuple>::type; + return solvers.pop(tuple_elem_type::num_computational_layers); + } + static size_t backward(std::tuple& grp, const tensor& group_gradient_in, + const tensor& subnet_out, tensor& group_gradient_out, size_t offset) + { + auto& item = std::get<0>(grp); + auto& gr_input = item.get_gradient_input(); + tt::split_depth(gr_input, offset, group_gradient_in); + item.back_propagate_error(subnet_out); + + tt::add(group_gradient_out, group_gradient_out, item.get_final_data_gradient()); + return offset + gr_input.nc() * gr_input.nr() * gr_input.k(); + } + }; + template + struct group_helper>{ + static void serialize(const std::tuple & data, std::ostream& out){ + group_helper_impl>::value - 1, T...>::serialize_impl(data, out); + } + static void deserialize(std::tuple& data, std::istream& in){ + group_helper_impl>::value - 1, T...>::deserialize_impl(data, in); + } + static void forward(const tensor& x, std::tuple& grp, long& group_depth){ + group_helper_impl>::value - 1, T...>::forward(x, grp, group_depth); + } + static void concat(resizable_tensor& out, std::tuple& grp){ + group_helper_impl>::value - 1, T...>::concat(out, grp, 0); + } + template + static void update_parameters(sstack solvers, double step_size, std::tuple& grp){ + group_helper_impl>::value - 1, T...>::update_parameters(solvers, step_size, grp); + } + static void backward(std::tuple& grp, const tensor& group_gradient_in, const tensor& subnet_out, tensor& group_gradient_out) + { + group_helper_impl>::value - 1, T...>::backward(grp, group_gradient_in, subnet_out, group_gradient_out, 0); + } + }; + + // helper classes to understand the count of group items layers + template + struct group_count_helper{ + const static size_t num_layers = T::num_layers; + const static size_t num_computational_layers = T::num_computational_layers; + }; + + template + struct group_count_helper{ + const static size_t num_layers = group_count_helper::num_layers + group_count_helper::num_layers; + const static size_t num_computational_layers = group_count_helper::num_computational_layers + group_count_helper::num_computational_layers; + }; + template + struct group_count_helper>{ + const static size_t num_layers = group_count_helper::num_layers; + const static size_t num_computational_layers = group_count_helper::num_computational_layers; + }; + + } } #endif // DLIB_DNn_CORE_H_ diff --git a/dlib/dnn/tensor_tools.cpp b/dlib/dnn/tensor_tools.cpp index 521918eba..62e9c9c62 100644 --- a/dlib/dnn/tensor_tools.cpp +++ b/dlib/dnn/tensor_tools.cpp @@ -634,6 +634,66 @@ namespace dlib { namespace tt #endif } + // ---------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------------------ + + void concat_depth(tensor& dest, size_t sample_offset, const tensor& src) + { + const size_t dest_sample_size = static_cast(dest.nc() * dest.nr() * dest.k()); + const size_t src_sample_size = static_cast(src.nc() * src.nr() * src.k()); + + DLIB_CASSERT(dest.num_samples() == src.num_samples() && + dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size"); + DLIB_CASSERT(dest_sample_size >= src_sample_size + sample_offset, "Not enough space in dest tensor"); + +#ifdef DLIB_USE_CUDA + float* dest_p = dest.device_write_only() + sample_offset; + const float* src_p = src.device(); +#else + float* dest_p = dest.host_write_only() + sample_offset; + const float* src_p = src.host(); +#endif + 
+        for (unsigned long i = 0; i < src.num_samples(); ++i)
+        {
+#ifdef DLIB_USE_CUDA
+            CHECK_CUDA(cudaMemcpy(dest_p, src_p, src_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
+#else
+            ::memcpy(dest_p, src_p, src_sample_size * sizeof(float));
+#endif
+            dest_p += dest_sample_size;
+            src_p += src_sample_size;
+        }
+    }
+
+    void split_depth(tensor& dest, size_t sample_offset, const tensor& src)
+    {
+        const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
+        const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
+
+        DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
+                     dest.nc() == src.nc() && dest.nr() == src.nr(), "dest and src must have the same num_samples, nr, and nc");
+        DLIB_CASSERT(dest_sample_size <= src_sample_size - sample_offset, "Not enough data in src tensor");
+
+#ifdef DLIB_USE_CUDA
+        float* dest_p = dest.device_write_only();
+        const float* src_p = src.device() + sample_offset;
+#else
+        float* dest_p = dest.host_write_only();
+        const float* src_p = src.host() + sample_offset;
+#endif
+
+        for (unsigned long i = 0; i < src.num_samples(); ++i)
+        {
+#ifdef DLIB_USE_CUDA
+            CHECK_CUDA(cudaMemcpy(dest_p, src_p, dest_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
+#else
+            ::memcpy(dest_p, src_p, dest_sample_size * sizeof(float));
+#endif
+            dest_p += dest_sample_size;
+            src_p += src_sample_size;
+        }
+    }
 
 // ----------------------------------------------------------------------------------------
 
 }}
diff --git a/dlib/dnn/tensor_tools.h b/dlib/dnn/tensor_tools.h
index 638f0c338..2f10db58a 100644
--- a/dlib/dnn/tensor_tools.h
+++ b/dlib/dnn/tensor_tools.h
@@ -1171,6 +1171,43 @@ namespace dlib { namespace tt
         resizable_tensor accum_buffer;
     };
 
+// ----------------------------------------------------------------------------------------
+
+    void concat_depth(
+        tensor& dest,
+        size_t sample_offset,
+        const tensor& src
+    );
+    /*!
+        requires
+            - dest.nc() == src.nc()
+            - dest.nr() == src.nr()
+            - dest.num_samples() == src.num_samples()
+            - dest.nc()*dest.nr()*dest.k() >= src.nc()*src.nr()*src.k() + sample_offset
+            - is_same_object(dest,src) == false
+            - sample_offset is a count of floats within a sample, not bytes
+        ensures
+            - copies each sample of src into the corresponding sample of dest, starting
+              sample_offset floats into that sample.  Since dest and src share nr() and
+              nc(), this places src's channels into dest beginning at channel
+              sample_offset/(dest.nr()*dest.nc()).
+    !*/
+
+    void split_depth(
+        tensor& dest,
+        size_t sample_offset,
+        const tensor& src
+    );
+    /*!
+        requires
+            - dest.nc() == src.nc()
+            - dest.nr() == src.nr()
+            - dest.num_samples() == src.num_samples()
+            - dest.nc()*dest.nr()*dest.k() <= src.nc()*src.nr()*src.k() - sample_offset
+            - is_same_object(dest,src) == false
+            - sample_offset is a count of floats within a sample, not bytes
+        ensures
+            - fills each sample of dest from the corresponding sample of src, reading
+              dest.nc()*dest.nr()*dest.k() floats starting sample_offset floats into that
+              sample.  Since dest and src share nr() and nc(), this extracts dest's
+              channels from src beginning at channel sample_offset/(src.nr()*src.nc()).
+    !*/
 
 // ----------------------------------------------------------------------------------------
 
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 61bc606bf..153fe6b0e 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -33,6 +33,7 @@ ENDMACRO()
 if (COMPILER_CAN_DO_CPP_11)
    add_example(dnn_mnist_ex)
    add_example(dnn_mnist_advanced_ex)
+   add_example(dnn_inception_ex)
 endif()
 
 #here we apply our macros
diff --git a/examples/dnn_inception_ex.cpp b/examples/dnn_inception_ex.cpp
new file mode 100644
index 000000000..6490aebd0
--- /dev/null
+++ b/examples/dnn_inception_ex.cpp
@@ -0,0 +1,145 @@
+// The contents of this file are in the public domain.  See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+/*
+    This is an example illustrating the use of the deep learning tools from the
+    dlib C++ Library.  I'm assuming you have already read the dnn_mnist_ex.cpp
+    example.  So in this example program I'm going to go over a number of more
+    advanced parts of the API, including:
+        - Using the grp layer to construct an inception layer
+
+    An inception layer is a kind of NN architecture that runs several types of
+    convolution over the same input area and joins all of the convolution results
+    into one output.  For further reading refer to
+    http://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf
+*/
+
+
+#include <dlib/dnn.h>
+#include <iostream>
+#include <dlib/data_io.h>
+#include 
+
+using namespace std;
+using namespace dlib;
+
+// Here we define the inception module as described in the GoogLeNet specification.
+// The depth of each sublayer can be changed.
+template <typename SUBNET>
+using inception = grp<std::tuple<con<8,1,1,1,1, group_input>,
+        con<8,3,3,1,1, con<8,1,1,1,1, group_input>>,
+        con<8,5,5,1,1, con<8,1,1,1,1, group_input>>,
+        con<8,1,1,1,1, max_pool<3,3,1,1, group_input>>>,
+        SUBNET>;
+
+int main(int argc, char** argv) try
+{
+    // This example is going to run on the MNIST dataset.
+    if (argc != 2)
+    {
+        cout << "This example needs the MNIST dataset to run!" << endl;
+        cout << "You can get MNIST from http://yann.lecun.com/exdb/mnist/" << endl;
+        cout << "Download the 4 files that comprise the dataset, decompress them, and" << endl;
+        cout << "put them in a folder.  Then give that folder as input to this program." << endl;
+        return 1;
+    }
+
+
+    std::vector<matrix<unsigned char>> training_images;
+    std::vector<unsigned long>         training_labels;
+    std::vector<matrix<unsigned char>> testing_images;
+    std::vector<unsigned long>         testing_labels;
+    load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
+
+
+    // Create the same network as in dnn_mnist_ex, but use an inception layer instead of
+    // the convolution in the middle.
+    using net_type = loss_multiclass_log<
+            fc<10,
+            relu<fc<84,
+            relu<fc<120,
+            max_pool<2,2,2,2,relu<inception<
+            max_pool<2,2,2,2,relu<con<6,5,5,1,1,
+            input<matrix<unsigned char>>
+            >>>>>>>>>>>>;
+
+
+    // Create a network as defined above.  This network will produce 10 outputs
+    // because that's how we defined net_type.  However, fc layers can have the
+    // number of outputs they produce changed at runtime.
+    net_type net;
+
+    // The following training process is the same as in the dnn_mnist_ex sample.
+
+    // And then train it using the MNIST data.  The code below uses mini-batch stochastic
+    // gradient descent with an initial learning rate of 0.01 to accomplish this.
+    dnn_trainer<net_type> trainer(net);
+    trainer.set_learning_rate(0.01);
+    trainer.set_min_learning_rate(0.00001);
+    trainer.set_mini_batch_size(128);
+    trainer.be_verbose();
+    // Since DNN training can take a long time, we can ask the trainer to save its state to
+    // a file named "mnist_sync" every 20 seconds.  This way, if we kill this program and
+    // start it again it will begin where it left off rather than restarting the training
+    // from scratch.  This is because, when the program restarts, this call to
+    // set_synchronization_file() will automatically reload the settings from mnist_sync if
+    // the file exists.
+    trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
+    // Finally, this line begins training.  By default, it runs SGD with our specified
+    // learning rate until the loss stops decreasing.  Then it reduces the learning rate by
+    // a factor of 10 and continues running until the loss stops decreasing again.  It will
+    // keep doing this until the learning rate has dropped below the min learning rate
+    // defined above or the maximum number of epochs has been executed (defaulted to 10000).
+    trainer.train(training_images, training_labels);
+
+    // At this point our net object should have learned how to classify MNIST images.  But
+    // before we try it out let's save it to disk.  Note that, since the trainer has been
+    // running images through the network, net will have a bunch of state in it related to
+    // the last batch of images it processed (e.g. outputs from each layer).  Since we
+    // don't care about saving that kind of stuff to disk we can tell the network to forget
+    // about that kind of transient data so that our file will be smaller.  We do this by
+    // "cleaning" the network before saving it.
+    net.clean();
+    serialize("mnist_network.dat") << net;
+    // Now if we later wanted to recall the network from disk we can simply say:
+    // deserialize("mnist_network.dat") >> net;
+
+
+    // Now let's run the training images through the network.  This statement runs all the
+    // images through it and asks the loss layer to convert the network's raw output into
+    // labels.  In our case, these labels are the numbers between 0 and 9.
+    std::vector<unsigned long> predicted_labels = net(training_images);
+    int num_right = 0;
+    int num_wrong = 0;
+    // And then let's see if it classified them correctly.
+    for (size_t i = 0; i < training_images.size(); ++i)
+    {
+        if (predicted_labels[i] == training_labels[i])
+            ++num_right;
+        else
+            ++num_wrong;
+
+    }
+    cout << "training num_right: " << num_right << endl;
+    cout << "training num_wrong: " << num_wrong << endl;
+    cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
+
+    // Let's also see if the network can correctly classify the testing images.  Since
+    // MNIST is an easy dataset, we should see at least 99% accuracy.
+    predicted_labels = net(testing_images);
+    num_right = 0;
+    num_wrong = 0;
+    for (size_t i = 0; i < testing_images.size(); ++i)
+    {
+        if (predicted_labels[i] == testing_labels[i])
+            ++num_right;
+        else
+            ++num_wrong;
+
+    }
+    cout << "testing num_right: " << num_right << endl;
+    cout << "testing num_wrong: " << num_wrong << endl;
+    cout << "testing accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
+
+}
+catch(std::exception& e)
+{
+    cout << e.what() << endl;
+}
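
As a quick illustration of the depth-wise packing that the new grouping layer relies on, here is a minimal stand-alone sketch (not part of the patch) showing how tt::concat_depth() and tt::split_depth() stack several branch outputs along the k dimension and recover one of them again, mirroring what depth_group::forward() and group_helper_impl::concat()/backward() do.  It assumes this patch is applied and that sample_offset counts floats within a sample (k*nr*nc), as in the implementation above; the tensor shapes and constant fill values are made up for the example.

#include <dlib/dnn.h>
#include <iostream>

using namespace std;
using namespace dlib;

int main()
{
    // Three "branch" outputs that agree on num_samples, nr and nc but differ in depth (k),
    // just like the outputs of the inception branches in the example program.
    resizable_tensor b1(2, 2, 4, 4), b2(2, 3, 4, 4), b3(2, 4, 4, 4);
    b1 = 1;  b2 = 2;  b3 = 3;   // constant fills so the packing is easy to see

    // The grouped output stacks the branches along k, as depth_group::forward() does.
    resizable_tensor grouped(2, b1.k() + b2.k() + b3.k(), 4, 4);

    // Offsets are counted in floats per sample, mirroring group_helper_impl::concat().
    size_t offset = 0;
    for (auto* b : {&b1, &b2, &b3})
    {
        tt::concat_depth(grouped, offset, *b);
        offset += (size_t)(b->k() * b->nr() * b->nc());
    }

    // split_depth() undoes the packing: recover the middle branch from its offset.
    resizable_tensor b2_back(2, 3, 4, 4);
    tt::split_depth(b2_back, (size_t)(b1.k() * b1.nr() * b1.nc()), grouped);

    cout << "grouped depth (k): " << grouped.k() << endl;                    // prints 9
    cout << "middle branch recovered: " << (b2_back.host()[0] == 2) << endl; // prints 1
}

The same offset bookkeeping is what lets back_propagate_error() route the gradient of the concatenated tensor back into the individual branches before summing their data gradients into the subnet's gradient input.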