Mirror of https://github.com/davisking/dlib.git
Added test_one_step() to the dnn_trainer. This allows you to do automatic
early stopping based on observing the loss on held-out data.
commit c20c11af90
parent 5f5684a8fb
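For context, the workflow this commit enables looks roughly like the sketch below. The tiny net_type and the random make_mini_batch() helper are hypothetical stand-ins for a real network and data pipeline; only the dnn_trainer calls reflect the actual API.

#include <dlib/dnn.h>
#include <vector>

using namespace dlib;

// Hypothetical toy network: a 10-way classifier over small float matrices.
using net_type = loss_multiclass_log<fc<10, input<matrix<float>>>>;

// Hypothetical stand-in for a real data pipeline: fills a mini-batch with
// random 8x1 samples and random labels in [0,10).
void make_mini_batch (
    std::vector<matrix<float>>& samples,
    std::vector<unsigned long>& labels
)
{
    static dlib::rand rnd;
    samples.clear();
    labels.clear();
    for (int i = 0; i < 32; ++i)
    {
        matrix<float> m(8,1);
        for (auto& v : m) v = rnd.get_random_float();
        samples.push_back(std::move(m));
        labels.push_back(rnd.get_random_32bit_number()%10);
    }
}

int main()
{
    net_type net;
    dnn_trainer<net_type> trainer(net);
    trainer.set_learning_rate(0.01);
    trainer.set_min_learning_rate(1e-5);
    // New knob from this commit: shrink the learning rate once the loss
    // reported by test_one_step() stops decreasing.
    trainer.set_test_iterations_without_progress_threshold(200);
    trainer.be_verbose();

    std::vector<matrix<float>> samples;
    std::vector<unsigned long> labels;
    size_t step = 0;
    // Train until the learning rate has been shrunk below the minimum.
    while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
    {
        make_mini_batch(samples, labels);      // training batch (hypothetical)
        trainer.train_one_step(samples, labels);

        // Periodically report the loss on held-out data.  This never updates
        // the network's parameters; it only records the test loss so the
        // trainer can decide when to drop the learning rate.
        if (++step % 10 == 0)
        {
            make_mini_batch(samples, labels);  // held-out batch (hypothetical)
            trainer.test_one_step(samples, labels);
        }
    }
    trainer.get_net();  // block until trainer threads stop touching net
}

The key point is that test_one_step() pushes held-out mini-batches through the same trainer pipeline, so the trainer itself can detect the test-loss plateau and shrink the learning rate.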
dlib/dnn/trainer.h

@@ -40,6 +40,7 @@ namespace dlib
             std::vector<std::vector<training_label_type>> labels;
             std::vector<resizable_tensor> t;
             std::vector<int> have_data;  // have_data[i] is true if there is data in labels[i] and t[i].
+            bool test_only = false;
         };

         template <typename training_label_type>
@@ -48,6 +49,7 @@ namespace dlib
             a.labels.swap(b.labels);
             a.t.swap(b.t);
             a.have_data.swap(b.have_data);
+            std::swap(a.test_only,b.test_only);
         }
     }

@@ -205,22 +207,9 @@ namespace dlib
        {
            DLIB_CASSERT(std::distance(dbegin, dend) > 0);

-           if (verbose)
-           {
-               using namespace std::chrono;
-               auto now_time = system_clock::now();
-               if (now_time-last_time > seconds(40))
-               {
-                   last_time = now_time;
-                   std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
-                             << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
-                             << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
-                   print_progress();
-                   clear_average_loss();
-               }
-           }
+           print_periodic_verbose_status();
            sync_to_disk();
-           send_job(dbegin, dend, lbegin);
+           send_job(false, dbegin, dend, lbegin);

            ++train_one_step_calls;
        }
@@ -241,22 +230,60 @@ namespace dlib
        )
        {
            DLIB_CASSERT(std::distance(dbegin, dend) > 0);
-           if (verbose)
-           {
-               using namespace std::chrono;
-               auto now_time = system_clock::now();
-               if (now_time-last_time > seconds(40))
-               {
-                   last_time = now_time;
-                   std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
-                             << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
-                             << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
-                   print_progress();
-                   clear_average_loss();
-               }
-           }
+           print_periodic_verbose_status();
            sync_to_disk();
-           send_job(dbegin, dend);
+           send_job(false, dbegin, dend);
            ++train_one_step_calls;
        }
+
+       void test_one_step (
+           const std::vector<input_type>& data,
+           const std::vector<training_label_type>& labels
+       )
+       {
+           DLIB_CASSERT(data.size() == labels.size());
+
+           test_one_step(data.begin(), data.end(), labels.begin());
+       }
+
+       template <
+           typename data_iterator,
+           typename label_iterator
+           >
+       void test_one_step (
+           data_iterator dbegin,
+           data_iterator dend,
+           label_iterator lbegin
+       )
+       {
+           DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+
+           print_periodic_verbose_status();
+           sync_to_disk();
+           send_job(true, dbegin, dend, lbegin);
+
+           ++train_one_step_calls;
+       }
+
+       void test_one_step (
+           const std::vector<input_type>& data
+       )
+       {
+           test_one_step(data.begin(), data.end());
+       }
+
+       template <
+           typename data_iterator
+           >
+       void test_one_step (
+           data_iterator dbegin,
+           data_iterator dend
+       )
+       {
+           DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+           print_periodic_verbose_status();
+           sync_to_disk();
+           send_job(true, dbegin, dend);
+           ++train_one_step_calls;
+       }
+
@@ -294,7 +321,7 @@ namespace dlib
                }

                sync_to_disk();
-               send_job(data.begin()+epoch_pos,
+               send_job(false, data.begin()+epoch_pos,
                         data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
                         labels.begin()+epoch_pos);
            }
@@ -352,7 +379,7 @@ namespace dlib
                }

                sync_to_disk();
-               send_job(data.begin()+epoch_pos,
+               send_job(false, data.begin()+epoch_pos,
                         data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
            }
            epoch_pos = 0;
@@ -394,6 +421,16 @@ namespace dlib
            return rs.mean();
        }

+       double get_average_test_loss (
+       ) const
+       {
+           wait_for_thread_to_pause();
+           running_stats<double> tmp;
+           for (auto& x : test_previous_loss_values)
+               tmp.add(x);
+           return tmp.mean();
+       }
+
        void clear_average_loss (
        )
        {
@@ -410,7 +447,9 @@ namespace dlib
            if (learning_rate != lr)
            {
                steps_without_progress = 0;
+               test_steps_without_progress = 0;
                previous_loss_values.clear();
+               test_previous_loss_values.clear();
            }
            learning_rate = lr;
            lr_schedule.set_size(0);
@@ -479,6 +518,27 @@ namespace dlib
            return steps_without_progress;
        }

+       void set_test_iterations_without_progress_threshold (
+           unsigned long thresh
+       )
+       {
+           wait_for_thread_to_pause();
+           lr_schedule.set_size(0);
+           test_iter_without_progress_thresh = thresh;
+       }
+
+       unsigned long get_test_iterations_without_progress_threshold (
+       ) const
+       {
+           return test_iter_without_progress_thresh;
+       }
+
+       unsigned long get_test_steps_without_progress (
+       ) const
+       {
+           return test_steps_without_progress;
+       }
+
        void set_learning_rate_shrink_factor (
            double shrink
        )
@@ -488,6 +548,7 @@ namespace dlib
            lr_schedule.set_size(0);
            learning_rate_shrink = shrink;
            steps_without_progress = 0;
+           test_steps_without_progress = 0;
        }

        double get_learning_rate_shrink_factor (
@@ -504,6 +565,14 @@ namespace dlib

    private:

+       void record_test_loss(double loss)
+       {
+           test_previous_loss_values.push_back(loss);
+           // discard really old loss values.
+           while (test_previous_loss_values.size() > test_iter_without_progress_thresh)
+               test_previous_loss_values.pop_front();
+       }
+
        void record_loss(double loss)
        {
            // Say that we will check if the gradient is bad 200 times during each
@@ -527,7 +596,10 @@ namespace dlib
            {
                auto&& dev = *devices[device];
                dlib::cuda::set_device(dev.device_id);
-               return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
+               if (next_job.test_only)
+                   return dev.net.compute_loss(next_job.t[device], next_job.labels[device].begin());
+               else
+                   return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
            }
            else
            {
@@ -542,7 +614,10 @@ namespace dlib
                auto&& dev = *devices[device];
                dlib::cuda::set_device(dev.device_id);
                no_label_type pick_which_run_update;
-               return dev.net.compute_parameter_gradients(next_job.t[device]);
+               if (next_job.test_only)
+                   return dev.net.compute_loss(next_job.t[device]);
+               else
+                   return dev.net.compute_parameter_gradients(next_job.t[device]);
            }
            else
            {
@@ -585,6 +660,37 @@ namespace dlib
            main_iteration_counter = 0;
            while(job_pipe.dequeue(next_job))
            {
+               if (next_job.test_only)
+               {
+                   // compute the testing loss
+                   for (size_t i = 0; i < devices.size(); ++i)
+                       tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
+                   // aggregate loss values from all the network computations.
+                   double theloss = 0;
+                   for (auto&& loss : losses)
+                       theloss += loss.get();
+                   record_test_loss(theloss/losses.size());
+
+                   // Check if we should shrink the learning rate based on how the test
+                   // error has been doing lately.
+                   if (learning_rate_shrink != 1)
+                   {
+                       test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
+                       if (test_steps_without_progress >= test_iter_without_progress_thresh)
+                       {
+                           test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values);
+                           if (test_steps_without_progress >= test_iter_without_progress_thresh)
+                           {
+                               // optimization has flattened out, so drop the learning rate.
+                               learning_rate = learning_rate_shrink*learning_rate;
+                               test_steps_without_progress = 0;
+                               test_previous_loss_values.clear();
+                           }
+                       }
+                   }
+                   continue;
+               }
+
                updated_net_since_last_sync = true;
                ++main_iteration_counter;
                // Call compute_parameter_gradients() and update_parameters() but pick the
@@ -719,7 +825,7 @@ namespace dlib
            job_pipe.wait_for_num_blocked_dequeues(1);
        }

-       const static long string_pad = 10;
+       const static long string_pad = 11;
        const static long epoch_string_pad = 4;
        const static long lr_string_pad = 4;

@@ -732,6 +838,9 @@ namespace dlib
            min_learning_rate = 1e-5;
            iter_without_progress_thresh = 2000;
            steps_without_progress = 0;
+           test_iter_without_progress_thresh = 200;
+           test_steps_without_progress = 0;
+
            learning_rate_shrink = 0.1;
            epoch_iteration = 0;
            epoch_pos = 0;
@@ -755,7 +864,7 @@ namespace dlib
        friend void serialize(const dnn_trainer& item, std::ostream& out)
        {
            item.wait_for_thread_to_pause();
-           int version = 7;
+           int version = 8;
            serialize(version, out);

            size_t nl = dnn_trainer::num_layers;
@@ -777,13 +886,17 @@ namespace dlib
            serialize(item.train_one_step_calls, out);
            serialize(item.lr_schedule, out);
            serialize(item.lr_schedule_pos, out);
+           serialize(item.test_iter_without_progress_thresh.load(), out);
+           serialize(item.test_steps_without_progress.load(), out);
+           serialize(item.test_previous_loss_values, out);
+
        }

        friend void deserialize(dnn_trainer& item, std::istream& in)
        {
            item.wait_for_thread_to_pause();
            int version = 0;
            deserialize(version, in);
-           if (version != 7)
+           if (version != 8)
                throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");

            size_t num_layers = 0;
@@ -815,6 +928,9 @@ namespace dlib
            deserialize(item.train_one_step_calls, in);
            deserialize(item.lr_schedule, in);
            deserialize(item.lr_schedule_pos, in);
+           deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
+           deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
+           deserialize(item.test_previous_loss_values, in);

            if (item.devices.size() > 1)
            {
@@ -966,6 +1082,7 @@ namespace dlib
            typename label_iterator
            >
        void send_job (
+           bool test_only,
            data_iterator dbegin,
            data_iterator dend,
            label_iterator lbegin
@@ -977,6 +1094,7 @@ namespace dlib
            job.t.resize(devs);
            job.labels.resize(devs);
            job.have_data.resize(devs);
+           job.test_only = test_only;

            // chop the data into devs blocks, each of about block_size elements.
            size_t block_size = (num+devs-1)/devs;
@@ -1009,19 +1127,23 @@ namespace dlib
            typename data_iterator
            >
        void send_job (
+           bool test_only,
            data_iterator dbegin,
            data_iterator dend
        )
        {
            typename std::vector<training_label_type>::iterator nothing;
-           send_job(dbegin, dend, nothing);
+           send_job(test_only, dbegin, dend, nothing);
        }

        void print_progress()
        {
            if (lr_schedule.size() == 0)
            {
-               std::cout << "steps without apparent progress: " << steps_without_progress;
+               if (test_previous_loss_values.size() == 0)
+                   std::cout << "steps without apparent progress: " << steps_without_progress;
+               else
+                   std::cout << "steps without apparent progress: train=" << steps_without_progress << ", test=" << test_steps_without_progress;
            }
            else
            {
@@ -1032,6 +1154,32 @@ namespace dlib
            std::cout << std::endl;
        }

+       void print_periodic_verbose_status()
+       {
+           if (verbose)
+           {
+               using namespace std::chrono;
+               auto now_time = system_clock::now();
+               if (now_time-last_time > seconds(40))
+               {
+                   last_time = now_time;
+                   std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
+                             << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " ";
+                   if (test_previous_loss_values.size() == 0)
+                   {
+                       std::cout << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+                   }
+                   else
+                   {
+                       std::cout << "train loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+                       std::cout << "test loss: " << rpad(cast_to_string(get_average_test_loss()),string_pad) << " ";
+                   }
+                   print_progress();
+                   clear_average_loss();
+               }
+           }
+       }
+
        std::vector<std::shared_ptr<device_data>> devices;
        dlib::pipe<job_t> job_pipe;
        job_t job;
@@ -1047,6 +1195,11 @@ namespace dlib
        double min_learning_rate;
        std::atomic<unsigned long> iter_without_progress_thresh;
        std::atomic<unsigned long> steps_without_progress;
+
+       std::atomic<unsigned long> test_iter_without_progress_thresh;
+       std::atomic<unsigned long> test_steps_without_progress;
+       std::deque<double> test_previous_loss_values;
+
        std::atomic<double> learning_rate_shrink;
        std::chrono::time_point<std::chrono::system_clock> last_sync_time;
        std::string sync_filename;
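The learning-rate shrink decision in the trainer thread above reuses dlib's count_steps_without_decrease() and count_steps_without_decrease_robust() from <dlib/statistics/running_gradient.h>. A standalone sketch of that plateau test follows; the synthetic loss sequence is invented for illustration, so the exact counts it produces are only indicative.

#include <dlib/statistics/running_gradient.h>
#include <deque>
#include <iostream>

int main()
{
    // Pretend the test loss fell for a while and then flattened out.
    std::deque<double> test_losses;
    for (int i = 0; i < 100; ++i)
        test_losses.push_back(1.0/(i+1));
    for (int i = 0; i < 300; ++i)
        test_losses.push_back(0.01 + 0.001*((i%7)/7.0));

    const unsigned long thresh = 200;  // plays the role of test_iter_without_progress_thresh
    // The cheap test runs first; the robust variant (which discards outliers)
    // confirms the plateau before the learning rate would actually be shrunk.
    if (dlib::count_steps_without_decrease(test_losses) >= thresh &&
        dlib::count_steps_without_decrease_robust(test_losses) >= thresh)
    {
        std::cout << "test loss has plateaued: shrink the learning rate\n";
    }
}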
dlib/dnn/trainer_abstract.h

@@ -76,6 +76,7 @@ namespace dlib
                - #get_learning_rate() == 1e-2
                - #get_min_learning_rate() == 1e-5
                - #get_iterations_without_progress_threshold() == 2000
+               - #get_test_iterations_without_progress_threshold() == 200
                - #get_learning_rate_shrink_factor() == 0.1
                - #get_learning_rate_schedule().size() == 0
                - #get_train_one_step_calls() == 0
@@ -536,6 +537,170 @@ namespace dlib
              stopped touching the net.
        !*/

+       // ----------------------
+
+       double get_average_test_loss (
+       ) const;
+       /*!
+           ensures
+               - returns the average loss value observed during previous calls to
+                 test_one_step().
+               - This function blocks until all threads inside the dnn_trainer have
+                 stopped touching the net.
+       !*/
+
+       void test_one_step (
+           const std::vector<input_type>& data,
+           const std::vector<training_label_type>& labels
+       );
+       /*!
+           requires
+               - data.size() == labels.size()
+               - data.size() > 0
+               - net_type uses a supervised loss.
+                 i.e. net_type::training_label_type != no_label_type.
+           ensures
+               - Runs the given data through the network and computes and records the
+                 loss.
+               - This call does not modify network parameters.  The point of
+                 test_one_step() is twofold: to allow you to observe the accuracy of
+                 the network on held-out data during training, and to allow the trainer
+                 to automatically adjust the learning rate when the test loss stops
+                 improving.  Note that you are not required to use test_one_step() at
+                 all, but if you want to do this kind of thing it is available.
+               - You can observe the current average loss value by calling
+                 get_average_test_loss().
+               - The computation will happen in another thread.  Therefore, after
+                 calling this function you should call get_net() before you touch the
+                 net object from the calling thread to ensure no other threads are
+                 still accessing the network.
+       !*/
+
+       template <
+           typename data_iterator,
+           typename label_iterator
+           >
+       void test_one_step (
+           data_iterator dbegin,
+           data_iterator dend,
+           label_iterator lbegin
+       );
+       /*!
+           requires
+               - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferenceable
+               - std::distance(dbegin, dend) > 0
+               - net_type uses a supervised loss.
+                 i.e. net_type::training_label_type != no_label_type.
+           ensures
+               - Runs the given data through the network and computes and records the
+                 loss.
+               - This call does not modify network parameters.  The point of
+                 test_one_step() is twofold: to allow you to observe the accuracy of
+                 the network on held-out data during training, and to allow the trainer
+                 to automatically adjust the learning rate when the test loss stops
+                 improving.  Note that you are not required to use test_one_step() at
+                 all, but if you want to do this kind of thing it is available.
+               - You can observe the current average loss value by calling
+                 get_average_test_loss().
+               - The computation will happen in another thread.  Therefore, after
+                 calling this function you should call get_net() before you touch the
+                 net object from the calling thread to ensure no other threads are
+                 still accessing the network.
+       !*/
+
+       void test_one_step (
+           const std::vector<input_type>& data
+       );
+       /*!
+           requires
+               - data.size() > 0
+               - net_type uses an unsupervised loss.
+                 i.e. net_type::training_label_type == no_label_type.
+           ensures
+               - Runs the given data through the network and computes and records the
+                 loss.
+               - This call does not modify network parameters.  The point of
+                 test_one_step() is twofold: to allow you to observe the accuracy of
+                 the network on held-out data during training, and to allow the trainer
+                 to automatically adjust the learning rate when the test loss stops
+                 improving.  Note that you are not required to use test_one_step() at
+                 all, but if you want to do this kind of thing it is available.
+               - You can observe the current average loss value by calling
+                 get_average_test_loss().
+               - The computation will happen in another thread.  Therefore, after
+                 calling this function you should call get_net() before you touch the
+                 net object from the calling thread to ensure no other threads are
+                 still accessing the network.
+       !*/
+
+       template <
+           typename data_iterator
+           >
+       void test_one_step (
+           data_iterator dbegin,
+           data_iterator dend
+       );
+       /*!
+           requires
+               - std::distance(dbegin, dend) > 0
+               - net_type uses an unsupervised loss.
+                 i.e. net_type::training_label_type == no_label_type.
+           ensures
+               - Runs the given data through the network and computes and records the
+                 loss.
+               - This call does not modify network parameters.  The point of
+                 test_one_step() is twofold: to allow you to observe the accuracy of
+                 the network on held-out data during training, and to allow the trainer
+                 to automatically adjust the learning rate when the test loss stops
+                 improving.  Note that you are not required to use test_one_step() at
+                 all, but if you want to do this kind of thing it is available.
+               - You can observe the current average loss value by calling
+                 get_average_test_loss().
+               - The computation will happen in another thread.  Therefore, after
+                 calling this function you should call get_net() before you touch the
+                 net object from the calling thread to ensure no other threads are
+                 still accessing the network.
+       !*/
+
+       void set_test_iterations_without_progress_threshold (
+           unsigned long thresh
+       );
+       /*!
+           ensures
+               - #get_test_iterations_without_progress_threshold() == thresh
+               - #get_learning_rate_schedule().size() == 0
+               - This function blocks until all threads inside the dnn_trainer have
+                 stopped touching the net.
+       !*/
+
+       unsigned long get_test_iterations_without_progress_threshold (
+       ) const;
+       /*!
+           ensures
+               - This object monitors the progress of training and estimates if the
+                 testing error is being reduced.  It does this by looking at the
+                 previous get_test_iterations_without_progress_threshold() mini-batch
+                 results from test_one_step() and applying the statistical test defined
+                 by the running_gradient object to see if the testing error is getting
+                 smaller.  If it isn't being reduced then get_learning_rate() is made
+                 smaller by a factor of get_learning_rate_shrink_factor().
+
+                 Therefore, get_test_iterations_without_progress_threshold() should
+                 always be set to something sensibly large so that this test can be
+                 done with reasonably high confidence.  Think of this test as saying
+                 "if the testing loss hasn't decreased for the previous
+                 get_test_iterations_without_progress_threshold() calls to
+                 test_one_step() then shrink the learning rate".
+       !*/
+
+       unsigned long get_test_steps_without_progress (
+       ) const;
+       /*!
+           ensures
+               - if (get_learning_rate_shrink_factor() != 1) then
+                   - returns an estimate of how many mini-batches have executed without
+                     us observing a statistically significant decrease in the testing
+                     error (i.e. the error on the data given to the trainer via
+                     test_one_step() calls).
+               - else
+                   - returns 0
+       !*/
+
    };

// ----------------------------------------------------------------------------------------
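Per the spec above, the test threshold trades detection confidence against reaction time: a larger value means more test_one_step() calls must pass without improvement before the learning rate drops. A small sketch of adjusting it, with illustrative values only:

#include <dlib/dnn.h>

// Illustrative only: net_type stands for whatever network the trainer wraps.
template <typename net_type>
void make_test_plateau_detector_more_patient (dlib::dnn_trainer<net_type>& trainer)
{
    // Require 800 test_one_step() calls without a statistically significant
    // decrease in the test loss before shrinking the learning rate
    // (the default set by this commit is 200).
    trainer.set_test_iterations_without_progress_threshold(800);
    // Each detected plateau multiplies the learning rate by this factor.
    trainer.set_learning_rate_shrink_factor(0.1);
}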