szha closed pull request #11380: Add ability to query cuDNN BatchNorm min. epsilon. Allow ONNX importer to use cuDNN BN if chosen eps >= cuDNN min. eps.
URL: https://github.com/apache/incubator-mxnet/pull/11380
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 2b98aa08feb..61f342a9ae9 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -21,7 +21,6 @@
 import numpy as np
 from . import _translation_utils as translation_utils
 from .... import symbol
-
 # Method definitions for the callable objects mapped in the import_helper module
 def identity(attrs, inputs, proto_obj):
@@ -209,7 +208,10 @@ def batch_norm(attrs, inputs, proto_obj):
                                                                'is_test': 'fix_gamma'})
     new_attrs = translation_utils._remove_attributes(new_attrs,
                                                      ['spatial', 'consumed_inputs'])
-    new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1})
+    # Disable cuDNN BN only if epsilon from model is < than minimum cuDNN eps (1e-5)
+    cudnn_min_eps = 1e-5
+    cudnn_off = 0 if attrs.get('epsilon', cudnn_min_eps) >= cudnn_min_eps else 1
+    new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': cudnn_off})
     # in test mode "fix_gamma" should be unset.
     new_attrs['fix_gamma'] = not attrs.get('is_test', 1)
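To exercise the new rule outside MXNet, the epsilon check from batch_norm() in the diff can be pulled out into a plain function. This is a minimal sketch: the helper name select_cudnn_off and the constant CUDNN_MIN_EPS are hypothetical and not part of MXNet's API. Only the 1e-5 floor and the fallback to that floor when the ONNX node carries no 'epsilon' attribute are taken from the diff.

```python
# Hypothetical stand-alone version of the cudnn_off decision in batch_norm().
# CUDNN_MIN_EPS mirrors the 1e-5 value hard-coded in the diff; neither name
# below exists in MXNet itself.
CUDNN_MIN_EPS = 1e-5


def select_cudnn_off(attrs, cudnn_min_eps=CUDNN_MIN_EPS):
    """Return 1 to disable cuDNN BatchNorm, 0 to allow it.

    A missing 'epsilon' attribute falls back to the cuDNN minimum, so cuDNN
    stays enabled in that case, matching attrs.get('epsilon', cudnn_min_eps)
    in the diff.
    """
    eps = attrs.get('epsilon', cudnn_min_eps)
    return 0 if eps >= cudnn_min_eps else 1


if __name__ == '__main__':
    # Sample ONNX BatchNormalization attribute dicts (illustrative only).
    for sample in ({}, {'epsilon': 1e-5}, {'epsilon': 1e-3}, {'epsilon': 1e-6}):
        print(sample, '-> cudnn_off =', select_cudnn_off(sample))
```

With an epsilon of 1e-6 the helper returns 1, so the importer falls back to MXNet's non-cuDNN BatchNorm; any epsilon at or above 1e-5, or no epsilon attribute at all, leaves cuDNN enabled. Before this change the importer always set cudnn_off to 1.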
----------------------------------------------------------------
This is an automated message from the Apache Git Service.