diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index cfb2402ea..bb017d1dc 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -850,14 +850,14 @@ namespace dlib { const float temp = log1pexp(-out_data[idx]); const float focus = std::pow(1 - g[idx], gamma); - loss += y * scale * temp; + loss += y * scale * temp * focus; g[idx] = y * scale * focus * (g[idx] * (gamma * temp + 1) - 1); } else { const float temp = -(-out_data[idx] - log1pexp(-out_data[idx])); const float focus = std::pow(g[idx], gamma); - loss += -y * scale * temp; + loss += -y * scale * temp * focus; g[idx] = -y * scale * focus * g[idx] * (gamma * temp + 1); } }