YOLO improvements: Smoothed L1 loss, Focal Loss, Weighted Loss (#2628)

* Add some improvements to the YOLO loss layer

Notably:
- stabilized the bounding box regression loss during training by smoothing the gradient
- added the possibility to compute iou_anchor_threshold adaptively with ATSS
  (adaptive training sample selection)
- made use of the detection_confidence field in the truth box for a class-balanced loss
- added an optional gamma parameter to enable the focal loss in the classifier

* Added method to return a non-const reference to yolo_options

Also added gamma_cls to the loss_details operator<<

* Add focal gamma for objectness, too

* Fix yolo_options deserialization logic

* Remove non-const reference method to yolo_options

* Use the weight only for the positive class, and be more consistent with floating-point types
Adrià Arrufat, 2022-07-23 23:54:28 +09:00, committed via GitHub
parent 6dfba4970d, commit f81c4b2d00
2 changed files with 85 additions and 33 deletions

File: dlib/dnn/loss.h

@@ -3518,12 +3518,14 @@ namespace dlib
             double lambda_obj = 1.0;
             double lambda_box = 1.0;
             double lambda_cls = 1.0;
+            double gamma_obj = 0.0;
+            double gamma_cls = 0.0;
         };

         inline void serialize(const yolo_options& item, std::ostream& out)
         {
-            int version = 1;
+            int version = 2;
             serialize(version, out);
             serialize(item.anchors, out);
             serialize(item.labels, out);
@@ -3534,13 +3536,15 @@ namespace dlib
             serialize(item.lambda_obj, out);
             serialize(item.lambda_box, out);
             serialize(item.lambda_cls, out);
+            serialize(item.gamma_obj, out);
+            serialize(item.gamma_cls, out);
         }

         inline void deserialize(yolo_options& item, std::istream& in)
         {
             int version = 0;
             deserialize(version, in);
-            if (version != 1)
+            if (!(1 <= version && version <= 2))
                 throw serialization_error("Unexpected version found while deserializing dlib::yolo_options.");
             deserialize(item.anchors, in);
             deserialize(item.labels, in);
@@ -3551,6 +3555,11 @@ namespace dlib
             deserialize(item.lambda_obj, in);
             deserialize(item.lambda_box, in);
             deserialize(item.lambda_cls, in);
+            if (version >= 2)
+            {
+                deserialize(item.gamma_obj, in);
+                deserialize(item.gamma_cls, in);
+            }
         }

         inline std::ostream& operator<<(std::ostream& out, const std::map<int, std::vector<yolo_options::anchor_box_details>>& anchors)
@@ -3655,19 +3664,21 @@ namespace dlib
                 {
                     for (long c = 0; c < output_tensor.nc(); ++c)
                     {
-                        const float obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
+                        const double obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
                         if (obj > adjust_threshold)
                         {
-                            const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
-                            const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
-                            const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
-                            const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
+                            // The scaling and shifting in the x and y coordinates avoids the grid sensitivity
+                            // effect by allowing the network to output centers along the grid boundaries.
+                            const double x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
+                            const double y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
+                            const double w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
+                            const double h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
                             yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
                                                          w / (1 - w) * anchors[a].width,
                                                          h / (1 - h) * anchors[a].height));
                             for (long i = 0; i < num_classes; ++i)
                             {
-                                const float conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
+                                const double conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
                                 if (conf > adjust_threshold)
                                     det.labels.emplace_back(conf, options.labels[i]);
                             }
@@ -3719,10 +3730,10 @@ namespace dlib
                         for (size_t a = 0; a < anchors.size(); ++a)
                         {
                             const long k = a * num_feats;
-                            const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
-                            const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
-                            const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
-                            const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
+                            const double x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
+                            const double y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
+                            const double w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
+                            const double h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];

                             // The prediction at r, c for anchor a
                             const yolo_rect pred(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
@@ -3739,9 +3750,14 @@ namespace dlib
                             }

                             // Incur loss for the boxes that are below a certain IoU threshold with any truth box
-                            const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
                             if (best_iou < options.iou_ignore_threshold)
-                                g[o_idx] = options.lambda_obj * out_data[o_idx];
+                            {
+                                const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
+                                const double p = out_data[o_idx];
+                                const double focus = std::pow(p, options.gamma_obj);
+                                const double g_obj = focus * (options.gamma_obj * (1 - p) * safe_log(1 - p) + p);
+                                g[o_idx] = options.lambda_obj * g_obj;
+                            }
                         }
                     }
                 }
@@ -3755,6 +3771,7 @@ namespace dlib
                         double best_iou = 0;
                         size_t best_a = 0;
                         size_t best_tag_id = 0;
+                        running_stats<double> ious;
                         for (const auto& item : options.anchors)
                         {
                             const auto tag_id = item.first;
@@ -3768,22 +3785,27 @@ namespace dlib
                                     best_iou = iou;
                                     best_a = a;
                                     best_tag_id = tag_id;
                                 }
+                                ious.add(iou);
                             }
                         }

+                        // ATSS: Adaptive Training Sample Selection
+                        double iou_anchor_threshold = options.iou_anchor_threshold;
+                        if (iou_anchor_threshold == 0)
+                            iou_anchor_threshold = ious.mean() + ious.stddev();
                         for (size_t a = 0; a < anchors.size(); ++a)
                         {
                             // Update best anchor if it's from the current stride, and optionally other anchors
-                            if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || options.iou_anchor_threshold < 1)
+                            if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || iou_anchor_threshold < 1)
                             {
                                 // do not update other anchors if they have low IoU
                                 if (!(best_tag_id == tag_id<TAG_TYPE>::id && best_a == a))
                                 {
                                     const yolo_rect anchor(centered_drect(t_center, anchors[a].width, anchors[a].height));
-                                    const double iou = box_intersection_over_union(truth_box.rect, anchor.rect);
-                                    if (iou < options.iou_anchor_threshold)
+                                    if (box_intersection_over_union(truth_box.rect, anchor.rect) < iou_anchor_threshold)
                                         continue;
                                 }
@@ -3800,34 +3822,48 @@ namespace dlib
                                 // Scale regression error according to the truth size
                                 const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area());

