ntcmp2u opened a new issue, #15907:
URL: https://github.com/apache/tvm/issues/15907

   I used TVM to optimize an ONNX model, but the optimized model produces inconsistent outputs: compiling and running it twice on the same input gives different results.
   
   The PoC code is as follows:
   
   ```python
   import tvm
   import onnx
   from tvm import relay
   import numpy as np
   from numpy import testing
   import pickle
   
   model_input = np.array([
       [6.71610641,3.85466981,6.40658808,5.93880939,4.22816277,4.05935574,4.05884647,6.85629654,4.35045576,4.98825026,4.79773617,5.90505457,4.07920122,3.91599631,5.48346233,4.68261051,5.96164513,4.04757214,6.40736198,5.78212643,3.97077084,5.38297272,5.98323917,3.33484483,5.0777626],
       [5.78360128,3.50137734,6.22678471,6.58975887,5.86054516,6.28650713,4.02936316,6.55729437,4.83005142,5.48753119,6.6756916,3.1908493,5.902318,3.63500333,6.6633153,4.03696537,4.59259987,5.29771805,5.47530413,6.98861551,5.21775055,6.27956057,6.06063652,6.99890566,6.74760342],
       [5.15759087,5.79957485,3.87717772,3.98443675,4.08592176,3.68012309,3.56660628,6.5352087,6.32946777,5.21131229,4.83698702,3.24459314,5.40390682,3.10241055,6.76744366,4.29461193,5.75841236,5.95449638,3.74979687,4.88354635,6.29391479,5.12002468,4.30822229,5.12105751,3.2507937],
       [3.6196816,4.79808092,4.01847553,3.04140043,6.1275444,6.30466843,3.18326187,4.40191364,6.87912273,5.95551109,6.82588673,5.95711136,5.95687199,6.31480217,3.16748881,4.01974392,4.79580688,4.65139341,4.1572628,6.83029556,6.77638483,4.56482315,6.59405708,6.15137196,3.67091894],
       [6.05837917,6.34490871,6.82791471,6.16849041,3.10806918,5.42449474,5.17070627,3.94981074,4.28552246,4.36995983,6.1662178,4.59585524,3.32610178,6.27614117,4.7198143,4.82491684,5.11868238,5.589221,5.05218983,4.47221899,6.56419563,3.33831954,4.76141453,5.47758198,6.76508904],
       [5.27446032,4.80305195,6.86809921,5.71008205,4.43950748,4.55266809,6.0204134,4.15397835,5.40533447,4.21241283,5.2563076,5.31718588,5.03535795,5.61827183,5.727705,5.56590462,3.99580646,3.07781434,6.91844368,4.20229912,6.46835232,6.02725315,6.19255781,5.68293953,3.91823912],
       [5.54070759,5.55740356,4.43334723,5.79145813,4.09821415,4.99966526,6.07888222,5.35138416,5.53214931,5.38911247,6.39173222,3.72211218,4.38897705,6.03391933,5.74040985,5.49542809,6.62584782,4.33855057,6.8596487,6.41575336,6.35653496,3.90258241,6.14905071,4.76633644,5.97037888],
       [6.64982224,6.40295553,4.78704929,3.03489995,5.26092148,3.17754674,3.61741209,5.90558434,3.44067335,4.10587311,5.2109766,4.80613804,6.83341599,5.98525429,6.70596409,4.64003181,4.31005573,3.14922762,6.44953346,6.32083702,4.5173068,4.73698139,5.2820859,5.57041645,3.72178435],
       [6.80869818,3.77339244,3.81124735,5.25667477,6.97375488,5.42400169,3.45591283,4.84353495,3.8287816,5.41924429,3.18944407,4.95114517,3.18213964,3.20189476,5.98147392,4.49553967,6.49250174,5.02359581,5.37200642,4.57207108,5.9689579,4.29011774,6.34126377,4.39409637,5.17850113],
       [4.48860836,3.66642022,4.86899662,3.43860793,6.94423056,4.19769287,3.77650118,4.41444111,6.10652637,6.75893021,6.0855093,5.83897305,3.97421741,4.10853386,3.62891579,5.17287302,6.16482878,5.45052338,5.91892958,3.24301076,6.47304344,5.24120235,6.28845692,3.56531596,5.53275776],
       [3.08167267,5.02443886,4.67914724,4.21279573,6.79280615,6.10375261,6.02019453,6.13239574,6.71213055,4.50908089,4.03038025,6.02951813,5.82834053,6.15195274,3.17172861,5.51577044,3.96025443,5.51473904,6.00128365,6.34118938,4.94945621,3.03591824,5.76955414,6.09682941,5.2738781],
       [4.9579711,4.0002923,3.25455737,3.34851503,6.37354565,6.01042414,4.52113342,5.26745939,6.26920605,6.16422987,3.1255343,6.43074322,3.16986299,4.44038916,4.80228901,3.59629822,4.95137262,6.40821123,4.88515186,4.09138393,5.46744061,5.57049751,4.30876446,6.48064327,4.49197578],
       [5.54204559,3.35072994,5.81306267,5.0076189,6.09887314,4.83277607,3.76282811,5.04093981,6.85498047,6.33627892,4.55758762,6.29734421,3.90469551,6.35388994,4.73173046,5.31449032,3.06976533,3.54300165,3.2381289,3.11252737,3.15451765,6.34363556,6.97934818,6.39008713,6.97971106],
   ])
   
   onnx_model_0 = onnx.load('./model_0.onnx')
   input_dict_0 = {"v5_0": model_input}
   shape_dict_0 = {key: val.shape for key, val in input_dict_0.items()}
   mod_0, params_0 = relay.frontend.from_onnx(onnx_model_0, shape_dict_0, freeze_params=True)
   with tvm.transform.PassContext(opt_level=4):
       executor_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       output_0 = executor_0(**input_dict_0)
   
   with tvm.transform.PassContext(opt_level=4):
       executor_0 = relay.build_module.create_executor(
           "graph", mod_0, tvm.cpu(), tvm.target.Target("llvm"), params_0
       ).evaluate()
       output_1 = executor_0(**input_dict_0)
   
   print("============")
   try:
       testing.assert_allclose(output_0.numpy(), output_1.numpy(), equal_nan=True)
   except BaseException as e:
       print("tvm triggers the assertion failure")
       print(e)
   print("============")
   
   import onnxruntime as ort
   
   sess_options = ort.SessionOptions()
   sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
   sess_0 = ort.InferenceSession(
       './model_0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options
   )
   res_0 = sess_0.run(["v4_0"], input_dict_0)
   
   sess_options = ort.SessionOptions()
   sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
   sess_1 = ort.InferenceSession(
       './model_0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options
   )
   res_1 = sess_1.run(["v4_0"], input_dict_0)
   
   
   print("============")
   testing.assert_allclose(res_0, res_1, equal_nan=True)
   print("onnxruntime did not trigger the assertion failure")
   print("============")
   ```
   
   The model file and the test script are attached in [report_bug.zip](https://github.com/apache/tvm/files/12856618/report_bug.zip).
   
   ### Expected behavior
   
   Obviously, ``output_0`` and ``output_1`` should be identical, because they are produced by the same model from the same input.
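The expectation above can be turned into a repeatable check. Below is a minimal sketch; `outputs_are_deterministic` is a hypothetical helper, not a TVM or NumPy API. It calls a zero-argument `run` function several times and requires every result to match the first one exactly, with NaNs in matching positions counted as equal. For this PoC, `run` could rebuild the executor and execute it, as the code does for `output_0` and `output_1`, and return the result of `.numpy()`.

```python
import numpy as np
from typing import Callable


def outputs_are_deterministic(run: Callable[[], np.ndarray], trials: int = 5) -> bool:
    """Return True if every call to `run` yields an array equal to the first one.

    Hypothetical helper for illustration: `run` is any zero-argument callable
    that produces a single output array (e.g. build + execute a compiled model).
    """
    reference = np.asarray(run())
    for _ in range(trials - 1):
        candidate = np.asarray(run())
        # Shape and dtype must match before comparing values.
        if candidate.shape != reference.shape or candidate.dtype != reference.dtype:
            return False
        # Exact equality; equal_nan=True treats NaNs in the same positions as equal.
        if not np.array_equal(candidate, reference, equal_nan=True):
            return False
    return True
```

Exact equality is used instead of `assert_allclose` because repeated runs of one compiled model on one input would normally be expected to match exactly, not merely within a tolerance.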
   
   ### Actual behavior
   
   ``output_0`` and ``output_1`` differ, and their values change from one trial to the next. In contrast, onnxruntime produces identical outputs on every run. The stdout from three trials is as follows:
   
   ```
   ============
   tvm triggers the assertion failure
   
   Not equal to tolerance rtol=1e-07, atol=0
   
   Mismatched elements: 21 / 25 (84%)
   Max absolute difference: 140577680982095
   Max relative difference: 1.40577681e+14
    x: array([[              1,  93883951470360,    137438953472,
                        49,  93883950590848,  93883950359888,
                         0,  93883951470408,     17179869184,...
    y: array([[              1,  93883951603280, 140577771430160,
                         2,  93883947480464,               1,
            93883690123272,  93883950380544,               1,...
   ============
   ============
   onnxruntime did not trigger the assertion failure
   ============
   ```
   
   ```
   ============
   tvm triggers the assertion failure
   
   Not equal to tolerance rtol=1e-07, atol=0
   
   Mismatched elements: 19 / 25 (76%)
   Max absolute difference: 8313387590019572390
   Max relative difference: 4.2949673e+09
    x: array([[             1, 93921710608472,   137438953472,          73728,
                       96,             64, 93933299986265, 93921710608520,
              17179869184,          81920, 93921710368576, 93921710761760,...
    y: array([[                  1,      93921710281360,                  32,
                            48,      93933293032941,                  12,
                93922344828927, 8313481511730180910,                  80,...
   ============
   ============
   onnxruntime did not trigger the assertion failure
   ============
   ```
   
   ```
   ============
   tvm triggers the assertion failure
   
   Not equal to tolerance rtol=1e-07, atol=0
   
   Mismatched elements: 23 / 25 (92%)
   Max absolute difference: 15022847525276686
   Max relative difference: 4.23258643e+12
    x: array([[              1,  94888918231384,    137438953472,
                       240, 139675352272960,    236223201563,
                        55,  94888918231432,     17179869184,...
    y: array([[                1,    94888917322064,   139675377593892,
              94888918299624,       17179869184,    94888917338808,
              94888712470529,               145,    94888918934400,...
   ============
   ============
   onnxruntime did not trigger the assertion failure
   ============
   ```
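Many of the mismatching values above are large integers that change between runs. One way to inspect them is to print them in hexadecimal. This is only a heuristic, not a diagnosis. On x86-64 Linux, addresses inside a position-independent executable and its heap commonly start around `0x55`/`0x56`, and shared-library, `mmap` and stack addresses commonly start around `0x7f`. If the values fall into those ranges, that would be consistent with the output buffer holding leftover memory contents rather than computed results. The `describe_value` helper and its address ranges are assumptions for illustration.

```python
def describe_value(v: int) -> str:
    """Format an integer in hex and tag it with a rough, heuristic address guess.

    The ranges below are typical x86-64 Linux user-space layouts and are
    assumptions, not guarantees (ASLR randomizes the exact addresses).
    """
    h = hex(v)
    if 0x550000000000 <= v < 0x570000000000:
        tag = "looks like a PIE/heap address"
    elif 0x7F0000000000 <= v < 0x800000000000:
        tag = "looks like an mmap/library/stack address"
    else:
        tag = "no obvious pattern"
    return f"{v:>20} {h:>16}  {tag}"


# A few values copied from the x/y arrays of the first trial above.
for value in [93883951470360, 140577771430160, 137438953472, 49]:
    print(describe_value(value))
```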
   
   It is worth noting that with ``opt_level`` set to ``0``, TVM produces consistent outputs every time, so I suspect the inconsistency is introduced by one of TVM's optimization passes.
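Since `opt_level=0` is reported as stable, one way to narrow the problem down is to repeat the run-twice comparison at each optimization level and report the lowest level at which the runs disagree. The sketch below is generic and hypothetical. `first_unstable_level` and the `run(level)` callable are illustrative names. `run` is assumed to build and execute the model inside `tvm.transform.PassContext(opt_level=level)`, as the PoC does, and to return a NumPy array.

```python
import numpy as np
from typing import Callable, Iterable, Optional


def first_unstable_level(
    run: Callable[[int], np.ndarray],
    levels: Iterable[int] = range(5),
    trials: int = 3,
) -> Optional[int]:
    """Return the lowest level whose repeated runs disagree, or None if all agree.

    `run(level)` is a hypothetical callable that builds and runs the model at
    the given optimization level and returns its output as an array.
    """
    for level in levels:
        reference = np.asarray(run(level))
        for _ in range(trials - 1):
            # Any difference between repeated runs at this level flags it.
            if not np.array_equal(np.asarray(run(level)), reference, equal_nan=True):
                return level
    return None
```

Nondeterminism may not appear on every pair of runs. A `None` result therefore only means that no difference was observed within `trials` runs per level.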
   
   ### Environment
   
   OS: Ubuntu 22.04 LTS (Linux jin-pc 6.2.0-26-generic #26~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jul 13 16:27:29 UTC 2 x86_64 x86_64 x86_64 GNU/Linux)
   TVM: 0.14.dev226
   
   ### Steps to reproduce
   
   Unzip the attached archive and run ``python3 test.py`` to execute the PoC.
   
   ### Triage
   
   * needs-triage
   