cchung100m commented on a change in pull request #5073: [Relay][Frontend][ONNX] 
operator support NonZero
URL: https://github.com/apache/incubator-tvm/pull/5073#discussion_r394366255
 
 

 ##########
 File path: tests/python/frontend/onnx/test_forward.py
 ##########
 @@ -30,22 +30,54 @@
 import scipy
 
 
-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, 
output_dtype='float32', opset=None):
-    """ Generic function to execute and get tvm output"""
-    target = 'llvm'
+def get_input_data_shape_dict(graph_def, input_data):
     if isinstance(input_data, list):
         input_names = {}
         shape_dict = {}
-        dtype_dict = {}
         for i, _ in enumerate(input_data):
             input_names[i] = graph_def.graph.input[i].name
             shape_dict[input_names[i]] = input_data[i].shape
-            dtype_dict[input_names[i]] = input_data[i].dtype
     else:
         input_names = graph_def.graph.input[0].name
         shape_dict = {input_names: input_data.shape}
+
+    return input_names, shape_dict
+
+
+def get_input_data_dtype_dict(graph_def, input_data):
+    if isinstance(input_data, list):
+        input_names = {}
+        dtype_dict = {}
+        for i, _ in enumerate(input_data):
+            input_names[i] = graph_def.graph.input[i].name
+            dtype_dict[input_names[i]] = input_data[i].dtype
+    else:
+        input_names = graph_def.graph.input[0].name
         dtype_dict = {input_names: input_data.dtype}
 
+    return input_names, dtype_dict
+
+
+def get_tvm_output_with_vm(graph_def, input_data, target, ctx,
+                           opset=None):
+    """ Generic function to execute and get tvm output with vm executor"""
+
+    _, shape_dict = get_input_data_shape_dict(graph_def, input_data)
+
+    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
+
+    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
+    indata = tvm.nd.array(input_data)
+    result = ex.evaluate()(indata)
+    return result.asnumpy()
+
+
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, 
output_dtype='float32', opset=None):
 
 Review comment:
   Hi @kazum 
   I'd like to keep both `relay.create_executor` and `relay.build` in this 
PR. 
   
   I cannot replace `relay.build` with `relay.create_executor` directly, because 
doing so causes many errors like the ones below:
   
   ```
     File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2282, in 
<module>
       test_flatten()
   
     File "/tvm/tests/python/frontend/onnx/test_forward.py", line 374, in 
test_flatten
       tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
   
     File "/tvm/tests/python/frontend/onnx/test_forward.py", line 70, in 
get_tvm_output
       result = ex.evaluate()(indata)
   
     File "/tvm/python/tvm/relay/backend/vm.py", line 256, in _vm_wrapper
       return self.vm.run(*args)
   
     File "/tvm/python/tvm/runtime/vm.py", line 366, in run
       return self.invoke("main", *args, **kwargs)
   
     File "/tvm/python/tvm/runtime/vm.py", line 348, in invoke
       return self._invoke(func_name)
   
     File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
       raise get_last_ffi_error()
   
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (7) 8   ???                                 0x00007fff54230930 0x0 + 
140734604970288
     [bt] (6) 7   _ctypes.cpython-37m-darwin.so       0x00000001104dc2bf 
ffi_call_unix64 + 79
     [bt] (5) 6   libtvm.dylib                        0x0000000125071f78 
TVMFuncCall + 72
     [bt] (4) 5   libtvm.dylib                        0x00000001250a4e3f 
std::__1::__function::__func<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char,
 std::__1::char_traits<char>, std::__1::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, 
std::__1::allocator<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char,
 std::__1::char_traits<char>, std::__1::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void 
(tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, 
tvm::runtime::TVMRetValue*&&) + 735
     [bt] (3) 4   libtvm.dylib                        0x00000001250a199a 
tvm::runtime::vm::VirtualMachine::RunLoop() + 7610
     [bt] (2) 3   libtvm.dylib                        0x00000001250a310f 
tvm::runtime::vm::VirtualMachine::InvokePacked(long long, 
tvm::runtime::PackedFunc const&, long long, long long, 
std::__1::vector<tvm::runtime::ObjectRef, 
std::__1::allocator<tvm::runtime::ObjectRef> > const&) + 1039
     [bt] (1) 2   libtvm.dylib                        0x000000012507b396 
std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, 
int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> 
const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int 
(*)(TVMValue*, int*, int, TVMValue*, int*), 
tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void 
(tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, 
tvm::runtime::TVMRetValue*&&) + 310
     [bt] (0) 1   libtvm.dylib                        0x0000000124666af9 
dmlc::LogMessageFatal::~LogMessageFatal() + 57
     File "/tvm/src/runtime/library_module.cc", line 89
   TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: 
(((tvm_struct_get(arg0, 0, 5) == (uint8)2) && (tvm_struct_get(arg0, 0, 6) == 
(uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), arg0.dtype is 
expected to be float32
   
   ```
   ```
   File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2128, in 
verify_lstm
       output_dtype=['float32', 'float32', 'float32'])
   
     File "/tvm/tests/python/frontend/onnx/test_forward.py", line 69, in 
get_tvm_output
       indata = tvm.nd.array(input_data)
   
     File "/tvm/python/tvm/runtime/ndarray.py", line 487, in array
       return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
   
     File "/tvm/python/tvm/runtime/ndarray.py", line 270, in empty
       dtype = DataType(dtype)
   
     File "/tvm/python/tvm/_ffi/runtime_ctypes.py", line 101, in __init__
       raise ValueError("Do not know how to handle type %s" % type_str)
   
   ValueError: Do not know how to handle type object
   ```
   
   Maybe we can open a separate PR to address the above issues and then replace 
`relay.build` with `relay.create_executor`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to