This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 57af88f  [ONNX] Use relay softmax op to convert Softmax if posssible (#9892)
57af88f is described below

commit 57af88faaab56b347bce97076b272464ec796aeb
Author: Masahiro Masuda <[email protected]>
AuthorDate: Tue Jan 11 19:05:55 2022 +0900

    [ONNX] Use relay softmax op to convert Softmax if posssible (#9892)
---
 python/tvm/relay/frontend/onnx.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index b8bbcf8..263eb85 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1967,6 +1967,11 @@ class Softmax(OnnxOpConverter):
         ndim = len(infer_shape(inputs[0]))
         if axis < 0:
             axis += ndim
+        # Older ONNX Softmax op does not properly support inputs of dimension > 2
+        # But we can use our softmax when the axis is -1
+        if axis == ndim - 1:
+            return _op.nn.softmax(inputs[0], axis=axis)
+
         axes = list(range(axis, ndim))
         x = inputs[0]
         m = _op.max(x, axes, keepdims=True)
@@ -1974,16 +1979,12 @@ class Softmax(OnnxOpConverter):
         return e / _op.sum(e, axes, keepdims=True)
 
     @classmethod
-    def _impl_v13(cls, inputs, attr, params):
+    def _impl_v13(cls, inputs, attr, _):
         axis = attr.get("axis", -1)
         ndim = len(infer_shape(inputs[0]))
         if axis < 0:
             axis += ndim
-        axes = [axis]
-        x = inputs[0]
-        m = _op.max(x, axes, keepdims=True)
-        e = _op.exp(x - m)
-        return e / _op.sum(e, axes, keepdims=True)
+        return _op.nn.softmax(inputs[0], axis=axis)
 
 
 class LogSoftmax(OnnxOpConverter):

Reply via email to