This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 90b946c  [ONNX][Converter] Add dynamic nodes support (#9380)
90b946c is described below

commit 90b946cae367b2d0a0bef6c15fee14224459c6f8
Author: Colin Y. Li <[email protected]>
AuthorDate: Wed Nov 3 11:13:41 2021 +0800

    [ONNX][Converter] Add dynamic nodes support (#9380)
    
    * [ONNX][Converter] Add support for dynamic nodes
    
    * fix lint
---
 python/tvm/contrib/target/onnx.py | 10 +++++++---
 python/tvm/relay/backend/vm.py    | 12 ++++++++----
 tests/python/contrib/test_onnx.py | 34 +++++++++++++++++++++++++++-------
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py
index c26255f..2b142dc 100644
--- a/python/tvm/contrib/target/onnx.py
+++ b/python/tvm/contrib/target/onnx.py
@@ -70,6 +70,10 @@ def get_onnx_version():
     return onnx.__version__
 
 
+def get_node_shape(node):
+    return tuple("Any" if isinstance(i, tvm.tir.Any) else int(i) for i in node.shape)
+
+
 def infer_type(node):
     """A method to infer the type of a relay expression."""
     mod = tvm.IRModule.from_expr(node)
@@ -521,7 +525,7 @@ class Split(OpConverter):
         input_node = node_dict[node_entry["inputs"][0]]
         assert len(input_node) == 1, "input node can not be a Tuple"
         input_node = input_node[0]
-        shape = input_node["types"][0].concrete_shape
+        shape = get_node_shape(input_node["types"][0])
 
         indices_or_sect = attrs["indices_or_section"]
         axis = attrs["axis"]
@@ -1019,7 +1023,7 @@ class RelayToONNXConverter(ExprVisitor):
             node_type = node_entry["types"][0]
             dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)]
             input = onnx.helper.make_tensor_value_info(
-                node_entry["name"], dtype, shape=node_type.concrete_shape
+                node_entry["name"], dtype, shape=get_node_shape(node_type)
             )
             self._mc.add_inputs([input])
 
@@ -1030,7 +1034,7 @@ class RelayToONNXConverter(ExprVisitor):
             for node_type, output_name in zip(node_entry["types"], node_entry["output_names"]):
                 dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)]
                 output = onnx.helper.make_tensor_value_info(
-                    output_name, dtype, shape=node_type.concrete_shape
+                    output_name, dtype, shape=get_node_shape(node_type)
                 )
                 self._mc.add_outputs([output])
 
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 363ff89..1dde27f 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -275,14 +275,18 @@ class VMExecutor(Executor):
         self.mod = mod
         self.device = device
         self.target = target
-        self.executable = compile(mod, target)
-        self.vm = vm_rt.VirtualMachine(self.executable, device)
+        self.executable = None
+        self.vm = None
 
     def _make_executor(self, expr=None):
-        main = self.mod["main"]
+        if expr:
+            self.mod["main"] = expr
+
+        self.executable = compile(self.mod, self.target)
+        self.vm = vm_rt.VirtualMachine(self.executable, self.device)
 
         def _vm_wrapper(*args, **kwargs):
-            args = self._convert_args(main, args, kwargs)
+            args = self._convert_args(self.mod["main"], args, kwargs)
             return self.vm.run(*args)
 
         return _vm_wrapper
diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py
index 121edc4..6f23228 100644
--- a/tests/python/contrib/test_onnx.py
+++ b/tests/python/contrib/test_onnx.py
@@ -47,12 +47,11 @@ def run_onnx(onnx_model, input_data):
     return res
 
 
-def run_relay(func, data_tuple):
+def run_relay(func, data_tuple, is_dyn=False):
     target = "llvm"
     dev = tvm.device("llvm", 0)
-    relay_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
-        *data_tuple
-    )
+    kind = "graph" if not is_dyn else "vm"
+    relay_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(*data_tuple)
 
     result = []
     relay_res = relay_res if isinstance(relay_res, list) else [relay_res]
@@ -62,8 +61,8 @@ def run_relay(func, data_tuple):
     return result
 
 
-def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0):
-    relay_results = run_relay(relay_func, indata)
+def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0, is_dyn=False):
+    relay_results = run_relay(relay_func, indata, is_dyn)
     onnx_results = run_onnx(func_to_onnx(relay_func, test_name), indata)
 
     for relay_res, onnx_res in zip(relay_results, onnx_results):
@@ -111,7 +110,7 @@ def test_conv2d():
         func = relay.Function([x, w], y)
         data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
         kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
-        verify_results(func, [data, kernel], "test_conv2d", rtol=1e-5, atol=1e-5)
+        verify_results(func, [data, kernel], "test_conv2d", rtol=1e-5, atol=1e-5, is_dyn=True)
 
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
@@ -700,6 +699,26 @@ def test_resize():
                 verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k)
 
 
+def test_dyn():
+    """Dynamic unit test."""
+
+    def verify_dyn_bcast(lhs_shape, rhs_shape, dtype):
+        lhs_dyn_shape = tuple(relay.Any() for i in range(len(lhs_shape)))
+        rhs_dyn_shape = tuple(relay.Any() for i in range(len(rhs_shape)))
+        x = relay.var("x", shape=lhs_dyn_shape, dtype=dtype)
+        y = relay.var("y", shape=rhs_dyn_shape, dtype=dtype)
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+        lhs_data = np.random.uniform(size=lhs_shape).astype(dtype)
+        rhs_data = np.random.uniform(size=rhs_shape).astype(dtype)
+        verify_results(
+            func, [lhs_data, rhs_data], "test_dyn_bcast", rtol=1e-5, atol=1e-5, is_dyn=True
+        )
+
+    verify_dyn_bcast((1, 3, 32, 1), (1, 3, 1, 3), "float32")
+    verify_dyn_bcast((1, 13), (4, 3, 5, 1), "float32")
+
+
 if __name__ == "__main__":
     test_add()
     test_bias_add()
@@ -730,3 +749,4 @@ if __name__ == "__main__":
     test_round()
     test_cast()
     test_resize()
+    test_dyn()

Reply via email to