KellenSunderland commented on a change in pull request #15812: cuDNN support
cleanup
URL: https://github.com/apache/incubator-mxnet/pull/15812#discussion_r313127388
##########
File path: src/operator/nn/cudnn/cudnn_convolution-inl.h
##########
@@ -577,29 +427,19 @@ class CuDNNConvolutionOp {
oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW);
}
// Set "allow tensor core" flag in convolution descriptors, if available.
- #if CUDNN_MAJOR >= 7
- cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH
- : CUDNN_DEFAULT_MATH;
- #if CUDNN_VERSION >= 7200
- if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
- (DataType<DType>::kFlag != kFloat16))
- math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
- #endif
- CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type));
- CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type));
- CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type));
- CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group));
- CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group));
- CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group));
- #endif
-
- #if CUDNN_MAJOR <= 6
- dshape[1] /= param_.num_group;
- oshape[1] /= param_.num_group;
- #endif
- weight_offset_ = wshape.Size();
- data_offset_ = dstride[1] * dshape[1];
- out_offset_ = ostride[1] * oshape[1];
+ cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH
+ : CUDNN_DEFAULT_MATH;
+#if CUDNN_VERSION >= 7200
+ if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
+ (DataType<DType>::kFlag != kFloat16))
+ math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
Review comment:
Non-blocking for this PR, but this change reminded me that I believe there's
a bug here. IIRC the problem was that math_type can be reset by auto-tuning.
Here's the issue I opened a while ago that I haven't had time to follow up on:
https://github.com/apache/incubator-mxnet/issues/14684
"Auto-tuning is overwriting the math mode at convolution tuning time.
Probably the right thing to do when implementing TCs but it's preventing the
conversion math type from being used. We'll have to think about the long-term
fix for this, but I've currently commented out the math type reset locally and
I'm trying to verify this cudnn feature provides a significant speedup before
moving forward."
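
To make the concern concrete, here's a rough sketch of the kind of post-tuning
fix-up I have in mind (my own sketch, not code from this PR; the
`ReapplyMathTypeAfterTuning` helper and its call site are hypothetical, only
`cudnnSetConvolutionMathType` and `CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION` come
from the diff above):

```cpp
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

// Minimal error-check macro, standing in for MXNet's CUDNN_CALL.
#define CUDNN_CALL(expr)                                               \
  do {                                                                 \
    cudnnStatus_t status = (expr);                                     \
    if (status != CUDNN_STATUS_SUCCESS) {                              \
      std::fprintf(stderr, "cuDNN error %d at %s:%d\n",                \
                   static_cast<int>(status), __FILE__, __LINE__);      \
      std::exit(EXIT_FAILURE);                                         \
    }                                                                  \
  } while (0)

// requested_math is what the layer asked for when the descriptors were set up,
// e.g. CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION for non-fp16 data when the env
// var checks above are enabled.
void ReapplyMathTypeAfterTuning(cudnnConvolutionDescriptor_t conv_desc,
                                cudnnMathType_t requested_math) {
#if CUDNN_VERSION >= 7200
  // Auto-tuning overwrites the math mode at convolution tuning time (see the
  // issue linked above), which downgrades ALLOW_CONVERSION before it can take
  // effect. Restoring the requested type after tuning keeps the conversion
  // path available while still using the tuned algorithm.
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, requested_math));
#else
  (void)conv_desc;
  (void)requested_math;
#endif
}
```

In MXNet terms the call would go at the end of algorithm selection, once per
convolution descriptor, instead of just commenting out the reset like I did
locally.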
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services