This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f0dc42 [Frontend][MXNet] add _npi_stack, issue #7186 (#7209)
3f0dc42 is described below
commit 3f0dc420ff9b891a79f55181886b536ecc337796
Author: insop <[email protected]>
AuthorDate: Tue Jan 5 17:17:02 2021 -0800
[Frontend][MXNet] add _npi_stack, issue #7186 (#7209)
- https://github.com/apache/tvm/issues/7186
- add MXNet stack, `_npi_stack`
- https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.stack.html?highlight=stack
---
python/tvm/relay/frontend/mxnet.py | 9 +++++++++
tests/python/frontend/mxnet/test_forward.py | 28 ++++++++++++++++++++++++++++
2 files changed, 37 insertions(+)
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 1085e90..b272ead 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -2335,6 +2335,14 @@ def _mx_npi_concatenate(inputs, attrs):
return _op.concatenate(tuple(inputs), axis=int(axis))
+def _mx_npi_stack(inputs, attrs):
+    """Convert MXNet's ``_npi_stack`` (``mx.np.stack``) to a Relay ``stack``.
+
+    Parameters
+    ----------
+    inputs : list
+        Converted input expressions to be stacked.
+    attrs : attribute dict
+        Operator attributes; only ``axis`` is read, as a string defaulting to "0".
+
+    Returns
+    -------
+    expr
+        The stacked expression built with ``_op.stack``. If ``axis`` is the
+        string "None", the inputs are stacked along axis 0 and the result is
+        flattened to 1-D.
+    """
+    axis = attrs.get_str("axis", "0")
+    data = tuple(inputs)
+    if axis == "None":
+        # "None" axis: stack on axis 0, then flatten the result to 1-D.
+        # NOTE(review): mx.np.stack documents axis as an int; confirm MXNet can
+        # actually emit "None" here before relying on this path.
+        return _op.reshape(_op.stack(data, axis=0), (-1,))
+    return _op.stack(data, axis=int(axis))
+
+
def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
@@ -2700,6 +2708,7 @@ _convert_map = {
"_npi_less_equal": _mx_compare(_op.less_equal, _rename),
"_npi_tanh": _rename(_op.tanh),
"_npi_true_divide_scalar": _binop_scalar(_op.divide),
+ "_npi_stack": _mx_npi_stack,
}
# set identity list
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index d3be8c0..537349e 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -2012,6 +2012,34 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target,
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+@pytest.mark.parametrize(
+    "data_shape1, data_shape2, axis",
+    [
+        ((3,), (3,), 0),
+        ((3,), (3,), -1),
+        ((1, 3, 2), (1, 3, 2), 2),
+        ((1, 3, 3), (1, 3, 3), 1),
+        ((1, 3), (1, 3), 0),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
+@tvm.testing.parametrize_targets
+@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
+def test_forward_npi_stack(data_shape1, data_shape2, axis, dtype, target, ctx, kind):
+    """Check the Relay conversion of ``mx.np.stack`` against MXNet's own result."""
+    # Draw from a wide signed range: uniform(0, 1) truncates to all zeros for
+    # the integer dtypes, which would let any axis/ordering bug pass unnoticed.
+    data_np1 = np.random.uniform(low=-100, high=100, size=data_shape1).astype(dtype)
+    data_np2 = np.random.uniform(low=-100, high=100, size=data_shape2).astype(dtype)
+    data1 = mx.sym.var("data1")
+    data2 = mx.sym.var("data2")
+    ref_res = mx.np.stack([mx.np.array(data_np1), mx.np.array(data_np2)], axis=axis)
+    mx_sym = mx.sym.np.stack([data1.as_np_ndarray(), data2.as_np_ndarray()], axis=axis)
+    mod, _ = relay.frontend.from_mxnet(
+        mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype
+    )
+    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+    op_res = intrp.evaluate()(data_np1, data_np2)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+
@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2,
3, 1), (1, 8)])
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32",
"bool"])
@tvm.testing.parametrize_targets