Made input_layer() work in a more reasonable and general way.

This commit is contained in:
Davis King 2016-09-05 09:16:44 -04:00
parent a105c616d0
commit 70619d2fd6
3 changed files with 38 additions and 6 deletions

View File

@ -2617,12 +2617,43 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace dimpl
{
template <typename T>
T& get_input_details (
T& net
)
{
return net;
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
const dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
}
template <typename net_type>
auto input_layer (
net_type& net
) -> decltype(layer<net_type::num_layers-1>(net))&
) -> decltype(dimpl::get_input_details(layer<net_type::num_layers-1>(net)))&
{
return layer<net_type::num_layers-1>(net);
// Calling input_layer() on a subnet_wrapper is a little funny since the behavior of
// .subnet() returns another subnet_wrapper rather than an input details object as it
// does in add_layer.
return dimpl::get_input_details(layer<net_type::num_layers-1>(net));
}
// ----------------------------------------------------------------------------------------

View File

@ -1396,6 +1396,7 @@ namespace dlib
- returns the input later of the given network object. Specifically, this
function is equivalent to calling:
layer<net_type::num_layers-1>(net);
That is, you get the input layer details object for the network.
!*/
// ----------------------------------------------------------------------------------------

View File

@ -725,7 +725,7 @@ namespace dlib
{
dpoint p = output_tensor_to_input_tensor(net, point(c,r));
drectangle rect = centered_drect(p, options.detector_width, options.detector_height);
rect = input_layer(net).layer_details().tensor_space_to_image_space(input_tensor,rect);
rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect);
dets_accum.push_back(intermediate_detection(rect, score, r*output_tensor.nc() + c));
}
@ -743,7 +743,7 @@ namespace dlib
) const
{
using namespace std;
if (!input_layer(net).layer_details().image_contained_point(input_tensor,center(rect)))
if (!input_layer(net).image_contained_point(input_tensor,center(rect)))
{
std::ostringstream sout;
sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << endl;
@ -757,12 +757,12 @@ namespace dlib
// it means the box can't be matched by the sliding window. But picking the
// max causes the right error message to be selected in the logic below.
const double scale = std::max(options.detector_width/(double)rect.width(), options.detector_height/(double)rect.height());
const rectangle mapped_rect = input_layer(net).layer_details().image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
// compute the detection window that we would use at this position.
point tensor_p = center(mapped_rect);
rectangle det_window = centered_rect(tensor_p, options.detector_width,options.detector_height);
det_window = input_layer(net).layer_details().tensor_space_to_image_space(input_tensor, det_window);
det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window);
// make sure the rect can actually be represented by the image pyramid we are
// using.