ChijinZ commented on issue #16053:
URL: https://github.com/apache/tvm/issues/16053#issuecomment-1798412201

   This looks like a very interesting bug. I simplified the code to make it 
runnable without external input files.
   
   ```python
   import onnx
   import numpy as np
   import pickle
   from numpy import testing
   import tvm
   from tvm import relay
   import torch
   
   # Frozen constant parameter shared by both models. Note torch.empty leaves
   # the values uninitialized; only shape/dtype matter for the repro.
   p0 = torch.nn.Parameter(
       torch.empty([18, 1, 1, 1, 60], dtype=torch.int64),
       requires_grad=False,
   )

   # Random 0/1 input, cast to int64; its shape broadcasts against p0.
   _uniform = torch.rand(18, 1, 1, 50, 1, dtype=torch.float64)
   input_array = torch.round(_uniform).to(torch.int64)
   shape_dict = {"x": input_array.shape}
   
   class Model0(torch.nn.Module):
       """Baseline graph: argmin(-(x + const), dim=4), no transposes."""

       def __init__(self):
           super().__init__()
           # Module-level constant captured from the script's global scope.
           self.const = p0

       def forward(self, x):
           # x broadcasts against the stored constant before negation.
           summed = x + self.const
           return torch.neg(summed).argmin(dim=4)
   
   model_0 = Model0()
   # Serialize the baseline model to ONNX (opset 14, constant folding off so
   # the graph structure is preserved exactly as authored).
   torch.onnx.export(
       model_0,
       input_array,
       '0.onnx',
       verbose=False,
       input_names=['x'],
       output_names=['out'],
       opset_version=14,
       do_constant_folding=False,
   )
   
   class Model1(torch.nn.Module):
       """Mathematically equivalent to Model0, but wraps the add in a
       transpose(1, 0) round-trip — the pattern that exposes the bug."""

       def __init__(self):
           super().__init__()
           # Same global constant as Model0.
           self.const = p0

       def forward(self, x):
           # Transpose both operands, add, then transpose back; the net
           # result equals Model0's plain broadcasted add.
           lhs = x.transpose(1, 0)
           rhs = self.const.transpose(1, 0)
           restored = (lhs + rhs).transpose(1, 0)
           return torch.neg(restored).argmin(dim=4)
   
   model_1 = Model1()
   # Export the transpose-wrapped variant under identical ONNX settings.
   torch.onnx.export(
       model_1,
       input_array,
       '1.onnx',
       verbose=False,
       input_names=['x'],
       output_names=['out'],
       opset_version=14,
       do_constant_folding=False,
   )
   
   # Import model 0 into Relay and run it at the highest optimization level.
   onnx_model_0 = onnx.load('0.onnx')
   onnx_model_outputs_0 = [node.name for node in onnx_model_0.graph.output]
   mod_0, params_0 = relay.frontend.from_onnx(
       onnx_model_0, shape_dict, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=4):
       run_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       results_0 = [run_0(input_array).numpy()]
       output_0 = dict(zip(onnx_model_outputs_0, results_0))
   
   # Import model 1 into Relay and run it at the highest optimization level.
   onnx_model_1 = onnx.load('1.onnx')
   onnx_model_outputs_1 = [node.name for node in onnx_model_1.graph.output]
   shape_dict_1 = {"x": input_array.shape}
   # BUG FIX: shape_dict_1 was defined but never used — the original passed
   # shape_dict here. The values happen to be identical, but passing the
   # intended dict removes the dead variable and the latent inconsistency.
   mod_1, params_1 = relay.frontend.from_onnx(
       onnx_model_1, shape_dict_1, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=4):
       executor_1 = relay.build_module.create_executor(
           "graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1
       ).evaluate()
       executor_res_1 = [executor_1(input_array).numpy()]
       output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))
   
   # Compare the two (mathematically identical) models at opt_level=4.
   print('=========================')
   try:
       testing.assert_allclose(output_0["out"], output_1["out"])
   except AssertionError as e:
       print("tvm_opt_4 triggers assertion")
       print(e)
   else:
       print("tvm_opt_4 does not trigger assertion")
   print('=========================')
   
   # Re-import and re-run model 0 with all optimizations disabled.
   mod_0, params_0 = relay.frontend.from_onnx(
       onnx_model_0, shape_dict, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=0):
       run_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       results_0 = [run_0(input_array).numpy()]
       output_0 = dict(zip(onnx_model_outputs_0, results_0))
   
   # Re-import and re-run model 1 with all optimizations disabled.
   mod_1, params_1 = relay.frontend.from_onnx(
       onnx_model_1, shape_dict, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=0):
       run_1 = relay.build_module.create_executor(
           "graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1
       ).evaluate()
       results_1 = [run_1(input_array).numpy()]
       output_1 = dict(zip(onnx_model_outputs_1, results_1))
   
   # Baseline comparison at opt_level=0 — expected to agree.
   print('=========================')
   try:
       testing.assert_allclose(output_0["out"], output_1["out"])
   except AssertionError as e:
       print("tvm_opt_0 triggers assertion")
       print(e)
   else:
       print("tvm_opt_0 does not trigger assertion")
   print('=========================')
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to