This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4f5ab57d34 [Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507)
4f5ab57d34 is described below
commit 4f5ab57d348e97b707d0707f9272cebe03a79777
Author: ChunPing Chung <[email protected]>
AuthorDate: Fri Jun 3 00:28:38 2022 +0800
[Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507)
* [Frontend][ONNX] Fix softmax converter when input shape is dynamic
* [Frontend][ONNX] mark dynamic softmax tests as xfailed with cuda
---
python/tvm/relay/frontend/onnx.py | 2 ++
tests/python/frontend/onnx/test_forward.py | 37 ++++++++++++++++++++++++------
2 files changed, 32 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 30e8188a83..997aa6240e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2420,6 +2420,8 @@ class Softmax(OnnxOpConverter):
axis += ndim
if axis == 0:
reshape_shape = [-1]
+ elif axis == ndim - 1:
+ return _op.nn.softmax(inputs[0], axis=axis)
else:
axis_val = [in_shape[i] for i in range(axis)]
reshape_shape = [np.prod(axis_val)] + [-1]
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dbc5147e20..c4cd93aa7d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1589,26 +1589,45 @@ def test_upsample3d_trilinear(target, dev):
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
+# TODO: Fix softmax with dynamic input on cuda and enable this test
[email protected]_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_softmax(target, dev):
- def verify_softmax(inshape, axis):
+ def verify_softmax(inshape, axis, opset=None, dynamic=False):
opname = "Softmax"
- indata = np.random.uniform(size=inshape).astype(np.float32)
outshape = inshape
- y = helper.make_node(opname, ["in"], ["out"])
+ node_list = []
+ input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))]
+ output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))]
+ input_list = [np.random.uniform(size=inshape).astype(np.float32)]
+ softmax_inputs = ["in"]
+
+ if dynamic:
+ input_node_list.append(
+ helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)])
+ )
+ input_list.append(np.asarray(inshape))
+ reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"])
+ softmax_inputs[0] = "dynamic_in"
+ node_list += [reshape_node]
+
+ y = helper.make_node(opname, softmax_inputs, ["out"])
if axis is not None:
axis_attr = helper.make_attribute("axis", axis)
y.attribute.append(axis_attr)
+ node_list.append(y)
graph = helper.make_graph(
- [y],
+ node_list,
opname + "_test",
- inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
- outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))],
+ inputs=input_node_list,
+ outputs=output_node_list,
)
model = helper.make_model(graph, producer_name=opname + "_test")
- verify_with_ort_with_inputs(model, [indata], target=target, dev=dev)
+ verify_with_ort_with_inputs(
+ model, input_list, use_vm=True, opset=opset, target=target, dev=dev
+ )
verify_softmax((1, 10), None)
verify_softmax((1, 10), 1)
@@ -1616,6 +1635,10 @@ def test_softmax(target, dev):
verify_softmax((1, 2, 3, 10), 2)
verify_softmax((1, 2, 3, 4, 10), 3)
verify_softmax((1, 2, 3, 4, 10), 4)
+ verify_softmax((1, 10), -1, dynamic=True)
+ verify_softmax((1, 2, 3, 10), -1, dynamic=True)
+ verify_softmax((1, 10), -1, opset=8, dynamic=True)
+ verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True)
@tvm.testing.parametrize_targets