mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Fully setup the functional python interface to the sequence segmenter tool.
Need to add documentation next.
This commit is contained in:
parent
bb0f764ca8
commit
66d5a906bb
@ -12,7 +12,7 @@ using namespace dlib;
|
||||
using namespace std;
|
||||
using namespace boost::python;
|
||||
|
||||
typedef matrix<double,0,1> sample_type;
|
||||
typedef matrix<double,0,1> dense_vect;
|
||||
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;
|
||||
typedef std::vector<std::pair<unsigned long, unsigned long> > ranges;
|
||||
|
||||
@ -33,7 +33,7 @@ public:
|
||||
unsigned long _window_size;
|
||||
|
||||
segmenter_feature_extractor(
|
||||
) : _num_features(0), _window_size(0) {}
|
||||
) : _num_features(1), _window_size(1) {}
|
||||
|
||||
segmenter_feature_extractor(
|
||||
unsigned long _num_features_,
|
||||
@ -49,7 +49,7 @@ public:
|
||||
template <typename feature_setter>
|
||||
void get_features (
|
||||
feature_setter& set_feature,
|
||||
const std::vector<sample_type>& x,
|
||||
const std::vector<dense_vect>& x,
|
||||
unsigned long position
|
||||
) const
|
||||
{
|
||||
@ -88,17 +88,54 @@ public:
|
||||
|
||||
struct segmenter_type
|
||||
{
|
||||
segmenter_type() : mode(0)
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is the object that python will use directly to represent a
|
||||
sequence_segmenter. All it does is contain all the possible template
|
||||
instantiations of a sequence_segmenter and invoke the right one depending on
|
||||
the mode variable.
|
||||
!*/
|
||||
|
||||
segmenter_type() : mode(-1)
|
||||
{ }
|
||||
|
||||
ranges segment_sequence (
|
||||
const std::vector<sample_type>& x
|
||||
ranges segment_sequence_dense (
|
||||
const std::vector<dense_vect>& x
|
||||
) const
|
||||
{
|
||||
return ranges();
|
||||
switch (mode)
|
||||
{
|
||||
case 0: return segmenter0(x);
|
||||
case 1: return segmenter1(x);
|
||||
case 2: return segmenter2(x);
|
||||
case 3: return segmenter3(x);
|
||||
case 4: return segmenter4(x);
|
||||
case 5: return segmenter5(x);
|
||||
case 6: return segmenter6(x);
|
||||
case 7: return segmenter7(x);
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
}
|
||||
|
||||
const matrix<double,0,1>& get_weights()
|
||||
// Segments a sequence of sparse vectors using whichever sparse segmenter the
// mode member currently selects.
//
// x: the sequence to segment, one sparse_vect per element.
//
// Returns the ranges produced by the selected segmenter.  Only mode values 8
// through 15 are dispatched by this method; any other value of mode results in
// dlib::error("Invalid mode") being thrown.
ranges segment_sequence_sparse (
    const std::vector<sparse_vect>& x
) const
{
    if (mode == 8)  return segmenter8(x);
    if (mode == 9)  return segmenter9(x);
    if (mode == 10) return segmenter10(x);
    if (mode == 11) return segmenter11(x);
    if (mode == 12) return segmenter12(x);
    if (mode == 13) return segmenter13(x);
    if (mode == 14) return segmenter14(x);
    if (mode == 15) return segmenter15(x);

    // Every other mode value is rejected by this method.
    throw dlib::error("Invalid mode");
}
|
||||
|
||||
const matrix<double,0,1> get_weights()
|
||||
{
|
||||
switch(mode)
|
||||
{
|
||||
@ -110,6 +147,17 @@ struct segmenter_type
|
||||
case 5: return segmenter5.get_weights();
|
||||
case 6: return segmenter6.get_weights();
|
||||
case 7: return segmenter7.get_weights();
|
||||
|
||||
case 8: return segmenter8.get_weights();
|
||||
case 9: return segmenter9.get_weights();
|
||||
case 10: return segmenter10.get_weights();
|
||||
case 11: return segmenter11.get_weights();
|
||||
case 12: return segmenter12.get_weights();
|
||||
case 13: return segmenter13.get_weights();
|
||||
case 14: return segmenter14.get_weights();
|
||||
case 15: return segmenter15.get_weights();
|
||||
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,6 +174,16 @@ struct segmenter_type
|
||||
case 5: serialize(item.segmenter5, out); break;
|
||||
case 6: serialize(item.segmenter6, out); break;
|
||||
case 7: serialize(item.segmenter7, out); break;
|
||||
|
||||
case 8: serialize(item.segmenter8, out); break;
|
||||
case 9: serialize(item.segmenter9, out); break;
|
||||
case 10: serialize(item.segmenter10, out); break;
|
||||
case 11: serialize(item.segmenter11, out); break;
|
||||
case 12: serialize(item.segmenter12, out); break;
|
||||
case 13: serialize(item.segmenter13, out); break;
|
||||
case 14: serialize(item.segmenter14, out); break;
|
||||
case 15: serialize(item.segmenter15, out); break;
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
}
|
||||
friend void deserialize (segmenter_type& item, std::istream& in)
|
||||
@ -141,19 +199,29 @@ struct segmenter_type
|
||||
case 5: deserialize(item.segmenter5, in); break;
|
||||
case 6: deserialize(item.segmenter6, in); break;
|
||||
case 7: deserialize(item.segmenter7, in); break;
|
||||
|
||||
case 8: deserialize(item.segmenter8, in); break;
|
||||
case 9: deserialize(item.segmenter9, in); break;
|
||||
case 10: deserialize(item.segmenter10, in); break;
|
||||
case 11: deserialize(item.segmenter11, in); break;
|
||||
case 12: deserialize(item.segmenter12, in); break;
|
||||
case 13: deserialize(item.segmenter13, in); break;
|
||||
case 14: deserialize(item.segmenter14, in); break;
|
||||
case 15: deserialize(item.segmenter15, in); break;
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
}
|
||||
|
||||
int mode;
|
||||
|
||||
typedef segmenter_feature_extractor<sample_type, true, true, true> fe0;
|
||||
typedef segmenter_feature_extractor<sample_type, true, true, false> fe1;
|
||||
typedef segmenter_feature_extractor<sample_type, true, false,true> fe2;
|
||||
typedef segmenter_feature_extractor<sample_type, true, false,false> fe3;
|
||||
typedef segmenter_feature_extractor<sample_type, false,true, true> fe4;
|
||||
typedef segmenter_feature_extractor<sample_type, false,true, false> fe5;
|
||||
typedef segmenter_feature_extractor<sample_type, false,false,true> fe6;
|
||||
typedef segmenter_feature_extractor<sample_type, false,false,false> fe7;
|
||||
typedef segmenter_feature_extractor<dense_vect, false,false,false> fe0;
|
||||
typedef segmenter_feature_extractor<dense_vect, false,false,true> fe1;
|
||||
typedef segmenter_feature_extractor<dense_vect, false,true, false> fe2;
|
||||
typedef segmenter_feature_extractor<dense_vect, false,true, true> fe3;
|
||||
typedef segmenter_feature_extractor<dense_vect, true, false,false> fe4;
|
||||
typedef segmenter_feature_extractor<dense_vect, true, false,true> fe5;
|
||||
typedef segmenter_feature_extractor<dense_vect, true, true, false> fe6;
|
||||
typedef segmenter_feature_extractor<dense_vect, true, true, true> fe7;
|
||||
sequence_segmenter<fe0> segmenter0;
|
||||
sequence_segmenter<fe1> segmenter1;
|
||||
sequence_segmenter<fe2> segmenter2;
|
||||
@ -163,14 +231,14 @@ struct segmenter_type
|
||||
sequence_segmenter<fe6> segmenter6;
|
||||
sequence_segmenter<fe7> segmenter7;
|
||||
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, true, true> fe8;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, true, false> fe9;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, false,true> fe10;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, false,false> fe11;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,true, true> fe12;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,true, false> fe13;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,false,true> fe14;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,false,false> fe15;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,false,false> fe8;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,false,true> fe9;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,true, false> fe10;
|
||||
typedef segmenter_feature_extractor<sparse_vect, false,true, true> fe11;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, false,false> fe12;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, false,true> fe13;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, true, false> fe14;
|
||||
typedef segmenter_feature_extractor<sparse_vect, true, true, true> fe15;
|
||||
sequence_segmenter<fe8> segmenter8;
|
||||
sequence_segmenter<fe9> segmenter9;
|
||||
sequence_segmenter<fe10> segmenter10;
|
||||
@ -195,6 +263,7 @@ struct segmenter_params
|
||||
num_threads = 4;
|
||||
epsilon = 0.1;
|
||||
max_cache_size = 40;
|
||||
be_verbose = false;
|
||||
C = 100;
|
||||
}
|
||||
|
||||
@ -209,11 +278,77 @@ struct segmenter_params
|
||||
double C;
|
||||
};
|
||||
|
||||
|
||||
// Builds a one-line, comma separated summary of a segmenter_params object.
//
// The fields are emitted in this order: the model kind ("BIO" or "BILOU"),
// the feature order ("highFeats" or "lowFeats"), the weight constraint
// ("signed" or "non-negative"), then win=, threads=, eps=, cache=, the
// verbosity flag ("verbose" or "non-verbose") and finally C=.
//
// The assembled text is passed through trim() before being returned.
string segmenter_params__str__(const segmenter_params& p)
{
    ostringstream sout;

    // The three boolean model options each map onto one of two fixed labels.
    sout << (p.use_BIO_model           ? "BIO,"       : "BILOU,");
    sout << (p.use_high_order_features ? "highFeats," : "lowFeats,");
    sout << (p.allow_negative_weights  ? "signed,"    : "non-negative,");

    // Numeric settings are written as key=value pairs.
    sout << "win="     << p.window_size    << ",";
    sout << "threads=" << p.num_threads    << ",";
    sout << "eps="     << p.epsilon        << ",";
    sout << "cache="   << p.max_cache_size << ",";

    sout << (p.be_verbose ? "verbose," : "non-verbose,");

    // C is the last field, so it carries no trailing comma.
    sout << "C=" << p.C;

    return trim(sout.str());
}
|
||||
|
||||
// Returns the text produced by segmenter_params__str__(p) wrapped in angle
// brackets, i.e. "<" followed by the summary followed by ">".
string segmenter_params__repr__(const segmenter_params& p)
{
    const string summary = segmenter_params__str__(p);
    return "<" + summary + ">";
}
|
||||
|
||||
void serialize ( const segmenter_params& item, std::ostream& out)
|
||||
{
|
||||
serialize(item.use_BIO_model, out);
|
||||
serialize(item.use_high_order_features, out);
|
||||
serialize(item.allow_negative_weights, out);
|
||||
serialize(item.window_size, out);
|
||||
serialize(item.num_threads, out);
|
||||
serialize(item.epsilon, out);
|
||||
serialize(item.max_cache_size, out);
|
||||
serialize(item.be_verbose, out);
|
||||
serialize(item.C, out);
|
||||
}
|
||||
|
||||
// Reads every field of a segmenter_params object from in by calling
// deserialize() on each member in turn.
//
// The fields are read in exactly the order serialize() for segmenter_params
// writes them, so the two routines must be kept in step with each other.
void deserialize (segmenter_params& item, std::istream& in)
{
    // Model selection flags.
    deserialize(item.use_BIO_model, in);
    deserialize(item.use_high_order_features, in);
    deserialize(item.allow_negative_weights, in);

    // Feature window and trainer settings.
    deserialize(item.window_size, in);
    deserialize(item.num_threads, in);
    deserialize(item.epsilon, in);
    deserialize(item.max_cache_size, in);
    deserialize(item.be_verbose, in);
    deserialize(item.C, in);
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
void configure_trainer (
|
||||
const std::vector<std::vector<sample_type> >& samples,
|
||||
const std::vector<std::vector<dense_vect> >& samples,
|
||||
structural_sequence_segmentation_trainer<T>& trainer,
|
||||
const segmenter_params& params
|
||||
)
|
||||
@ -233,8 +368,35 @@ void configure_trainer (
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
void configure_trainer (
|
||||
const std::vector<std::vector<sparse_vect> >& samples,
|
||||
structural_sequence_segmentation_trainer<T>& trainer,
|
||||
const segmenter_params& params
|
||||
)
|
||||
{
|
||||
pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences.");
|
||||
pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences.");
|
||||
|
||||
unsigned long dims = 0;
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
dims = std::max(dims, max_index_plus_one(samples[i]));
|
||||
}
|
||||
|
||||
trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size));
|
||||
trainer.set_num_threads(params.num_threads);
|
||||
trainer.set_epsilon(params.epsilon);
|
||||
trainer.set_max_cache_size(params.max_cache_size);
|
||||
trainer.set_c(params.C);
|
||||
if (params.be_verbose)
|
||||
trainer.be_verbose();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
segmenter_type train_dense (
|
||||
const std::vector<std::vector<sample_type> >& samples,
|
||||
const std::vector<std::vector<dense_vect> >& samples,
|
||||
const std::vector<ranges>& segments,
|
||||
segmenter_params params
|
||||
)
|
||||
@ -255,6 +417,7 @@ segmenter_type train_dense (
|
||||
else
|
||||
mode = mode*2;
|
||||
|
||||
|
||||
segmenter_type res;
|
||||
res.mode = mode;
|
||||
switch(mode)
|
||||
@ -291,6 +454,76 @@ segmenter_type train_dense (
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter7 = trainer.train(samples, segments);
|
||||
} break;
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
segmenter_type train_sparse (
|
||||
const std::vector<std::vector<sparse_vect> >& samples,
|
||||
const std::vector<ranges>& segments,
|
||||
segmenter_params params
|
||||
)
|
||||
{
|
||||
pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
|
||||
|
||||
int mode = 0;
|
||||
if (params.use_BIO_model)
|
||||
mode = mode*2 + 1;
|
||||
else
|
||||
mode = mode*2;
|
||||
if (params.use_high_order_features)
|
||||
mode = mode*2 + 1;
|
||||
else
|
||||
mode = mode*2;
|
||||
if (params.allow_negative_weights)
|
||||
mode = mode*2 + 1;
|
||||
else
|
||||
mode = mode*2;
|
||||
|
||||
mode += 8;
|
||||
|
||||
segmenter_type res;
|
||||
res.mode = mode;
|
||||
switch(mode)
|
||||
{
|
||||
case 8: { structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter8 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 9: { structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter9 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 10: { structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter10 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 11: { structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter11 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 12: { structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter12 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 13: { structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter13 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 14: { structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter14 = trainer.train(samples, segments);
|
||||
} break;
|
||||
case 15: { structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer;
|
||||
configure_trainer(samples, trainer, params);
|
||||
res.segmenter15 = trainer.train(samples, segments);
|
||||
} break;
|
||||
default: throw dlib::error("Invalid mode");
|
||||
}
|
||||
|
||||
|
||||
@ -304,21 +537,27 @@ void bind_sequence_segmenter()
|
||||
class_<segmenter_params>("segmenter_params",
|
||||
"This class is used to define all the optional parameters to the \n\
|
||||
train_sequence_segmenter() routine. ")
|
||||
.add_property("use_BIO_model", &segmenter_params::use_BIO_model)
|
||||
.add_property("use_high_order_features", &segmenter_params::use_high_order_features)
|
||||
.add_property("allow_negative_weights", &segmenter_params::allow_negative_weights)
|
||||
.add_property("window_size", &segmenter_params::window_size)
|
||||
.add_property("num_threads", &segmenter_params::num_threads)
|
||||
.add_property("epsilon", &segmenter_params::epsilon)
|
||||
.add_property("max_cache_size", &segmenter_params::max_cache_size)
|
||||
.add_property("C", &segmenter_params::C);
|
||||
.def_readwrite("use_BIO_model", &segmenter_params::use_BIO_model)
|
||||
.def_readwrite("use_high_order_features", &segmenter_params::use_high_order_features)
|
||||
.def_readwrite("allow_negative_weights", &segmenter_params::allow_negative_weights)
|
||||
.def_readwrite("window_size", &segmenter_params::window_size)
|
||||
.def_readwrite("num_threads", &segmenter_params::num_threads)
|
||||
.def_readwrite("epsilon", &segmenter_params::epsilon)
|
||||
.def_readwrite("max_cache_size", &segmenter_params::max_cache_size)
|
||||
.def_readwrite("C", &segmenter_params::C, "SVM C parameter")
|
||||
.def("__repr__",&segmenter_params__repr__)
|
||||
.def("__str__",&segmenter_params__str__)
|
||||
.def_pickle(serialize_pickle<segmenter_params>());
|
||||
|
||||
class_<segmenter_type> ("segmenter_type")
|
||||
.def("segment_sequence", &segmenter_type::segment_sequence)
|
||||
.def("segment_sequence", &segmenter_type::segment_sequence_dense)
|
||||
.def("segment_sequence", &segmenter_type::segment_sequence_sparse)
|
||||
.def_readonly("weights", &segmenter_type::get_weights)
|
||||
.def_pickle(serialize_pickle<segmenter_type>());
|
||||
|
||||
using boost::python::arg;
|
||||
def("train_sequence_segmenter", train_dense, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
|
||||
def("train_sequence_segmenter", train_sparse, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user