-                                // Compute the gradient for the box coordinates
+                                // Compute the smoothed L1 gradient for the box coordinates
                                 const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
                                 const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
                                 const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
                                 const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
-                                g[x_idx] = scale_box * (out_data[x_idx] * 2.0 - 0.5 - tx);
-                                g[y_idx] = scale_box * (out_data[y_idx] * 2.0 - 0.5 - ty);
-                                g[w_idx] = scale_box * (out_data[w_idx] - tw);
-                                g[h_idx] = scale_box * (out_data[h_idx] - th);
+                                g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx));
+                                g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty));
+                                g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw));
+                                g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th));

                                 // This grid cell should detect an object
                                 const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
-                                g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
+                                {
+                                    const auto p = out_data[o_idx];
+                                    const double focus = std::pow(1 - p, options.gamma_obj);
+                                    const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1);
+                                    g[o_idx] = options.lambda_obj * g_obj;
+                                }

-                                // Compute the classification error
+                                // Compute the classification error using the truth weights and the focal loss
                                 for (long i = 0; i < num_classes; ++i)
                                 {
                                     const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
+                                    const auto p = out_data[c_idx];
                                     if (truth_box.label == options.labels[i])
-                                        g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1);
+                                    {
+                                        const double focus = std::pow(1 - p, options.gamma_cls);
+                                        const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1);
+                                        g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls;
+                                    }
                                     else
-                                        g[c_idx] = options.lambda_cls * out_data[c_idx];
+                                    {
+                                        const double focus = std::pow(p, options.gamma_cls);
+                                        const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p);
+                                        g[c_idx] = options.lambda_cls * g_cls;
+                                    }
                                 }
                             }
                         }
                     }

-                    // Compute the L2 loss
+                    // The loss is the squared norm of the gradient
                     loss += length_squared(rowm(mat(grad), n));
                 }
             };
@@ -3956,6 +3992,8 @@ namespace dlib
             out << ", lambda_obj:" << opts.lambda_obj;
             out << ", lambda_box:" << opts.lambda_box;
             out << ", lambda_cls:" << opts.lambda_cls;
+            out << ", gamma_obj:" << opts.gamma_obj;
+            out << ", gamma_cls:" << opts.gamma_cls;
             out << ", overlaps_nms:(" << opts.overlaps_nms.get_iou_thresh() << "," << opts.overlaps_nms.get_percent_covered_thresh() << ")";
             out << ", classwise_nms:" << std::boolalpha << opts.classwise_nms;
             out << ")";

File: dlib/dnn/loss_abstract.h

@@ -1928,21 +1928,35 @@ namespace dlib
             // When computing the YOLO loss (objectness + bounding box regression + classification),
             // the best match between a truth and an anchor is always used, regardless of the IoU.
             // However, if other anchors have an IoU with a truth box above iou_anchor_threshold, they
-            // will also experience loss against that truth box as well. Setting iou_anchor_threshold to 1 will
-            // make the model use only the best anchor for each ground truth, so other anchors can be
-            // used for other ground truth boxes in the same cell (useful for detecting objects in crowds).
-            // This setting is meant to be used with "high capacity" models, not small ones.
+            // will also experience loss against that truth box. Setting iou_anchor_threshold
+            // to 1 will make the model use only the best anchor for each ground truth, so other
+            // anchors can be used for other ground truth boxes in the same cell (useful for
+            // detecting objects in crowds). This setting is meant to be used with "high capacity"
+            // models, not small ones. Additionally, when this value is set to 0, the IoU
+            // threshold is computed adaptively from the statistics of the IoUs between all
+            // anchors and the current target truth box. In particular, it follows the adaptive
+            // training sample selection (ATSS) scheme from the paper: "Bridging the Gap Between
+            // Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
+            // by Shifeng Zhang, et al. (https://arxiv.org/abs/1912.02424)
             double iou_anchor_threshold = 1.0;
             // When doing non-max suppression, we use overlaps_nms to decide if a box overlaps
             // an already output detection and should therefore be thrown out.
             test_box_overlap overlaps_nms = test_box_overlap(0.45, 1.0);
             // When set to true, NMS will only be applied between objects with the same class label.
             bool classwise_nms = true;
-            // These parameters control how we penalize different kinds of mistakes: notably the objectness loss,
-            // the box (bounding box regression) loss, and the classification loss.
+            // These parameters control how we penalize different kinds of mistakes: notably the
+            // objectness loss, the box (bounding box regression) loss, and the classification loss.
             double lambda_obj = 1.0;
             double lambda_box = 1.0;
             double lambda_cls = 1.0;
+            // These parameters make YOLO behave like the focal loss, presented in the paper:
+            // "Focal Loss for Dense Object Detection", by Tsung-Yi Lin, et al.
+            // (https://arxiv.org/abs/1708.02002)
+            // gamma_obj and gamma_cls act as modulating factors on the cross-entropy terms for
+            // objectness and classification, reducing the relative loss for well-classified
+            // examples and focusing training on the difficult ones.
+            double gamma_obj = 0.0;
+            double gamma_cls = 0.0;
         };