From 087c1f2e6136be6d013da74ff3004a6552edee05 Mon Sep 17 00:00:00 2001 From: Davis King Date: Mon, 18 Dec 2017 15:59:23 -0500 Subject: [PATCH] Added a minor optimization. --- .../global_function_search.cpp | 18 +++++++++++++++++- dlib/test/global_optimization.cpp | 18 +++++++++--------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/dlib/global_optimization/global_function_search.cpp b/dlib/global_optimization/global_function_search.cpp index 0da7a92cd..813527b47 100644 --- a/dlib/global_optimization/global_function_search.cpp +++ b/dlib/global_optimization/global_function_search.cpp @@ -702,7 +702,23 @@ namespace dlib for (auto& info : functions) { const long dims = info->spec.lower.size(); - if (info->ub.num_points() < std::max(3,dims)) + if (info->ub.num_points() < 1) + { + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + // Pick the point right in the center of the bounds to evaluate first since + // people will commonly center the bound on a location they think is good. + // So might as well try there first. + new_req.x = (info->spec.lower + info->spec.upper)/2.0; + for (long i = 0; i < new_req.x.size(); ++i) + { + if (info->spec.is_integer_variable[i]) + new_req.x(i) = std::round(new_req.x(i)); + } + info->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req,info); + } + else if (info->ub.num_points() < std::max(3,dims)) { outstanding_function_eval_request new_req; new_req.request_id = next_request_id++; diff --git a/dlib/test/global_optimization.cpp b/dlib/test/global_optimization.cpp index b343cf3d8..504eba136 100644 --- a/dlib/test/global_optimization.cpp +++ b/dlib/test/global_optimization.cpp @@ -160,29 +160,29 @@ namespace print_spinner(); auto rosen = [](const matrix& x) { return -1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); }; - auto result = find_max_global(rosen, {0, 0}, {2, 2}, max_function_calls(100), 0); + auto result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0); matrix true_x = {1,1}; dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); print_spinner(); - result = find_max_global(rosen, {0, 0}, {2, 2}, max_function_calls(100)); + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); print_spinner(); - result = find_max_global(rosen, {0, 0}, {2, 2}, std::chrono::seconds(5)); + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); print_spinner(); - result = find_max_global(rosen, {0, 0}, {2, 2}, {false,false}, max_function_calls(100)); + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); print_spinner(); - result = find_max_global(rosen, {0, 0}, {0.9, 0.9}, {false,false}, max_function_calls(100)); + result = find_max_global(rosen, {0.1, 0.1}, {0.9, 0.9}, {false,false}, max_function_calls(140)); true_x = {0.9, 0.81}; dlog << LINFO << "rosen, bounded at 0.9: " << trans(result.x); DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); @@ -221,24 +221,24 @@ namespace print_spinner(); auto rosen = [](const matrix& x) { return +1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); }; - auto result = find_min_global(rosen, {0, 0}, {2, 2}, max_function_calls(100), 0); + auto result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0); matrix true_x = {1,1}; dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); print_spinner(); - result = find_min_global(rosen, {0, 0}, {2, 2}, max_function_calls(100)); + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); print_spinner(); - result = find_min_global(rosen, {0, 0}, {2, 2}, std::chrono::seconds(5)); + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); print_spinner(); - result = find_min_global(rosen, {0, 0}, {2, 2}, {false,false}, max_function_calls(100)); + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100)); dlog << LINFO << "rosen: " << trans(result.x); DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); print_spinner();