mirror of https://github.com/davisking/dlib.git
add loss multiclass log weighted (#2022)

* add loss_multiclass_log_weighted
* fix class name in loss_abstract
* add loss_multiclass_log_weighted test
* rename test function to match class name
* fix typo
* reuse the weighted label struct across weighted losses
* do not break compatibility with loss_multiclass_log_per_pixel_weighted
* actually test the loss and fix docs
* fix build with gcc 9
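For orientation, a minimal usage sketch of the new layer (modeled on the unit test added below; the network shape and solver settings are illustrative, not mandated by the commit):

#include <dlib/dnn.h>
using namespace dlib;

// The weighted loss drops in wherever loss_multiclass_log would be used.
using net_type = loss_multiclass_log_weighted<fc<4, input<matrix<double>>>>;

int main()
{
    std::vector<matrix<double>> samples;               // training inputs, filled elsewhere
    std::vector<weighted_label<unsigned long>> labels; // (label, weight) pairs
    // e.g. a sample of class 3 that should count twice as much as usual:
    // labels.emplace_back(3, 2.f);

    net_type net;
    dnn_trainer<net_type> trainer(net, sgd(0, 0.9));
    trainer.set_learning_rate(0.1);
    trainer.train(samples, labels); // each sample's loss/gradient is scaled by its weight
}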
This commit is contained in:
parent
9185a925ce
commit
1380e6b95f
dlib/dnn/loss.h: 149 changed lines
@@ -367,6 +367,140 @@ namespace dlib
     template <typename SUBNET>
     using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <typename label_type>
+    struct weighted_label
+    {
+        weighted_label()
+        {}
+
+        weighted_label(label_type label, float weight = 1.f)
+            : label(label), weight(weight)
+        {}
+
+        label_type label{};
+        float weight = 1.f;
+    };
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_multiclass_log_weighted_
+    {
+    public:
+
+        typedef weighted_label<unsigned long> training_label_type;
+        typedef unsigned long output_label_type;
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(output_tensor.nr() == 1 &&
+                         output_tensor.nc() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+            // Note that output_tensor.k() should match the number of labels.
+
+            for (long i = 0; i < output_tensor.num_samples(); ++i)
+            {
+                // The index of the largest output for this sample is the label.
+                *iter++ = index_of_max(rowm(mat(output_tensor), i));
+            }
+        }
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            tensor& grad = sub.get_gradient_input();
+
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() != 0);
+            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+            DLIB_CASSERT(output_tensor.nr() == 1 &&
+                         output_tensor.nc() == 1);
+            DLIB_CASSERT(grad.nr() == 1 &&
+                         grad.nc() == 1);
+
+            tt::softmax(grad, output_tensor);
+
+            // The loss we output is the average loss over the mini-batch.
+            const double scale = 1.0/output_tensor.num_samples();
+            double loss = 0;
+            float* g = grad.host();
+            for (long i = 0; i < output_tensor.num_samples(); ++i)
+            {
+                const auto wl = *truth++;
+                const long y = wl.label;
+                const float weight = wl.weight;
+                // The network must produce a number of outputs that is equal to the number
+                // of labels when using this type of loss.
+                DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k());
+                for (long k = 0; k < output_tensor.k(); ++k)
+                {
+                    const unsigned long idx = i*output_tensor.k()+k;
+                    if (k == y)
+                    {
+                        loss += weight*scale*-safe_log(g[idx]);
+                        g[idx] = weight*scale*(g[idx]-1);
+                    }
+                    else
+                    {
+                        g[idx] = weight*scale*g[idx];
+                    }
+                }
+            }
+            return loss;
+        }
+
+        friend void serialize(const loss_multiclass_log_weighted_& , std::ostream& out)
+        {
+            serialize("loss_multiclass_log_weighted_", out);
+        }
+
+        friend void deserialize(loss_multiclass_log_weighted_& , std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "loss_multiclass_log_weighted_")
+                throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_weighted_.");
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_weighted_& )
+        {
+            out << "loss_multiclass_log_weighted";
+            return out;
+        }
+
+        friend void to_xml(const loss_multiclass_log_weighted_& /*item*/, std::ostream& out)
+        {
+            out << "<loss_multiclass_log_weighted/>";
+        }
+
+    };
+
+    template <typename SUBNET>
+    using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class loss_multimulticlass_log_
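A cross-check on the loop above (editor's note, not part of the commit): writing p = softmax(z) for sample i with truth label y_i, weight w_i, and mini-batch size N, the code computes

    \ell = -\frac{1}{N} \sum_{i=1}^{N} w_i \log p_{i, y_i},
    \qquad
    \frac{\partial \ell}{\partial z_{i,k}} = \frac{w_i}{N} \left( p_{i,k} - \mathbf{1}[k = y_i] \right),

i.e. exactly the unweighted multiclass log loss and its gradient, scaled per sample by w_i. This is why setting every weight to 1 reduces the layer to loss_multiclass_log_.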
@@ -2832,20 +2966,7 @@ namespace dlib
     {
     public:
 
-        struct weighted_label
-        {
-            weighted_label()
-            {}
-
-            weighted_label(uint16_t label, float weight = 1.f)
-                : label(label), weight(weight)
-            {}
-
-            // In semantic segmentation, 65536 classes ought to be enough for anybody.
-            uint16_t label = 0;
-            float weight = 1.f;
-        };
-
+        typedef dlib::weighted_label<uint16_t> weighted_label;
         typedef matrix<weighted_label> training_label_type;
         typedef matrix<uint16_t> output_label_type;
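Two notes on the hunk above: the class being edited is evidently loss_multiclass_log_per_pixel_weighted_ (per the commit message's compatibility bullet), and the new typedef keeps the old nested-struct name alive, so existing user code such as the following (an illustrative snippet, not from the commit) still compiles:

    // Both names now refer to the same type:
    loss_multiclass_log_per_pixel_weighted_::weighted_label a(1, 0.5f);
    dlib::weighted_label<uint16_t> b(1, 0.5f);

The next two hunks switch files: judging from the /*! ... !*/ specification blocks, they are most likely from dlib/dnn/loss_abstract.h.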
@@ -369,6 +369,107 @@ namespace dlib
     template <typename SUBNET>
     using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <typename label_type>
+    struct weighted_label
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object represents the truth label of a single sample, together with
+                an associated weight (the higher the weight, the more emphasis the
+                corresponding sample is given during the training).
+
+                This object is used in the following loss layers:
+                    - loss_multiclass_log_weighted_ with unsigned long as label_type
+                    - loss_multiclass_log_per_pixel_weighted_ with uint16_t as label_type,
+                      since, in semantic segmentation, 65536 classes ought to be enough
+                      for anybody.
+        !*/
+
+        weighted_label()
+        {}
+
+        weighted_label(label_type label, float weight = 1.f)
+            : label(label), weight(weight)
+        {}
+
+        // The ground truth label
+        label_type label{};
+
+        // The weight of the corresponding sample
+        float weight = 1.f;
+    };
+
+// ----------------------------------------------------------------------------------------
+
+    class loss_multiclass_log_weighted_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object implements the loss layer interface defined above by
+                EXAMPLE_LOSS_LAYER_.  In particular, it implements the multiclass logistic
+                regression loss (e.g. negative log-likelihood loss), which is appropriate
+                for multiclass classification problems.  It is basically just like
+                loss_multiclass_log_ except that it lets you define per-sample weights,
+                which might be useful e.g. if you want to emphasize rare classes while
+                training.  If the classification problem is difficult, a flat weight
+                structure may lead the network to always predict the most common label,
+                in particular if the degree of imbalance is high.  To emphasize a certain
+                class or classes, simply increase the weights of the corresponding
+                samples, relative to the weights of the other samples.
+
+                Note that if you set all the weights equal to 1, then you get
+                loss_multiclass_log_ as a special case.
+        !*/
+
+    public:
+
+        typedef weighted_label<unsigned long> training_label_type;
+        typedef unsigned long output_label_type;
+
+        template <
+            typename SUB_TYPE,
+            typename label_iterator
+            >
+        void to_label (
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+            it has the additional calling requirements that:
+                - sub.get_output().nr() == 1
+                - sub.get_output().nc() == 1
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.sample_expansion_factor() == 1
+            and the output label is the predicted class for each classified object.  The
+            number of possible output classes is sub.get_output().k().
+        !*/
+
+        template <
+            typename const_label_iterator,
+            typename SUBNET
+            >
+        double compute_loss_value_and_gradient (
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const;
+        /*!
+            This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+            except it has the additional calling requirements that:
+                - sub.get_output().nr() == 1
+                - sub.get_output().nc() == 1
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - sub.sample_expansion_factor() == 1
+                - all labels pointed to by truth are < sub.get_output().k()
+        !*/
+
+    };
+
+    template <typename SUBNET>
+    using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class loss_multimulticlass_log_
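The documentation above recommends raising the weights of rare classes. A sketch of one common recipe, inverse class frequency weighting (illustrative only; make_balanced_labels is a hypothetical helper, not part of dlib or this commit):

#include <dlib/dnn.h>
#include <cstddef>
#include <map>
#include <vector>

// Weight each sample by N / (K * count(class)), so every class contributes
// roughly the same total weight regardless of how common it is.
std::vector<dlib::weighted_label<unsigned long>>
make_balanced_labels(const std::vector<unsigned long>& labels)
{
    std::map<unsigned long, std::size_t> counts;
    for (const auto& l : labels)
        ++counts[l];

    std::vector<dlib::weighted_label<unsigned long>> out;
    out.reserve(labels.size());
    for (const auto& l : labels)
        out.emplace_back(l, static_cast<float>(labels.size()) /
                            (counts.size() * counts[l]));
    return out;
}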
@@ -1434,27 +1535,7 @@ namespace dlib
         !*/
     public:
 
-        struct weighted_label
-        {
-            /*!
-                WHAT THIS OBJECT REPRESENTS
-                    This object represents the truth label of a single pixel, together with
-                    an associated weight (the higher the weight, the more emphasis the
-                    corresponding pixel is given during the training).
-            !*/
-
-            weighted_label();
-            weighted_label(uint16_t label, float weight = 1.f);
-
-            // The ground-truth label. In semantic segmentation, 65536 classes ought to be
-            // enough for anybody.
-            uint16_t label = 0;
-
-            // The weight of the corresponding pixel.
-            float weight = 1.f;
-        };
-
-        typedef matrix<weighted_label> training_label_type;
+        typedef matrix<weighted_label<uint16_t>> training_label_type;
         typedef matrix<uint16_t> output_label_type;
 
         template <
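The final hunks add a regression test; given the helpers used (print_spinner, DLIB_TEST_MSG) this appears to be dlib's DNN test suite, presumably dlib/test/dnn.cpp.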
@@ -3224,6 +3224,86 @@ namespace
         }
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_loss_multiclass_log_weighted()
+    {
+        print_spinner();
+
+        constexpr int input_height = 5;
+        constexpr int input_width = 7;
+        const size_t num_samples = 1000;
+        const size_t num_classes = 4;
+
+        ::std::vector<matrix<double>> x(num_samples);
+        ::std::vector<unsigned long> y(num_samples);
+
+        matrix<double> xtmp(input_height, input_width);
+
+        dlib::rand rnd;
+        // Generate input data
+        for (size_t ii = 0; ii < num_samples; ++ii)
+        {
+            for (int jj = 0; jj < input_height; ++jj)
+            {
+                for (int kk = 0; kk < input_width; ++kk)
+                {
+                    xtmp(jj, kk) = rnd.get_random_float();
+                }
+            }
+            x[ii] = xtmp;
+            y[ii] = rnd.get_integer_in_range(0, num_classes);
+        }
+
+        using net_type = loss_multiclass_log_weighted<fc<num_classes, input<matrix<double>>>>;
+
+        ::std::vector<weighted_label<unsigned long>> y_weighted(num_samples);
+
+        for (size_t weighted_class = 0; weighted_class < num_classes; ++weighted_class)
+        {
+            print_spinner();
+
+            // Assign weights
+            for (size_t ii = 0; ii < num_samples; ++ii)
+            {
+                const unsigned long label = y[ii];
+                const float weight
+                    = label == weighted_class
+                    ? 1.4f
+                    : 0.6f;
+                y_weighted[ii] = weighted_label<unsigned long>(label, weight);
+            }
+
+            net_type net;
+            sgd defsolver(0, 0.9);
+            dnn_trainer<net_type> trainer(net, defsolver);
+            trainer.set_learning_rate(0.1);
+            trainer.set_min_learning_rate(0.01);
+            trainer.set_mini_batch_size(10);
+            trainer.set_max_num_epochs(10);
+            trainer.train(x, y_weighted);
+
+            const ::std::vector<unsigned long> predictions = net(x);
+
+            int num_weighted_class = 0;
+            int num_not_weighted_class = 0;
+
+            for (size_t ii = 0; ii < num_samples; ++ii)
+            {
+                if (predictions[ii] == weighted_class)
+                    ++num_weighted_class;
+                else
+                    ++num_not_weighted_class;
+            }
+
+            DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class,
+                          "The weighted class (" << weighted_class << ") does not dominate: "
+                          << num_weighted_class << " <= " << num_not_weighted_class);
+        }
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void test_tensor_resize_bilinear(long samps, long k, long nr, long nc, long onr, long onc)
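Why this test works (editor's reading, not stated in the commit): the inputs are pure noise, so the only learnable signal is in the label weights. With the emphasized class weighted 1.4 and the rest 0.6, the expected weighted loss is minimized by predicting the emphasized class most of the time, which is exactly what the DLIB_TEST_MSG assertion requires.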
@@ -3645,6 +3725,7 @@ namespace
         test_loss_multiclass_per_pixel_outputs_on_trivial_task();
         test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore();
         test_loss_multiclass_per_pixel_weighted();
+        test_loss_multiclass_log_weighted();
         test_serialization();
         test_loss_dot();
         test_loss_multimulticlass_log();