mirror of https://github.com/davisking/dlib.git (synced 2024-11-01 10:14:53 +08:00)

commit 90c9d0be6e ("cleaned this up a little")
parent 29964d2858
@@ -73,8 +73,8 @@ void deserialize(feature_extractor&, std::istream&) {}
 // ----------------------------------------------------------------------------------------
 
 void make_dataset (
-    const matrix<double>& emission_probabilities,
     const matrix<double>& transition_probabilities,
+    const matrix<double>& emission_probabilities,
     std::vector<std::vector<unsigned long> >& samples,
     std::vector<std::vector<unsigned long> >& labels,
     unsigned long dataset_size
@@ -90,8 +90,10 @@ void make_dataset (
         - This function randomly samples a bunch of sequences from the HMM defined by
           transition_probabilities and emission_probabilities.
         - The HMM is defined by:
-            - P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
-            - P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
+            - The probability of transitioning from hidden state H1 to H2
+              is given by transition_probabilities(H1,H2).
+            - The probability of a hidden state H producing an observed state
+              O is given by emission_probabilities(H,O).
         - #samples.size() == labels.size() == dataset_size
         - for all valid i:
             - #labels[i] is a randomly sampled sequence of hidden states from the
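The contract above fully specifies the generative process: starting from some hidden state, repeatedly draw the next label from the row of transition_probabilities indexed by the current label, then draw an observation from the matching row of emission_probabilities. Below is a minimal sketch of that sampling loop using std::discrete_distribution; it is illustrative only and is not the example's own sample_hmm()/make_dataset() code (sample_from_row and sample_sequence are made-up names).

#include <dlib/matrix.h>
#include <random>
#include <vector>

using namespace dlib;

// Draw a column index according to the probabilities stored in one row of a matrix.
unsigned long sample_from_row (const matrix<double>& probs, unsigned long row, std::mt19937& rng)
{
    std::vector<double> weights;
    for (long c = 0; c < probs.nc(); ++c)
        weights.push_back(probs(row, c));
    std::discrete_distribution<unsigned long> dist(weights.begin(), weights.end());
    return dist(rng);
}

// Sample one hidden label sequence and its observed sequence from the HMM described by
// transition_probabilities(H1,H2) and emission_probabilities(H,O).
void sample_sequence (
    const matrix<double>& transition_probabilities,
    const matrix<double>& emission_probabilities,
    std::vector<unsigned long>& labels,
    std::vector<unsigned long>& observations,
    unsigned long sequence_length,
    std::mt19937& rng
)
{
    labels.clear();
    observations.clear();
    unsigned long label = 0;  // arbitrary initial hidden state
    for (unsigned long i = 0; i < sequence_length; ++i)
    {
        // next label depends only on the previous label, then the observation
        // depends only on the new label, exactly as in the contract above.
        label = sample_from_row(transition_probabilities, label, rng);
        labels.push_back(label);
        observations.push_back(sample_from_row(emission_probabilities, label, rng));
    }
}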
@@ -103,6 +105,10 @@ void make_dataset (
 
 int main()
 {
+    matrix<double> transition_probabilities(num_label_states, num_label_states);
+    transition_probabilities = 0.05, 0.90, 0.05,
+                               0.05, 0.05, 0.90,
+                               0.90, 0.05, 0.05;
 
     // set this up so emission_probabilities(L,X) == The probability of a state with label L
     // emitting an X.
@@ -111,17 +117,11 @@ int main()
                              0.0, 0.5, 0.5,
                              0.5, 0.0, 0.5;
 
-    matrix<double> transition_probabilities(num_label_states, num_label_states);
-
-    transition_probabilities = 0.05, 0.90, 0.05,
-                               0.05, 0.05, 0.90,
-                               0.90, 0.05, 0.05;
-
 
 
     std::vector<std::vector<unsigned long> > samples;
     std::vector<std::vector<unsigned long> > labels;
-    make_dataset(emission_probabilities, transition_probabilities,
+    make_dataset(transition_probabilities,emission_probabilities,
                  samples, labels, 1000);
 
     cout << "samples.size(): "<< samples.size() << endl;
@@ -139,17 +139,19 @@ int main()
     trainer.set_num_threads(4);
 
 
-    matrix<double> confusion_matrix;
+
     // Learn to do sequence labeling from the dataset
     sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
-    confusion_matrix = test_sequence_labeler(labeler, samples, labels);
-    cout << "trained sequence labeler: " << endl;
-    cout << confusion_matrix;
-    cout << "label accuracy: "<< sum(diag(confusion_matrix))/sum(confusion_matrix) << endl;
+
     std::vector<unsigned long> predicted_labels = labeler(samples[0]);
     cout << "true hidden states: "<< trans(vector_to_matrix(labels[0]));
     cout << "predicted hidden states: "<< trans(vector_to_matrix(predicted_labels));
+
+
+
+
     // We can also do cross-validation
+    matrix<double> confusion_matrix;
     confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
     cout << "\ncross-validation: " << endl;
     cout << confusion_matrix;
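The label accuracy computed from a confusion matrix in this example is simply sum(diag(confusion_matrix))/sum(confusion_matrix), i.e. the fraction of sequence elements whose predicted label matches the true label. A self-contained check with made-up numbers (not output from this example):

#include <dlib/matrix.h>
#include <iostream>

int main()
{
    // Hypothetical 3x3 confusion matrix: rows are true labels, columns are predicted labels.
    dlib::matrix<double> confusion_matrix(3,3);
    confusion_matrix = 330,   5,  10,
                         8, 340,   7,
                        12,   6, 332;

    // Correctly labeled elements sit on the diagonal, so accuracy = trace / total.
    double accuracy = dlib::sum(dlib::diag(confusion_matrix)) / dlib::sum(confusion_matrix);
    std::cout << "label accuracy: " << accuracy << std::endl;  // 1002/1050, about 0.954
}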
@@ -236,8 +238,8 @@ void sample_hmm (
 // ----------------------------------------------------------------------------------------
 
 void make_dataset (
-    const matrix<double>& emission_probabilities,
     const matrix<double>& transition_probabilities,
+    const matrix<double>& emission_probabilities,
     std::vector<std::vector<unsigned long> >& samples,
    std::vector<std::vector<unsigned long> >& labels,
     unsigned long dataset_size