Azyka opened a new issue, #16053:
URL: https://github.com/apache/tvm/issues/16053

   ### Expected behavior
   
   When a node is wrapped in a pair of `torch.Tensor.transpose` calls that cancel each other out, e.g. `transpose(transpose(x, 1, 0) + transpose(p, 1, 0), 1, 0) == x + p`, the two models below are semantically equivalent, so their outputs are expected to be identical.
   Model0:
   <img width="187" alt="image" src="https://github.com/apache/tvm/assets/74590664/8e8bab18-91a9-49b9-82f0-5402458be51e">
   
   Model1:
   <img width="265" alt="image" src="https://github.com/apache/tvm/assets/74590664/bafbfcbd-392c-4c0d-abaa-3ee379c25396">
   
   
   ### Actual behavior
   
   The outputs of the two models differ when they are compiled with `opt_level=4` (`tvm_opt_4`), while they match with `opt_level=0` (`tvm_opt_0`):
   ```
   =========================
   tvm_opt_4 triggers assertion
   
   Not equal to tolerance rtol=1e-07, atol=0
   
   Mismatched elements: 900 / 900 (100%)
   Max absolute difference: 55
   Max relative difference: 55.
    x: array([[[[18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
             18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
             18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,...
    y: array([[[[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,...
   =========================
   =========================
   tvm_opt_0 does not trigger assertion
   =========================
   ```
   
   ### Environment
   
   - OS: Ubuntu 22.04.3 LTS (x86_64)
   - TVM version: 0.14.dev189
   - Execution provider: cpu
   - ONNX opset version: 14
   
   ### Steps to reproduce
   Input data file:
   [input_data.zip](https://github.com/apache/tvm/files/13246218/input_data.zip)
   
   Sample code:
   ```python
   import onnx
   import numpy as np
   import pickle
   from numpy import testing
   import tvm
   from tvm import relay
   import torch
   
   # Constant operand shared by Model0 and Model1. `torch.empty` would leave it
   # holding arbitrary memory contents, which makes every run different and may
   # overflow int64 in add/neg; fill it from a fixed-seed local generator instead
   # so the reproduction is deterministic (without touching global RNG state).
   _p0_generator = torch.Generator().manual_seed(0)
   p0 = torch.nn.Parameter(
       torch.randint(-100, 100, [18, 1, 1, 1, 60], dtype=torch.int64,
                     generator=_p0_generator),
       requires_grad=False,
   )
   
   class Model0(torch.nn.Module):
       """Reference graph: ``argmin(neg(x + p0), dim=4)``."""

       def __init__(self):
           super().__init__()
           # Registered as the module parameter `v6_0` (the shared tensor `p0`).
           self.v6_0 = p0

       def forward(self, *args):
           """Return the argmin over dim 4 of ``-(args[0] + v6_0)``; extra args are ignored."""
           summed = torch.add(args[0], self.v6_0)
           return torch.neg(summed).argmin(4)
   
   model_0 = Model0()
   # Use a context manager so the pickle file handle is closed promptly.
   # NOTE(review): pickle.load can execute arbitrary code -- only load
   # input_data files obtained from a trusted source.
   with open('./0.pickle', 'rb') as pickle_file:
       input_dict_0 = pickle.load(pickle_file)
   # Positional model inputs in the pickle's key order (keys are not needed here).
   inputs_0 = tuple(torch.from_numpy(v).to('cpu') for v in input_dict_0.values())
   torch.onnx.export(
       model_0,
       inputs_0,
       '0.onnx',
       verbose=False,
       input_names=['v5_0'],
       output_names=['v4_0'],
       opset_version=14,
       do_constant_folding=False,
   )
   
   class Model1(torch.nn.Module):
       """Model0 with dims 0 and 1 swapped around the add; the two transposes cancel."""

       def __init__(self):
           super().__init__()
           # Registered as the module parameter `v2_0` (the shared tensor `p0`).
           self.v2_0 = p0

       def forward(self, *args):
           """Return ``argmin(neg(transpose(w^T + x^T)), dim=4)`` with ``^T`` = swap dims 0/1."""
           x_swapped = args[0].transpose(1, 0)
           w_swapped = self.v2_0.transpose(1, 0)
           # Swap dims back so argmin sees the same layout as in Model0.
           restored = torch.add(w_swapped, x_swapped).transpose(1, 0)
           return torch.neg(restored).argmin(4)
   
   model_1 = Model1()
   # Use a context manager so the pickle file handle is closed promptly.
   # NOTE(review): pickle.load can execute arbitrary code -- only load
   # input_data files obtained from a trusted source.
   with open('./1.pickle', 'rb') as pickle_file:
       input_dict_1 = pickle.load(pickle_file)
   inputs_1 = tuple(torch.from_numpy(v).to('cpu') for v in input_dict_1.values())
   torch.onnx.export(
       model_1,
       inputs_1,
       '1.onnx',
       verbose=False,
       input_names=['v0_0'],
       output_names=['v10_0'],
       opset_version=14,
       do_constant_folding=False,
   )

   onnx_model_0 = onnx.load('0.onnx')
   onnx_model_1 = onnx.load('1.onnx')
   # Output of model 0 -> the output of model 1 it must equal.
   output_name_dict = {'v4_0': 'v10_0'}


   def run_tvm(onnx_model, input_dict, opt_level):
       """Compile ``onnx_model`` with TVM at ``opt_level`` and run it on ``input_dict``.

       Returns a dict mapping each ONNX graph output name to its numpy result.
       """
       output_names = [node.name for node in onnx_model.graph.output]
       shape_dict = {key: val.shape for key, val in input_dict.items()}
       mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
       with tvm.transform.PassContext(opt_level=opt_level):
           executor = relay.build_module.create_executor(
               "graph", mod, tvm.cpu(), tvm.target.Target("llvm"), params
           ).evaluate()
           results = [executor(**input_dict).numpy()]
       return dict(zip(output_names, results))


   def compare_outputs(opt_level):
       """Run both models at ``opt_level`` and print whether their outputs match."""
       output_0 = run_tvm(onnx_model_0, input_dict_0, opt_level)
       output_1 = run_tvm(onnx_model_1, input_dict_1, opt_level)
       print('=========================')
       try:
           for tensor_name_0, tensor_name_1 in output_name_dict.items():
               testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
           print(f"tvm_opt_{opt_level} does not trigger assertion")
       except AssertionError as e:
           print(f"tvm_opt_{opt_level} triggers assertion")
           print(e)
       print('=========================')


   compare_outputs(4)
   compare_outputs(0)
   ```
   
   1. Download the data file and put it in the same directory as the code file.
   2. Execute the code.
   3. Notably, the mismatch is also triggered when `argmax` is used in place of `argmin`.
   
   ### Triage
   
   * needs-triage
   * frontend:onnx
   * flow:relay
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to