Made layer_details() part of the SUBNET interface so that user defined layer

details objects can access each other.  Also added the input_layer() global
function for accessing the input layer specifically.
This commit is contained in:
Davis King 2016-08-14 13:48:18 -04:00
parent 285bba7646
commit 390c8e90aa
3 changed files with 56 additions and 2 deletions

View File

@ -503,9 +503,13 @@ namespace dlib
subnet_wrapper(const subnet_wrapper&) = delete;
subnet_wrapper& operator=(const subnet_wrapper&) = delete;
subnet_wrapper(T& /*l_*/) {}
// Nothing here because in this case T is one of the input layer types
subnet_wrapper(T& l_) : l(l_) {}
// Not much here because in this case T is one of the input layer types
// that doesn't have anything in it.
typedef T layer_details_type;
const layer_details_type& layer_details() const { return l; }
private:
T& l;
};
template <typename T>
@ -518,12 +522,16 @@ namespace dlib
typedef T wrapped_type;
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
subnet_wrapper(T& l_) : l(l_),subnetwork(l.subnet()) {}
const tensor& get_output() const { return l.private_get_output(); }
tensor& get_gradient_input() { return l.private_get_gradient_input(); }
const layer_details_type& layer_details() const { return l.layer_details(); }
const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
@ -542,12 +550,16 @@ namespace dlib
typedef T wrapped_type;
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
subnet_wrapper(T& l_) : l(l_),subnetwork(l.subnet()) {}
const tensor& get_output() const { return l.get_output(); }
tensor& get_gradient_input() { return l.get_gradient_input(); }
const layer_details_type& layer_details() const { return l.layer_details(); }
const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
@ -1358,6 +1370,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
@ -1554,6 +1567,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers);
const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num;
const static size_t num_computational_layers = comp_layers_in_repeated_group + SUBNET::num_computational_layers;
@ -1825,6 +1839,7 @@ namespace dlib
public:
typedef INPUT_LAYER subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_computational_layers = 0;
const static size_t num_layers = 2;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
@ -2544,6 +2559,16 @@ namespace dlib
return impl::layer_helper_match<Match,T,i>::layer(n);
}
// ----------------------------------------------------------------------------------------
template <typename net_type>
auto input_layer (
net_type& net
) -> decltype(layer<net_type::num_layers-1>(net))&
{
return layer<net_type::num_layers-1>(net);
}
// ----------------------------------------------------------------------------------------
template <template<typename> class TAG_TYPE, typename SUBNET>
@ -2552,6 +2577,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;

View File

@ -1332,6 +1332,22 @@ namespace dlib
- returns layer<i>(layer<Match>(n))
!*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
auto& input_layer (
net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- returns the input later of the given network object. Specifically, this
function is equivalent to calling:
layer<net_type::num_layers-1>(net);
!*/
// ----------------------------------------------------------------------------------------
template <

View File

@ -82,6 +82,18 @@ namespace dlib
above, if *this was layer1 then subnet() would return the network that
begins with layer2.
!*/
const layer_details_type& layer_details(
) const;
/*!
ensures
- returns the layer_details_type instance that defines the behavior of the
layer at the top of this network. I.e. returns the layer details that
defines the behavior of the layer nearest to the network output rather
than the input layer. For computational layers, this is the object
implementing the EXAMPLE_COMPUTATIONAL_LAYER_ interface that defines the
layer's behavior.
!*/
};
// ----------------------------------------------------------------------------------------