add loss multiclass log weighted (#2022)

* add loss_multiclass_log_weighted

* fix class name in loss_abstract

* add loss_multiclass_log_weighted test

* rename test function to match class name

* fix typo

* reuse the weighted label struct across weighted losses

* do not break compatibility with loss_multiclass_log_per_pixel_weighted

* actually test the loss and fix docs

* fix build with gcc 9
Adrià Arrufat 2020-03-18 21:33:54 +09:00 committed by GitHub
parent 9185a925ce
commit 1380e6b95f
3 changed files with 318 additions and 35 deletions


@@ -367,6 +367,140 @@ namespace dlib
template <typename SUBNET>
using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename label_type>
struct weighted_label
{
weighted_label()
{}
weighted_label(label_type label, float weight = 1.f)
: label(label), weight(weight)
{}
label_type label{};
float weight = 1.f;
};
// ----------------------------------------------------------------------------------------
class loss_multiclass_log_weighted_
{
public:
typedef weighted_label<unsigned long> training_label_type;
typedef unsigned long output_label_type;
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const tensor& input_tensor,
const SUB_TYPE& sub,
label_iterator iter
) const
{
const tensor& output_tensor = sub.get_output();
DLIB_CASSERT(sub.sample_expansion_factor() == 1);
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1 );
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
// Note that output_tensor.k() should match the number of labels.
for (long i = 0; i < output_tensor.num_samples(); ++i)
{
// The index of the largest output for this sample is the label.
*iter++ = index_of_max(rowm(mat(output_tensor),i));
}
}
template <
typename const_label_iterator,
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub
) const
{
const tensor& output_tensor = sub.get_output();
tensor& grad = sub.get_gradient_input();
DLIB_CASSERT(sub.sample_expansion_factor() == 1);
DLIB_CASSERT(input_tensor.num_samples() != 0);
DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1);
DLIB_CASSERT(grad.nr() == 1 &&
grad.nc() == 1);
tt::softmax(grad, output_tensor);
// The loss we output is the average loss over the mini-batch.
const double scale = 1.0/output_tensor.num_samples();
double loss = 0;
float* g = grad.host();
for (long i = 0; i < output_tensor.num_samples(); ++i)
{
const auto wl = *truth++;
const long y = wl.label;
const float weight = wl.weight;
// The network must produce a number of outputs that is equal to the number
// of labels when using this type of loss.
DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k());
for (long k = 0; k < output_tensor.k(); ++k)
{
const unsigned long idx = i*output_tensor.k()+k;
if (k == y)
{
loss += weight*scale*-safe_log(g[idx]);
g[idx] = weight*scale*(g[idx]-1);
}
else
{
g[idx] = weight*scale*g[idx];
}
}
}
return loss;
}
friend void serialize(const loss_multiclass_log_weighted_& , std::ostream& out)
{
serialize("loss_multiclass_log_weighted_", out);
}
friend void deserialize(loss_multiclass_log_weighted_& , std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_multiclass_log_weighted_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_weighted_.");
}
friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_weighted_& )
{
out << "loss_multiclass_log_weighted";
return out;
}
friend void to_xml(const loss_multiclass_log_weighted_& /*item*/, std::ostream& out)
{
out << "<loss_multiclass_log_weighted/>";
}
};
template <typename SUBNET>
using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
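For reference, a sketch of the math the weighted loop above implements (notation assumed here, not taken from the source: p_{i,k} is the softmax of the network outputs z, y_i and w_i are the truth label and weight of sample i, and N is the mini-batch size):

    loss = \frac{1}{N} \sum_{i=1}^{N} w_i \, \bigl(-\log p_{i,y_i}\bigr)
    \frac{\partial\, loss}{\partial z_{i,k}} = \frac{w_i}{N} \, \bigl(p_{i,k} - [k = y_i]\bigr)

With all weights set to 1 this reduces to the plain loss_multiclass_log_ loss.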
// ----------------------------------------------------------------------------------------
class loss_multimulticlass_log_
@@ -2832,20 +2966,7 @@ namespace dlib
{
public:
struct weighted_label
{
weighted_label()
{}
weighted_label(uint16_t label, float weight = 1.f)
: label(label), weight(weight)
{}
// In semantic segmentation, 65536 classes ought to be enough for anybody.
uint16_t label = 0;
float weight = 1.f;
};
typedef dlib::weighted_label<uint16_t> weighted_label;
typedef matrix<weighted_label> training_label_type;
typedef matrix<uint16_t> output_label_type;
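The hunk above is the "do not break compatibility" part of the change: the old nested struct in loss_multiclass_log_per_pixel_weighted_ becomes a typedef for the shared template. A minimal compile-time sketch (my own check, not part of the commit) of what that compatibility means:

#include <dlib/dnn.h>
#include <type_traits>

// The old nested spelling and the new shared template should name the same type.
static_assert(std::is_same<dlib::loss_multiclass_log_per_pixel_weighted_::weighted_label,
                           dlib::weighted_label<uint16_t>>::value,
              "the nested typedef aliases the shared weighted_label template");

int main()
{
    // Old-style construction keeps working: label 3 with weight 2.
    dlib::loss_multiclass_log_per_pixel_weighted_::weighted_label wl(3, 2.0f);
    (void)wl;
    return 0;
}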


@@ -369,6 +369,107 @@ namespace dlib
template <typename SUBNET>
using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename label_type>
struct weighted_label
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents the truth label of a single sample, together with
an associated weight (the higher the weight, the more emphasis the
corresponding sample is given during the training).
This object is used in the following loss layers:
- loss_multiclass_log_weighted_ with unsigned long as label_type
- loss_multiclass_log_per_pixel_weighted_ with uint16_t as label_type,
since, in semantic segmentation, 65536 classes ought to be enough for
anybody.
!*/
weighted_label()
{}
weighted_label(label_type label, float weight = 1.f)
: label(label), weight(weight)
{}
// The ground truth label
label_type label{};
// The weight of the corresponding sample
float weight = 1.f;
};
// ----------------------------------------------------------------------------------------
class loss_multiclass_log_weighted_
{
/*!
WHAT THIS OBJECT REPRESENTS
This object implements the loss layer interface defined above by
EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic
regression loss (e.g. negative log-likelihood loss), which is appropriate
for multiclass classification problems. It is basically just like the
loss_multiclass_log except that it lets you define per-sample weights,
which might be useful e.g. if you want to emphasize rare classes while
training. If the classification problem is difficult, a flat weight
structure may lead the network to always predict the most common label,
in particular if the degree of imbalance is high. To emphasize a certain
class or classes, simply increase the weights of the corresponding samples,
relative to the weights of the other samples.
Note that if you set all the weights equal to 1, then you get
loss_multiclass_log_ as a special case.
!*/
public:
typedef weighted_label<unsigned long> training_label_type;
typedef unsigned long output_label_type;
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const tensor& input_tensor,
const SUB_TYPE& sub,
label_iterator iter
) const;
/*!
This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
it has the additional calling requirements that:
- sub.get_output().nr() == 1
- sub.get_output().nc() == 1
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
and the output label is the predicted class for each classified object. The number
of possible output classes is sub.get_output().k().
!*/
template <
typename const_label_iterator,
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub
) const;
/*!
This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
except it has the additional calling requirements that:
- sub.get_output().nr() == 1
- sub.get_output().nc() == 1
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
- all the labels pointed to by truth are < sub.get_output().k()
!*/
};
template <typename SUBNET>
using loss_multiclass_log_weighted = add_loss_layer<loss_multiclass_log_weighted_, SUBNET>;
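A minimal usage sketch to go with the spec above (the network layout, the weight values, and the choice of class 3 as the emphasized class are illustrative assumptions, not taken from the commit; it mirrors the unit test further down):

#include <dlib/dnn.h>
#include <vector>
using namespace dlib;

int main()
{
    // Small 4-class toy problem on 5x7 inputs.
    using net_type = loss_multiclass_log_weighted<fc<4, input<matrix<double>>>>;

    dlib::rand rnd;
    std::vector<matrix<double>> samples;
    std::vector<weighted_label<unsigned long>> labels;
    for (int i = 0; i < 1000; ++i)
    {
        matrix<double> x(5, 7);
        for (long r = 0; r < x.nr(); ++r)
            for (long c = 0; c < x.nc(); ++c)
                x(r, c) = rnd.get_random_double();
        samples.push_back(x);
        const unsigned long y = rnd.get_integer_in_range(0, 4);
        // Emphasize the (hypothetical) rare class 3 by giving it a larger weight.
        labels.emplace_back(y, y == 3 ? 2.0f : 1.0f);
    }

    net_type net;
    dnn_trainer<net_type> trainer(net, sgd(0, 0.9));
    trainer.set_learning_rate(0.1);
    trainer.set_min_learning_rate(0.01);
    trainer.set_mini_batch_size(10);
    trainer.set_max_num_epochs(10);
    trainer.train(samples, labels);

    // Inference is unchanged: the predicted label is still a plain unsigned long.
    const std::vector<unsigned long> predictions = net(samples);
    return predictions.empty();
}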
// ----------------------------------------------------------------------------------------
class loss_multimulticlass_log_
@@ -1434,27 +1535,7 @@ namespace dlib
!*/
public:
struct weighted_label
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents the truth label of a single pixel, together with
an associated weight (the higher the weight, the more emphasis the
corresponding pixel is given during the training).
!*/
weighted_label();
weighted_label(uint16_t label, float weight = 1.f);
// The ground-truth label. In semantic segmentation, 65536 classes ought to be
// enough for anybody.
uint16_t label = 0;
// The weight of the corresponding pixel.
float weight = 1.f;
};
typedef matrix<weighted_label> training_label_type;
typedef matrix<weighted_label<uint16_t>> training_label_type;
typedef matrix<uint16_t> output_label_type;
template <


@@ -3224,6 +3224,86 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_loss_multiclass_log_weighted()
{
print_spinner();
constexpr int input_height = 5;
constexpr int input_width = 7;
const size_t num_samples = 1000;
const size_t num_classes = 4;
::std::vector<matrix<double>> x(num_samples);
::std::vector<unsigned long> y(num_samples);
matrix<double> xtmp(input_height, input_width);
dlib::rand rnd;
// Generate input data
for (size_t ii = 0; ii < num_samples; ++ii)
{
for (int jj = 0; jj < input_height; ++jj)
{
for (int kk = 0; kk < input_width; ++kk)
{
xtmp(jj, kk) = rnd.get_random_float();
}
}
x[ii] = xtmp;
y[ii] = rnd.get_integer_in_range(0, num_classes);
}
using net_type = loss_multiclass_log_weighted<fc<num_classes, input<matrix<double>>>>;
::std::vector<weighted_label<unsigned long>> y_weighted(num_samples);
for (size_t weighted_class = 0; weighted_class < num_classes; ++weighted_class)
{
print_spinner();
// Assign weights
for (size_t ii = 0; ii < num_samples; ++ii)
{
const unsigned long label = y[ii];
const float weight
= label == weighted_class
? 1.4f
: 0.6f;
y_weighted[ii] = weighted_label<unsigned long>(label, weight);
}
net_type net;
sgd defsolver(0, 0.9);
dnn_trainer<net_type> trainer(net, defsolver);
trainer.set_learning_rate(0.1);
trainer.set_min_learning_rate(0.01);
trainer.set_mini_batch_size(10);
trainer.set_max_num_epochs(10);
trainer.train(x, y_weighted);
const ::std::vector<unsigned long> predictions = net(x);
int num_weighted_class = 0;
int num_not_weighted_class = 0;
for (size_t ii = 0; ii < num_samples; ++ii)
{
if (predictions[ii] == weighted_class)
++num_weighted_class;
else
++num_not_weighted_class;
}
DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class,
"The weighted class (" << weighted_class << ") does not dominate: "
<< num_weighted_class << " <= " << num_not_weighted_class);
}
}
// ----------------------------------------------------------------------------------------
void test_tensor_resize_bilinear(long samps, long k, long nr, long nc, long onr, long onc)
@@ -3645,6 +3725,7 @@ namespace
test_loss_multiclass_per_pixel_outputs_on_trivial_task();
test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore();
test_loss_multiclass_per_pixel_weighted();
test_loss_multiclass_log_weighted();
test_serialization();
test_loss_dot();
test_loss_multimulticlass_log();