Azyka opened a new issue, #16054:
URL: https://github.com/apache/tvm/issues/16054

   ### Expected behavior
   
   When the model contains an obvious division by zero, TVM should detect it and
   stop execution. Moreover, two identical models should produce the same output
   for the same inputs.
   
   ### Actual behavior
   
   The outputs of the two identical graphs differ, and contain strangely large
   numbers, when compiled with opt_level=4 (tvm_opt_4).
   ```
   =========================
   tvm_opt_4 triggers assertion
   
   Not equal to tolerance rtol=1e-07, atol=0
   
   Mismatched elements: 21 / 25 (84%)
   Max absolute difference: 4503603923386370
   Max relative difference: 1.
    x: array([[             1,              0,              0,              0,
                        0,              0,              0,              0,
                        0,              0,              0,              0,...
    y: array([[                1,        4294967296,                 0,
              94244568770664,      137438953472,    94244568246608,
              94244569409280,                20,    94244568246608,...
   =========================
   =========================
   tvm_opt_0 does not trigger assertion
   =========================
   ```
   
   ### Environment
   
   - OS: Ubuntu 22.04.3 LTS (x86_64)
   - TVM version: 0.14.dev189
   - Execution provider: cpu
   - ONNX opset version: 14
   
   ### Steps to reproduce
   Input data file:
   [input_data.zip](https://github.com/apache/tvm/files/13246324/input_data.zip)
   
   Sample code:
   ```
   import onnx
   import numpy as np
   import pickle
   from numpy import testing
   import tvm
   from tvm import relay
   import torch
   
   class Model0(torch.nn.Module):
       """Repro model: lower-triangular part, self-division, int64 cast, row slice.

       The op sequence is kept verbatim on purpose: the ONNX graph exported
       from it is what reproduces the reported TVM behaviour.
       """

       def __init__(self):
           super().__init__()
   
       def forward(self, *args):
           """Return a 1-tuple with rows [-13:-12] of int64(tril(x) / tril(x)).

           Only ``args[0]`` is used. The slice selects at most one row; it is
           empty when the first dimension has fewer than 13 rows.

           NOTE(review): ``tril(0)`` sets every element above the diagonal to
           zero, so ``torch.div`` computes 0 / 0 there. With true division this
           presumably yields NaN (confirm the input dtype), and casting NaN to
           int64 is not well defined, which may explain why results differ
           between builds/opt levels.
           """
           _args = args
           getitem = _args[0]
           # Keep the lower triangle (diagonal included); entries above become 0.
           tril = getitem.tril(0)
           # Element-wise self-division; zero entries produce 0 / 0 here.
           div = torch.div(tril, tril)
           to = div.to(dtype = torch.int64)
           # Rows from index -13 up to (not including) -12, all columns.
           getitem_1 = to[(slice(-13, -12, 1), slice(None, None, None))]
           return (getitem_1,)
   
   # Export the first repro model to ONNX ('0.onnx').
   model_0 = Model0()
   output_names_0 = ['v4_0']
   # Use a context manager so the pickle file handle is closed promptly.
   # NOTE: pickle.load can execute arbitrary code; only load trusted files.
   with open('./0.pickle', 'rb') as pickle_file:
       input_dict_0 = pickle.load(pickle_file)
   inputs_0 = tuple(torch.from_numpy(v).to('cpu') for v in input_dict_0.values())
   torch.onnx.export(
       model_0,
       inputs_0,
       '0.onnx',
       verbose=False,
       input_names=['v5_0'],
       output_names=output_names_0,
       opset_version=14,
       do_constant_folding=False,
   )
   
   class Model1(torch.nn.Module):
       """Second repro model; its op sequence is identical to Model0's.

       The op sequence is kept verbatim on purpose: the ONNX graph exported
       from it is what reproduces the reported TVM behaviour.
       """

       def __init__(self):
           super().__init__()
   
       def forward(self, *args):
           """Return a 1-tuple with rows [-13:-12] of int64(tril(x) / tril(x)).

           Only ``args[0]`` is used. The slice selects at most one row; it is
           empty when the first dimension has fewer than 13 rows.

           NOTE(review): ``tril(0)`` sets every element above the diagonal to
           zero, so ``torch.div`` computes 0 / 0 there. With true division this
           presumably yields NaN (confirm the input dtype), and casting NaN to
           int64 is not well defined, which may explain why results differ
           between builds/opt levels.
           """
           _args = args
           getitem = _args[0]
           # Keep the lower triangle (diagonal included); entries above become 0.
           tril = getitem.tril(0)
           # Element-wise self-division; zero entries produce 0 / 0 here.
           div = torch.div(tril, tril)
           to = div.to(dtype = torch.int64)
           # Rows from index -13 up to (not including) -12, all columns.
           getitem_1 = to[(slice(-13, -12, 1), slice(None, None, None))]
           return (getitem_1,)
   
   # Export the second (identical) repro model to ONNX ('1.onnx').
   model_1 = Model1()
   output_names_1 = ['v5_0']
   # Use a context manager so the pickle file handle is closed promptly.
   # NOTE: pickle.load can execute arbitrary code; only load trusted files.
   with open('./1.pickle', 'rb') as pickle_file:
       input_dict_1 = pickle.load(pickle_file)
   inputs_1 = tuple(torch.from_numpy(v).to('cpu') for v in input_dict_1.values())
   torch.onnx.export(
       model_1,
       inputs_1,
       '1.onnx',
       verbose=False,
       input_names=['v0_0'],
       output_names=output_names_1,
       opset_version=14,
       do_constant_folding=False,
   )

   # Build and run both ONNX models with TVM at opt_level=4.
   onnx_model_0 = onnx.load('0.onnx')
   onnx_model_outputs_0 = [node.name for node in onnx_model_0.graph.output]
   shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
   mod_0, params_0 = relay.frontend.from_onnx(
       onnx_model_0, shape_dict_0, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=4):
       executor_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       executor_res_0 = [executor_0(**input_dict_0).numpy()]
       output_0 = dict(zip(onnx_model_outputs_0, executor_res_0))

   onnx_model_1 = onnx.load('1.onnx')
   onnx_model_outputs_1 = [node.name for node in onnx_model_1.graph.output]
   shape_dict_1 = {key: val.shape for key, val in input_dict_1.items()}
   mod_1, params_1 = relay.frontend.from_onnx(
       onnx_model_1, shape_dict_1, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=4):
       executor_1 = relay.build_module.create_executor(
           "graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1
       ).evaluate()
       executor_res_1 = [executor_1(**input_dict_1).numpy()]
       output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))
   # Maps each output name of model 0 to the matching output name of model 1.
   output_name_dict = {'v4_0': 'v5_0'}

   print('=========================')
   try:
       for tensor_name_0, tensor_name_1 in output_name_dict.items():
           testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
       print("tvm_opt_4 does not trigger assertion")
   except AssertionError as e:
       print("tvm_opt_4 triggers assertion")
       print(e)
   print('=========================')

   # Rebuild and rerun both models at opt_level=0 for comparison.
   shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
   mod_0, params_0 = relay.frontend.from_onnx(
       onnx_model_0, shape_dict_0, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=0):
       executor_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       executor_res_0 = [executor_0(**input_dict_0).numpy()]
       output_0 = dict(zip(onnx_model_outputs_0, executor_res_0))

   shape_dict_1 = {key: val.shape for key, val in input_dict_1.items()}
   mod_1, params_1 = relay.frontend.from_onnx(
       onnx_model_1, shape_dict_1, freeze_params=True
   )
   with tvm.transform.PassContext(opt_level=0):
       executor_1 = relay.build_module.create_executor(
           "graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1
       ).evaluate()
       executor_res_1 = [executor_1(**input_dict_1).numpy()]
       output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))

   print('=========================')
   try:
       for tensor_name_0, tensor_name_1 in output_name_dict.items():
           testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
       print("tvm_opt_0 does not trigger assertion")
   except AssertionError as e:
       print("tvm_opt_0 triggers assertion")
       print(e)
   print('=========================')
   ```
   
   1. Download and unzip the input data file, then put the extracted data files and the script in the same directory.
   2. Run the script.
   
   ### Triage
   
   * needs-triage
   * frontend:onnx
   * flow:relay
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to