Mohamad11Dab opened a new issue, #17987:
URL: https://github.com/apache/tvm/issues/17987

   Thanks for participating in the TVM community! We use https://discuss.tvm.ai 
for general usage questions and discussions. The issue tracker is used for 
actionable items such as feature proposal discussions, roadmaps, and bug 
tracking. You are always welcome to post on the forum first :smile_cat:
   
   Issues that are inactive for a period of time may get closed. We adopt this 
policy so that we won't lose track of actionable issues that may fall to the 
bottom of the pile. Feel free to open a new issue if you feel there is an 
additional problem that needs attention after an old one gets closed.
   
   ### Expected behavior
   
   Consistent results between TVM and ONNXRuntime
   
   ### Actual behavior
   
   ```
   ––––– MISMATCH DETECTED –––––

   Not equal to tolerance rtol=0.01, atol=0.001

   Mismatched elements: 65536 / 65536 (100%)
   Max absolute difference: 0.00585903
   Max relative difference: 0.8311606
    x: array([[[[0.00119, 0.00119, 0.00119, ..., 0.00119, 0.00119, 0.00119],
           [0.00119, 0.00119, 0.00119, ..., 0.00119, 0.00119, 0.00119],
           [0.00119, 0.00119, 0.00119, ..., 0.00119, 0.00119, 0.00119], ...
    y: array([[[[0.007049, 0.007049, 0.007049, ..., 0.007049, 0.007049,
            0.007049],
           [0.007049, 0.007049, 0.007049, ..., 0.007049, ...
   ```
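   
   For context on why every element is flagged: `assert_allclose(actual, desired)` reports a mismatch when `|actual - desired| > atol + rtol * |desired|`. Plugging in the representative values above (here `x` is the first argument, ONNXRuntime, and `y` the second, TVM), a minimal check shows the reported difference is well outside the tolerance:
   
   ```python
   import numpy as np

   # assert_allclose(actual, desired) flags an element when
   # |actual - desired| > atol + rtol * |desired|.
   # Representative values taken from the report:
   actual = np.float32(0.00119)    # x: ONNXRuntime output
   desired = np.float32(0.007049)  # y: TVM output
   tol = 1e-3 + 1e-2 * abs(desired)    # atol + rtol * |desired|, roughly 0.00107
   print(abs(actual - desired) > tol)  # True: ~0.00586 exceeds the tolerance
   ```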
   
   ### Environment
   
   TVM: 0.17.0
   ONNXRuntime: 1.16.3
   
   ### Steps to reproduce
   
   ```python
   import tempfile

   import numpy as np
   import onnx
   import onnxruntime as ort
   import torch
   import torch.nn as nn
   import torch.nn.functional as F
   import tvm
   from numpy.testing import assert_allclose
   from tvm import relay
   from tvm.contrib import graph_executor


   class SimpleBugModel(nn.Module):
       def __init__(self):
           super().__init__()
           self.input_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
           self.block5 = nn.InstanceNorm2d(num_features=16)

       def forward(self, x):
           x = self.input_conv(x)
           x = F.softplus(x)
           x = torch.tanh(x)
           x = torch.ceil(x)
           x = F.gelu(x)
           x = F.interpolate(x, scale_factor=2.0, mode="nearest")
           x = self.block5(x)
           return x


   def main():
       model = SimpleBugModel()
       model.eval()
       dummy = torch.randn(1, 3, 32, 32, dtype=torch.float32)
       with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp:
           onnx_path = tmp.name
       torch.onnx.export(
           model, dummy, onnx_path,
           opset_version=19,
           input_names=["input"],
           output_names=["output"],
       )

       # Run ONNXRuntime
       ort_sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
       ort_out = ort_sess.run(None, {"input": dummy.numpy()})[0]

       # Compile & run TVM
       onnx_model = onnx.load(onnx_path)
       shape_dict = {"input": dummy.numpy().shape}
       mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
       with tvm.transform.PassContext(opt_level=4):
           lib = relay.build(mod, target="llvm", params=params)
       m = graph_executor.GraphModule(lib["default"](tvm.cpu()))
       m.set_input("input", tvm.nd.array(dummy.numpy()))
       m.run()
       tvm_out = m.get_output(0).numpy()

       try:
           assert_allclose(ort_out, tvm_out, rtol=1e-2, atol=1e-3, equal_nan=True)
       except AssertionError as e:
           print("––––– MISMATCH DETECTED –––––")
           print(e)  # just the assertion message
       except Exception as e:
           print("––––– UNEXPECTED ERROR DURING COMPARISON –––––")
           print(f"{type(e).__name__}: {e}")


   if __name__ == "__main__":
       main()
   ```
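   
   One possible explanation (my assumption, not verified against either runtime's graph): `softplus(x) > 0` for every input, so `tanh(softplus(x))` lies in (0, 1) and `torch.ceil` maps the whole tensor to 1.0. GELU and nearest-neighbour interpolation then keep each channel constant, so `InstanceNorm2d` normalizes a zero-variance input, `out = (x - mean) / sqrt(var + eps)`, and the `sqrt(1e-5)` denominator amplifies any tiny cross-runtime difference in how the per-channel mean is reduced. (Note also that `opt_level=4` enables additional aggressive Relay passes, such as fast-math approximations, which may contribute.) A NumPy sketch of the amplification:
   
   ```python
   import numpy as np

   # Hypothetical scenario: after ceil(tanh(softplus(x))) the tensor is
   # constant per channel, so InstanceNorm sees var == 0 and divides by
   # sqrt(eps) ~ 3.2e-3, magnifying any reduction-order noise in the mean.
   eps = np.float32(1e-5)
   x = np.full((16, 64, 64), 0.8413447, dtype=np.float32)  # roughly gelu(1.0)

   mean_a = np.float32(0.8413447)                  # one runtime's per-channel mean
   mean_b = np.nextafter(mean_a, np.float32(1.0))  # the other's, one ulp away

   out_a = (x - mean_a) / np.sqrt(eps)
   out_b = (x - mean_b) / np.sqrt(eps)

   # A single-ulp difference in the mean becomes an O(1e-5) absolute difference
   # in outputs that are themselves near zero, so the relative error is large.
   print(np.abs(out_a - out_b).max())
   ```
   
   If this is what is happening, the mismatch is a sensitivity of the model itself rather than a miscompiled operator, and comparing at `opt_level=3` or with a non-constant pre-norm tensor would help isolate it.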
   
   
   

