mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
improved example a little
This commit is contained in:
parent
02b6328704
commit
29964d2858
@ -67,8 +67,127 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
void serialize(const feature_extractor&, std::ostream&) {}
|
||||
void deserialize(feature_extractor&, std::istream&) {}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void make_dataset (
|
||||
const matrix<double>& emission_probabilities,
|
||||
const matrix<double>& transition_probabilities,
|
||||
std::vector<std::vector<unsigned long> >& samples,
|
||||
std::vector<std::vector<unsigned long> >& labels,
|
||||
unsigned long dataset_size
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- transition_probabilities.nr() == transition_probabilities.nc()
|
||||
- transition_probabilities.nr() == emission_probabilities.nr()
|
||||
- The rows of transition_probabilities and emission_probabilities must sum to 1.
|
||||
(i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
|
||||
must evaluate to vectors of all 1s.)
|
||||
ensures
|
||||
- This function randomly samples a bunch of sequences from the HMM defined by
|
||||
transition_probabilities and emission_probabilities.
|
||||
- The HMM is defined by:
|
||||
- P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
|
||||
- P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
|
||||
- #samples.size() == labels.size() == dataset_size
|
||||
- for all valid i:
|
||||
- #labels[i] is a randomly sampled sequence of hidden states from the
|
||||
given HMM. #samples[i] is its corresponding randomly sampled sequence
|
||||
of observed states.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main()
|
||||
{
|
||||
|
||||
// set this up so emission_probabilities(L,X) == The probability of a state with label L
|
||||
// emitting an X.
|
||||
matrix<double> emission_probabilities(num_label_states,num_sample_states);
|
||||
emission_probabilities = 0.5, 0.5, 0.0,
|
||||
0.0, 0.5, 0.5,
|
||||
0.5, 0.0, 0.5;
|
||||
|
||||
matrix<double> transition_probabilities(num_label_states, num_label_states);
|
||||
|
||||
transition_probabilities = 0.05, 0.90, 0.05,
|
||||
0.05, 0.05, 0.90,
|
||||
0.90, 0.05, 0.05;
|
||||
|
||||
|
||||
|
||||
std::vector<std::vector<unsigned long> > samples;
|
||||
std::vector<std::vector<unsigned long> > labels;
|
||||
make_dataset(emission_probabilities, transition_probabilities,
|
||||
samples, labels, 1000);
|
||||
|
||||
cout << "samples.size(): "<< samples.size() << endl;
|
||||
|
||||
// print out some of the randomly sampled sequences
|
||||
for (int i = 0; i < 10; ++i)
|
||||
{
|
||||
cout << "hidden states: " << trans(vector_to_matrix(labels[i]));
|
||||
cout << "observed states: " << trans(vector_to_matrix(samples[i]));
|
||||
cout << "******************************" << endl;
|
||||
}
|
||||
|
||||
structural_sequence_labeling_trainer<feature_extractor> trainer;
|
||||
trainer.set_c(4);
|
||||
trainer.set_num_threads(4);
|
||||
|
||||
|
||||
matrix<double> confusion_matrix;
|
||||
|
||||
// Learn to do sequence labeling from the dataset
|
||||
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
|
||||
confusion_matrix = test_sequence_labeler(labeler, samples, labels);
|
||||
cout << "trained sequence labeler: " << endl;
|
||||
cout << confusion_matrix;
|
||||
cout << "label accuracy: "<< sum(diag(confusion_matrix))/sum(confusion_matrix) << endl;
|
||||
|
||||
|
||||
// We can also do cross-validation
|
||||
confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
|
||||
cout << "\ncross-validation: " << endl;
|
||||
cout << confusion_matrix;
|
||||
cout << "label accuracy: "<< sum(diag(confusion_matrix))/sum(confusion_matrix) << endl;
|
||||
|
||||
|
||||
|
||||
matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
|
||||
reshape_to_column_vector(emission_probabilities)));
|
||||
|
||||
sequence_labeler<feature_extractor> labeler_true(feature_extractor(), true_hmm_model_weights);
|
||||
|
||||
confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
|
||||
cout << "\nTrue HMM model: " << endl;
|
||||
cout << confusion_matrix;
|
||||
cout << "label accuracy: "<< sum(diag(confusion_matrix))/sum(confusion_matrix) << endl;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// Finally, the labeler can be serialized to disk just like most dlib objects.
|
||||
ofstream fout("labeler.dat", ios::binary);
|
||||
serialize(labeler, fout);
|
||||
fout.close();
|
||||
|
||||
// recall from disk
|
||||
ifstream fin("labeler.dat", ios::binary);
|
||||
deserialize(labeler, fin);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// Code for creating a bunch of random samples from our HMM.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void sample_hmm (
|
||||
dlib::rand& rnd,
|
||||
@ -78,7 +197,26 @@ void sample_hmm (
|
||||
unsigned long& next_label,
|
||||
unsigned long& next_sample
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- previous_label < transition_probabilities.nr()
|
||||
- transition_probabilities.nr() == transition_probabilities.nc()
|
||||
- transition_probabilities.nr() == emission_probabilities.nr()
|
||||
- The rows of transition_probabilities and emission_probabilities must sum to 1.
|
||||
(i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
|
||||
must evaluate to vectors of all 1s.)
|
||||
ensures
|
||||
- This function randomly samples the HMM defined by transition_probabilities
|
||||
and emission_probabilities assuming that the previous hidden state
|
||||
was previous_label.
|
||||
- The HMM is defined by:
|
||||
- P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
|
||||
- P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
|
||||
- #next_label == the sampled value of the hidden state
|
||||
- #next_sample == the sampled value of the observed state
|
||||
!*/
|
||||
{
|
||||
// sample next_label
|
||||
double p = rnd.get_random_double();
|
||||
for (long c = 0; p >= 0 && c < transition_probabilities.nc(); ++c)
|
||||
{
|
||||
@ -86,7 +224,7 @@ void sample_hmm (
|
||||
p -= transition_probabilities(previous_label, c);
|
||||
}
|
||||
|
||||
|
||||
// now sample next_sample
|
||||
p = rnd.get_random_double();
|
||||
for (long c = 0; p >= 0 && c < emission_probabilities.nc(); ++c)
|
||||
{
|
||||
@ -104,10 +242,6 @@ void make_dataset (
|
||||
std::vector<std::vector<unsigned long> >& labels,
|
||||
unsigned long dataset_size
|
||||
)
|
||||
/*!
|
||||
2 kinds of label
|
||||
3 kinds of input state
|
||||
!*/
|
||||
{
|
||||
samples.clear();
|
||||
labels.clear();
|
||||
@ -117,9 +251,9 @@ void make_dataset (
|
||||
// now randomly sample some labeled sequences from our Hidden Markov Model
|
||||
for (unsigned long iter = 0; iter < dataset_size; ++iter)
|
||||
{
|
||||
const unsigned long size = rnd.get_random_32bit_number()%20+3;
|
||||
std::vector<unsigned long> sample(size);
|
||||
std::vector<unsigned long> label(size);
|
||||
const unsigned long sequence_size = rnd.get_random_32bit_number()%20+3;
|
||||
std::vector<unsigned long> sample(sequence_size);
|
||||
std::vector<unsigned long> label(sequence_size);
|
||||
|
||||
unsigned long previous_label = rnd.get_random_32bit_number()%num_label_states;
|
||||
for (unsigned long i = 0; i < sample.size(); ++i)
|
||||
@ -141,64 +275,3 @@ void make_dataset (
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main()
|
||||
{
|
||||
std::vector<std::vector<unsigned long> > samples;
|
||||
std::vector<std::vector<unsigned long> > labels;
|
||||
|
||||
// set this up so emission_probabilities(L,X) == The probability of a state with label L
|
||||
// emitting an X.
|
||||
matrix<double> emission_probabilities(num_label_states,num_sample_states);
|
||||
emission_probabilities = 0.5, 0.5, 0.0,
|
||||
0.0, 0.5, 0.5,
|
||||
0.5, 0.0, 0.5;
|
||||
|
||||
matrix<double> transition_probabilities(num_label_states, num_label_states);
|
||||
|
||||
transition_probabilities = 0.05, 0.90, 0.05,
|
||||
0.05, 0.05, 0.90,
|
||||
0.90, 0.05, 0.05;
|
||||
|
||||
|
||||
make_dataset(emission_probabilities, transition_probabilities,
|
||||
samples, labels, 1000);
|
||||
|
||||
cout << "samples.size(): "<< samples.size() << endl;
|
||||
|
||||
for (int i = 0; i < 10; ++i)
|
||||
{
|
||||
cout << trans(vector_to_matrix(labels[i]));
|
||||
cout << trans(vector_to_matrix(samples[i]));
|
||||
cout << "******************************" << endl;
|
||||
}
|
||||
|
||||
structural_sequence_labeling_trainer<feature_extractor> trainer;
|
||||
trainer.set_c(1000);
|
||||
trainer.set_num_threads(4);
|
||||
//trainer.be_verbose();
|
||||
|
||||
//sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
|
||||
//cout << labeler.get_weights() << endl;
|
||||
|
||||
matrix<double> cm;
|
||||
|
||||
cm = cross_validate_sequence_labeler(trainer, samples, labels, 4);
|
||||
//cm = test_sequence_labeler(labeler, samples, labels);
|
||||
cout << cm << endl;
|
||||
cout << "label accuracy: "<< sum(diag(cm))/sum(cm) << endl;
|
||||
|
||||
|
||||
|
||||
matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
|
||||
reshape_to_column_vector(emission_probabilities)));
|
||||
|
||||
sequence_labeler<feature_extractor> labeler_true(feature_extractor(), true_hmm_model_weights);
|
||||
|
||||
cout << endl;
|
||||
cm = test_sequence_labeler(labeler_true, samples, labels);
|
||||
cout << cm << endl;
|
||||
cout << "label accuracy: "<< sum(diag(cm))/sum(cm) << endl;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user