From 7d3fac5502b34b85b58d1ce92b5e4784d34258de Mon Sep 17 00:00:00 2001 From: Davis King Date: Mon, 18 Jun 2018 21:36:36 -0400 Subject: [PATCH] Added a --split-train-test option to imglab. --- tools/imglab/src/main.cpp | 52 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/tools/imglab/src/main.cpp b/tools/imglab/src/main.cpp index 6054b5220..90224600c 100644 --- a/tools/imglab/src/main.cpp +++ b/tools/imglab/src/main.cpp @@ -21,7 +21,7 @@ #include -const char* VERSION = "1.14"; +const char* VERSION = "1.15"; @@ -127,6 +127,44 @@ int split_dataset ( // ---------------------------------------------------------------------------------------- +int make_train_test_splits ( + const command_line_parser& parser +) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --split-train-test option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + const double train_frac = get_option(parser, "split-train-test", 0.5); + + dlib::image_dataset_metadata::dataset data, data_train, data_test; + load_image_dataset_metadata(data, parser[0]); + + data_train.name = data.name; + data_train.comment = data.comment; + data_test.name = data.name; + data_test.comment = data.comment; + + const unsigned long num_train_images = static_cast(std::round(train_frac*data.images.size())); + + for (unsigned long i = 0; i < data.images.size(); ++i) + { + if (i < num_train_images) + data_train.images.push_back(data.images[i]); + else + data_test.images.push_back(data.images[i]); + } + + save_image_dataset_metadata(data_train, left_substr(parser[0],".") + "_train.xml"); + save_image_dataset_metadata(data_test, left_substr(parser[0],".") + "_test.xml"); + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + void print_all_labels ( const dlib::image_dataset_metadata::dataset& data ) @@ -545,6 +583,10 @@ int main(int argc, char** argv) parser.add_option("seed", "When using --shuffle, set the random seed to the string .",1); parser.add_option("split", "Split the contents of an XML file into two separate files. One containing the " "images with objects labeled and another file with all the other images. ",1); + parser.add_option("split-train-test", "Split the contents of an XML file into two separate files. A training " + "file containing fraction of the images and a testing file containing the remaining (1-) images. " + "The partitioning is done deterministically by putting the first images in the input xml file into the training split " + "and the later images into the test split.",1); parser.add_option("add", "Add the image metadata from into . If any of the image " "tags are in both files then the ones in are deleted and replaced with the " "image tags from . The results are saved into merged.xml and neither or " @@ -581,7 +623,7 @@ int main(int argc, char** argv) const char* singles[] = {"h","c","r","l","files","convert","parts","rmdiff", "rmtrunc", "rmdupes", "seed", "shuffle", "split", "add", "flip-basic", "flip", "rotate", "tile", "size", "cluster", "resample", "min-object-size", "rmempty", "crop-size", "cropped-object-size", "rmlabel", "rm-other-labels", "rm-if-overlaps", "sort-num-objects", - "one-object-per-image", "jpg", "rmignore", "sort"}; + "one-object-per-image", "jpg", "rmignore", "sort", "split-train-test"}; parser.check_one_time_options(singles); const char* c_sub_ops[] = {"r", "convert"}; parser.check_sub_options("c", c_sub_ops); @@ -676,6 +718,7 @@ int main(int argc, char** argv) parser.check_option_arg_range("min-object-size", 1, 10000*10000); parser.check_option_arg_range("cropped-object-size", 4, 10000*10000); parser.check_option_arg_range("crop-size", 1.0, 100.0); + parser.check_option_arg_range("split-train-test", 0.0, 1.0); if (parser.option("h")) { @@ -1016,6 +1059,11 @@ int main(int argc, char** argv) return split_dataset(parser); } + if (parser.option("split-train-test")) + { + return make_train_test_splits(parser); + } + if (parser.option("shuffle")) { if (parser.number_of_arguments() != 1)