diff --git a/examples/dnn_instance_segmentation_train_ex.cpp b/examples/dnn_instance_segmentation_train_ex.cpp index 633da585d..1f3ef9d92 100644 --- a/examples/dnn_instance_segmentation_train_ex.cpp +++ b/examples/dnn_instance_segmentation_train_ex.cpp @@ -312,7 +312,8 @@ matrix keep_only_current_instance(const matrix& rgb_label_i seg_bnet_type train_segmentation_network( const std::vector& listing, const std::vector>& truth_instances, - unsigned int seg_minibatch_size + unsigned int seg_minibatch_size, + const std::string& classlabel ) { seg_bnet_type seg_net; @@ -321,10 +322,15 @@ seg_bnet_type train_segmentation_network( const double weight_decay = 0.0001; const double momentum = 0.9; + const std::string synchronization_file_name + = "pascal_voc2012_seg_trainer_state_file" + + (classlabel.empty() ? "" : ("_" + classlabel)) + + ".dat"; + dnn_trainer seg_trainer(seg_net, sgd(weight_decay, momentum)); seg_trainer.be_verbose(); seg_trainer.set_learning_rate(initial_learning_rate); - seg_trainer.set_synchronization_file("pascal_voc2012_seg_trainer_state_file.dat", std::chrono::minutes(10)); + seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10)); seg_trainer.set_iterations_without_progress_threshold(5000); set_all_bn_running_stats_window_sizes(seg_net, 1000); @@ -626,10 +632,10 @@ int main(int argc, char** argv) try filter_listing(listing, truth_instances, desired_classlabels); - cout << "images in dataset filtered by class: " << listing.size() << endl << endl; + cout << "images in dataset filtered by class: " << listing.size() << endl; // First train a detection network (loss_mmod), and then a mask segmentation network (loss_log_per_pixel) - cout << "Training detector network:" << endl; + cout << endl << "Training detector network:" << endl; const auto det_net = train_detection_network (listing, truth_instances, det_minibatch_size); std::map seg_nets_by_class; @@ -644,16 +650,16 @@ int main(int argc, char** argv) try auto truth_instances_for_classlabel = truth_instances; filter_listing(listing_for_classlabel, truth_instances_for_classlabel, { classlabel }); - cout << "Training segmentation network for class " << classlabel << ":" << endl; + cout << endl << "Training segmentation network for class " << classlabel << ":" << endl; seg_nets_by_class[classlabel] = train_segmentation_network( - listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size + listing_for_classlabel, truth_instances_for_classlabel, seg_minibatch_size, classlabel ); } } else { cout << "Training a single segmentation network:" << endl; - seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size); + seg_nets_by_class[""] = train_segmentation_network(listing, truth_instances, seg_minibatch_size, ""); } cout << "Saving networks" << endl;