This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d814733 Added query for cuDNN BN min. epsilon. Enabled choice of BN
impl. for ONNX import (#11380)
d814733 is described below
commit d814733f16d50b2dd51b6f3e7e3256b5e66c8026
Author: Marek Kolodziej <[email protected]>
AuthorDate: Tue Jul 10 22:36:32 2018 -0700
Added query for cuDNN BN min. epsilon. Enabled choice of BN impl. for ONNX
import (#11380)
---
python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 2b98aa0..61f342a 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -21,7 +21,6 @@
import numpy as np
from . import _translation_utils as translation_utils
from .... import symbol
-
# Method definitions for the callable objects mapped in the import_helper
module
def identity(attrs, inputs, proto_obj):
@@ -209,7 +208,10 @@ def batch_norm(attrs, inputs, proto_obj):
'is_test':
'fix_gamma'})
new_attrs = translation_utils._remove_attributes(new_attrs,
['spatial',
'consumed_inputs'])
- new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'cudnn_off': 1})
+ # Disable cuDNN BN only if epsilon from model is less than the minimum cuDNN eps (1e-5)
+ cudnn_min_eps = 1e-5
+ cudnn_off = 0 if attrs.get('epsilon', cudnn_min_eps) >= cudnn_min_eps else
1
+ new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'cudnn_off': cudnn_off})
# in test mode "fix_gamma" should be unset.
new_attrs['fix_gamma'] = not attrs.get('is_test', 1)