This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.3.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.3.x by this push:
new a94109a CudnnFind() usage improvements (v1.3.x) (#13123)
a94109a is described below
commit a94109a0e5d5f15659cd0e428f06aebde0f30e90
Author: Anton Chernov <[email protected]>
AuthorDate: Wed Nov 7 19:11:48 2018 +0100
CudnnFind() usage improvements (v1.3.x) (#13123)
* Add mx.context.gpu_memory_info() to python api for flexible tests.
* Add test_gluon_gpu.py:test_large_models to show cudnnFind headroom issue.
* Output model sizes tried by test_gluon_gpu.py:test_large_models.
* Fix perl interface to MXGetGPUMemoryInformation.
* Increase difficulty of test_gluon_gpu.py:test_large_models.
* Forgot a file in fix for perl.
* Modify test to pass on no-cudnn CI runner.
* Mutex algo reg updates, serialize cudnnFind calls.
* Fix for cudnnFind memory headroom issue.
* Fix cpplint.
* Respond to reviewers comments.
* Guard against improper MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE values.
* Fix potentially unassigned var.
---
CONTRIBUTORS.md | 14 +-
docs/faq/env_var.md | 4 +
include/mxnet/base.h | 14 +-
include/mxnet/c_api.h | 10 +
perl-package/AI-MXNetCAPI/mxnet.i | 9 +
python/mxnet/context.py | 24 ++
src/c_api/c_api.cc | 11 +
src/operator/nn/cudnn/cudnn_algoreg-inl.h | 66 ++--
src/operator/nn/cudnn/cudnn_convolution-inl.h | 480 ++++++++++++----------
src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 502 +++++++++++++-----------
src/storage/pooled_storage_manager.h | 30 +-
tests/python/gpu/test_gluon_gpu.py | 43 +-
12 files changed, 714 insertions(+), 493 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 8d8aeac..404f135 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -44,6 +44,11 @@ The committers are the granted write access to the project.
* [Sergey Kolychev](https://github.com/sergeykolychev)
- Sergey is original author and current maintainer of Perl5 interface.
* [Naveen Swamy](https://github.com/nswamy)
+* [Marco de Abreu](https://github.com/marcoabreu)
+ - Marco is the creator of the current MXNet CI.
+* [Carin Meier](https://github.com/gigasquid)
+ - Carin created and is the current maintainer for the Clojure interface.
+
### Become a Committer
MXNet is a opensource project and we are actively looking for new committers
@@ -153,8 +158,6 @@ List of Contributors
* [Manu Seth](https://github.com/mseth10/)
* [Calum Leslie](https://github.com/calumleslie)
* [Andre Tamm](https://github.com/andretamm)
-* [Marco de Abreu](https://github.com/marcoabreu)
- - Marco is the creator of the current MXNet CI.
* [Julian Salazar](https://github.com/JulianSlzr)
* [Meghna Baijal](https://github.com/mbaijal)
* [Tao Hu](https://github.com/dongzhuoyao)
@@ -178,3 +181,10 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
+* [Zhijingcheng Yu](https://github.com/jasonyu1996)
+* [Cheng-Che Lee](https://github.com/stu1130)
+* [Chaitanya Bapat](https://github.com/ChaiBapchya)
+* [LuckyPigeon](https://github.com/LuckyPigeon)
+* [Anton Chernov](https://github.com/lebeg)
+* [Denisa Roberts](https://github.com/D-Roberts)
+* [Dick Carter](https://github.com/DickJC123)
diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 0664d79..6546752 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -58,6 +58,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
- Values: Int ```(default=5)```
- The percentage of GPU memory to reserve for things other than the GPU
array, such as kernel launch or cudnn handle space.
- If you see a strange out-of-memory error from the kernel launch, after
multiple iterations, try setting this to a larger value.
+* MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE
+ - Values: Int ```(default=2097152)```
+ - When using the naive pool type, memory allocations larger than this
threshold are rounded up to a multiple of this value.
+ - The default was chosen to minimize global memory fragmentation within the
GPU driver. Set this to 1 to disable.
## Engine Type
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 75784a3..a56ca6f 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -225,11 +225,11 @@ struct Context {
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
- * \param free_mem pointer to the integer holding free GPU memory
- * \param total_mem pointer to the integer holding total GPU memory
+ * \param free_mem pointer to the uint64_t holding free GPU memory
+ * \param total_mem pointer to the uint64_t holding total GPU memory
* \return No return value
*/
- inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
+ inline static void GetGPUMemoryInformation(int dev, uint64_t *free, uint64_t
*total);
/*!
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() {
#endif
}
-inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
- int *total_mem) {
+inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
+ uint64_t *total_mem) {
#if MXNET_USE_CUDA
size_t memF, memT;
@@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int
*free_mem,
e = cudaSetDevice(curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
- *free_mem = static_cast<int>(memF);
- *total_mem = static_cast<int>(memT);
+ *free_mem = static_cast<uint64_t>(memF);
+ *total_mem = static_cast<uint64_t>(memT);
#else
LOG(FATAL)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 0043996..1c2ebb8 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -439,6 +439,7 @@ MXNET_DLL int MXGetGPUCount(int* out);
/*!
* \brief get the free and total available memory on a GPU
+ * Note: Deprecated, use MXGetGPUMemoryInformation64 instead.
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
@@ -447,6 +448,15 @@ MXNET_DLL int MXGetGPUCount(int* out);
MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int
*total_mem);
/*!
+ * \brief get the free and total available memory on a GPU
+ * \param dev the GPU number to query
+ * \param free_mem pointer to the uint64_t holding free GPU memory
+ * \param total_mem pointer to the uint64_t holding total GPU memory
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem,
uint64_t *total_mem);
+
+/*!
* \brief get the MXNet library version as an integer
* \param pointer to the integer holding the version number
* \return 0 when success, -1 when failure happens
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i
b/perl-package/AI-MXNetCAPI/mxnet.i
index 2540e1b..64c1654 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -342,6 +342,15 @@ int MXEngineSetBulkSize(int bulk_size, int* out);
*/
int MXGetGPUCount(int* out);
+/*!
+ * \brief get the free and total available memory on a GPU
+ * \param dev the GPU number to query
+ * \param free_mem pointer to the uint64_t holding free GPU memory
+ * \param total_mem pointer to the uint64_t holding total GPU memory
+ * \return 0 when success, -1 when failure happens
+ */
+int MXGetGPUMemoryInformation64(int dev, uint64_t *out, uint64_t *out);
+
//-------------------------------------
// Part 1: NDArray creation and deletion
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index 61b7053..15ea990 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -258,6 +258,30 @@ def num_gpus():
check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
return count.value
+def gpu_memory_info(device_id=0):
+ """Query CUDA for the free and total bytes of GPU global memory.
+
+ Parameters
+ ----------
+ device_id : int, optional
+ The device id of the GPU device.
+
+ Raises
+ ------
+ Will raise an exception on any CUDA error.
+
+ Returns
+ -------
+ (free, total) : (int, int)
+ The free and total bytes of GPU global memory.
+
+ """
+ free = ctypes.c_uint64()
+ total = ctypes.c_uint64()
+ dev_id = ctypes.c_int(device_id)
+ check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free),
ctypes.byref(total)))
+ return (free.value, total.value)
+
def current_context():
"""Returns the current context.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1ef3f0f..feed336 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -122,8 +122,19 @@ int MXGetGPUCount(int* out) {
API_END();
}
+// Deprecated: use MXGetGPUMemoryInformation64() instead.
int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
API_BEGIN();
+ uint64_t free_mem64 = 0UL;
+ uint64_t total_mem64 = 0UL;
+ Context::GetGPUMemoryInformation(dev, &free_mem64, &total_mem64);
+ *free_mem = static_cast<int>(free_mem64);
+ *total_mem = static_cast<int>(total_mem64);
+ API_END();
+}
+
+int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t
*total_mem) {
+ API_BEGIN();
Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
API_END();
}
diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h
b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
index 3b59fd1..21d3a30 100644
--- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h
+++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
@@ -30,6 +30,8 @@
#include <mutex>
#include <string>
#include <vector>
+#include <functional>
+#include <utility>
#include "../../../common/cuda_utils.h"
#include "../convolution-inl.h"
#include "../deconvolution-inl.h"
@@ -65,7 +67,11 @@ class CuDNNAlgo {
template<typename ParamType>
class CuDNNAlgoReg {
public:
- bool Find(const ParamType ¶m,
+ using AlgoSetter_t = std::function<void(CuDNNAlgo<cudnnConvolutionFwdAlgo_t>
*,
+ CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *,
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *)>;
+
+ void FindOrElseRegister(const ParamType ¶m,
const std::vector<TShape> &in_shape,
const std::vector<TShape> &out_shape,
cudnnDataType_t cudnn_data_type,
@@ -75,7 +81,8 @@ class CuDNNAlgoReg {
bool add_to_weight,
CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
- CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt,
+ const AlgoSetter_t &algo_setter) {
CHECK(in_shape.size() == 2 || in_shape.size() == 3);
ParamKey key{param, in_shape[0], in_shape[1], out_shape[0],
cudnn_data_type,
cudnn_forward_compute_type, cudnn_backward_compute_type,
sm_arch, add_to_weight};
@@ -85,45 +92,28 @@ class CuDNNAlgoReg {
*fwd = i->second.fwd;
*bwd = i->second.bwd;
*flt = i->second.flt;
- return true;
- }
- return false;
- }
-
- void Register(const ParamType ¶m,
- const std::vector<TShape> &in_shape,
- const std::vector<TShape> &out_shape,
- cudnnDataType_t cudnn_data_type,
- cudnnDataType_t cudnn_forward_compute_type,
- cudnnDataType_t cudnn_backward_compute_type,
- int sm_arch,
- bool add_to_weight,
- const CuDNNAlgo<cudnnConvolutionFwdAlgo_t> &fwd,
- const CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> &bwd,
- const CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> &flt) {
- CHECK(in_shape.size() == 2 || in_shape.size() == 3);
- ParamKey key{param, in_shape[0], in_shape[1], out_shape[0],
cudnn_data_type,
- cudnn_forward_compute_type, cudnn_backward_compute_type,
sm_arch, add_to_weight};
- std::lock_guard<std::mutex> guard(lock_);
- if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
- LOG(INFO) << "Running performance tests to find the best convolution "
- "algorithm, "
- "this can take a while... (setting env variable "
- "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
- if (reg_.size() >= 1000) {
- // Many people are very concerned about this warning, so change the
warning once.
- if (!is_warning_autotune_) {
- LOG(INFO)
- << "If you see this message in the middle of training, you are "
- "probably using bucketing. Consider setting env variable "
- "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning.";
- is_warning_autotune_ = true;
+ } else {
+ if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
+ LOG(INFO) << "Running performance tests to find the best convolution "
+ "algorithm, "
+ "this can take a while... (setting env variable "
+ "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
+ if (reg_.size() >= 1000) {
+ // Many people are very concerned about this warning, so only print
it once.
+ if (!is_warning_autotune_) {
+ LOG(INFO)
+ << "If you see this message in the middle of training, you are
"
+ "probably using bucketing. Consider setting env variable "
+ "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn
tuning.";
+ is_warning_autotune_ = true;
+ }
}
}
+ // Call the provided function to determine the algos; it likely uses
cudnnFind() or cudnnGet()
+ algo_setter(fwd, bwd, flt);
+ // Save result so future lookups hit in this registry
+ reg_.insert(std::pair<ParamKey, CudnnAlgorithms>(key,
CudnnAlgorithms{*fwd, *bwd, *flt}));
}
- reg_[key].fwd = fwd;
- reg_[key].bwd = bwd;
- reg_[key].flt = flt;
}
static CuDNNAlgoReg *Get();
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h
b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index acdd649..4dc7ff8 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -26,6 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_
#define MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_
+#include <mxnet/storage.h>
#include <algorithm>
#include <vector>
#include <mutex>
@@ -606,236 +607,265 @@ class CuDNNConvolutionOp {
}
}
- void SelectAlgo(const RunContext& rctx,
+ void CuDNNAlgoSetter(const RunContext& rctx,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
cudnnDataType_t cudnn_forward_compute_type,
- cudnnDataType_t cudnn_backward_compute_type) {
- if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
- cudnn_forward_compute_type,
cudnn_backward_compute_type,
- SMArch(rctx.ctx.dev_id), add_to_weight_,
- &forward_algo_, &back_algo_,
&back_algo_w_)) {
- mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
- CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
- size_t workspace_byte = static_cast<size_t>(param_.workspace *
sizeof(DType));
- #if CUDNN_MAJOR >= 7
- // Starting with cuDNNv7, the algo number returned by *Get*() is not the
entire
- // story: the notion of whether the algo ran in Tensor Core mode is not
known.
- // Since we want to report the Tensor Core mode in the verbose output,
we switch
- // to using the new *Get*_v7() call. Since the function signature of
*Get*_v7() matches
- // that of *Find*(), we can unify the find-vs-get logic by using
function pointers.
-
- // Forward Algorithm Find/Get() v7
- std::vector<cudnnConvolutionFwdAlgoPerf_t>
fwd_results(MaxForwardAlgos(s->dnn_handle_));
- int actual_fwd_algos = 0;
- auto fwd_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionForwardAlgorithm_v7
- :
cudnnFindConvolutionForwardAlgorithm;
- CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
- in_desc_,
- filter_desc_,
- forward_conv_desc_,
- out_desc_,
- fwd_results.size(),
- &actual_fwd_algos,
- fwd_results.data()));
- fwd_results.resize(actual_fwd_algos);
- AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
- cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
- workspace_byte,
&forward_algo_);
-
- // Backprop-to-Filter Algorithm Find/Get() v7
- auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
- std::vector<cudnnConvolutionBwdFilterAlgoPerf_t>
bwd_filt_results(max_bwd_filt_algos);
- int actual_bwd_filter_algos = 0;
- // In cudnn v7.1.4, find() returned wgrad algos that could fail for
large c if we
- // were summing into the output (i.e. beta != 0). Get() returned OK
algos though.
- auto bwd_filter_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7
- :
cudnnFindConvolutionBackwardFilterAlgorithm;
- CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+ cudnnDataType_t cudnn_backward_compute_type,
+ CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+ CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+ // Not in algo registry, must determine via *Get*() or *Find*()
+ mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+ CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+ size_t workspace_byte = static_cast<size_t>(param_.workspace *
sizeof(DType));
+#if CUDNN_MAJOR >= 7
+ // Starting with cuDNNv7, the algo number returned by *Get*() is not the
entire
+ // story: the notion of whether the algo ran in Tensor Core mode is not
known.
+ // Since we want to report the Tensor Core mode in the verbose output, we
switch
+ // to using the new *Get*_v7() call. Since the function signature of
*Get*_v7() matches
+ // that of *Find*(), we can unify the find-vs-get logic by using function
pointers.
+
+ // Forward Algorithm Find/Get() v7
+ std::vector<cudnnConvolutionFwdAlgoPerf_t>
fwd_results(MaxForwardAlgos(s->dnn_handle_));
+ int actual_fwd_algos = 0;
+ auto fwd_algo_discoverer =
+ param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionForwardAlgorithm_v7
+ :
cudnnFindConvolutionForwardAlgorithm;
+ CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+ in_desc_,
+ filter_desc_,
+ forward_conv_desc_,
+ out_desc_,
+ fwd_results.size(),
+ &actual_fwd_algos,
+ fwd_results.data()));
+ fwd_results.resize(actual_fwd_algos);
+ AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+ cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+ workspace_byte, fwd);
+
+ // Backprop-to-Filter Algorithm Find/Get() v7
+ auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+ std::vector<cudnnConvolutionBwdFilterAlgoPerf_t>
bwd_filt_results(max_bwd_filt_algos);
+ int actual_bwd_filter_algos = 0;
+ // In cudnn v7.1.4, find() returned wgrad algos that could fail for large
c if we
+ // were summing into the output (i.e. beta != 0). Get() returned OK algos
though.
+ auto bwd_filter_algo_discoverer =
+ param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7
+ :
cudnnFindConvolutionBackwardFilterAlgorithm;
+ CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+ in_desc_,
+ out_desc_,
+ back_conv_desc_w_,
+ filter_desc_,
+ bwd_filt_results.size(),
+ &actual_bwd_filter_algos,
+ bwd_filt_results.data()));
+ bwd_filt_results.resize(actual_bwd_filter_algos);
+ AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+ cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results,
"backprop-to-filter",
+ workspace_byte, flt);
+
+ // Backprop-to-Data Algorithm Find/Get() v7
+ auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+ std::vector<cudnnConvolutionBwdDataAlgoPerf_t>
bwd_data_results(max_bwd_data_algos);
+ int actual_bwd_data_algos = 0;
+ auto bwd_data_algo_discoverer =
+ param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardDataAlgorithm_v7
+ :
cudnnFindConvolutionBackwardDataAlgorithm;
+ CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+ filter_desc_,
+ out_desc_,
+ back_conv_desc_,
+ in_desc_,
+ bwd_data_results.size(),
+ &actual_bwd_data_algos,
+ bwd_data_results.data()));
+ bwd_data_results.resize(actual_bwd_data_algos);
+ AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+ cudnnConvolutionBwdDataAlgo_t>(bwd_data_results,
"backprop-to-data",
+ workspace_byte, bwd);
+#else
+ // CUDNN_MAJOR < 7
+ const int kMaxAlgos = 10;
+ int nalgo = kMaxAlgos;
+ int i = 0;
+ size_t min_memory_needs = 0;
+ // Forward Algorithm Find/Get, v6 and earlier
+ if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+ // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
is
+ // supported. Hard-coded this since the algo find() or get() throws an
FPE.
+ fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+ } else if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+ CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
in_desc_,
- out_desc_,
- back_conv_desc_w_,
filter_desc_,
- bwd_filt_results.size(),
- &actual_bwd_filter_algos,
- bwd_filt_results.data()));
- bwd_filt_results.resize(actual_bwd_filter_algos);
- AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
- cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results,
"backprop-to-filter",
- workspace_byte, &back_algo_w_);
-
- // Backprop-to-Data Algorithm Find/Get() v7
- auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
- std::vector<cudnnConvolutionBwdDataAlgoPerf_t>
bwd_data_results(max_bwd_data_algos);
- int actual_bwd_data_algos = 0;
- auto bwd_data_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardDataAlgorithm_v7
- :
cudnnFindConvolutionBackwardDataAlgorithm;
- CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
- filter_desc_,
- out_desc_,
- back_conv_desc_,
- in_desc_,
- bwd_data_results.size(),
- &actual_bwd_data_algos,
- bwd_data_results.data()));
- bwd_data_results.resize(actual_bwd_data_algos);
- AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
- cudnnConvolutionBwdDataAlgo_t>(bwd_data_results,
"backprop-to-data",
- workspace_byte, &back_algo_);
- #else
- // CUDNN_MAJOR < 7
- const int kMaxAlgos = 10;
- int nalgo = kMaxAlgos;
- int i = 0;
- size_t min_memory_needs = 0;
- // Forward Algorithm Find/Get, v6 and earlier
- if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
- // In cuDNNv6, for kNHWC, only
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
- // supported. Hard-coded this since the algo find() or get() throws
an FPE.
- forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
- } else if (!param_.cudnn_tune.value()) {
- cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
- CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
- in_desc_,
- filter_desc_,
- forward_conv_desc_,
- out_desc_,
-
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_fwd_algo));
- forward_algo_.Set(fastest_fwd_algo, false);
- } else {
- cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
- in_desc_,
- filter_desc_,
- forward_conv_desc_,
- out_desc_,
- kMaxAlgos,
- &nalgo,
- fwd_algo));
- i = 0;
- while (i < nalgo
- && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == conv::kLimited
- && fwd_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs =
- (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs,
fwd_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " forward algorithms with minimum memory
requirement "
- << min_memory_needs << " bytes have been tried. Workspace
size is set to "
- << workspace_byte << " bytes, please consider reducing
the batch/model size, "
- << "or increasing workspace size.";
- } else {
- forward_algo_.Set(fwd_algo[i].algo, false);
- }
+ forward_conv_desc_,
+ out_desc_,
+
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_fwd_algo));
+ fwd->Set(fastest_fwd_algo, false);
+ } else {
+ cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+ in_desc_,
+ filter_desc_,
+ forward_conv_desc_,
+ out_desc_,
+ kMaxAlgos,
+ &nalgo,
+ fwd_algo));
+ i = 0;
+ while (i < nalgo
+ && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == conv::kLimited
+ && fwd_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs =
+ (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs,
fwd_algo[i].memory);
}
- // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
- if (!param_.cudnn_tune.value()) {
- cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
- CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
- in_desc_,
- out_desc_,
- back_conv_desc_w_,
- filter_desc_,
-
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_bwd_filt_algo));
- back_algo_w_.Set(fastest_bwd_filt_algo, false);
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
"forward");
} else {
- cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
- in_desc_,
- out_desc_,
-
back_conv_desc_w_,
- filter_desc_,
- kMaxAlgos,
- &nalgo,
-
bwd_filter_algo));
- i = 0;
- while (i < nalgo
- && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == conv::kLimited
- && bwd_filter_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs = (i == 0) ?
- bwd_filter_algo[i].memory :
- std::min(min_memory_needs,
bwd_filter_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " backward filter algorithms with minimum
memory requirement "
- << min_memory_needs << " bytes have been tried. Workspace
size is set to "
- << workspace_byte << " bytes, please consider reducing
the batch/model size, "
- << "or increasing workspace size.";
- } else {
- back_algo_w_.Set(bwd_filter_algo[i].algo, false);
- }
+ fwd->Set(fwd_algo[i].algo, false);
}
- // Backprop-to-Data Algorithm Get(), v6 and earlier
- if (!param_.cudnn_tune.value()) {
- cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
- CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
- filter_desc_,
- out_desc_,
- back_conv_desc_,
- in_desc_,
-
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_bwd_data_algo));
- back_algo_.Set(fastest_bwd_data_algo, false);
- } else {
- cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
- filter_desc_,
- out_desc_,
- back_conv_desc_,
+ }
+ // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+ if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+ CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+ in_desc_,
+ out_desc_,
+ back_conv_desc_w_,
+ filter_desc_,
+
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_bwd_filt_algo));
+ flt->Set(fastest_bwd_filt_algo, false);
+ } else {
+ cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
in_desc_,
+ out_desc_,
+ back_conv_desc_w_,
+ filter_desc_,
kMaxAlgos,
&nalgo,
- bwd_data_algo));
- i = 0;
- while (i < nalgo
- && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == conv::kLimited
- && bwd_data_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs = (i == 0) ?
- bwd_data_algo[i].memory :
- std::min(min_memory_needs,
bwd_data_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " backward data algorithms with minimum
memory requirement "
- << min_memory_needs << " bytes have been tried. Workspace
size is set to "
- << workspace_byte << " bytes, please consider reducing
the batch/model size, "
- << "or increasing workspace size.";
- } else {
- back_algo_.Set(bwd_data_algo[i].algo, false);
- }
+ bwd_filter_algo));
+ i = 0;
+ while (i < nalgo
+ && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == conv::kLimited
+ && bwd_filter_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs = (i == 0) ?
+ bwd_filter_algo[i].memory :
+ std::min(min_memory_needs,
bwd_filter_algo[i].memory);
}
- #endif // CUDNN_MAJOR < 7
-
- // Fix for issue #11241
- int cudnn_find_issue_max_features = 64 * 1024;
- if (add_to_weight_ && Features(in_shape[conv::kData]) >=
cudnn_find_issue_max_features) {
- this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
"backward filter");
+ } else {
+ flt->Set(bwd_filter_algo[i].algo, false);
}
+ }
+ // Backprop-to-Data Algorithm Get(), v6 and earlier
+ if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+ CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+ filter_desc_,
+ out_desc_,
+ back_conv_desc_,
+ in_desc_,
+
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_bwd_data_algo));
+ bwd->Set(fastest_bwd_data_algo, false);
+ } else {
+ cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+ filter_desc_,
+ out_desc_,
+ back_conv_desc_,
+ in_desc_,
+ kMaxAlgos,
+ &nalgo,
+ bwd_data_algo));
+ i = 0;
+ while (i < nalgo
+ && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == conv::kLimited
+ && bwd_data_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs = (i == 0) ?
+ bwd_data_algo[i].memory :
+ std::min(min_memory_needs, bwd_data_algo[i].memory);
+ }
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
"backward data");
+ } else {
+ bwd->Set(bwd_data_algo[i].algo, false);
+ }
+ }
+#endif // CUDNN_MAJOR < 7
- // An algo specification by the user may be cached here, but another
- // convolution will match only if identically specified.
- // We're caching results of *Get* as well as *Find*, but these records
- // will be held distinctly because param_.cudnn_tune is part of the key.
- CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
- cudnn_forward_compute_type,
- cudnn_backward_compute_type,
- SMArch(rctx.ctx.dev_id),
this->add_to_weight_,
- this->forward_algo_,
- this->back_algo_, this->back_algo_w_);
+ // Fix for issue #11241
+ int cudnn_find_issue_max_features = 64 * 1024;
+ if (add_to_weight_ && Features(in_shape[conv::kData]) >=
cudnn_find_issue_max_features) {
+ flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
}
+ }
+
+ void SelectAlgo(const RunContext& rctx,
+ const std::vector<TShape>& in_shape,
+ const std::vector<TShape>& out_shape,
+ cudnnDataType_t cudnn_forward_compute_type,
+ cudnnDataType_t cudnn_backward_compute_type) {
+ auto algo_setter = [&](CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+ CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+ if (param_.cudnn_tune.value() == conv::kOff) {
+ // The routine will only be calling cudnnGet, so no need to grab the
Storage lock.
+ this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ fwd, bwd, flt);
+ } else {
+ // One potential problem is that cudnnFind() uses cudaMalloc() to
directly allocate
+ // I/O and workspace areas, and these allocations may result in an
out-of-memory
+ // error even though the StorageManager free pool is not empty.
Ideally, cudnnFind
+ // would use MXNet's storage allocator for its I/O and workspace
areas, instead of using
+ // the area carved out by MXNET_GPU_MEM_POOL_RESERVE.
+ // To get somewhat the same effect as this, we can pre-allocate the
areas needed for the
+ // I/Os (possibly triggering a desirable
StorageManager::ReleaseAll()), followed by a
+ // DirectFree(), which makes these areas available for cudnn's
subsequent cudaMalloc().
+
+ // Allocate for x (or dx), w (or dw) and y (or dy).
+ ReserveElements({in_shape[conv::kData].Size(),
+ in_shape[conv::kWeight].Size(),
+ out_shape[conv::kOut].Size()});
+
+ // We're about to call cudnnFind so we need to quiet the system by
grabbing
+ // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate
timing
+ // measurements of the algos, and can prevent the cuda driver's proper
freeing
+ // of cudnnFind's internal temporary allocations. Grabbing the lock
might also
+ // impede other threads from launching work on the GPU.
+ std::lock_guard<std::mutex>
lock(Storage::Get()->GetMutex(Context::kGPU));
+ this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ fwd, bwd, flt);
+ }
+ };
+
+ CuDNNConvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape,
dtype_,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ SMArch(rctx.ctx.dev_id), add_to_weight_,
+ &forward_algo_, &back_algo_,
&back_algo_w_, algo_setter);
+
// If we're allowing Tensor Core variants of the algos to be considered in
// *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
// we must change the descriptor to preclude Tensor Core. Simplest is to
@@ -872,6 +902,7 @@ class CuDNNConvolutionOp {
<< " please consider reducing batch/model size or increasing
the workspace size";
}
+
void GetTempSize(const OpContext& ctx) {
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
size_t back_size = 0, back_size_w = 0;
@@ -970,6 +1001,25 @@ class CuDNNConvolutionOp {
return c;
}
+ // Make a number of allocations and directly free them, ensuring room for an
equivalent set of
+ // cudaMalloc() calls by (say) cudnnFind(). `elements` specifies the alloc size
in DTypes, not bytes.
+ void ReserveElements(const std::vector<size_t> &elements) {
+ std::vector<Storage::Handle> handles;
+ for (size_t alloc_element : elements)
+ handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType),
Context::GPU()));
+ for (auto &handle : handles)
+ Storage::Get()->DirectFree(handle);
+ }
+
+ // Log that no suitable algo was found that met the workspace constraints,
then exit.
+ void LogNoSuitableAlgoAndExit(int num_algos_tried, size_t min_memory_needs,
+ size_t workspace_byte, std::string algo_kind) {
+ LOG(FATAL) << num_algos_tried << " " << algo_kind << " with minimum memory
requirement "
+ << min_memory_needs << " bytes have been tried. Workspace size
is set to "
+ << workspace_byte << " bytes, please consider reducing the
batch/model size, "
+ << "or increasing workspace size.";
+ }
+
std::vector<int> param_stride_;
std::vector<int> param_dilate_;
std::vector<int> param_pad_;
diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
index 041bea6..c0c5650 100644
--- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
@@ -26,6 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_
#define MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_
+#include <mxnet/storage.h>
#include <algorithm>
#include <vector>
#include <mutex>
@@ -538,245 +539,273 @@ class CuDNNDeconvolutionOp {
}
}
- void SelectAlgo(const RunContext& rctx,
- const std::vector<TShape>& in_shape,
- const std::vector<TShape>& out_shape,
- cudnnDataType_t cudnn_forward_compute_type,
- cudnnDataType_t cudnn_backward_compute_type) {
- if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
- cudnn_forward_compute_type,
- cudnn_backward_compute_type,
- SMArch(rctx.ctx.dev_id),
add_to_weight_,
- &forward_algo_, &back_algo_,
&back_algo_w_)) {
- mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
- CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
- size_t workspace_byte = static_cast<size_t>(param_.workspace *
sizeof(DType));
- #if CUDNN_MAJOR >= 7
- // Starting with cuDNNv7, the algo number returned by *Get*() is not the
entire
- // story: the notion of whether the algo ran in Tensor Core mode is not
known.
- // Since we want to report the Tensor Core mode in the verbose output,
we switch
- // to using the new *Get*_v7() call. Since the function signature of
*Get*_v7() matches
- // that of *Find*(), we can unify the find-vs-get logic by using
function pointers.
-
- // Forward Algorithm Find/Get() v7
- std::vector<cudnnConvolutionFwdAlgoPerf_t>
fwd_results(MaxForwardAlgos(s->dnn_handle_));
- int actual_fwd_algos = 0;
- auto fwd_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionForwardAlgorithm_v7
+ void CuDNNAlgoSetter(const RunContext& rctx,
+ const std::vector<TShape>& in_shape,
+ const std::vector<TShape>& out_shape,
+ cudnnDataType_t cudnn_forward_compute_type,
+ cudnnDataType_t cudnn_backward_compute_type,
+ CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+ CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+ // Not in algo registry, must determine via *Get*() or *Find*()
+ mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
+ CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+ size_t workspace_byte = static_cast<size_t>(param_.workspace *
sizeof(DType));
+#if CUDNN_MAJOR >= 7
+ // Starting with cuDNNv7, the algo number returned by *Get*() is not the
entire
+ // story: the notion of whether the algo ran in Tensor Core mode is not
known.
+ // Since we want to report the Tensor Core mode in the verbose output, we
switch
+ // to using the new *Get*_v7() call. Since the function signature of
*Get*_v7() matches
+ // that of *Find*(), we can unify the find-vs-get logic by using function
pointers.
+
+ // Forward Algorithm Find/Get() v7
+ std::vector<cudnnConvolutionFwdAlgoPerf_t>
fwd_results(MaxForwardAlgos(s->dnn_handle_));
+ int actual_fwd_algos = 0;
+ auto fwd_algo_discoverer =
+ param_.cudnn_tune.value() == deconv::kOff ?
cudnnGetConvolutionForwardAlgorithm_v7
:
cudnnFindConvolutionForwardAlgorithm;
- CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
- out_desc_,
- filter_desc_,
- back_conv_desc_, // fwd algo used to
backprop-to-data
- in_desc_,
- fwd_results.size(),
- &actual_fwd_algos,
- fwd_results.data()));
- fwd_results.resize(actual_fwd_algos);
- AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
- cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
- workspace_byte,
&forward_algo_);
-
- // Backprop-to-Filter Algorithm Find/Get() v7
- auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
- std::vector<cudnnConvolutionBwdFilterAlgoPerf_t>
bwd_filt_results(max_bwd_filt_algos);
- int actual_bwd_filter_algos = 0;
- // In cudnn v7.1.4, find() returned wgrad algos that could fail for
large c if we
- // were summing into the output (i.e. beta != 0). Get() returned OK
algos though.
- auto bwd_filter_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7
- :
cudnnFindConvolutionBackwardFilterAlgorithm;
- CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
- out_desc_,
- in_desc_,
- back_conv_desc_,
- filter_desc_,
- bwd_filt_results.size(),
- &actual_bwd_filter_algos,
- bwd_filt_results.data()));
- bwd_filt_results.resize(actual_bwd_filter_algos);
- AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
- cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results,
"backprop-to-filter",
- workspace_byte,
&back_algo_w_);
-
- // Backprop-to-Data Algorithm Find/Get() v7
- auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
- std::vector<cudnnConvolutionBwdDataAlgoPerf_t>
bwd_data_results(max_bwd_data_algos);
- int actual_bwd_data_algos = 0;
- auto bwd_data_algo_discoverer =
- param_.cudnn_tune.value() == conv::kOff ?
cudnnGetConvolutionBackwardDataAlgorithm_v7
+ CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+ out_desc_,
+ filter_desc_,
+ back_conv_desc_, // fwd algo used to
backprop-to-data
+ in_desc_,
+ fwd_results.size(),
+ &actual_fwd_algos,
+ fwd_results.data()));
+ fwd_results.resize(actual_fwd_algos);
+ AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+ cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+ workspace_byte, fwd);
+
+ // Backprop-to-Filter Algorithm Find/Get() v7
+ auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+ std::vector<cudnnConvolutionBwdFilterAlgoPerf_t>
bwd_filt_results(max_bwd_filt_algos);
+ int actual_bwd_filter_algos = 0;
+ // In cudnn v7.1.4, find() returned wgrad algos that could fail for large
c if we
+ // were summing into the output (i.e. beta != 0). Get() returned OK algos
though.
+ auto bwd_filter_algo_discoverer =
+ param_.cudnn_tune.value() == deconv::kOff ?
cudnnGetConvolutionBackwardFilterAlgorithm_v7
+ :
cudnnFindConvolutionBackwardFilterAlgorithm;
+ CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+ out_desc_,
+ in_desc_,
+ back_conv_desc_,
+ filter_desc_,
+ bwd_filt_results.size(),
+ &actual_bwd_filter_algos,
+ bwd_filt_results.data()));
+ bwd_filt_results.resize(actual_bwd_filter_algos);
+ AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+ cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results,
"backprop-to-filter",
+ workspace_byte, flt);
+ // Backprop-to-Data Algorithm Find/Get() v7
+ auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+ std::vector<cudnnConvolutionBwdDataAlgoPerf_t>
bwd_data_results(max_bwd_data_algos);
+ int actual_bwd_data_algos = 0;
+ auto bwd_data_algo_discoverer =
+ param_.cudnn_tune.value() == deconv::kOff ?
cudnnGetConvolutionBackwardDataAlgorithm_v7
:
cudnnFindConvolutionBackwardDataAlgorithm;
- CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+ CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+ filter_desc_,
+ in_desc_,
+ forward_conv_desc_, // bwd algo
used in inference
+ out_desc_,
+ bwd_data_results.size(),
+ &actual_bwd_data_algos,
+ bwd_data_results.data()));
+ bwd_data_results.resize(actual_bwd_data_algos);
+ AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+ cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
+ workspace_byte, bwd);
+#else
+ // CUDNN_MAJOR < 7
+ const int kMaxAlgos = 10;
+ int nalgo = kMaxAlgos;
+ int i = 0;
+ size_t min_memory_needs = 0;
+ // Forward Algorithm Find/Get, v6 and earlier
+ if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+ // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
is
+ // supported. Hard-coded this since the algo find() or get() throws an
FPE.
+ fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+ } else if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+ CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
+ out_desc_,
+ filter_desc_,
+ back_conv_desc_, // fwd algo
used in dgrad
+ in_desc_,
+
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_fwd_algo));
+ fwd->Set(fastest_fwd_algo, false);
+ } else {
+ cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+ out_desc_,
+ filter_desc_,
+ back_conv_desc_, // fwd
algo used in dgrad
+ in_desc_,
+ kMaxAlgos,
+ &nalgo,
+ fwd_algo));
+ i = 0;
+ while (i < nalgo
+ && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == deconv::kLimited
+ && fwd_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs = (i == 0) ?
+ fwd_algo[i].memory :
+ std::min(min_memory_needs, fwd_algo[i].memory);
+ }
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
+ "forward algos (for use in deconv op
backprop-to-data)");
+ } else {
+ fwd->Set(fwd_algo[i].algo, false);
+ }
+ }
+ // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+ if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+ CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+ out_desc_,
+ in_desc_,
+ back_conv_desc_,
+ filter_desc_,
+
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_bwd_filt_algo));
+ flt->Set(fastest_bwd_filt_algo, false);
+ } else {
+ cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+ out_desc_,
+ in_desc_,
+ back_conv_desc_,
+ filter_desc_,
+ kMaxAlgos,
+ &nalgo,
+ bwd_filter_algo));
+ i = 0;
+ while (i < nalgo
+ && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == deconv::kLimited
+ && bwd_filter_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs = (i == 0) ?
+ bwd_filter_algo[i].memory :
+ std::min(min_memory_needs,
bwd_filter_algo[i].memory);
+ }
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
+ "backward filter algos (for use in deconv op
backprop-to-filter)");
+ } else {
+ flt->Set(bwd_filter_algo[i].algo, false);
+ }
+ }
+ // Backprop-to-Data Algorithm Get(), v6 and earlier
+ if (!param_.cudnn_tune.value()) {
+ cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+ CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+ filter_desc_,
+ in_desc_,
+ forward_conv_desc_, // bwd algo
used for inference
+ out_desc_,
+
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+ workspace_byte,
+ &fastest_bwd_data_algo));
+ bwd->Set(fastest_bwd_data_algo, false);
+ } else {
+ cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+ CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
filter_desc_,
in_desc_,
forward_conv_desc_, // bwd algo
used in inference
out_desc_,
- bwd_data_results.size(),
- &actual_bwd_data_algos,
- bwd_data_results.data()));
- bwd_data_results.resize(actual_bwd_data_algos);
- AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
- cudnnConvolutionBwdDataAlgo_t>(bwd_data_results,
"backprop-to-data",
- workspace_byte,
&back_algo_);
- #else
- // CUDNN_MAJOR < 7
- const int kMaxAlgos = 10;
- int nalgo = kMaxAlgos;
- int i = 0;
- size_t min_memory_needs = 0;
- // Forward Algorithm Find/Get, v6 and earlier
- if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
- // In cuDNNv6, for kNHWC, only
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
- // supported. Hard-coded this since the algo find() or get() throws
an FPE.
- forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
- } else if (!param_.cudnn_tune.value()) {
- cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
- CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
- out_desc_,
- filter_desc_,
- back_conv_desc_, // fwd
algo used in dgrad
- in_desc_,
-
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_fwd_algo));
- forward_algo_.Set(fastest_fwd_algo, false);
- } else {
- cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
- out_desc_,
- filter_desc_,
- back_conv_desc_, // fwd
algo used in dgrad
- in_desc_,
- kMaxAlgos,
- &nalgo,
- fwd_algo));
- i = 0;
- while (i < nalgo
- && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == deconv::kLimited
- && fwd_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs = (i == 0) ?
- fwd_algo[i].memory :
- std::min(min_memory_needs, fwd_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " forward algorithms"
- << " (for use in deconvolution operator backprop-to-data)"
- << " with minimum memory requirement " << min_memory_needs
- << " bytes have been tried. Workspace size is set to " <<
workspace_byte
- << " bytes, please consider reducing the batch/model
size,"
- << " or increasing workspace size.";
- } else {
- forward_algo_.Set(fwd_algo[i].algo, false);
- }
+ kMaxAlgos,
+ &nalgo,
+ bwd_data_algo));
+ i = 0;
+ while (i < nalgo
+ && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+ || (param_.cudnn_tune.value() == deconv::kLimited
+ && bwd_data_algo[i].memory > workspace_byte))) {
+ ++i;
+ min_memory_needs = (i == 0) ?
+ bwd_data_algo[i].memory :
+ std::min(min_memory_needs, bwd_data_algo[i].memory);
}
- // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
- if (!param_.cudnn_tune.value()) {
- cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
- CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
- out_desc_,
- in_desc_,
- back_conv_desc_,
- filter_desc_,
-
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_bwd_filt_algo));
- back_algo_w_.Set(fastest_bwd_filt_algo, false);
+ if (i == nalgo) {
+ LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte,
+ "backward data algos (for use in deconv op
forward inference)");
} else {
- cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
- out_desc_,
- in_desc_,
- back_conv_desc_,
- filter_desc_,
- kMaxAlgos,
- &nalgo,
-
bwd_filter_algo));
- i = 0;
- while (i < nalgo
- && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == deconv::kLimited
- && bwd_filter_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs = (i == 0) ?
- bwd_filter_algo[i].memory :
- std::min(min_memory_needs,
bwd_filter_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " backward filter algorithms"
- << " (for use in deconvolution operator
backprop-to-filter)"
- << " with minimum memory requirement " << min_memory_needs
- << " bytes have been tried. Workspace size is set to " <<
workspace_byte
- << " bytes, please consider reducing the batch/model
size,"
- << " or increasing workspace size.";
- } else {
- back_algo_w_.Set(bwd_filter_algo[i].algo, false);
- }
+ bwd->Set(bwd_data_algo[i].algo, false);
}
- // Backprop-to-Data Algorithm Get(), v6 and earlier
- if (!param_.cudnn_tune.value()) {
- cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
- CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
- filter_desc_,
- in_desc_,
- forward_conv_desc_, // bwd algo
used for inference
- out_desc_,
-
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
- workspace_byte,
- &fastest_bwd_data_algo));
- back_algo_.Set(fastest_bwd_data_algo, false);
+ }
+#endif // CUDNN_MAJOR < 7
+
+ // Fix for issue #11241
+ int cudnn_find_issue_max_features = 64 * 1024;
+ // With deconvolution, the algo sensitivity is to a large number of output
features
+ if (add_to_weight_ && Features(out_shape[deconv::kOut]) >=
cudnn_find_issue_max_features) {
+ flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+ }
+ }
+
+ void SelectAlgo(const RunContext& rctx,
+ const std::vector<TShape>& in_shape,
+ const std::vector<TShape>& out_shape,
+ cudnnDataType_t cudnn_forward_compute_type,
+ cudnnDataType_t cudnn_backward_compute_type) {
+ auto algo_setter = [&](CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+ CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+ CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+ if (param_.cudnn_tune.value() == deconv::kOff) {
+ // The routine will only be calling cudnnGet, so no need to grab the
Storage lock.
+ this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ fwd, bwd, flt);
} else {
- cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
- CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
- filter_desc_,
- in_desc_,
- forward_conv_desc_, // bwd
algo used in inference
- out_desc_,
- kMaxAlgos,
- &nalgo,
- bwd_data_algo));
- i = 0;
- while (i < nalgo
- && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
- || (param_.cudnn_tune.value() == deconv::kLimited
- && bwd_data_algo[i].memory > workspace_byte))) {
- ++i;
- min_memory_needs = (i == 0) ?
- bwd_data_algo[i].memory :
- std::min(min_memory_needs,
bwd_data_algo[i].memory);
- }
- if (i == nalgo) {
- LOG(FATAL) << nalgo << " backward data algorithms"
- << " (for use in deconvolution operator forward
inference) with"
- << " minimum memory requirement " << min_memory_needs
- << " bytes have been tried. Workspace size is set to " <<
workspace_byte
- << " bytes, please consider reducing the batch/model
size,"
- << " or increasing workspace size.";
- } else {
- back_algo_.Set(bwd_data_algo[i].algo, false);
- }
- }
- #endif // CUDNN_MAJOR < 7
+ // One potential problem is that cudnnFind() uses cudaMalloc() to
directly allocate
+ // I/O and workspace areas, and these allocations may result in an
out-of-memory
+ // error even though the StorageManager free pool is not empty.
Ideally, cudnnFind
+ // would use MXNet's storage allocator for its I/O and workspace
areas, instead of using
+ // the area carved out by MXNET_GPU_MEM_POOL_RESERVE.
+ // To get somewhat the same effect as this, we can pre-allocate the
areas needed for the
+ // I/Os (possibly triggering a desirable
StorageManager::ReleaseAll()), followed by a
+ // DirectFree(), which makes these areas available for cudnn's
subsequent cudaMalloc().
+
+ // Allocate for x (or dx), w (or dw) and y (or dy).
+ ReserveElements({in_shape[conv::kData].Size(),
+ in_shape[conv::kWeight].Size(),
+ out_shape[conv::kOut].Size()});
- // Fix for issue #11241
- int cudnn_find_issue_max_features = 64 * 1024;
- // With deconvolution, the algo sensitivity is to a large number of
output features
- if (add_to_weight_ && Features(out_shape[deconv::kOut]) >=
cudnn_find_issue_max_features) {
- this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+ // We're about to call cudnnFind so we need to quiet the system by
grabbing
+ // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate
timing
+ // measurements of the algos, and can prevent the cuda driver's proper
freeing
+ // of cudnnFind's internal temporary allocations. Grabbing the lock
might also
+ // impede other threads from launching work on the GPU.
+ std::lock_guard<std::mutex>
lock(Storage::Get()->GetMutex(Context::kGPU));
+ this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ fwd, bwd, flt);
}
+ };
+
+ // An algo specification by the user may be cached here, but another
+ // convolution will match only if identically specified.
+ // We're caching results of *Get* as well as *Find*, but these records
+ // will be held distinctly because param_.cudnn_tune is part of the key.
+ CuDNNDeconvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape,
dtype_,
+ cudnn_forward_compute_type,
+ cudnn_backward_compute_type,
+ SMArch(rctx.ctx.dev_id),
add_to_weight_,
+ &forward_algo_, &back_algo_,
&back_algo_w_, algo_setter);
- // An algo specification by the user may be cached here, but another
- // convolution will match only if identically specified.
- // We're caching results of *Get* as well as *Find*, but these records
- // will be held distinctly because param_.cudnn_tune is part of the key.
- CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
- cudnn_forward_compute_type,
- cudnn_backward_compute_type,
- SMArch(rctx.ctx.dev_id),
this->add_to_weight_,
- this->forward_algo_,
- this->back_algo_,
this->back_algo_w_);
- }
// If we're allowing Tensor Core variants of the algos to be considered in
// *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
// we must change the descriptor to preclude Tensor Core. Simplest is to
@@ -818,6 +847,7 @@ class CuDNNDeconvolutionOp {
<< " please consider reducing batch/model size or increasing
the workspace size";
}
+
void GetTempSize(const OpContext& ctx) {
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
size_t back_data_algo_workspace_size = 0;
@@ -921,6 +951,26 @@ class CuDNNDeconvolutionOp {
return c;
}
+ // Make a number of allocations and directly free them, ensuring room for an
equivalent set of
+ // cudaMalloc() calls by (say) cudnnFind(). `elements` specifies the alloc size
in DTypes, not bytes.
+ void ReserveElements(const std::vector<size_t> &elements) {
+ std::vector<Storage::Handle> handles;
+ for (size_t alloc_element : elements)
+ handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType),
Context::GPU()));
+ for (auto &handle : handles)
+ Storage::Get()->DirectFree(handle);
+ }
+
+
+ // Log that no suitable algo was found that met the workspace constraints,
then exit.
+ void LogNoSuitableAlgoAndExit(int num_algos_tried, size_t min_memory_needs,
+ size_t workspace_byte, std::string algo_kind) {
+ LOG(FATAL) << num_algos_tried << " " << algo_kind << " with minimum memory
requirement "
+ << min_memory_needs << " bytes have been tried. Workspace size
is set to "
+ << workspace_byte << " bytes, please consider reducing the
batch/model size, "
+ << "or increasing workspace size.";
+ }
+
std::vector<int> param_stride_;
std::vector<int> param_dilate_;
diff --git a/src/storage/pooled_storage_manager.h
b/src/storage/pooled_storage_manager.h
index f3a9b16..cade8d9 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -57,6 +57,11 @@ class GPUPooledStorageManager final : public StorageManager {
GPUPooledStorageManager() {
reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096);
+ large_alloc_round_size_ =
dmlc::GetEnv("MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE", 2 * 1024 * 1024);
+ if (large_alloc_round_size_ <= 0) {
+ LOG(FATAL) << "MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE cannot be set to a
value <= 0, found: "
+ << large_alloc_round_size_;
+ }
if (page_size_ < NDEV) {
LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE cannot be set to a value
smaller than " << NDEV \
<< ". Got " << page_size_ << ".";
@@ -80,7 +85,7 @@ class GPUPooledStorageManager final : public StorageManager {
private:
void DirectFreeNoLock(Storage::Handle handle) {
cudaError_t err = cudaFree(handle.dptr);
- size_t size = std::max(handle.size, page_size_);
+ size_t size = RoundAllocSize(handle.size);
// ignore unloading error, as memory has already been recycled
if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
@@ -88,12 +93,31 @@ class GPUPooledStorageManager final : public StorageManager
{
used_memory_ -= size;
}
+ // Round a value 'x' up to the next multiple of 'multiple'
+ size_t RoundToMultiple(size_t x, size_t multiple) {
+ size_t retVal = ((x + multiple - 1) / multiple) * multiple;
+ return retVal;
+ }
+
+ size_t RoundAllocSize(size_t size) {
+ // Round up small allocs to the page_size_ to consolidate the pool lookups
+ size = std::max(size, page_size_);
+ // To ensure proper freeing under some driver variants, make sure
+ // large allocs entirely occupy their slabs, which cannot then be
+ // locked by smaller permanent allocations sharing the slab.
+ if (size > large_alloc_round_size_)
+ size = RoundToMultiple(size, large_alloc_round_size_);
+ return size;
+ }
+
private:
void ReleaseAll();
// used memory
size_t used_memory_ = 0;
// page size
size_t page_size_;
+ // size that large allocations should be rounded to, for proper freeing.
+ size_t large_alloc_round_size_;
// percentage of reserved memory
int reserve_;
// number of devices
@@ -105,7 +129,7 @@ class GPUPooledStorageManager final : public StorageManager
{
void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
- size_t size = std::max(handle->size, page_size_);
+ size_t size = RoundAllocSize(handle->size);
auto&& reuse_it = memory_pool_.find(size);
if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) {
size_t free, total;
@@ -130,7 +154,7 @@ void GPUPooledStorageManager::Alloc(Storage::Handle*
handle) {
void GPUPooledStorageManager::Free(Storage::Handle handle) {
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
- size_t size = std::max(handle.size, page_size_);
+ size_t size = RoundAllocSize(handle.size);
auto&& reuse_pool = memory_pool_[size];
reuse_pool.push_back(handle.dptr);
}
diff --git a/tests/python/gpu/test_gluon_gpu.py
b/tests/python/gpu/test_gluon_gpu.py
index ac7df62..80c28d9 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -25,12 +25,14 @@ import unittest
import mxnet as mx
import numpy as np
import unittest
+import math
from nose.tools import assert_raises
from mxnet.test_utils import check_consistency, set_default_context,
assert_almost_equal
from mxnet.base import MXNetError
from mxnet import autograd
from numpy.testing import assert_allclose
+
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown,
assert_raises_cudnn_disabled
@@ -57,7 +59,7 @@ def check_rnn_layer(layer):
for g, c in zip(gs, cs):
assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
-
+@with_seed()
def check_rnn_layer_w_rand_inputs(layer):
layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
x = mx.nd.uniform(shape=(10, 16, 30))
@@ -182,7 +184,7 @@ def _check_batchnorm_result(input, num_devices=1,
cuda=False):
input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for
output in inputs2], dim=0)
assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(),
atol=1e-3, rtol=1e-3)
-
+@with_seed()
def test_sync_batchnorm():
def get_num_devices():
for i in range(100):
@@ -200,6 +202,43 @@ def test_sync_batchnorm():
num_devices=ndev, cuda=True)
@with_seed()
+def test_large_models():
+ ctx = default_context()
+ # Create model
+ net = gluon.nn.HybridSequential()
+
+ largest_num_features = 256
+ with net.name_scope():
+ net.add(nn.Conv2D(largest_num_features, 3))
+
+ net.hybridize()
+ net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx)
+
+ # Compute the height (=width) of the square tensor of the given size in
bytes
+ def tensor_size(big_tensor_bytes):
+ bytes_per_float = 4
+ sz = int(math.sqrt(big_tensor_bytes / largest_num_features /
bytes_per_float))
+ return (sz // 100) * 100
+
+ # The idea is to create models with large tensors of (say) 20% of the
total memory.
+ # This in the past has given cudnnFind() trouble when it needed to
allocate similar I/O's
+ # from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by
default 5%).
+ (free_mem_bytes, total_mem_bytes) =
mx.context.gpu_memory_info(ctx.device_id)
+ start_size = tensor_size(0.20 * total_mem_bytes)
+ num_trials = 10
+ sys.stderr.write(' testing global memory of size {} ...
'.format(total_mem_bytes))
+ sys.stderr.flush()
+ for i in range(num_trials):
+ sz = start_size - 10 * i
+ (height, width) = (sz,sz)
+ sys.stderr.write(" {}x{} ".format(height,width))
+ sys.stderr.flush()
+ data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height,
width),
+ ctx=ctx, dtype="float32")
+ # Evaluate model
+ net(data_in).asnumpy()
+
+@with_seed()
def test_symbol_block_fp16():
# Test case to verify if initializing the SymbolBlock from a model with
params
# other than fp32 param dtype.