mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add input_tensor input type (#2951)
This commit is contained in:
parent
fa0e3ff954
commit
51c7a35979
@ -680,6 +680,14 @@ namespace dlib
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
inline void memcpy (
|
||||||
|
alias_tensor_instance&& dest,
|
||||||
|
const tensor& src
|
||||||
|
)
|
||||||
|
{
|
||||||
|
memcpy(static_cast<tensor&>(dest), src);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // DLIB_DNn_TENSOR_H_
|
#endif // DLIB_DNn_TENSOR_H_
|
||||||
|
@ -607,6 +607,14 @@ namespace dlib
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline void memcpy (
|
||||||
|
alias_tensor_instance&& dest,
|
||||||
|
const tensor& src
|
||||||
|
) { memcpy(static_cast<tensor&>(dest), src); }
|
||||||
|
/*!
|
||||||
|
A convenient overload for copying from src to dest when you have a temporary alias tensor.
|
||||||
|
!*/
|
||||||
|
|
||||||
class alias_tensor_const_instance
|
class alias_tensor_const_instance
|
||||||
{
|
{
|
||||||
/*!
|
/*!
|
||||||
|
@ -1082,6 +1082,93 @@ namespace dlib
|
|||||||
float avg_blue;
|
float avg_blue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class input_tensor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef tensor input_type;
|
||||||
|
|
||||||
|
input_tensor() {}
|
||||||
|
input_tensor(const input_tensor&) {}
|
||||||
|
|
||||||
|
template<typename forward_iterator>
|
||||||
|
void to_tensor(
|
||||||
|
forward_iterator ibegin,
|
||||||
|
forward_iterator iend,
|
||||||
|
resizable_tensor& data
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(std::distance(ibegin, iend) > 0);
|
||||||
|
const auto k = ibegin->k();
|
||||||
|
const auto nr = ibegin->nr();
|
||||||
|
const auto nc = ibegin->nc();
|
||||||
|
// make sure all the input tensors have the same dimensions
|
||||||
|
for (auto i = ibegin; i != iend; ++i)
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(i->k() == k && i->nr() == nr && i->nc() == nc,
|
||||||
|
"\t input_tensor::to_tensor()"
|
||||||
|
<< "\n\t All tensor objects given to to_tensor() must have the same dimensions."
|
||||||
|
<< "\n\t k: " << k
|
||||||
|
<< "\n\t nr: " << nr
|
||||||
|
<< "\n\t nc: " << nc
|
||||||
|
<< "\n\t i->k(): " << i->k()
|
||||||
|
<< "\n\t i->nr(): " << i->nr()
|
||||||
|
<< "\n\t i->nc(): " << i->nc()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto num_samples = count_samples(ibegin, iend);
|
||||||
|
// initialize data to the right size to contain the stuff in the iterator range.
|
||||||
|
data.set_size(num_samples, k, nr, nc);
|
||||||
|
|
||||||
|
const size_t stride = k * nr * nc;
|
||||||
|
size_t offset = 0;
|
||||||
|
for (auto i = ibegin; i != iend; ++i)
|
||||||
|
{
|
||||||
|
alias_tensor slice(i->num_samples(), k, nr, nc);
|
||||||
|
memcpy(slice(data, offset), *i);
|
||||||
|
offset += slice.num_samples() * stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
friend void serialize(const input_tensor&, std::ostream& out)
|
||||||
|
{
|
||||||
|
serialize("input_tensor", out);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend void deserialize(input_tensor&, std::istream& in)
|
||||||
|
{
|
||||||
|
std::string version;
|
||||||
|
deserialize(version, in);
|
||||||
|
if (version != "input_tensor")
|
||||||
|
throw serialization_error("Unexpected version found while deserializing dlib::input_tensor.");
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& out, const input_tensor&)
|
||||||
|
{
|
||||||
|
out << "input_tensor";
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend void to_xml(const input_tensor&, std::ostream& out)
|
||||||
|
{
|
||||||
|
out << "<input_tensor/>\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
template<typename forward_iterator>
|
||||||
|
long long count_samples(
|
||||||
|
forward_iterator ibegin,
|
||||||
|
forward_iterator iend
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
return std::accumulate(ibegin, iend, 0,
|
||||||
|
[](long long a, const auto& b) { return a + b.num_samples(); });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -719,6 +719,57 @@ namespace dlib
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class input_tensor
|
||||||
|
{
|
||||||
|
/*!
|
||||||
|
WHAT THIS OBJECT REPRESENTS
|
||||||
|
This input layer works with dlib::tensor objects. It is very similar to
|
||||||
|
the dlib::input layer except that it allows for concatenating data that
|
||||||
|
already resides in GPU memory.
|
||||||
|
!*/
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef tensor input_type;
|
||||||
|
|
||||||
|
input_tensor(
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- input_tensor objects are default constructable
|
||||||
|
!*/
|
||||||
|
|
||||||
|
input_tensor(
|
||||||
|
const input_tensor& item
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- input_tensor objects are copy constructable
|
||||||
|
!*/
|
||||||
|
|
||||||
|
template <typename forward_iterator>
|
||||||
|
void to_tensor(
|
||||||
|
forward_iterator ibegin,
|
||||||
|
forward_iterator iend,
|
||||||
|
resizable_tensor& data
|
||||||
|
) const;
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- [ibegin, iend) is an iterator range over input_type objects.
|
||||||
|
- std::distance(ibegin,iend) > 0
|
||||||
|
- The input range should contain tensor objects that all have the same
|
||||||
|
dimensions.
|
||||||
|
ensures
|
||||||
|
- Copies the iterator range into #data. In particular, if the input tensors
|
||||||
|
have R rows, C columns, and K channels then we will have:
|
||||||
|
- #data.num_samples() == count_samples(ibegin,iend)
|
||||||
|
- #data.nr() == R
|
||||||
|
- #data.nc() == C
|
||||||
|
- #data.k() == K
|
||||||
|
This results in a tensor concatenation along the sample dimension.
|
||||||
|
!*/
|
||||||
|
};
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // DLIB_DNn_INPUT_ABSTRACT_H_
|
#endif // DLIB_DNn_INPUT_ABSTRACT_H_
|
||||||
|
@ -4276,6 +4276,38 @@ namespace
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_input_tensor()
|
||||||
|
{
|
||||||
|
using namespace dlib::tt;
|
||||||
|
print_spinner();
|
||||||
|
tt::tensor_rand rnd;
|
||||||
|
std::vector<resizable_tensor> tensors(3);
|
||||||
|
|
||||||
|
for (auto& t : tensors) {
|
||||||
|
t.set_size(1, 3, 224, 224);
|
||||||
|
rnd.fill_gaussian(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
resizable_tensor out;
|
||||||
|
input_tensor input_layer;
|
||||||
|
|
||||||
|
input_layer.to_tensor(tensors.begin(), tensors.end(), out);
|
||||||
|
|
||||||
|
DLIB_TEST(out.num_samples() == 3);
|
||||||
|
DLIB_TEST(out.k() == 3);
|
||||||
|
DLIB_TEST(out.nr() == 224);
|
||||||
|
DLIB_TEST(out.nc() == 224);
|
||||||
|
size_t stride = out.k() * out.nr() * out.nc();
|
||||||
|
size_t offset = 0;
|
||||||
|
int error = 0;
|
||||||
|
|
||||||
|
for (auto& t : tensors) {
|
||||||
|
error = memcmp(out.host() + offset, t.host(), sizeof(float) * t.size());
|
||||||
|
DLIB_TEST(error == 0);
|
||||||
|
offset += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
class dnn_tester : public tester
|
class dnn_tester : public tester
|
||||||
@ -4386,6 +4418,7 @@ namespace
|
|||||||
test_input_ouput_mappers();
|
test_input_ouput_mappers();
|
||||||
test_fuse_layers();
|
test_fuse_layers();
|
||||||
test_reorg();
|
test_reorg();
|
||||||
|
test_input_tensor();
|
||||||
}
|
}
|
||||||
|
|
||||||
void perform_test()
|
void perform_test()
|
||||||
|
Loading…
Reference in New Issue
Block a user