mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Add dnn_trainer::train_one_step iterator signature (#212)
Add an overload of dnn_trainer::train_one_step that takes a pair of iterators rather than a std::vector.
This commit is contained in:
parent
5e770e848e
commit
9726ce1cac
@ -183,7 +183,22 @@ namespace dlib
|
|||||||
const std::vector<label_type>& labels
|
const std::vector<label_type>& labels
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
|
DLIB_CASSERT(data.size() == labels.size());
|
||||||
|
|
||||||
|
train_one_step(data.begin(), data.end(), labels.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename data_iterator,
|
||||||
|
typename label_iterator
|
||||||
|
>
|
||||||
|
void train_one_step (
|
||||||
|
data_iterator dbegin,
|
||||||
|
data_iterator dend,
|
||||||
|
label_iterator lbegin
|
||||||
|
)
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
|
||||||
|
|
||||||
if (verbose)
|
if (verbose)
|
||||||
{
|
{
|
||||||
@ -200,7 +215,7 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
sync_to_disk();
|
sync_to_disk();
|
||||||
send_job(data.begin(), data.end(), labels.begin());
|
send_job(dbegin, dend, lbegin);
|
||||||
|
|
||||||
++train_one_step_calls;
|
++train_one_step_calls;
|
||||||
}
|
}
|
||||||
@ -209,7 +224,18 @@ namespace dlib
|
|||||||
const std::vector<input_type>& data
|
const std::vector<input_type>& data
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
DLIB_CASSERT(data.size() > 0);
|
train_one_step(data.begin(), data.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename data_iterator
|
||||||
|
>
|
||||||
|
void train_one_step (
|
||||||
|
data_iterator dbegin,
|
||||||
|
data_iterator dend
|
||||||
|
)
|
||||||
|
{
|
||||||
|
DLIB_CASSERT(std::distance(dbegin, dend) > 0);
|
||||||
if (verbose)
|
if (verbose)
|
||||||
{
|
{
|
||||||
using namespace std::chrono;
|
using namespace std::chrono;
|
||||||
@ -225,7 +251,7 @@ namespace dlib
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
sync_to_disk();
|
sync_to_disk();
|
||||||
send_job(data.begin(), data.end());
|
send_job(dbegin, dend);
|
||||||
++train_one_step_calls;
|
++train_one_step_calls;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,6 +414,37 @@ namespace dlib
|
|||||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename data_iterator,
|
||||||
|
typename label_iterator
|
||||||
|
>
|
||||||
|
void train_one_step (
|
||||||
|
data_iterator dbegin,
|
||||||
|
data_iterator dend,
|
||||||
|
label_iterator lbegin
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
|
||||||
|
- std::distance(dbegin, dend) > 0
|
||||||
|
- net_type uses a supervised loss.
|
||||||
|
i.e. net_type::label_type != no_label_type.
|
||||||
|
ensures
|
||||||
|
- Performs one stochastic gradient update step based on the mini-batch of
|
||||||
|
data and labels supplied to this function. In particular, calling
|
||||||
|
train_one_step() in a loop is equivalent to calling the train() method
|
||||||
|
defined above. However, train_one_step() allows you to stream data from
|
||||||
|
disk into the training process while train() requires you to first load
|
||||||
|
all the training data into RAM. Otherwise, these training methods are
|
||||||
|
equivalent.
|
||||||
|
- You can observe the current average loss value by calling get_average_loss().
|
||||||
|
- The network training will happen in another thread. Therefore, after
|
||||||
|
calling this function you should call get_net() before you touch the net
|
||||||
|
object from the calling thread to ensure no other threads are still
|
||||||
|
accessing the network.
|
||||||
|
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||||
|
!*/
|
||||||
|
|
||||||
void train_one_step (
|
void train_one_step (
|
||||||
const std::vector<input_type>& data
|
const std::vector<input_type>& data
|
||||||
);
|
);
|
||||||
@ -438,6 +469,34 @@ namespace dlib
|
|||||||
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename data_iterator
|
||||||
|
>
|
||||||
|
void train_one_step (
|
||||||
|
data_iterator dbegin,
|
||||||
|
data_iterator dend
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- std::distance(dbegin, dend) > 0
|
||||||
|
- net_type uses an unsupervised loss.
|
||||||
|
i.e. net_type::label_type == no_label_type.
|
||||||
|
ensures
|
||||||
|
- Performs one stochastic gradient update step based on the mini-batch of
|
||||||
|
data supplied to this function. In particular, calling train_one_step()
|
||||||
|
in a loop is equivalent to calling the train() method defined above.
|
||||||
|
However, train_one_step() allows you to stream data from disk into the
|
||||||
|
training process while train() requires you to first load all the
|
||||||
|
training data into RAM. Otherwise, these training methods are
|
||||||
|
equivalent.
|
||||||
|
- You can observe the current average loss value by calling get_average_loss().
|
||||||
|
- The network training will happen in another thread. Therefore, after
|
||||||
|
calling this function you should call get_net() before you touch the net
|
||||||
|
object from the calling thread to ensure no other threads are still
|
||||||
|
accessing the network.
|
||||||
|
- #get_train_one_step_calls() == get_train_one_step_calls() + 1.
|
||||||
|
!*/
|
||||||
|
|
||||||
double get_average_loss (
|
double get_average_loss (
|
||||||
) const;
|
) const;
|
||||||
/*!
|
/*!
|
||||||
|
Loading…
Reference in New Issue
Block a user