Replaced sizeof... with variadic templates

This commit is contained in:
Fm 2016-05-29 17:21:42 +03:00
parent 1974e68d31
commit 01b3b08be6

View File

@ -1844,6 +1844,7 @@ namespace dlib
};
template <template<typename> class TAG_TYPE>
struct concat_helper_impl<TAG_TYPE>{
constexpr static size_t tag_count() {return 1;}
template<typename SUBNET>
static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
{
@ -1865,6 +1866,9 @@ namespace dlib
};
template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
struct concat_helper_impl<TAG_TYPE, TAG_TYPES...>{
constexpr static size_t tag_count() {return 1 + concat_helper_impl<TAG_TYPES...>::tag_count();}
template<typename SUBNET>
static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
{
@ -1896,6 +1900,8 @@ namespace dlib
class concat_
{
public:
constexpr static size_t tag_count() {return impl::concat_helper_impl<TAG_TYPES...>::tag_count();};
template <typename SUBNET>
void setup (const SUBNET&)
{
@ -1924,7 +1930,8 @@ namespace dlib
friend void serialize(const concat_& item, std::ostream& out)
{
serialize("concat_", out);
serialize(sizeof...(TAG_TYPES), out);
size_t count = tag_count();
serialize(count, out);
}
friend void deserialize(concat_& item, std::istream& in)
@ -1935,15 +1942,16 @@ namespace dlib
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
size_t count_tags;
deserialize(count_tags, in);
if (count_tags != sizeof...(TAG_TYPES))
if (count_tags != tag_count())
throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " +
std::to_string(sizeof...(TAG_TYPES)) + " found while deserializing dlib::concat_.");
std::to_string(tag_count()) +
" found while deserializing dlib::concat_.");
}
friend std::ostream& operator<<(std::ostream& out, const concat_& item)
{
out << "concat\t ("
<< sizeof...(TAG_TYPES)
<< tag_count()
<< ")";
return out;
}