mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add dnn_trainer::train_one_step iterator signature (#212)
Add an overload of dnn_trainer::train_one_step that takes a pair of iterators rather than a std::vector.
This commit is contained in:
parent
5e770e848e
commit
9726ce1cac
@ -183,7 +183,22 @@ namespace dlib
|
||||
const std::vector<label_type>& labels
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
|
||||
DLIB_CASSERT(data.size() == labels.size());
|
||||
|
||||
train_one_step(data.begin(), data.end(), labels.begin());
|
||||
}
|
||||
|
||||
template <
|
||||
typename data_iterator,
|
||||
typename label_iterator
|
||||
>
|
||||
void train_one_step (
|
||||
data_iterator dbegin,
|
||||
data_iterator dend,
|
||||
label_iterator lbegin
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
@ -200,7 +215,7 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
sync_to_disk();
|
||||
send_job(data.begin(), data.end(), labels.begin());
|
||||
send_job(dbegin, dend, lbegin);
|
||||
|
||||
++train_one_step_calls;
|
||||
}
|
||||
@ -209,7 +224,18 @@ namespace dlib
|
||||
const std::vector<input_type>& data
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(data.size() > 0);
|
||||
train_one_step(data.begin(), data.end());
|
||||
}
|
||||
|
||||
template <
|
||||
typename data_iterator
|
||||
>
|
||||
void train_one_step (
|
||||
data_iterator dbegin,
|
||||
data_iterator dend
|
||||
)
|
||||
{
|
||||
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
|
||||
if (verbose)
|
||||
{
|
||||
using namespace std::chrono;
|
||||
@ -225,7 +251,7 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
sync_to_disk();
|
||||
send_job(data.begin(), data.end());
|
||||
send_job(dbegin, dend);
|
||||
++train_one_step_calls;
|
||||
}
|
||||
|
||||
|
@ -414,6 +414,37 @@ namespace dlib
|
||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename data_iterator,
|
||||
typename label_iterator
|
||||
>
|
||||
void train_one_step (
|
||||
data_iterator dbegin,
|
||||
data_iterator dend,
|
||||
label_iterator lbegin
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
|
||||
- std::distance(dbegin, dend) > 0
|
||||
- net_type uses a supervised loss.
|
||||
i.e. net_type::label_type != no_label_type.
|
||||
ensures
|
||||
- Performs one stochastic gradient update step based on the mini-batch of
|
||||
data and labels supplied to this function. In particular, calling
|
||||
train_one_step() in a loop is equivalent to calling the train() method
|
||||
defined above. However, train_one_step() allows you to stream data from
|
||||
disk into the training process while train() requires you to first load
|
||||
all the training data into RAM. Otherwise, these training methods are
|
||||
equivalent.
|
||||
- You can observe the current average loss value by calling get_average_loss().
|
||||
- The network training will happen in another thread. Therefore, after
|
||||
calling this function you should call get_net() before you touch the net
|
||||
object from the calling thread to ensure no other threads are still
|
||||
accessing the network.
|
||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||
!*/
|
||||
|
||||
void train_one_step (
|
||||
const std::vector<input_type>& data
|
||||
);
|
||||
@ -438,6 +469,34 @@ namespace dlib
|
||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||
!*/
|
||||
|
||||
template <
|
||||
typename data_iterator
|
||||
>
|
||||
void train_one_step (
|
||||
data_iterator dbegin,
|
||||
data_iterator dend
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- std::distance(dbegin, dend) > 0
|
||||
- net_type uses an unsupervised loss.
|
||||
i.e. net_type::label_type == no_label_type.
|
||||
ensures
|
||||
- Performs one stochastic gradient update step based on the mini-batch of
|
||||
data supplied to this function. In particular, calling train_one_step()
|
||||
in a loop is equivalent to calling the train() method defined above.
|
||||
However, train_one_step() allows you to stream data from disk into the
|
||||
training process while train() requires you to first load all the
|
||||
training data into RAM. Otherwise, these training methods are
|
||||
equivalent.
|
||||
- You can observe the current average loss value by calling get_average_loss().
|
||||
- The network training will happen in another thread. Therefore, after
|
||||
calling this function you should call get_net() before you touch the net
|
||||
object from the calling thread to ensure no other threads are still
|
||||
accessing the network.
|
||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||
!*/
|
||||
|
||||
double get_average_loss (
|
||||
) const;
|
||||
/*!
|
||||
|
Loading…
Reference in New Issue
Block a user