This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new caa0d59c33 [ONNX] Add more dynamism to Eyelike (#11615)
caa0d59c33 is described below
commit caa0d59c335713d29b1e63714395fc2ba3d979dc
Author: An Wang <[email protected]>
AuthorDate: Wed Jun 22 12:33:13 2022 -0700
[ONNX] Add more dynamism to Eyelike (#11615)
* support dynamic input shapes in the EyeLike ONNX importer
* add dynamic-shape handling to EyeLike
* make the EyeLike ONNX importer more robust to dynamic shapes
* noop
---
python/tvm/relay/frontend/onnx.py | 15 +++++++++------
python/tvm/relay/op/_transform.py | 1 +
tests/python/frontend/onnx/test_forward.py | 29 +++++++++++++++++++++++------
3 files changed, 33 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 595f12d4d5..352eb99ba4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2124,18 +2124,21 @@ class EyeLike(OnnxOpConverter):
@classmethod
def _impl_v9(cls, inputs, attr, params):
- in_checked_type = infer_type(inputs[0]).checked_type
- in_dtype = in_checked_type.dtype
- in_shape = list(get_const_tuple(in_checked_type.shape))
dtype = attr.get("dtype", None)
if dtype is None:
+ in_checked_type = infer_type(inputs[0]).checked_type
+ in_dtype = in_checked_type.dtype
dtype = in_dtype
else:
dtype = get_type(dtype)
+
+ in_shape = _op.shape_of(inputs[0])
zeros = _op.zeros(in_shape, dtype)
- dim = in_shape[0]
- indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32")
- ones = _op.full(_op.const(1), (dim,), dtype=dtype)
+
+ dim = _op.take(in_shape, _op.const(0))
+
+ indices = _op.arange(_op.const(0), dim, dtype="int32")
+ ones = _op.full(_op.const(1), _op.reshape(dim, (1,)), dtype=dtype)
k = _op.const(attr.get("k", 0), dtype="int32")
return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 90507ce29a..baf616a946 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -677,6 +677,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
+_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
@script
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index c58e920ead..12292a6fb7 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5089,28 +5089,45 @@ def test_cumsum(target, dev):
@tvm.testing.parametrize_targets
def test_eyelike(target, dev):
- def verify_eyelike(indata):
+ def verify_eyelike(indata, dynamic=False):
+ node_list = []
+ eyelike_inputs = ["X"]
+ input_node_list = [
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))
+ ]
+ input_list = [indata]
+
+ if dynamic:
+ input_node_list.append(
+ helper.make_tensor_value_info("shape", TensorProto.INT64, [len(indata.shape)])
+ )
+ input_list.append(np.asarray(indata.shape))
+ reshape_node = helper.make_node("Reshape", ["X", "shape"], ["X_dyn"])
+ eyelike_inputs[0] = "X_dyn"
+ node_list += [reshape_node]
+
node = helper.make_node(
"EyeLike",
- inputs=["X"],
+ inputs=eyelike_inputs,
outputs=["Y"],
)
+ node_list.append(node)
graph = helper.make_graph(
- [node],
+ node_list,
"eyelike_test",
- inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))],
+ inputs=input_node_list,
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))],
)
model = helper.make_model(graph, producer_name="eyelike_test")
-
verify_with_ort_with_inputs(
- model, [indata], dtype="float32", opset=9, target=target, dev=dev
+ model, input_list, dtype="float32", opset=9, target=target, dev=dev, use_vm=True
)
input_data = np.zeros((5, 5), dtype=np.float32)
verify_eyelike(input_data)
+ verify_eyelike(input_data, True)
"""