Add dnn_trainer::train_one_step iterator signature (#212)

Add an overload of dnn_trainer::train_one_step that takes a pair of
iterators rather than a std::vector.
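
For orientation, here is a minimal sketch of the call site this enables. The network type and the dummy mini-batch below are placeholders for illustration, not part of the commit:

    #include <dlib/dnn.h>
    #include <vector>
    using namespace dlib;

    // Placeholder supervised network; any net_type with a supervised loss works.
    using net_type = loss_multiclass_log<fc<10, input<matrix<unsigned char>>>>;

    int main()
    {
        net_type net;
        dnn_trainer<net_type> trainer(net);

        // Dummy mini-batch: 32 blank 28x28 images, all labeled 0.
        matrix<unsigned char> img = zeros_matrix<unsigned char>(28, 28);
        std::vector<matrix<unsigned char>> samples(32, img);
        std::vector<unsigned long> labels(32, 0);

        trainer.train_one_step(samples, labels);               // existing overload
        trainer.train_one_step(samples.begin(), samples.end(), // new overload
                               labels.begin());

        trainer.get_net(); // block until the training thread has finished
    }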
Authored by jpblackburn on 2016-08-31 22:13:39 -04:00; committed by Davis E. King
parent 5e770e848e
commit 9726ce1cac
2 changed files with 89 additions and 4 deletions

dlib/dnn/trainer.h

@@ -183,7 +183,22 @@ namespace dlib
             const std::vector<label_type>& labels
         )
         {
-            DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
+            DLIB_CASSERT(data.size() == labels.size());
+            train_one_step(data.begin(), data.end(), labels.begin());
+        }
+
+        template <
+            typename data_iterator,
+            typename label_iterator
+            >
+        void train_one_step (
+            data_iterator dbegin,
+            data_iterator dend,
+            label_iterator lbegin
+        )
+        {
+            DLIB_CASSERT(std::distance(dbegin, dend) > 0);
             if (verbose)
             {
@@ -200,7 +215,7 @@ namespace dlib
                 }
             }
             sync_to_disk();
-            send_job(data.begin(), data.end(), labels.begin());
+            send_job(dbegin, dend, lbegin);
             ++train_one_step_calls;
         }
@@ -209,7 +224,18 @@ namespace dlib
             const std::vector<input_type>& data
         )
         {
             DLIB_CASSERT(data.size() > 0);
+            train_one_step(data.begin(), data.end());
+        }
+
+        template <
+            typename data_iterator
+            >
+        void train_one_step (
+            data_iterator dbegin,
+            data_iterator dend
+        )
+        {
+            DLIB_CASSERT(std::distance(dbegin, dend) > 0);
             if (verbose)
             {
                 using namespace std::chrono;
@@ -225,7 +251,7 @@ namespace dlib
                 }
             }
             sync_to_disk();
-            send_job(data.begin(), data.end());
+            send_job(dbegin, dend);
             ++train_one_step_calls;
         }

dlib/dnn/trainer_abstract.h

@@ -414,6 +414,37 @@ namespace dlib
                 - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
         !*/
 
+        template <
+            typename data_iterator,
+            typename label_iterator
+            >
+        void train_one_step (
+            data_iterator dbegin,
+            data_iterator dend,
+            label_iterator lbegin
+        );
+        /*!
+            requires
+                - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferenceable
+                - std::distance(dbegin, dend) > 0
+                - net_type uses a supervised loss.
+                  i.e. net_type::label_type != no_label_type.
+            ensures
+                - Performs one stochastic gradient update step based on the mini-batch of
+                  data and labels supplied to this function.  In particular, calling
+                  train_one_step() in a loop is equivalent to calling the train() method
+                  defined above.  However, train_one_step() allows you to stream data from
+                  disk into the training process while train() requires you to first load
+                  all the training data into RAM.  Otherwise, these training methods are
+                  equivalent.
+                - You can observe the current average loss value by calling get_average_loss().
+                - The network training will happen in another thread.  Therefore, after
+                  calling this function you should call get_net() before you touch the net
+                  object from the calling thread to ensure no other threads are still
+                  accessing the network.
+                - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+        !*/
         void train_one_step (
             const std::vector<input_type>& data
         );
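
The std::advance clause above simply requires the label range to be at least as long as the data range. Because raw pointers satisfy these iterator requirements, plain arrays also work with the new signature; a small illustrative sketch:

    matrix<unsigned char> samples[64];
    unsigned long labels[64]; // >= 64 labels keeps lbegin + 63 dereferenceable
    // ... fill both arrays ...
    trainer.train_one_step(samples, samples + 64, labels);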
@@ -438,6 +469,34 @@ namespace dlib
                 - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
         !*/
 
+        template <
+            typename data_iterator
+            >
+        void train_one_step (
+            data_iterator dbegin,
+            data_iterator dend
+        );
+        /*!
+            requires
+                - std::distance(dbegin, dend) > 0
+                - net_type uses an unsupervised loss.
+                  i.e. net_type::label_type == no_label_type.
+            ensures
+                - Performs one stochastic gradient update step based on the mini-batch of
+                  data supplied to this function.  In particular, calling train_one_step()
+                  in a loop is equivalent to calling the train() method defined above.
+                  However, train_one_step() allows you to stream data from disk into the
+                  training process while train() requires you to first load all the
+                  training data into RAM.  Otherwise, these training methods are
+                  equivalent.
+                - You can observe the current average loss value by calling get_average_loss().
+                - The network training will happen in another thread.  Therefore, after
+                  calling this function you should call get_net() before you touch the net
+                  object from the calling thread to ensure no other threads are still
+                  accessing the network.
+                - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+        !*/
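
For completeness, a sketch of the unsupervised form; it assumes net_type uses a loss with label_type == no_label_type, and uses std::list only to show a non-vector container:

    std::list<matrix<float>> batch;
    // ... fill batch from disk ...
    trainer.train_one_step(batch.begin(), batch.end());
    trainer.get_net();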
         double get_average_loss (
         ) const;
         /*!