Azyka opened a new issue, #16057: URL: https://github.com/apache/tvm/issues/16057
### Expected behavior

When a model contains two duplicate `torch.mul` nodes (with the same inputs and outputs), the results should be the same no matter which `torch.mul` output the following nodes use.

Model0 (duplicate torch.mul):

<img width="248" alt="image" src="https://github.com/apache/tvm/assets/74590664/24e04cae-efc9-4ea9-bddd-6a47da88bb88">

Model1 (non-duplicate):

<img width="248" alt="image" src="https://github.com/apache/tvm/assets/74590664/ce65a109-e415-497a-b5da-a1997fc78eca">

### Actual behavior

When two duplicate `torch.mul` nodes (with the same inputs and outputs) are defined, `torch.gt` produces wrong results after tvm_opt_4.

```
=========================
tvm_opt_4 triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 18838 / 38019 (49.5%)
 x: array([[[ True,  True, False, ..., False, False, False],
        [False,  True, False, ..., False, False,  True],
        [ True,  True, False, ...,  True,  True,  True],...
 y: array([[[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],...
=========================
=========================
tvm_opt_0 does not trigger assertion
=========================
```

### Environment

- OS: Ubuntu 22.04.3 LTS (x86_64)
- TVM version: 0.14.dev189
- Execution provider: cpu
- ONNX opset version: 14

### Steps to reproduce

Input data file: [input_data.zip](https://github.com/apache/tvm/files/13246466/input_data.zip)

Sample code:

```
import onnx
import numpy as np
import pickle
from numpy import testing
import tvm
from tvm import relay
import torch

# Model0: two duplicate torch.mul nodes with swapped operands
class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        getitem_1 = _args[1]
        mul = torch.mul(getitem, getitem_1)
        mul_1 = torch.mul(getitem_1, getitem)
        gt = torch.gt(mul, mul_1)
        return (gt, mul_1)

model_0 = Model0()
output_names_0 = ['v6_0', 'v1_0']
input_dict_0 = pickle.load(open('./0.pickle', 'rb'))
inputs_0 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v3_0', 'v2_0'], output_names=output_names_0, opset_version=14, do_constant_folding=False)

# Model1: a single torch.mul node
class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        getitem_1 = _args[1]
        mul = torch.mul(getitem_1, getitem)
        gt = torch.gt(mul, mul)
        return (gt, mul)

model_1 = Model1()
output_names_1 = ['v4_0', 'v9_0']
input_dict_1 = pickle.load(open('./1.pickle', 'rb'))
inputs_1 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0', 'v2_0'], output_names=output_names_1, opset_version=14, do_constant_folding=False)

# Compile and run both models with opt_level=4
onnx_model_0 = onnx.load('0.onnx')
onnx_model_outputs_0 = [node.name for node in onnx_model_0.graph.output]
shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
mod_0, params_0 = relay.frontend.from_onnx(onnx_model_0, shape_dict_0, freeze_params=True)
with tvm.transform.PassContext(opt_level=4):
    executor_0 = relay.build_module.create_executor("graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0).evaluate()
    executor_res_0 = [tensor.numpy() for tensor in executor_0(**input_dict_0)]
    output_0 = dict(zip(onnx_model_outputs_0, executor_res_0))

onnx_model_1 = onnx.load('1.onnx')
onnx_model_outputs_1 = [node.name for node in onnx_model_1.graph.output]
shape_dict_1 = {key: val.shape for key, val in input_dict_1.items()}
mod_1, params_1 = relay.frontend.from_onnx(onnx_model_1, shape_dict_1, freeze_params=True)
with tvm.transform.PassContext(opt_level=4):
    executor_1 = relay.build_module.create_executor("graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1).evaluate()
    executor_res_1 = [tensor.numpy() for tensor in executor_1(**input_dict_1)]
    output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))

# Compare the torch.gt outputs of the two models
output_name_dict = {'v6_0': 'v4_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("tvm_opt_4 does not trigger assertion")
except AssertionError as e:
    print("tvm_opt_4 triggers assertion")
    print(e)
print('=========================')

# Repeat with opt_level=0
shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
mod_0, params_0 = relay.frontend.from_onnx(onnx_model_0, shape_dict_0, freeze_params=True)
with tvm.transform.PassContext(opt_level=0):
    executor_0 = relay.build_module.create_executor("graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0).evaluate()
    executor_res_0 = [tensor.numpy() for tensor in executor_0(**input_dict_0)]
    output_0 = dict(zip(onnx_model_outputs_0, executor_res_0))

shape_dict_1 = {key: val.shape for key, val in input_dict_1.items()}
mod_1, params_1 = relay.frontend.from_onnx(onnx_model_1, shape_dict_1, freeze_params=True)
with tvm.transform.PassContext(opt_level=0):
    executor_1 = relay.build_module.create_executor("graph", mod_1, tvm.cpu(), tvm.target.Target("llvm"), params_1).evaluate()
    executor_res_1 = [tensor.numpy() for tensor in executor_1(**input_dict_1)]
    output_1 = dict(zip(onnx_model_outputs_1, executor_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("tvm_opt_0 does not trigger assertion")
except AssertionError as e:
    print("tvm_opt_0 triggers assertion")
    print(e)
print('=========================')
```

1. Download the data file and put the data and the code file in the same directory.
2. Execute the code.

### Triage

* needs-triage
* frontend:onnx
* flow:relay
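
As a sanity check on the expected behavior, independent of TVM: elementwise floating-point multiplication is commutative, so `gt(a * b, b * a)` should be `False` everywhere, which matches the all-`False` `y` array in the report. The snippet below is only a minimal NumPy sketch of that reasoning. The random inputs and the shape `(3, 4, 5)` are hypothetical stand-ins, not the data from `0.pickle` / `1.pickle`.

```python
import numpy as np

# Hypothetical stand-in inputs (assumption): the issue's real tensors
# come from 0.pickle / 1.pickle and are not reproduced here.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4, 5)).astype(np.float32)
b = rng.standard_normal((3, 4, 5)).astype(np.float32)

# Mirror Model0: the same product computed with swapped operands.
mul = np.multiply(a, b)
mul_1 = np.multiply(b, a)

# The two products are elementwise equal, so "greater than" can
# never hold for any element.
gt = np.greater(mul, mul_1)

print("products equal:", np.array_equal(mul, mul_1))
print("any element greater:", bool(gt.any()))
```

Any `True` entry in the `torch.gt` output of Model0 therefore points at the compiled graph rather than at the input data.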
