This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4fec972 [Feat] Support lower PyTorch versions in dtype handling (#414)
4fec972 is described below
commit 4fec9721c6bd6862bd371cf65d3ac6a1e8eb737e
Author: Nan <[email protected]>
AuthorDate: Sat Jan 17 11:29:51 2026 +0800
[Feat] Support lower PyTorch versions in dtype handling (#414)
This PR adds partial support for older versions of PyTorch (it enables a
successful `import tvm_ffi`). Previous discussion:
https://github.com/apache/tvm-ffi/issues/381
This is especially helpful in scenarios where the environment uses an
older PyTorch version and patching/hacking PyTorch is not feasible.
Verified with PyTorch 1.10.2, 1.14.0a0, and 2.0.
---
python/tvm_ffi/cython/dtype.pxi | 26 +++++++++++++++-----------
1 file changed, 15 insertions(+), 11 deletions(-)
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 3b8530a..c320c63 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -180,9 +180,6 @@ if torch is not None:
torch.int64: DLDataType(0, 64, 1),
torch.long: DLDataType(0, 64, 1),
torch.uint8: DLDataType(1, 8, 1),
- torch.uint16: DLDataType(1, 16, 1),
- torch.uint32: DLDataType(1, 32, 1),
- torch.uint64: DLDataType(1, 64, 1),
torch.float16: DLDataType(2, 16, 1),
torch.half: DLDataType(2, 16, 1),
torch.float32: DLDataType(2, 32, 1),
@@ -191,15 +188,22 @@ if torch is not None:
torch.double: DLDataType(2, 64, 1),
torch.bfloat16: DLDataType(4, 16, 1),
torch.bool: DLDataType(6, 8, 1),
- torch.float8_e4m3fn: DLDataType(10, 8, 1),
- torch.float8_e4m3fnuz: DLDataType(11, 8, 1),
- torch.float8_e5m2: DLDataType(12, 8, 1),
- torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
}
- if hasattr(torch, "float8_e8m0fnu"):
- TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float8_e8m0fnu] = DLDataType(14, 8,
1)
- if hasattr(torch, "float4_e2m1fn_x2"):
- TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float4_e2m1fn_x2] = DLDataType(17,
4, 2)
+
+ extra_types = [
+ ("uint16", DLDataType(1, 16, 1)),
+ ("uint32", DLDataType(1, 32, 1)),
+ ("uint64", DLDataType(1, 64, 1)),
+ ("float8_e4m3fn", DLDataType(10, 8, 1)),
+ ("float8_e4m3fnuz", DLDataType(11, 8, 1)),
+ ("float8_e5m2", DLDataType(12, 8, 1)),
+ ("float8_e5m2fnuz", DLDataType(13, 8, 1)),
+ ("float8_e8m0fnu", DLDataType(14, 8, 1)),
+ ("float4_e2m1fn_x2", DLDataType(17, 4, 2)),
+ ]
+ for attr_name, dl_dtype in extra_types:
+ if hasattr(torch, attr_name):
+ TORCH_DTYPE_TO_DL_DATA_TYPE[getattr(torch, attr_name)] = dl_dtype
def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
cdef DLDataType dl_data_type = TORCH_DTYPE_TO_DL_DATA_TYPE[torch_dtype]