This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new caa0d59c33 [ONNX] Add more dynamism to Eyelike (#11615)
caa0d59c33 is described below

commit caa0d59c335713d29b1e63714395fc2ba3d979dc
Author: An Wang <[email protected]>
AuthorDate: Wed Jun 22 12:33:13 2022 -0700

    [ONNX] Add more dynamism to Eyelike (#11615)
    
    * add dynamism-okness to eyelike onnx importer
    
    * add dynamism to eyelike
    
    * add more dynamism robustness to eyelike onnx importer
    
    * noop
---
 python/tvm/relay/frontend/onnx.py          | 15 +++++++++------
 python/tvm/relay/op/_transform.py          |  1 +
 tests/python/frontend/onnx/test_forward.py | 29 +++++++++++++++++++++++------
 3 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 595f12d4d5..352eb99ba4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2124,18 +2124,21 @@ class EyeLike(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
-        in_checked_type = infer_type(inputs[0]).checked_type
-        in_dtype = in_checked_type.dtype
-        in_shape = list(get_const_tuple(in_checked_type.shape))
         dtype = attr.get("dtype", None)
         if dtype is None:
+            in_checked_type = infer_type(inputs[0]).checked_type
+            in_dtype = in_checked_type.dtype
             dtype = in_dtype
         else:
             dtype = get_type(dtype)
+
+        in_shape = _op.shape_of(inputs[0])
         zeros = _op.zeros(in_shape, dtype)
-        dim = in_shape[0]
-        indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32")
-        ones = _op.full(_op.const(1), (dim,), dtype=dtype)
+
+        dim = _op.take(in_shape, _op.const(0))
+
+        indices = _op.arange(_op.const(0), dim, dtype="int32")
+        ones = _op.full(_op.const(1), _op.reshape(dim, (1,)), dtype=dtype)
         k = _op.const(attr.get("k", 0), dtype="int32")
         return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update")
 
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 90507ce29a..baf616a946 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -677,6 +677,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
 
 _reg.register_shape_func("scatter", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_add", False, elemwise_shape_func)
+_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
 
 
 @script
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index c58e920ead..12292a6fb7 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5089,28 +5089,45 @@ def test_cumsum(target, dev):
 
 @tvm.testing.parametrize_targets
 def test_eyelike(target, dev):
-    def verify_eyelike(indata):
+    def verify_eyelike(indata, dynamic=False):
+        node_list = []
+        eyelike_inputs = ["X"]
+        input_node_list = [
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))
+        ]
+        input_list = [indata]
+
+        if dynamic:
+            input_node_list.append(
+                helper.make_tensor_value_info("shape", TensorProto.INT64, [len(indata.shape)])
+            )
+            input_list.append(np.asarray(indata.shape))
+            reshape_node = helper.make_node("Reshape", ["X", "shape"], ["X_dyn"])
+            eyelike_inputs[0] = "X_dyn"
+            node_list += [reshape_node]
+
         node = helper.make_node(
             "EyeLike",
-            inputs=["X"],
+            inputs=eyelike_inputs,
             outputs=["Y"],
         )
+        node_list.append(node)
 
         graph = helper.make_graph(
-            [node],
+            node_list,
             "eyelike_test",
-            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))],
+            inputs=input_node_list,
             outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))],
         )
 
         model = helper.make_model(graph, producer_name="eyelike_test")
-
         verify_with_ort_with_inputs(
-            model, [indata], dtype="float32", opset=9, target=target, dev=dev
+            model, input_list, dtype="float32", opset=9, target=target, dev=dev, use_vm=True
         )
 
     input_data = np.zeros((5, 5), dtype=np.float32)
     verify_eyelike(input_data)
+    verify_eyelike(input_data, True)
 
 
 """

Reply via email to