mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Made the optimizer_state serializable.
This commit is contained in:
parent
52e35c31fb
commit
a121906fd4
@ -327,6 +327,45 @@ namespace dlib
|
||||
std::vector<long> index;
|
||||
long dims;
|
||||
dlib::rand rnd;
|
||||
|
||||
public:
|
||||
friend void serialize(const optimizer_state& item, std::ostream& out)
|
||||
{
|
||||
const int version = 1;
|
||||
dlib::serialize(version, out);
|
||||
dlib::serialize(item.did_init, out);
|
||||
dlib::serialize(item.have_bias, out);
|
||||
dlib::serialize(item.last_weight_1, out);
|
||||
dlib::serialize(item.alpha, out);
|
||||
dlib::serialize(item.w, out);
|
||||
dlib::serialize(item.Q, out);
|
||||
dlib::serialize(item.index, out);
|
||||
dlib::serialize(item.dims, out);
|
||||
dlib::serialize(item.rnd, out);
|
||||
}
|
||||
|
||||
friend void deserialize(optimizer_state& item, std::istream& in)
|
||||
{
|
||||
int version = 0;
|
||||
dlib::deserialize(version, in);
|
||||
if (version != 1)
|
||||
{
|
||||
throw dlib::serialization_error(
|
||||
"Error while deserializing dlib::svm_c_linear_dcd_trainer::optimizer_state, unexpected version."
|
||||
);
|
||||
}
|
||||
|
||||
dlib::deserialize(item.did_init, in);
|
||||
dlib::deserialize(item.have_bias, in);
|
||||
dlib::deserialize(item.last_weight_1, in);
|
||||
dlib::deserialize(item.alpha, in);
|
||||
dlib::deserialize(item.w, in);
|
||||
dlib::deserialize(item.Q, in);
|
||||
dlib::deserialize(item.index, in);
|
||||
dlib::deserialize(item.dims, in);
|
||||
dlib::deserialize(item.rnd, in);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
|
Loading…
Reference in New Issue
Block a user