shaoyuyoung opened a new issue, #17081:
URL: https://github.com/apache/tvm/issues/17081

   #### Description
   The issue involves a single operator: `Cast`.
   
![image](https://github.com/apache/tvm/assets/100203773/ff67783a-4158-4545-946c-a77da40eb245)
   
   In TVM, when the input is **NaN**, `Cast` to bool outputs **False**.
   
   However, `PyTorch` outputs **True** for the same input.
   
   In PyTorch and ONNX, casting a floating-point value to bool maps +/-0.0 to **False** and every other value, including NaN, to **True**.
   The evidence is in the ONNX `Cast` specification:
   https://onnx.ai/onnx/operators/onnx__Cast.html#l-onnx-doc-cast
   
![image](https://github.com/apache/tvm/assets/100203773/cc7a71a6-c707-48b6-a3e8-8f86ac2c06de)
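The ONNX rule for float-to-bool casts amounts to `x != 0` under IEEE 754 comparison semantics: every ordered comparison involving NaN is false, but `!=` is true, so NaN maps to **True**. A minimal NumPy sketch of that reference behaviour follows; the helper name `onnx_float_to_bool` is hypothetical and is not an ONNX or TVM API.

```python
import numpy as np


def onnx_float_to_bool(x: np.ndarray) -> np.ndarray:
    """Hypothetical reference for ONNX Cast(float -> bool).

    +0.0 and -0.0 become False; every other value, including NaN and
    +/-inf, becomes True. IEEE 754 defines NaN != 0 as true, so a plain
    inequality test yields exactly this mapping.
    """
    return np.not_equal(x, 0.0)


values = np.array([np.nan, 0.0, -0.0, 1.5, -np.inf], dtype=np.float32)
print(onnx_float_to_bool(values))  # [ True False False  True  True]
```

NumPy's own `values.astype(bool)` follows the same nonzero rule, which is consistent with the PyTorch output in this report.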
   
   I am not sure how `Cast` is defined in TVM. If its semantics differ from those of other frameworks and compilers (e.g., PyTorch and ONNX), TVM's results will diverge from theirs in more complex scenarios, such as a model with many ops downstream of the cast.
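One way to catch such a divergence before it propagates through a larger model is to test the bool cast on the IEEE special values alone. The sketch below assumes the backend's result is already available as a NumPy array; `check_bool_cast` and `reference_bool_cast` are hypothetical helpers, and the "buggy" array merely imitates the NaN behaviour reported here rather than calling TVM.

```python
import numpy as np

# Special float32 inputs whose bool cast is easy to get wrong.
SPECIAL = np.array(
    [np.nan, -np.nan, np.inf, -np.inf, 0.0, -0.0, 1e-40],  # 1e-40 is subnormal in float32
    dtype=np.float32,
)


def reference_bool_cast(x: np.ndarray) -> np.ndarray:
    # ONNX/PyTorch rule: only +/-0.0 map to False (NaN != 0 is True).
    return np.not_equal(x, 0.0)


def check_bool_cast(backend_output: np.ndarray, inputs: np.ndarray = SPECIAL) -> list:
    """Return the inputs on which backend_output disagrees with the reference."""
    expected = reference_bool_cast(inputs)
    mismatch = backend_output != expected
    return inputs[mismatch].tolist()


# Stand-in for the behaviour reported in this issue: NaN is cast to False.
buggy = reference_bool_cast(SPECIAL) & ~np.isnan(SPECIAL)
print(check_bool_cast(buggy))  # [nan, nan]
```

Restricting the check to a handful of special values keeps the comparison exact (no tolerances are needed for bool outputs) and points directly at the offending inputs.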
   
   
   #### Code to repro
   ```python
   import torch
   import torch.nn as nn
   import tvm
   from tvm import relay
   import onnx
   import numpy.testing as npt
   
   
   class Model(nn.Module):
       def __init__(self):
           super(Model, self).__init__()
   
       def forward(self, input_tensor):
           cast_output = input_tensor.to(torch.bool)
   
           return cast_output
   
   
   model = Model()
   input_tensor = torch.tensor([float('nan')])
   
   torch_output = model(input_tensor).numpy()
   
   torch.onnx.export(
       model,
       input_tensor,
       "test.onnx",
       input_names=["input"],
       output_names=["output"],
       opset_version=14,
       do_constant_folding=True,
   )
   onnx_model = onnx.load("test.onnx")
   
   target = "llvm"
   
   shape_dict = {"input": input_tensor.shape}
   
   mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
   
   dev = tvm.cpu(0)
   with tvm.transform.PassContext(opt_level=4):
       executor = relay.build_module.create_executor(
           "graph", mod, dev, target, params
       ).evaluate()
   
   inputs = {"input": tvm.nd.array(input_tensor.numpy())}
   
   tvm_output = executor(**inputs).numpy()
   
   npt.assert_allclose(torch_output, tvm_output, rtol=1e-5, atol=1e-8)
   ```
   
   #### Error log
   ```
   AssertionError: 
   Not equal to tolerance rtol=1e-05, atol=1e-08
   
   Mismatched elements: 1 / 1 (100%)
    x: array([ True])
    y: array([False])
   ```
   
   #### Environment & Version
   Ubuntu 20
   TVM commit d1ac1c0202b3d8cb2af268ce79c2ac710554152b
   