This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 57af88f [ONNX] Use relay softmax op to convert Softmax if possible
(#9892)
57af88f is described below
commit 57af88faaab56b347bce97076b272464ec796aeb
Author: Masahiro Masuda <[email protected]>
AuthorDate: Tue Jan 11 19:05:55 2022 +0900
[ONNX] Use relay softmax op to convert Softmax if possible (#9892)
---
python/tvm/relay/frontend/onnx.py | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py
b/python/tvm/relay/frontend/onnx.py
index b8bbcf8..263eb85 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1967,6 +1967,11 @@ class Softmax(OnnxOpConverter):
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
+ # Older ONNX Softmax op does not properly support inputs of dimension > 2
+ # But we can use our softmax when the axis is -1
+ if axis == ndim - 1:
+ return _op.nn.softmax(inputs[0], axis=axis)
+
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
@@ -1974,16 +1979,12 @@ class Softmax(OnnxOpConverter):
return e / _op.sum(e, axes, keepdims=True)
@classmethod
- def _impl_v13(cls, inputs, attr, params):
+ def _impl_v13(cls, inputs, attr, _):
axis = attr.get("axis", -1)
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
- axes = [axis]
- x = inputs[0]
- m = _op.max(x, axes, keepdims=True)
- e = _op.exp(x - m)
- return e / _op.sum(e, axes, keepdims=True)
+ return _op.nn.softmax(inputs[0], axis=axis)
class LogSoftmax(OnnxOpConverter):