slyubomirsky opened a new pull request, #11511:
URL: https://github.com/apache/tvm/pull/11511
I ran into a quirk today when using the PyTorch importer. My local
PyTorch version is 1.5.1, built with CUDA (10.1 on my system), so
`torch.__version__` returns `1.5.1+cu101`.
When I used the importer I encountered the following error:
```
  File "/home/sslyu/tvm/python/tvm/relay/frontend/pytorch.py", line 4064, in from_pytorch
    _run_jit_passes(graph, enable_lower_all_tuples)
  File "/home/sslyu/tvm/python/tvm/relay/frontend/pytorch.py", line 3637, in _run_jit_passes
    torch._C._jit_pass_onnx_function_substitution(graph)
AttributeError: module 'torch._C' has no attribute '_jit_pass_onnx_function_substitution'
```
That is because PyTorch 1.5.1 doesn't have that function. In principle,
the importer should check the version and skip that call for versions that
don't have it, which is the logic here:
https://github.com/apache/tvm/blob/119afda6344785aee5cf1729eec30624ac068f33/python/tvm/relay/frontend/pytorch.py#L3628-L3642
However, the `+cu101` suffix in the version string causes
`is_version_greater_than` to return true even for 1.5.1. This change ignores
the `+cu...` suffix in the version string for that comparison.
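
As a rough illustration of the idea (not the actual patch), the sketch below drops everything from the `+` onward before comparing versions. The helper names `strip_local_suffix`, `release_tuple`, and `version_greater_than` are hypothetical and do not exist in TVM, and the sketch assumes the remaining version consists only of dot-separated integers, which would not hold for development or pre-release builds.

```python
def strip_local_suffix(version: str) -> str:
    """Drop a local build label such as "+cu101" (hypothetical helper)."""
    return version.split("+", 1)[0]


def release_tuple(version: str) -> tuple:
    """Turn "1.5.1+cu101" into (1, 5, 1).

    Assumes every remaining component is a plain integer.
    """
    return tuple(int(part) for part in strip_local_suffix(version).split("."))


def version_greater_than(installed: str, reference: str) -> bool:
    """Compare two versions numerically, ignoring any "+..." suffix."""
    return release_tuple(installed) > release_tuple(reference)


if __name__ == "__main__":
    # A CUDA build of 1.5.1 should not count as newer than 1.5.1.
    print(version_greater_than("1.5.1+cu101", "1.5.1"))
    # A genuinely newer release still compares as greater.
    print(version_greater_than("1.6.0+cu101", "1.5.1"))
```

Comparing integer tuples instead of raw strings also avoids ordering surprises such as `"1.10.0"` sorting before `"1.5.1"`.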