mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Use a cache to avoid calls to the cuDNN algorithm selection routines.
This commit is contained in:
parent
8910445a7a
commit
2c70aad12c
@ -8,6 +8,8 @@
|
||||
#include "cudnn_dlibapi.h"
|
||||
#include "tensor.h"
|
||||
#include <cudnn.h>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -782,6 +784,22 @@ namespace dlib
|
||||
const tensor_descriptor& dest_desc
|
||||
)
|
||||
{
|
||||
// Calling the cuDNN "find the best algorithm" functions are really slow. So we keep a
|
||||
// cache that tells us what method was best for a particular configuration.
|
||||
thread_local std::map<std::tuple<int,int,int,int,long,long>,
|
||||
std::tuple<int,int,int>> config_to_algo_cache;
|
||||
|
||||
// If we have already found good algorithms for this setting then just pull them from
|
||||
// the cache.
|
||||
const auto cache_key = std::make_tuple(stride_y, stride_x, padding_y, padding_x, filters_nr, filters_nc);
|
||||
const auto iter = config_to_algo_cache.find(cache_key);
|
||||
if (iter != config_to_algo_cache.end())
|
||||
{
|
||||
std::tie(forward_algo, backward_data_algo, backward_filters_algo) = iter->second;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Pick which forward algorithm we will use and allocate the necessary
|
||||
// workspace buffer.
|
||||
cudnnConvolutionFwdAlgo_t forward_best_algo;
|
||||
@ -902,6 +920,10 @@ namespace dlib
|
||||
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
|
||||
}
|
||||
backward_filters_algo = backward_filters_best_algo;
|
||||
|
||||
|
||||
// Save this algorithm selection in the cache
|
||||
config_to_algo_cache[cache_key] = std::make_tuple(forward_algo, backward_data_algo, backward_filters_algo);
|
||||
}
|
||||
|
||||
void tensor_conv::
|
||||
@ -916,7 +938,12 @@ namespace dlib
|
||||
{
|
||||
DLIB_CASSERT(data.k() == filters.k());
|
||||
|
||||
const bool non_data_params_unchanged =
|
||||
// if the last call to setup gave the same exact settings then don't do
|
||||
// anything.
|
||||
if (data_num_samples == data.num_samples() &&
|
||||
data_k == data.k() &&
|
||||
data_nr == data.nr() &&
|
||||
data_nc == data.nc() &&
|
||||
stride_y_ == stride_y &&
|
||||
stride_x_ == stride_x &&
|
||||
padding_y_ == padding_y &&
|
||||
@ -924,15 +951,7 @@ namespace dlib
|
||||
filters_num_samples == filters.num_samples() &&
|
||||
filters_k == filters.k() &&
|
||||
filters_nr == filters.nr() &&
|
||||
filters_nc == filters.nc();
|
||||
|
||||
// if the last call to setup gave the same exact settings then don't do
|
||||
// anything.
|
||||
if (non_data_params_unchanged &&
|
||||
data_num_samples == data.num_samples() &&
|
||||
data_k == data.k() &&
|
||||
data_nr == data.nr() &&
|
||||
data_nc == data.nc()
|
||||
filters_nc == filters.nc()
|
||||
)
|
||||
{
|
||||
return;
|
||||
@ -995,16 +1014,7 @@ namespace dlib
|
||||
tensor_descriptor dest_desc;
|
||||
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
|
||||
|
||||
// Ask cuDNN what algorithms are best to use. We always do this on the first call
|
||||
// to setup(). Then if something other than the size of the input tensor changes we
|
||||
// also ask cuDNN what to use. Note that in newer versions of cuDNN, asking for the
|
||||
// best algorithm is a relatively slow thing. So it's important we don't do it
|
||||
// unnecessarily.
|
||||
if (!selected_algos || !non_data_params_unchanged)
|
||||
{
|
||||
selected_algos = true;
|
||||
select_best_algorithms(data, dest_desc);
|
||||
}
|
||||
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
|
||||
context(),
|
||||
|
@ -233,8 +233,6 @@ namespace dlib
|
||||
int forward_algo;
|
||||
int backward_data_algo;
|
||||
int backward_filters_algo;
|
||||
// true if select_best_algorithms has been called at least once.
|
||||
bool selected_algos = false;
|
||||
|
||||
size_t forward_workspace_size_in_bytes;
|
||||
size_t backward_data_workspace_size_in_bytes;
|
||||
|
Loading…
Reference in New Issue
Block a user