Added visit_layers_backwards(), visit_layers_backwards_range(), and

visit_layers_range().
This commit is contained in:
Davis King 2016-09-03 07:14:07 -04:00
parent 9726ce1cac
commit d7c003d190
2 changed files with 160 additions and 0 deletions

View File

@ -3337,6 +3337,39 @@ namespace dlib
}
};
template <size_t i, size_t num>
struct vl_loop_backwards
{
template <
typename net_type,
typename visitor
>
static void visit(
net_type& net,
visitor&& v
)
{
vl_loop<i+1, num>::visit(net,v);
v(i, layer<i>(net));
}
};
template <size_t num>
struct vl_loop_backwards<num,num>
{
template <
typename net_type,
typename visitor
>
static void visit(
net_type&,
visitor&&
)
{
// Base case of recursion. Don't do anything.
}
};
}
template <
@ -3351,6 +3384,50 @@ namespace dlib
impl::vl_loop<0, net_type::num_layers>::visit(net, v);
}
template <
typename net_type,
typename visitor
>
void visit_layers_backwards(
net_type& net,
visitor v
)
{
impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v);
}
template <
size_t begin,
size_t end,
typename net_type,
typename visitor
>
void visit_layers_range(
net_type& net,
visitor v
)
{
static_assert(begin <= end, "Invalid range");
static_assert(end <= net_type::num_layers, "Invalid range");
impl::vl_loop<begin,end>::visit(net, v);
}
template <
size_t begin,
size_t end,
typename net_type,
typename visitor
>
void visit_layers_backwards_range(
net_type& net,
visitor v
)
{
static_assert(begin <= end, "Invalid range");
static_assert(end <= net_type::num_layers, "Invalid range");
impl::vl_loop_backwards<begin,end>::visit(net, v);
}
// ----------------------------------------------------------------------------------------
}

View File

@ -1491,6 +1491,89 @@ namespace dlib
v(i, layer<i>(net));
!*/
template <
typename net_type,
typename visitor
>
void visit_layers_backwards(
net_type& net,
visitor v
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
- v is a function object with a signature equivalent to:
v(size_t idx, any_net_type& t)
That is, it must take a size_t and then any of the network types such as
add_layer, add_loss_layer, etc.
ensures
- Loops over all the layers in net and calls v() on them. The loop happens in
the reverse order of visit_layers(). To be specific, this function
essentially performs the following:
for (size_t i = net_type::num_layers; i != 0; --i)
v(i-1, layer<i-1>(net));
!*/
// ----------------------------------------------------------------------------------------
template <
size_t begin,
size_t end,
typename net_type,
typename visitor
>
void visit_layers_range(
net_type& net,
visitor v
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
- v is a function object with a signature equivalent to:
v(size_t idx, any_net_type& t)
That is, it must take a size_t and then any of the network types such as
add_layer, add_loss_layer, etc.
- begin <= end <= net_type::num_layers
ensures
- Loops over the layers in the range [begin,end) in net and calls v() on them.
The loop happens in the reverse order of visit_layers(). To be specific,
this function essentially performs the following:
for (size_t i = begin; i < end; ++i)
v(i, layer<i>(net));
!*/
template <
size_t begin,
size_t end,
typename net_type,
typename visitor
>
void visit_layers_backwards_range(
net_type& net,
visitor v
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
- v is a function object with a signature equivalent to:
v(size_t idx, any_net_type& t)
That is, it must take a size_t and then any of the network types such as
add_layer, add_loss_layer, etc.
- begin <= end <= net_type::num_layers
ensures
- Loops over the layers in the range [begin,end) in net and calls v() on them.
The loop happens in the reverse order of visit_layers_range(). To be specific,
this function essentially performs the following:
for (size_t i = end; i != begin; --i)
v(i-1, layer<i-1>(net));
!*/
// ----------------------------------------------------------------------------------------
struct layer_test_results