diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index 4e97f9609..355ca6e3d 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -3719,7 +3719,8 @@ namespace dlib tensor& grad = layer(sub).get_gradient_input(); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); - const rectangle input_rect(input_tensor.nr(), input_tensor.nc()); + const drectangle input_rect = rectangle(input_tensor.nr(), input_tensor.nc()); + const auto input_area = input_rect.area(); float* g = grad.host(); // Compute the objectness loss for all grid cells @@ -3767,7 +3768,8 @@ namespace dlib { if (truth_box.ignore || !input_rect.contains(center(truth_box.rect))) continue; - const dpoint t_center = dcenter(truth_box); + const auto truth_box_area = truth_box.rect.area(); + const auto t_center = center(truth_box.rect); double best_iou = 0; size_t best_a = 0; size_t best_tag_id = 0; @@ -3778,8 +3780,8 @@ namespace dlib const auto details = item.second; for (size_t a = 0; a < details.size(); ++a) { - const yolo_rect anchor(centered_drect(t_center, details[a].width, details[a].height)); - const double iou = box_intersection_over_union(truth_box.rect, anchor.rect); + const auto anchor(centered_drect(t_center, details[a].width, details[a].height)); + const double iou = box_intersection_over_union(truth_box.rect, anchor); if (iou > best_iou) { best_iou = iou; @@ -3797,67 +3799,65 @@ namespace dlib for (size_t a = 0; a < anchors.size(); ++a) { - // Update best anchor if it's from the current stride, and optionally other anchors - if ((best_tag_id == tag_id::id && best_a == a) || iou_anchor_threshold < 1) + // We will always backpropagate on the best anchor, regardless of its IOU. + // For other anchors, only if they have an IOU >= iou_anchor_threshold. + if (!(best_tag_id == tag_id::id && best_a == a)) { + if (iou_anchor_threshold >= 1) + continue; + const auto anchor(centered_drect(t_center, anchors[a].width, anchors[a].height)); + if (box_intersection_over_union(truth_box.rect, anchor) < iou_anchor_threshold) + continue; + } - // do not update other anchors if they have low IoU - if (!(best_tag_id == tag_id::id && best_a == a)) + const long c = t_center.x() / stride_x; + const long r = t_center.y() / stride_y; + const long k = a * num_feats; + + // Get the truth box target values + const double tx = t_center.x() / stride_x - c; + const double ty = t_center.y() / stride_y - r; + const double tw = truth_box.rect.width() / (anchors[a].width + truth_box.rect.width()); + const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height()); + + // Scale regression error according to the truth size + const double scale_box = options.lambda_box * (2.0 - truth_box_area / input_area); + + // Compute the smoothed L1 gradient for the box coordinates + const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c); + const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c); + const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c); + const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c); + g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx)); + g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty)); + g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw)); + g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th)); + + // This grid cell should detect an object + const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c); + { + const auto p = out_data[o_idx]; + const double focus = std::pow(1 - p, options.gamma_obj); + const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1); + g[o_idx] = options.lambda_obj * g_obj; + } + + // Compute the classification error using the truth weights and the focal loss + for (long i = 0; i < num_classes; ++i) + { + const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c); + const auto p = out_data[c_idx]; + if (truth_box.label == options.labels[i]) { - const yolo_rect anchor(centered_drect(t_center, anchors[a].width, anchors[a].height)); - if (box_intersection_over_union(truth_box.rect, anchor.rect) < iou_anchor_threshold) - continue; + const double focus = std::pow(1 - p, options.gamma_cls); + const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1); + g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls; } - - const long c = t_center.x() / stride_x; - const long r = t_center.y() / stride_y; - const long k = a * num_feats; - - // Get the truth box target values - const double tx = t_center.x() / stride_x - c; - const double ty = t_center.y() / stride_y - r; - const double tw = truth_box.rect.width() / (anchors[a].width + truth_box.rect.width()); - const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height()); - - // Scale regression error according to the truth size - const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area()); - - // Compute the smoothed L1 gradient for the box coordinates - const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c); - const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c); - const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c); - const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c); - g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx)); - g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty)); - g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw)); - g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th)); - - // This grid cell should detect an object - const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c); + else { - const auto p = out_data[o_idx]; - const double focus = std::pow(1 - p, options.gamma_obj); - const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1); - g[o_idx] = options.lambda_obj * g_obj; - } - - // Compute the classification error using the truth weights and the focal loss - for (long i = 0; i < num_classes; ++i) - { - const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c); - const auto p = out_data[c_idx]; - if (truth_box.label == options.labels[i]) - { - const double focus = std::pow(1 - p, options.gamma_cls); - const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1); - g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls; - } - else - { - const double focus = std::pow(p, options.gamma_cls); - const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p); - g[c_idx] = options.lambda_cls * g_cls; - } + const double focus = std::pow(p, options.gamma_cls); + const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p); + g[c_idx] = options.lambda_cls * g_cls; } } }