Clean up the cuDNN conv algorithm selection code slightly by moving it into its own function.

This commit is contained in:
Davis King 2020-08-01 13:33:39 -04:00
parent 4d18e0d0c7
commit 6c3243f766
2 changed files with 135 additions and 115 deletions

View File

@ -776,6 +776,134 @@ namespace dlib
return best_alg;
}
// Selects the cuDNN algorithms used for the forward, backward-data and
// backward-filters convolution passes and stores them in forward_algo,
// backward_data_algo and backward_filters_algo respectively.
//
//   data      - the input tensor of the convolution.
//   dest_desc - descriptor of the convolution's output tensor.
//
// filter_handle and conv_handle are passed to cuDNN as the filter and
// convolution descriptors, so they presumably must already be configured by
// the caller (NOTE(review): confirm against setup()).  Every cuDNN call is
// wrapped in CHECK_CUDNN.  This function only selects algorithms; it does not
// query or allocate any workspace memory.
void tensor_conv::
select_best_algorithms (
    const tensor& data,
    const tensor_descriptor& dest_desc
)
{
    // Pick which forward algorithm we will use.
    cudnnConvolutionFwdAlgo_t forward_best_algo;
#if CUDNN_MAJOR >= 8
    {
        // Size the results buffer with the forward max count, collect the
        // performance results and let pick_best_algorithm() choose among them.
        int num_possible_algorithms = 0;
        CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possible_algorithms));
        std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
                context(),
                descriptor(data),
                (const cudnnFilterDescriptor_t)filter_handle,
                (const cudnnConvolutionDescriptor_t)conv_handle,
                descriptor(dest_desc),
                num_possible_algorithms,
                &num_algorithms,
                perf_results.data()));
        // Keep only the entries cuDNN actually filled in.
        perf_results.resize(num_algorithms);
        forward_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
            context(),
            descriptor(data),
            (const cudnnFilterDescriptor_t)filter_handle,
            (const cudnnConvolutionDescriptor_t)conv_handle,
            descriptor(dest_desc),
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
            std::numeric_limits<size_t>::max(),
            &forward_best_algo));
#endif
    forward_algo = forward_best_algo;

    // Pick which backward data algorithm we will use.
    cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
#if CUDNN_MAJOR >= 8
    {
        // The buffer must be sized with the backward *data* max count.  The
        // backward filter count describes a different family of algorithms.
        int num_possible_algorithms = 0;
        CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(context(), &num_possible_algorithms));
        std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
                context(),
                (const cudnnFilterDescriptor_t)filter_handle,
                descriptor(dest_desc),
                (const cudnnConvolutionDescriptor_t)conv_handle,
                descriptor(data),
                num_possible_algorithms,
                &num_algorithms,
                perf_results.data()));
        perf_results.resize(num_algorithms);
        backward_data_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
            context(),
            (const cudnnFilterDescriptor_t)filter_handle,
            descriptor(dest_desc),
            (const cudnnConvolutionDescriptor_t)conv_handle,
            descriptor(data),
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
            std::numeric_limits<size_t>::max(),
            &backward_data_best_algo));
#endif
    backward_data_algo = backward_data_best_algo;

    // Pick which backward filters algorithm we will use.
    cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
#if CUDNN_MAJOR >= 8
    {
        int num_possible_algorithms = 0;
        CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possible_algorithms));
        std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
                context(),
                descriptor(data),
                descriptor(dest_desc),
                (const cudnnConvolutionDescriptor_t)conv_handle,
                (const cudnnFilterDescriptor_t)filter_handle,
                num_possible_algorithms,
                &num_algorithms,
                perf_results.data()));
        perf_results.resize(num_algorithms);
        backward_filters_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
            context(),
            descriptor(data),
            descriptor(dest_desc),
            (const cudnnConvolutionDescriptor_t)conv_handle,
            (const cudnnFilterDescriptor_t)filter_handle,
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
            std::numeric_limits<size_t>::max(),
            &backward_filters_best_algo));
#endif
    // cuDNN 5.1 has a bug that causes
    // cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
    // algorithm even for cases where cuDNN doesn't support it, leading to
    // incorrect outputs.  So here we check if we are in a case where winograd
    // isn't supported and manually overrule
    // cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
    // algorithm.  Note that this override is applied regardless of which
    // cuDNN version branch above made the selection.
    if (dnn_prefer_fastest_algorithms() &&
        !(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
    )
    {
        backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
    }
    backward_filters_algo = backward_filters_best_algo;
}
void tensor_conv::
setup(
const tensor& data,
@ -863,81 +991,17 @@ namespace dlib
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Pick which forward algorithm we will use and allocate the necessary
// workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
forward_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&forward_best_algo));
#endif
forward_algo = forward_best_algo;
select_best_algorithms(data, dest_desc);
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
forward_best_algo,
(cudnnConvolutionFwdAlgo_t)forward_algo,
&forward_workspace_size_in_bytes));
// Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_data_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_data_best_algo));
#endif
backward_data_algo = backward_data_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(),
@ -945,55 +1009,9 @@ namespace dlib
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
backward_data_best_algo,
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
&backward_data_workspace_size_in_bytes));
// Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_filters_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_filters_best_algo));
#endif
// cuDNN 5.1 has a bug that causes
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// algorithm even for cases where cuDNN doesn't support it, leading to
// incorrect outputs. So here we check if we are in a case where winograd
// isn't supported and manually overrule
// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
// algorithm.
if (dnn_prefer_fastest_algorithms() &&
!(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
)
{
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
}
backward_filters_algo = backward_filters_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
context(),
@ -1001,7 +1019,7 @@ namespace dlib
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
backward_filters_best_algo,
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
&backward_filters_workspace_size_in_bytes));
}
catch(...)

View File

@ -228,6 +228,8 @@ namespace dlib
int out_nr;
int out_nc;
// Chooses the cuDNN forward, backward-data and backward-filters convolution
// algorithms for the given input tensor (data) and output tensor descriptor
// (dest_desc), and stores the choices in the three _algo fields:
// forward_algo, backward_data_algo and backward_filters_algo.
void select_best_algorithms(const tensor& data, const tensor_descriptor& dest_desc);
int forward_algo;
int backward_data_algo;
int backward_filters_algo;