mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added visit_layers_backwards(), visit_layers_backwards_range(), and
visit_layers_range().
This commit is contained in:
parent
9726ce1cac
commit
d7c003d190
@ -3337,6 +3337,39 @@ namespace dlib
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t i, size_t num>
|
||||
struct vl_loop_backwards
|
||||
{
|
||||
template <
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
static void visit(
|
||||
net_type& net,
|
||||
visitor&& v
|
||||
)
|
||||
{
|
||||
vl_loop<i+1, num>::visit(net,v);
|
||||
v(i, layer<i>(net));
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t num>
|
||||
struct vl_loop_backwards<num,num>
|
||||
{
|
||||
template <
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
static void visit(
|
||||
net_type&,
|
||||
visitor&&
|
||||
)
|
||||
{
|
||||
// Base case of recursion. Don't do anything.
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
@ -3351,6 +3384,50 @@ namespace dlib
|
||||
impl::vl_loop<0, net_type::num_layers>::visit(net, v);
|
||||
}
|
||||
|
||||
template <
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_backwards(
|
||||
net_type& net,
|
||||
visitor v
|
||||
)
|
||||
{
|
||||
impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v);
|
||||
}
|
||||
|
||||
template <
|
||||
size_t begin,
|
||||
size_t end,
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_range(
|
||||
net_type& net,
|
||||
visitor v
|
||||
)
|
||||
{
|
||||
static_assert(begin <= end, "Invalid range");
|
||||
static_assert(end <= net_type::num_layers, "Invalid range");
|
||||
impl::vl_loop<begin,end>::visit(net, v);
|
||||
}
|
||||
|
||||
template <
|
||||
size_t begin,
|
||||
size_t end,
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_backwards_range(
|
||||
net_type& net,
|
||||
visitor v
|
||||
)
|
||||
{
|
||||
static_assert(begin <= end, "Invalid range");
|
||||
static_assert(end <= net_type::num_layers, "Invalid range");
|
||||
impl::vl_loop_backwards<begin,end>::visit(net, v);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
@ -1491,6 +1491,89 @@ namespace dlib
|
||||
v(i, layer<i>(net));
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_backwards(
|
||||
net_type& net,
|
||||
visitor v
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
|
||||
add_tag_layer.
|
||||
- v is a function object with a signature equivalent to:
|
||||
v(size_t idx, any_net_type& t)
|
||||
That is, it must take a size_t and then any of the network types such as
|
||||
add_layer, add_loss_layer, etc.
|
||||
ensures
|
||||
- Loops over all the layers in net and calls v() on them. The loop happens in
|
||||
the reverse order of visit_layers(). To be specific, this function
|
||||
essentially performs the following:
|
||||
|
||||
for (size_t i = net_type::num_layers; i != 0; --i)
|
||||
v(i-1, layer<i-1>(net));
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
size_t begin,
|
||||
size_t end,
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_range(
|
||||
net_type& net,
|
||||
visitor v
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
|
||||
add_tag_layer.
|
||||
- v is a function object with a signature equivalent to:
|
||||
v(size_t idx, any_net_type& t)
|
||||
That is, it must take a size_t and then any of the network types such as
|
||||
add_layer, add_loss_layer, etc.
|
||||
- begin <= end <= net_type::num_layers
|
||||
ensures
|
||||
- Loops over the layers in the range [begin,end) in net and calls v() on them.
|
||||
The loop happens in the reverse order of visit_layers(). To be specific,
|
||||
this function essentially performs the following:
|
||||
|
||||
for (size_t i = begin; i < end; ++i)
|
||||
v(i, layer<i>(net));
|
||||
!*/
|
||||
|
||||
template <
|
||||
size_t begin,
|
||||
size_t end,
|
||||
typename net_type,
|
||||
typename visitor
|
||||
>
|
||||
void visit_layers_backwards_range(
|
||||
net_type& net,
|
||||
visitor v
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
|
||||
add_tag_layer.
|
||||
- v is a function object with a signature equivalent to:
|
||||
v(size_t idx, any_net_type& t)
|
||||
That is, it must take a size_t and then any of the network types such as
|
||||
add_layer, add_loss_layer, etc.
|
||||
- begin <= end <= net_type::num_layers
|
||||
ensures
|
||||
- Loops over the layers in the range [begin,end) in net and calls v() on them.
|
||||
The loop happens in the reverse order of visit_layers_range(). To be specific,
|
||||
this function essentially performs the following:
|
||||
|
||||
for (size_t i = end; i != begin; --i)
|
||||
v(i-1, layer<i-1>(net));
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct layer_test_results
|
||||
|
Loading…
Reference in New Issue
Block a user