diff --git a/dlib/svm/rbf_network.h b/dlib/svm/rbf_network.h index 1323decec..97cc20652 100644 --- a/dlib/svm/rbf_network.h +++ b/dlib/svm/rbf_network.h @@ -108,14 +108,13 @@ namespace dlib typedef typename decision_function::scalar_vector_type scalar_vector_type; // make sure requires clause is not broken - DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1, "\tdecision_function rbf_network_trainer::train(x,y)" << "\n\t invalid inputs were given to this function" << "\n\t x.nr(): " << x.nr() << "\n\t y.nr(): " << y.nr() << "\n\t x.nc(): " << x.nc() << "\n\t y.nc(): " << y.nc() - << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false") ); // first run all the sampes through a kcentroid object to find the rbf centers diff --git a/dlib/svm/rbf_network_abstract.h b/dlib/svm/rbf_network_abstract.h index 7dd6487cb..e3cad3427 100644 --- a/dlib/svm/rbf_network_abstract.h +++ b/dlib/svm/rbf_network_abstract.h @@ -27,8 +27,7 @@ namespace dlib - get_tolerance() == 0.01 WHAT THIS OBJECT REPRESENTS - This object implements a trainer for radial basis function network for - solving binary classification problems. + This object implements a trainer for an radial basis function network. The implementation of this algorithm follows the normal RBF training process. For more details see the code or the Wikipedia article @@ -94,7 +93,13 @@ namespace dlib ) const /*! requires - - is_binary_classification_problem(x,y) == true + - in_sample_vector_type == a matrix or something convertable to a matrix + via vector_to_matrix() + - in_scalar_vector_type == a matrix or something convertable to a matrix + via vector_to_matrix() + - x.nr() > 1 + - x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1 + (i.e. x and y are both column vectors of the same length) ensures - trains a RBF network given the training samples in x and labels in y.