|
|
|
@ -137,49 +137,6 @@ namespace dlib
|
|
|
|
|
mean = X*mean;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
inline std::pair<double,double> equal_error_rate (
|
|
|
|
|
const std::vector<double>& low_vals,
|
|
|
|
|
const std::vector<double>& high_vals
|
|
|
|
|
)
|
|
|
|
|
{
|
|
|
|
|
std::vector<std::pair<double,int> > temp;
|
|
|
|
|
temp.reserve(low_vals.size()+high_vals.size());
|
|
|
|
|
for (unsigned long i = 0; i < low_vals.size(); ++i)
|
|
|
|
|
temp.push_back(std::make_pair(low_vals[i], -1));
|
|
|
|
|
for (unsigned long i = 0; i < high_vals.size(); ++i)
|
|
|
|
|
temp.push_back(std::make_pair(high_vals[i], +1));
|
|
|
|
|
|
|
|
|
|
std::sort(temp.begin(), temp.end());
|
|
|
|
|
|
|
|
|
|
if (temp.size() == 0)
|
|
|
|
|
return std::make_pair(0,0);
|
|
|
|
|
|
|
|
|
|
double thresh = temp[0].first;
|
|
|
|
|
|
|
|
|
|
unsigned long num_low_wrong = low_vals.size();
|
|
|
|
|
unsigned long num_high_wrong = 0;
|
|
|
|
|
double low_error = num_low_wrong/(double)low_vals.size();
|
|
|
|
|
double high_error = num_high_wrong/(double)high_vals.size();
|
|
|
|
|
for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i)
|
|
|
|
|
{
|
|
|
|
|
thresh = temp[i].first;
|
|
|
|
|
if (temp[i].second > 0)
|
|
|
|
|
{
|
|
|
|
|
num_high_wrong++;
|
|
|
|
|
high_error = num_high_wrong/(double)high_vals.size();
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
num_low_wrong--;
|
|
|
|
|
low_error = num_low_wrong/(double)low_vals.size();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::make_pair((low_error+high_error)/2, thresh);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
struct roc_point
|
|
|
|
@ -199,10 +156,15 @@ namespace dlib
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double,int> > temp;
|
|
|
|
|
temp.reserve(true_detections.size()+false_detections.size());
|
|
|
|
|
// We use -1 for true labels and +1 for false so when we call std::sort() below it will sort
|
|
|
|
|
// runs with equal detection scores so true come first. This will avoid it seeming like we
|
|
|
|
|
// can separate true from false when scores are equal in the loop below.
|
|
|
|
|
const int true_label = -1;
|
|
|
|
|
const int false_label = +1;
|
|
|
|
|
for (unsigned long i = 0; i < true_detections.size(); ++i)
|
|
|
|
|
temp.push_back(std::make_pair(true_detections[i], +1));
|
|
|
|
|
temp.push_back(std::make_pair(true_detections[i], true_label));
|
|
|
|
|
for (unsigned long i = 0; i < false_detections.size(); ++i)
|
|
|
|
|
temp.push_back(std::make_pair(false_detections[i], -1));
|
|
|
|
|
temp.push_back(std::make_pair(false_detections[i], false_label));
|
|
|
|
|
|
|
|
|
|
std::sort(temp.rbegin(), temp.rend());
|
|
|
|
|
|
|
|
|
@ -214,7 +176,7 @@ namespace dlib
|
|
|
|
|
double num_true_included = 0;
|
|
|
|
|
for (unsigned long i = 0; i < temp.size(); ++i)
|
|
|
|
|
{
|
|
|
|
|
if (temp[i].second > 0)
|
|
|
|
|
if (temp[i].second == true_label)
|
|
|
|
|
num_true_included++;
|
|
|
|
|
else
|
|
|
|
|
num_false_included++;
|
|
|
|
@ -229,6 +191,39 @@ namespace dlib
|
|
|
|
|
return roc_curve;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
inline std::pair<double,double> equal_error_rate (
|
|
|
|
|
const std::vector<double>& low_vals,
|
|
|
|
|
const std::vector<double>& high_vals
|
|
|
|
|
)
|
|
|
|
|
{
|
|
|
|
|
if (low_vals.size() == 0 && high_vals.size() == 0)
|
|
|
|
|
return std::make_pair(0,0);
|
|
|
|
|
else if (low_vals.size() == 0)
|
|
|
|
|
return std::make_pair(0, min(mat(high_vals)));
|
|
|
|
|
else if (high_vals.size() == 0)
|
|
|
|
|
return std::make_pair(0, max(mat(low_vals))+1);
|
|
|
|
|
|
|
|
|
|
// Find the point of equal error rates
|
|
|
|
|
double best_thresh = 0;
|
|
|
|
|
double best_error = 0;
|
|
|
|
|
double best_delta = std::numeric_limits<double>::infinity();
|
|
|
|
|
for (const auto& pt : compute_roc_curve(high_vals, low_vals))
|
|
|
|
|
{
|
|
|
|
|
const double false_negative_rate = 1-pt.true_positive_rate;
|
|
|
|
|
const double delta = std::abs(false_negative_rate - pt.false_positive_rate);
|
|
|
|
|
if (delta < best_delta)
|
|
|
|
|
{
|
|
|
|
|
best_delta = delta;
|
|
|
|
|
best_error = std::max(false_negative_rate, pt.false_positive_rate);
|
|
|
|
|
best_thresh = pt.detection_threshold;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::make_pair(best_error, best_thresh);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|