mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Use separate synchronization file for each seg net of each class
This commit is contained in:
parent
2f17289803
commit
030c468003
@ -312,7 +312,8 @@ matrix<uint16_t> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_i
|
||||
seg_bnet_type train_segmentation_network(
|
||||
const std::vector<image_info>& listing,
|
||||
const std::vector<std::vector<truth_instance>>& truth_instances,
|
||||
unsigned int seg_minibatch_size
|
||||
unsigned int seg_minibatch_size,
|
||||
const std::string& classlabel
|
||||
)
|
||||
{
|
||||
seg_bnet_type seg_net;
|
||||
@ -321,10 +322,15 @@ seg_bnet_type train_segmentation_network(
|
||||
const double weight_decay = 0.0001;
|
||||
const double momentum = 0.9;
|
||||
|
||||
const std::string synchronization_file_name
|
||||
= "pascal_voc2012_seg_trainer_state_file"
|
||||
+ (classlabel.empty() ? "" : ("_" + classlabel))
|
||||
+ ".dat";
|
||||
|
||||
dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
|
||||
seg_trainer.be_verbose();
|
||||
seg_trainer.set_learning_rate(initial_learning_rate);
|
||||
seg_trainer.set_synchronization_file("pascal_voc2012_seg_trainer_state_file.dat", std::chrono::minutes(10));
|
||||
seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
|
||||
seg_trainer.set_iterations_without_progress_threshold(5000);
|
||||
set_all_bn_running_stats_window_sizes(seg_net, 1000);
|
||||
|
||||
@ -626,10 +632,10 @@ int main(int argc, char** argv) try
|
||||
|
||||
filter_listing(listing, truth_instances, desired_classlabels);
|
||||
|
||||
cout << "images in dataset filtered by class: " << listing.size() << endl << endl;
|
||||
cout << "images in dataset filtered by class: " << listing.size() << endl;
|
||||
|
||||
// First train a detection network (loss_mmod), and then a mask segmentation network (loss_log_per_pixel)
|
||||
cout << "Training detector network:" << endl;
|
||||
cout << endl << "Training detector network:" << endl;
|
||||
const auto det_net = train_detection_network (listing, truth_instances, det_minibatch_size);
|
||||
|
||||
std::map<std::string, seg_bnet_type> seg_nets_by_class;
|
||||
@ -644,16 +650,16 @@ int main(int argc, char** argv) try
|
||||
auto truth_instances_for_classlabel = truth_instances;
|
||||
filter_listing(listing_for_classlabel, truth_instances_for_classlabel, { classlabel });
|
||||
|
||||
cout << "Training segmentation network for class " << classlabel << ":" << endl;
|
||||
cout << endl << "Training segmentation network for class " << classlabel << ":" << endl;
|
||||
seg_nets_by_class[classlabel] = train_segmentation_network(
|
||||
listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size
|
||||
listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size, classlabel
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cout << "Training a single segmentation network:" << endl;
|
||||
seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size);
|
||||
seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size, "");
|
||||
}
|
||||
|
||||
cout << "Saving networks" << endl;
|
||||
|
Loading…
Reference in New Issue
Block a user