YOLO improvements: Smoothed L1 loss, Focal Loss, Weighted Loss (#2628)
* Add some improvements to the YOLO loss layer. Notably:
  - stabilized the bounding box regression loss during training by smoothing its gradient
  - added the possibility to compute the iou_anchor_threshold adaptively with ATSS (adaptive training sample selection)
  - made use of the detection_confidence field in the truth box for a class-balanced loss
  - added an optional gamma parameter to enable the focal loss in the classifier
* Added a method to return a non-const reference to yolo_options. Also added gamma_cls to the loss_details operator<<.
* Add focal gamma for objectness, too
* Fix yolo_options deserialization logic
* Remove the non-const reference method to yolo_options
* Use the weight only for the positive class, and be more consistent with floating-point types
parent 6dfba4970d
commit f81c4b2d00
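For orientation, here is a minimal sketch of how the knobs this commit adds can be set. The numeric values are illustrative assumptions, not defaults from the commit (both gammas default to 0, which disables the focal behavior). The class-balanced weighting needs no new option: it reuses the detection_confidence field of the truth yolo_rect as a per-box weight on the positive classification gradient.

    #include <dlib/dnn.h>
    #include <iostream>

    int main()
    {
        dlib::yolo_options options;
        // New in this commit: focal-loss modulation of the objectness and
        // classification terms (0 keeps the plain binary cross-entropy).
        options.gamma_obj = 2.0;  // illustrative value, not a recommended default
        options.gamma_cls = 2.0;  // illustrative value, not a recommended default
        // New in this commit: setting the anchor matching threshold to 0
        // switches to the adaptive (ATSS) per-truth-box threshold.
        options.iou_anchor_threshold = 0;
        std::cout << "gamma_cls = " << options.gamma_cls << '\n';
        return 0;
    }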
dlib/dnn/loss.h
@@ -3518,12 +3518,14 @@ namespace dlib
             double lambda_obj = 1.0;
             double lambda_box = 1.0;
             double lambda_cls = 1.0;
+            double gamma_obj = 0.0;
+            double gamma_cls = 0.0;

         };

         inline void serialize(const yolo_options& item, std::ostream& out)
         {
-            int version = 1;
+            int version = 2;
             serialize(version, out);
             serialize(item.anchors, out);
             serialize(item.labels, out);
@@ -3534,13 +3536,15 @@ namespace dlib
             serialize(item.lambda_obj, out);
             serialize(item.lambda_box, out);
             serialize(item.lambda_cls, out);
+            serialize(item.gamma_obj, out);
+            serialize(item.gamma_cls, out);
         }

         inline void deserialize(yolo_options& item, std::istream& in)
         {
             int version = 0;
             deserialize(version, in);
-            if (version != 1)
+            if (!(1 <= version && version <= 2))
                 throw serialization_error("Unexpected version found while deserializing dlib::yolo_options.");
             deserialize(item.anchors, in);
             deserialize(item.labels, in);
@@ -3551,6 +3555,11 @@ namespace dlib
             deserialize(item.lambda_obj, in);
             deserialize(item.lambda_box, in);
             deserialize(item.lambda_cls, in);
+            if (version >= 2)
+            {
+                deserialize(item.gamma_obj, in);
+                deserialize(item.gamma_cls, in);
+            }
         }

         inline std::ostream& operator<<(std::ostream& out, const std::map<int, std::vector<yolo_options::anchor_box_details>>& anchors)
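The version bump above is what keeps previously saved models loadable: a stream written as version 1 simply has no gamma values, so the new fields keep their defaults. A minimal sketch of the same version-gating pattern, using a hypothetical struct rather than the real yolo_options:

    #include <dlib/serialize.h>
    #include <istream>
    #include <ostream>

    // Hypothetical stand-in for yolo_options, just to show the pattern.
    struct toy_options { double gamma_cls = 0.0; };

    void serialize(const toy_options& item, std::ostream& out)
    {
        int version = 2;                   // always write the newest version
        dlib::serialize(version, out);
        dlib::serialize(item.gamma_cls, out);
    }

    void deserialize(toy_options& item, std::istream& in)
    {
        int version = 0;
        dlib::deserialize(version, in);
        if (!(1 <= version && version <= 2))
            throw dlib::serialization_error("unexpected version in toy_options");
        if (version >= 2)                  // field added in version 2
            dlib::deserialize(item.gamma_cls, in);
    }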
@@ -3655,19 +3664,21 @@ namespace dlib
                 {
                     for (long c = 0; c < output_tensor.nc(); ++c)
                     {
-                        const float obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
+                        const double obj = out_data[tensor_index(output_tensor, n, k + 4, r, c)];
                         if (obj > adjust_threshold)
                         {
-                            const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
-                            const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
-                            const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
-                            const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
+                            // The scaling and shifting in the x and y coordinates avoids the grid sensitivity
+                            // effect by allowing the network to output centers along the grid boundaries.
+                            const double x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
+                            const double y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
+                            const double w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
+                            const double h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
                             yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
                                           w / (1 - w) * anchors[a].width,
                                           h / (1 - h) * anchors[a].height));
                             for (long i = 0; i < num_classes; ++i)
                             {
-                                const float conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
+                                const double conf = obj * out_data[tensor_index(output_tensor, n, k + 5 + i, r, c)];
                                 if (conf > adjust_threshold)
                                     det.labels.emplace_back(conf, options.labels[i]);
                             }
@@ -3719,10 +3730,10 @@ namespace dlib
                         for (size_t a = 0; a < anchors.size(); ++a)
                         {
                             const long k = a * num_feats;
-                            const auto x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
-                            const auto y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
-                            const auto w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
-                            const auto h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];
+                            const double x = out_data[tensor_index(output_tensor, n, k + 0, r, c)] * 2.0 - 0.5;
+                            const double y = out_data[tensor_index(output_tensor, n, k + 1, r, c)] * 2.0 - 0.5;
+                            const double w = out_data[tensor_index(output_tensor, n, k + 2, r, c)];
+                            const double h = out_data[tensor_index(output_tensor, n, k + 3, r, c)];

                             // The prediction at r, c for anchor a
                             const yolo_rect pred(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
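Both hunks above decode the raw outputs into a box the same way, assuming (as this loss expects) that the network's final layer is a sigmoid so every output lies in (0, 1). Written out, with (r, c) the grid cell, (s_x, s_y) the stride, and (a_w, a_h) the anchor size:

    b_x = (2x - 0.5 + c)\,s_x \qquad b_y = (2y - 0.5 + r)\,s_y
    b_w = \frac{w}{1 - w}\,a_w  \qquad b_h = \frac{h}{1 - h}\,a_h

Since x, y \in (0, 1), the term 2x - 0.5 spans (-0.5, 1.5), so a predicted center can sit exactly on a cell boundary (the grid-sensitivity fix mentioned in the new comment), and w/(1 - w) maps (0, 1) onto (0, \infty), so a box can be smaller or larger than its anchor.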
@@ -3739,9 +3750,14 @@ namespace dlib
                             }

                             // Incur loss for the boxes that are below a certain IoU threshold with any truth box
-                            const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
                             if (best_iou < options.iou_ignore_threshold)
-                                g[o_idx] = options.lambda_obj * out_data[o_idx];
+                            {
+                                const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
+                                const double p = out_data[o_idx];
+                                const double focus = std::pow(p, options.gamma_obj);
+                                const double g_obj = focus * (options.gamma_obj * (1 - p) * safe_log(1 - p) + p);
+                                g[o_idx] = options.lambda_obj * g_obj;
+                            }
                         }
                     }
                 }
@@ -3755,6 +3771,7 @@ namespace dlib
                    double best_iou = 0;
                    size_t best_a = 0;
                    size_t best_tag_id = 0;
+                   running_stats<double> ious;
                    for (const auto& item : options.anchors)
                    {
                        const auto tag_id = item.first;
@@ -3768,22 +3785,27 @@ namespace dlib
                                best_iou = iou;
                                best_a = a;
                                best_tag_id = tag_id;
                            }
+                           ious.add(iou);
                        }
                    }

+                   // ATSS: Adaptive Training Sample Selection
+                   double iou_anchor_threshold = options.iou_anchor_threshold;
+                   if (iou_anchor_threshold == 0)
+                       iou_anchor_threshold = ious.mean() + ious.stddev();
+
                    for (size_t a = 0; a < anchors.size(); ++a)
                    {
                        // Update best anchor if it's from the current stride, and optionally other anchors
-                       if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || options.iou_anchor_threshold < 1)
+                       if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || iou_anchor_threshold < 1)
                        {

                            // do not update other anchors if they have low IoU
                            if (!(best_tag_id == tag_id<TAG_TYPE>::id && best_a == a))
                            {
                                const yolo_rect anchor(centered_drect(t_center, anchors[a].width, anchors[a].height));
-                               const double iou = box_intersection_over_union(truth_box.rect, anchor.rect);
-                               if (iou < options.iou_anchor_threshold)
+                               if (box_intersection_over_union(truth_box.rect, anchor.rect) < iou_anchor_threshold)
                                    continue;
                            }

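When iou_anchor_threshold is 0, the threshold used above becomes data driven: the mean plus one standard deviation of the IoUs between the current truth box and every anchor, following the ATSS paper. A self-contained sketch of just that statistic, using dlib's running_stats (the IoU values would come from box_intersection_over_union in the real code):

    #include <dlib/statistics.h>
    #include <vector>

    // Sketch: ATSS-style adaptive threshold over the IoUs of one truth box
    // against all candidate anchors.
    double atss_threshold(const std::vector<double>& anchor_ious)
    {
        dlib::running_stats<double> ious;
        for (const double iou : anchor_ious)
            ious.add(iou);
        // Anchors whose IoU reaches this value are kept as positive samples.
        return ious.mean() + ious.stddev();
    }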
@@ -3800,34 +3822,48 @@ namespace dlib
                            // Scale regression error according to the truth size
                            const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area());

-                           // Compute the gradient for the box coordinates
+                           // Compute the smoothed L1 gradient for the box coordinates
                            const auto x_idx = tensor_index(output_tensor, n, k + 0, r, c);
                            const auto y_idx = tensor_index(output_tensor, n, k + 1, r, c);
                            const auto w_idx = tensor_index(output_tensor, n, k + 2, r, c);
                            const auto h_idx = tensor_index(output_tensor, n, k + 3, r, c);
-                           g[x_idx] = scale_box * (out_data[x_idx] * 2.0 - 0.5 - tx);
-                           g[y_idx] = scale_box * (out_data[y_idx] * 2.0 - 0.5 - ty);
-                           g[w_idx] = scale_box * (out_data[w_idx] - tw);
-                           g[h_idx] = scale_box * (out_data[h_idx] - th);
+                           g[x_idx] = scale_box * put_in_range(-1, 1, (out_data[x_idx] * 2.0 - 0.5 - tx));
+                           g[y_idx] = scale_box * put_in_range(-1, 1, (out_data[y_idx] * 2.0 - 0.5 - ty));
+                           g[w_idx] = scale_box * put_in_range(-1, 1, (out_data[w_idx] - tw));
+                           g[h_idx] = scale_box * put_in_range(-1, 1, (out_data[h_idx] - th));

                            // This grid cell should detect an object
                            const auto o_idx = tensor_index(output_tensor, n, k + 4, r, c);
-                           g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
+                           {
+                               const auto p = out_data[o_idx];
+                               const double focus = std::pow(1 - p, options.gamma_obj);
+                               const double g_obj = focus * (options.gamma_obj * p * safe_log(p) + p - 1);
+                               g[o_idx] = options.lambda_obj * g_obj;
+                           }

-                           // Compute the classification error
+                           // Compute the classification error using the truth weights and the focal loss
                            for (long i = 0; i < num_classes; ++i)
                            {
                                const auto c_idx = tensor_index(output_tensor, n, k + 5 + i, r, c);
+                               const auto p = out_data[c_idx];
                                if (truth_box.label == options.labels[i])
-                                   g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1);
+                               {
+                                   const double focus = std::pow(1 - p, options.gamma_cls);
+                                   const double g_cls = focus * (options.gamma_cls * p * safe_log(p) + p - 1);
+                                   g[c_idx] = truth_box.detection_confidence * options.lambda_cls * g_cls;
+                               }
                                else
-                                   g[c_idx] = options.lambda_cls * out_data[c_idx];
+                               {
+                                   const double focus = std::pow(p, options.gamma_cls);
+                                   const double g_cls = focus * (options.gamma_cls * (1 - p) * safe_log(1 - p) + p);
+                                   g[c_idx] = options.lambda_cls * g_cls;
+                               }
                            }
                        }
                    }
                }

-               // Compute the L2 loss
+               // The loss is the squared norm of the gradient
                loss += length_squared(rowm(mat(grad), n));
            }
        };
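Two identities make the gradients above easier to audit. First, the put_in_range clamp is exactly the derivative of a smooth L1 (Huber-style) penalty with unit transition point:

    L(d) = \begin{cases} d^2/2 & |d| \le 1 \\ |d| - 1/2 & |d| > 1 \end{cases}
    \qquad\Rightarrow\qquad
    \frac{dL}{dd} = \max(-1, \min(1, d))

which is what put_in_range(-1, 1, d) computes. Second, for a positive example with p = \sigma(z), the focal loss of Lin et al. differentiates through the sigmoid as

    FL(p) = -(1 - p)^\gamma \log p
    \qquad\Rightarrow\qquad
    \frac{\partial FL}{\partial z} = (1 - p)^\gamma \,(\gamma\, p \log p + p - 1)

which is the focus * (gamma * p * safe_log(p) + p - 1) expression used for the objectness and true-class gradients above; \gamma = 0 recovers the plain cross-entropy gradient p - 1, and the detection_confidence factor simply scales the per-box contribution.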
@@ -3956,6 +3992,8 @@ namespace dlib
            out << ", lambda_obj:" << opts.lambda_obj;
            out << ", lambda_box:" << opts.lambda_box;
            out << ", lambda_cls:" << opts.lambda_cls;
+           out << ", gamma_obj:" << opts.gamma_obj;
+           out << ", gamma_cls:" << opts.gamma_cls;
            out << ", overlaps_nms:(" << opts.overlaps_nms.get_iou_thresh() << "," << opts.overlaps_nms.get_percent_covered_thresh() << ")";
            out << ", classwise_nms:" << std::boolalpha << opts.classwise_nms;
            out << ")";
dlib/dnn/loss_abstract.h
@@ -1928,21 +1928,35 @@ namespace dlib
            // When computing the YOLO loss (objectness + bounding box regression + classification),
            // the best match between a truth and an anchor is always used, regardless of the IoU.
            // However, if other anchors have an IoU with a truth box above iou_anchor_threshold, they
-           // will also experience loss against that truth box as well. Setting iou_anchor_threshold to 1 will
-           // make the model use only the best anchor for each ground truth, so other anchors can be
-           // used for other ground truth boxes in the same cell (useful for detecting objects in crowds).
-           // This setting is meant to be used with "high capacity" models, not small ones.
+           // will also experience loss against that truth box as well. Setting iou_anchor_threshold
+           // to 1 will make the model use only the best anchor for each ground truth, so other
+           // anchors can be used for other ground truth boxes in the same cell (useful for detecting
+           // objects in crowds). This setting is meant to be used with "high capacity" models, not
+           // small ones. Additionally, when this value is set to 0, it will adaptively compute the
+           // IoU threshold based on the statistics of the IoUs from all anchors with the current
+           // target truth box. In particular, it follows the adaptive training sample selection
+           // (ATSS) from the paper: "Bridging the Gap Between Anchor-based and Anchor-free Detection
+           // via Adaptive Training Sample Selection" by Shifeng Zhang, et al.
+           // (https://arxiv.org/abs/1912.02424)
            double iou_anchor_threshold = 1.0;
            // When doing non-max suppression, we use overlaps_nms to decide if a box overlaps
            // an already output detection and should therefore be thrown out.
            test_box_overlap overlaps_nms = test_box_overlap(0.45, 1.0);
            // When set to true, NMS will only be applied between objects with the same class label.
            bool classwise_nms = true;
-           // These parameters control how we penalize different kinds of mistakes: notably the objectness loss,
-           // the box (bounding box regression) loss, and the classification loss.
+           // These parameters control how we penalize different kinds of mistakes: notably the
+           // objectness loss, the box (bounding box regression) loss, and the classification loss.
            double lambda_obj = 1.0;
            double lambda_box = 1.0;
            double lambda_cls = 1.0;
+           // These parameters make YOLO behave like the Focal loss, presented in the paper:
+           // "Focal Loss for Dense Object Detection", by Tsung-Yi Lin, et al.
+           // (https://arxiv.org/abs/1708.02002)
+           // The gamma_obj and gamma_cls act as a modulating factor in the cross-entropy terms of
+           // objectness and classification, reducing the relative loss for well-classified
+           // examples and focusing on the difficult ones.
+           double gamma_obj = 0.0;
+           double gamma_cls = 0.0;

        };
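As a quick sense check of that modulating factor: with \gamma = 2, an easy example classified at p_t = 0.9 keeps only (1 - 0.9)^2 = 0.01 of its cross-entropy loss, while a hard one at p_t = 0.1 keeps (1 - 0.1)^2 = 0.81, so the relative weight shifts towards the hard example by nearly two orders of magnitude.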