Small YOLO loss improvements (#2653)

* Small YOLO loss improvements

* Refactor iou_anchor_threshold logic

* Simplify iou_anchor_threshold logic a bit more

* Be more robust in the iou_anchor_threshold check
This commit is contained in:
Adrià Arrufat 2022-08-26 10:35:24 +09:00 committed by GitHub
parent bf273a8c2e
commit 9b8f5d88f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3719,7 +3719,8 @@ namespace dlib
tensor& grad = layer<TAG_TYPE>(sub).get_gradient_input();
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
const rectangle input_rect(input_tensor.nr(), input_tensor.nc());
const drectangle input_rect = rectangle(input_tensor.nr(), input_tensor.nc());
const auto input_area = input_rect.area();
float* g = grad.host();
// Compute the objectness loss for all grid cells
@ -3767,7 +3768,8 @@ namespace dlib
{
if (truth_box.ignore || !input_rect.contains(center(truth_box.rect)))
continue;
const dpoint t_center = dcenter(truth_box);
const auto truth_box_area = truth_box.rect.area();
const auto t_center = center(truth_box.rect);
double best_iou = 0;
size_t best_a = 0;
size_t best_tag_id = 0;
@ -3778,8 +3780,8 @@ namespace dlib
const auto details = item.second;
for (size_t a = 0; a < details.size(); ++a)
{
const yolo_rect anchor(centered_drect(t_center, details[a].width, details[a].height));
const double iou = box_intersection_over_union(truth_box.rect, anchor.rect);
const auto anchor(centered_drect(t_center, details[a].width, details[a].height));
const double iou = box_intersection_over_union(truth_box.rect, anchor);
if (iou > best_iou)
{
best_iou = iou;
@ -3797,67 +3799,65 @@ namespace dlib
for (size_t a = 0; a < anchors.size(); ++a)
{
// Update best anchor if it's from the current stride, and optionally other anchors
if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || iou_anchor_threshold < 1)
// We will always backpropagate on the best anchor, regardless of its IOU.
// For other anchors, only if they have an IOU >= iou_anchor_threshold.
if (!(best_tag_id == tag_id<TAG_TYPE>::id && best_a == a))
{
if (iou_anchor_threshold >= 1)
continue;
const auto anchor(centered_drect(t_center, anchors[a].width, anchors[a].height));
if (box_intersection_over_union(truth_box.rect, anchor) < iou_anchor_threshold)
continue;
}
// do not update other anchors if they have low IoU
if (!(best_tag_id == tag_id<TAG_TYPE>::id && best_a == a))
const long c = t_center.x() / stride_x;
const long r = t_center.y() / stride_y;
const long k = a * num_feats;
// Get the truth box target values
const double tx = t_center.x() / stride_x - c;
const double ty = t_center.y() / stride_y - r;
const double tw = truth_box.rect.width() / (anchors[a].width + truth_box.rect.width());
const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height());
// Scale regression error according to the truth size
const double scale_box = options.lambda_box * (2.0 - truth_box_area / input_area);
// Compute the smoothed L1 gradient for the box coordinates
const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx));
g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty));
g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw));
g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th));
// This grid cell should detect an object
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
{
const auto p = out_data[o_idx];
const double focus = std::pow(1 - p, options.gamma_obj);
const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1);
g[o_idx] = options.lambda_obj * g_obj;
}
// Compute the classification error using the truth weights and the focal loss
for (long i = 0; i < num_classes; ++i)
{
const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
const auto p = out_data[c_idx];
if (truth_box.label == options.labels[i])
{
const yolo_rect anchor(centered_drect(t_center, anchors[a].width, anchors[a].height));
if (box_intersection_over_union(truth_box.rect, anchor.rect) < iou_anchor_threshold)
continue;
const double focus = std::pow(1 - p, options.gamma_cls);
const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1);
g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls;
}
const long c = t_center.x() / stride_x;
const long r = t_center.y() / stride_y;
const long k = a * num_feats;
// Get the truth box target values
const double tx = t_center.x() / stride_x - c;
const double ty = t_center.y() / stride_y - r;
const double tw = truth_box.rect.width() / (anchors[a].width + truth_box.rect.width());
const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height());
// Scale regression error according to the truth size
const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area());
// Compute the smoothed L1 gradient for the box coordinates
const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx));
g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty));
g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw));
g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th));
// This grid cell should detect an object
const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
else
{
const auto p = out_data[o_idx];
const double focus = std::pow(1 - p, options.gamma_obj);
const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1);
g[o_idx] = options.lambda_obj * g_obj;
}
// Compute the classification error using the truth weights and the focal loss
for (long i = 0; i < num_classes; ++i)
{
const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
const auto p = out_data[c_idx];
if (truth_box.label == options.labels[i])
{
const double focus = std::pow(1 - p, options.gamma_cls);
const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1);
g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls;
}
else
{
const double focus = std::pow(p, options.gamma_cls);
const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p);
g[c_idx] = options.lambda_cls * g_cls;
}
const double focus = std::pow(p, options.gamma_cls);
const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p);
g[c_idx] = options.lambda_cls * g_cls;
}
}
}