@ -3471,140 +3471,6 @@ namespace dlib
return impl_test_layer ( l , 0.01 ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template < size_t i , size_t num >
struct vlp_loop
{
template < typename T , typename U >
static typename std : : enable_if < ! is_add_layer < U > : : value > : : type invoke_functor ( T & & , size_t & , U & & )
{
// intentionally left empty
}
template < typename T , typename U >
static typename std : : enable_if < is_add_layer < U > : : value > : : type invoke_functor ( T & & v , size_t & comp_i , U & & l )
{
v ( comp_i , l . layer_details ( ) . get_layer_params ( ) ) ;
+ + comp_i ;
}
template <
typename net_type ,
typename visitor
>
static void visit (
size_t comp_i ,
net_type & net ,
visitor & & v
)
{
invoke_functor ( v , comp_i , layer < i > ( net ) ) ;
vlp_loop < i + 1 , num > : : visit ( comp_i , net , v ) ;
}
} ;
template < size_t num >
struct vlp_loop < num , num >
{
template <
typename net_type ,
typename visitor
>
static void visit (
size_t ,
net_type & ,
visitor & &
)
{
// Base case of recursion. Don't do anything.
}
} ;
}
template <
typename net_type ,
typename visitor
>
void visit_layer_parameters (
net_type & net ,
visitor v
)
{
size_t comp_i = 0 ;
impl : : vlp_loop < 0 , net_type : : num_layers > : : visit ( comp_i , net , v ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template < size_t i , size_t num >
struct vlpg_loop
{
template < typename T , typename U >
static typename std : : enable_if < ! is_add_layer < U > : : value > : : type invoke_functor ( T & & , size_t & , U & & )
{
// intentionally left empty
}
template < typename T , typename U >
static typename std : : enable_if < is_add_layer < U > : : value > : : type invoke_functor ( T & & v , size_t & comp_i , U & & l )
{
v ( comp_i , l . get_parameter_gradient ( ) ) ;
+ + comp_i ;
}
template <
typename net_type ,
typename visitor
>
static void visit (
size_t comp_i ,
net_type & net ,
visitor & & v
)
{
invoke_functor ( v , comp_i , layer < i > ( net ) ) ;
vlpg_loop < i + 1 , num > : : visit ( comp_i , net , v ) ;
}
} ;
template < size_t num >
struct vlpg_loop < num , num >
{
template <
typename net_type ,
typename visitor
>
static void visit (
size_t ,
net_type & ,
visitor & &
)
{
// Base case of recursion. Don't do anything.
}
} ;
}
template <
typename net_type ,
typename visitor
>
void visit_layer_parameter_gradients (
net_type & net ,
visitor v
)
{
size_t comp_i = 0 ;
impl : : vlpg_loop < 0 , net_type : : num_layers > : : visit ( comp_i , net , v ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
@ -3621,7 +3487,9 @@ namespace dlib
visitor & & v
)
{
v ( i , layer < i > ( net ) ) ;
// Call whatever version of the visitor the user provided.
call_if_valid ( v , i , layer < i > ( net ) ) ;
call_if_valid ( v , layer < i > ( net ) ) ;
vl_loop < i + 1 , num > : : visit ( net , v ) ;
}
} ;
@ -3655,7 +3523,9 @@ namespace dlib
)
{
vl_loop_backwards < i + 1 , num > : : visit ( net , v ) ;
v ( i , layer < i > ( net ) ) ;
// Call whatever version of the visitor the user provided.
call_if_valid ( v , i , layer < i > ( net ) ) ;
call_if_valid ( v , layer < i > ( net ) ) ;
}
} ;
@ -3751,7 +3621,7 @@ namespace dlib
visitor & & v
)
{
v( next_net ) ;
call_if_ valid ( v , next_net ) ;
vl_until_tag < i + 1 , tag_id > : : visit ( net , layer < i + 1 > ( net ) , v ) ;
}
@ -3766,7 +3636,7 @@ namespace dlib
visitor & & v
)
{
v( next_net ) ;
call_if_ valid ( v , next_net ) ;
}
template <
@ -3780,7 +3650,7 @@ namespace dlib
visitor & & v
)
{
v( next_net ) ;
call_if_ valid ( v , next_net ) ;
}
} ;
}
@ -3798,6 +3668,137 @@ namespace dlib
impl : : vl_until_tag < 0 , tag_id > : : visit ( net , net , v ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename visitor
>
class visitor_computational_layer
{
public :
explicit visitor_computational_layer ( visitor & v ) : v_ ( v ) { }
template < typename T , typename U , typename E >
void operator ( ) ( size_t idx , add_layer < T , U , E > & l ) const
{
// Call whatever version of the visitor the user provided.
call_if_valid ( v_ , idx , l . layer_details ( ) ) ;
call_if_valid ( v_ , l . layer_details ( ) ) ;
}
private :
visitor & v_ ;
} ;
}
template <
typename net_type ,
typename visitor
>
void visit_computational_layers (
net_type & net ,
visitor v
)
{
visit_layers ( net , impl : : visitor_computational_layer < visitor > ( v ) ) ;
}
template <
size_t begin ,
size_t end ,
typename net_type ,
typename visitor
>
void visit_computational_layers_range (
net_type & net ,
visitor v
)
{
visit_layers_range < begin , end > ( net , impl : : visitor_computational_layer < visitor > ( v ) ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename visitor
>
class visit_layer_parameters
{
public :
explicit visit_layer_parameters ( visitor & v ) : v_ ( v ) { }
template < typename layer >
void operator ( ) ( layer & l )
{
// Call whatever version of the visitor the user provided.
const bool visitor_called = call_if_valid ( v_ , computational_layer_idx , l . get_layer_params ( ) ) | |
call_if_valid ( v_ , l . get_layer_params ( ) ) ;
DLIB_CASSERT ( visitor_called , " A visitor function with an incorrect signature was given to visit_layer_parameters() " ) ;
+ + computational_layer_idx ;
}
private :
size_t computational_layer_idx = 0 ;
visitor & v_ ;
} ;
}
template <
typename net_type ,
typename visitor
>
void visit_layer_parameters (
net_type & net ,
visitor v
)
{
visit_computational_layers ( net , impl : : visit_layer_parameters < visitor > ( v ) ) ;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename visitor
>
class visit_layer_parameter_gradients
{
public :
explicit visit_layer_parameter_gradients ( visitor & v ) : v_ ( v ) { }
template < typename T , typename U , typename E >
void operator ( ) ( add_layer < T , U , E > & l )
{
// Call whatever version of the visitor the user provided.
const bool visitor_called = call_if_valid ( v_ , computational_layer_idx , l . get_parameter_gradient ( ) ) | |
call_if_valid ( v_ , l . get_parameter_gradient ( ) ) ;
DLIB_CASSERT ( visitor_called , " A visitor function with an incorrect signature was given to visit_layer_parameter_gradients() " ) ;
+ + computational_layer_idx ;
}
private :
size_t computational_layer_idx = 0 ;
visitor & v_ ;
} ;
}
template <
typename net_type ,
typename visitor
>
void visit_layer_parameter_gradients (
net_type & net ,
visitor v
)
{
visit_layers ( net , impl : : visit_layer_parameter_gradients < visitor > ( v ) ) ;
}
// ----------------------------------------------------------------------------------------
}