This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f5ab57d34 [Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507)
4f5ab57d34 is described below

commit 4f5ab57d348e97b707d0707f9272cebe03a79777
Author: ChunPing Chung <[email protected]>
AuthorDate: Fri Jun 3 00:28:38 2022 +0800

    [Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507)
    
    * [Frontend][ONNX] Fix softmax converter when input shape is dynamic
    
    * [Frontend][ONNX] mark dynamic softmax tests as xfailed with cuda
---
 python/tvm/relay/frontend/onnx.py          |  2 ++
 tests/python/frontend/onnx/test_forward.py | 37 ++++++++++++++++++++++++------
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 30e8188a83..997aa6240e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2420,6 +2420,8 @@ class Softmax(OnnxOpConverter):
             axis += ndim
         if axis == 0:
             reshape_shape = [-1]
+        elif axis == ndim - 1:
+            return _op.nn.softmax(inputs[0], axis=axis)
         else:
             axis_val = [in_shape[i] for i in range(axis)]
             reshape_shape = [np.prod(axis_val)] + [-1]
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dbc5147e20..c4cd93aa7d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1589,26 +1589,45 @@ def test_upsample3d_trilinear(target, dev):
     tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
 
+# TODO: Fix softmax with dynamic input on cuda and enable this test
+@tvm.testing.known_failing_targets("cuda")
 @tvm.testing.parametrize_targets
 def test_softmax(target, dev):
-    def verify_softmax(inshape, axis):
+    def verify_softmax(inshape, axis, opset=None, dynamic=False):
         opname = "Softmax"
-        indata = np.random.uniform(size=inshape).astype(np.float32)
         outshape = inshape
-        y = helper.make_node(opname, ["in"], ["out"])
+        node_list = []
+        input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))]
+        output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))]
+        input_list = [np.random.uniform(size=inshape).astype(np.float32)]
+        softmax_inputs = ["in"]
+
+        if dynamic:
+            input_node_list.append(
+                helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)])
+            )
+            input_list.append(np.asarray(inshape))
+            reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"])
+            softmax_inputs[0] = "dynamic_in"
+            node_list += [reshape_node]
+
+        y = helper.make_node(opname, softmax_inputs, ["out"])
         if axis is not None:
             axis_attr = helper.make_attribute("axis", axis)
             y.attribute.append(axis_attr)
+        node_list.append(y)
 
         graph = helper.make_graph(
-            [y],
+            node_list,
             opname + "_test",
-            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
-            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))],
+            inputs=input_node_list,
+            outputs=output_node_list,
         )
 
         model = helper.make_model(graph, producer_name=opname + "_test")
-        verify_with_ort_with_inputs(model, [indata], target=target, dev=dev)
+        verify_with_ort_with_inputs(
+            model, input_list, use_vm=True, opset=opset, target=target, dev=dev
+        )
 
     verify_softmax((1, 10), None)
     verify_softmax((1, 10), 1)
@@ -1616,6 +1635,10 @@ def test_softmax(target, dev):
     verify_softmax((1, 2, 3, 10), 2)
     verify_softmax((1, 2, 3, 4, 10), 3)
     verify_softmax((1, 2, 3, 4, 10), 4)
+    verify_softmax((1, 10), -1, dynamic=True)
+    verify_softmax((1, 2, 3, 10), -1, dynamic=True)
+    verify_softmax((1, 10), -1, opset=8, dynamic=True)
+    verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True)
 
 
 @tvm.testing.parametrize_targets

Reply via email to