diff --git a/dlib/dnn/core.h b/dlib/dnn/core.h index e231878a5..6cf6cee64 100644 --- a/dlib/dnn/core.h +++ b/dlib/dnn/core.h @@ -3337,6 +3337,39 @@ namespace dlib } }; + template + struct vl_loop_backwards + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type& net, + visitor&& v + ) + { + vl_loop::visit(net,v); + v(i, layer(net)); + } + }; + + template + struct vl_loop_backwards + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type&, + visitor&& + ) + { + // Base case of recursion. Don't do anything. + } + }; + } template < @@ -3351,6 +3384,50 @@ namespace dlib impl::vl_loop<0, net_type::num_layers>::visit(net, v); } + template < + typename net_type, + typename visitor + > + void visit_layers_backwards( + net_type& net, + visitor v + ) + { + impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v); + } + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_range( + net_type& net, + visitor v + ) + { + static_assert(begin <= end, "Invalid range"); + static_assert(end <= net_type::num_layers, "Invalid range"); + impl::vl_loop::visit(net, v); + } + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_backwards_range( + net_type& net, + visitor v + ) + { + static_assert(begin <= end, "Invalid range"); + static_assert(end <= net_type::num_layers, "Invalid range"); + impl::vl_loop_backwards::visit(net, v); + } + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/dnn/core_abstract.h b/dlib/dnn/core_abstract.h index 759079817..82868a7e4 100644 --- a/dlib/dnn/core_abstract.h +++ b/dlib/dnn/core_abstract.h @@ -1491,6 +1491,89 @@ namespace dlib v(i, layer(net)); !*/ + template < + typename net_type, + typename visitor + > + void visit_layers_backwards( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + ensures + - Loops over all the layers in net and calls v() on them. The loop happens in + the reverse order of visit_layers(). To be specific, this function + essentially performs the following: + + for (size_t i = net_type::num_layers; i != 0; --i) + v(i-1, layer(net)); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_range( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + - begin <= end <= net_type::num_layers + ensures + - Loops over the layers in the range [begin,end) in net and calls v() on them. + The loop happens in the reverse order of visit_layers(). To be specific, + this function essentially performs the following: + + for (size_t i = begin; i < end; ++i) + v(i, layer(net)); + !*/ + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_backwards_range( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + - begin <= end <= net_type::num_layers + ensures + - Loops over the layers in the range [begin,end) in net and calls v() on them. + The loop happens in the reverse order of visit_layers_range(). To be specific, + this function essentially performs the following: + + for (size_t i = end; i != begin; --i) + v(i-1, layer(net)); + !*/ + // ---------------------------------------------------------------------------------------- struct layer_test_results