This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fec972  [Feat] Support lower PyTorch versions in dtype handling (#414)
4fec972 is described below

commit 4fec9721c6bd6862bd371cf65d3ac6a1e8eb737e
Author: Nan <[email protected]>
AuthorDate: Sat Jan 17 11:29:51 2026 +0800

    [Feat] Support lower PyTorch versions in dtype handling (#414)
    
    This PR adds partial support for older versions of PyTorch (enough to
    make `import tvm_ffi` succeed). Previous discussion:
    https://github.com/apache/tvm-ffi/issues/381
    
    This is especially helpful in scenarios where the environment uses an
    older PyTorch version and patching/hacking PyTorch is not feasible.
    
    Verified with PyTorch 1.10.2, 1.14.0a0, and 2.0.
---
 python/tvm_ffi/cython/dtype.pxi | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 3b8530a..c320c63 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -180,9 +180,6 @@ if torch is not None:
         torch.int64: DLDataType(0, 64, 1),
         torch.long: DLDataType(0, 64, 1),
         torch.uint8: DLDataType(1, 8, 1),
-        torch.uint16: DLDataType(1, 16, 1),
-        torch.uint32: DLDataType(1, 32, 1),
-        torch.uint64: DLDataType(1, 64, 1),
         torch.float16: DLDataType(2, 16, 1),
         torch.half: DLDataType(2, 16, 1),
         torch.float32: DLDataType(2, 32, 1),
@@ -191,15 +188,22 @@ if torch is not None:
         torch.double: DLDataType(2, 64, 1),
         torch.bfloat16: DLDataType(4, 16, 1),
         torch.bool: DLDataType(6, 8, 1),
-        torch.float8_e4m3fn: DLDataType(10, 8, 1),
-        torch.float8_e4m3fnuz: DLDataType(11, 8, 1),
-        torch.float8_e5m2: DLDataType(12, 8, 1),
-        torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
     }
-    if hasattr(torch, "float8_e8m0fnu"):
-        TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
-    if hasattr(torch, "float4_e2m1fn_x2"):
-        TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
+
+    extra_types = [
+        ("uint16", DLDataType(1, 16, 1)),
+        ("uint32", DLDataType(1, 32, 1)),
+        ("uint64", DLDataType(1, 64, 1)),
+        ("float8_e4m3fn", DLDataType(10, 8, 1)),
+        ("float8_e4m3fnuz", DLDataType(11, 8, 1)),
+        ("float8_e5m2", DLDataType(12, 8, 1)),
+        ("float8_e5m2fnuz", DLDataType(13, 8, 1)),
+        ("float8_e8m0fnu", DLDataType(14, 8, 1)),
+        ("float4_e2m1fn_x2", DLDataType(17, 4, 2)),
+    ]
+    for attr_name, dl_dtype in extra_types:
+        if hasattr(torch, attr_name):
+            TORCH_DTYPE_TO_DL_DATA_TYPE[getattr(torch, attr_name)] = dl_dtype
 
     def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
         cdef DLDataType dl_data_type = TORCH_DTYPE_TO_DL_DATA_TYPE[torch_dtype]

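For readers who want to try the pattern outside the Cython build, below is a
minimal plain-Python sketch of the approach introduced in this diff: register
only the torch dtypes that actually exist in the installed PyTorch, so older
releases do not fail at import time on missing attributes. Note the DLDataType
namedtuple is a stand-in for the real Cython struct, and the dtype lists are
abbreviated for illustration; this is not the actual tvm_ffi implementation.

    from collections import namedtuple

    import torch

    # Stand-in for the DLPack DLDataType struct used in dtype.pxi.
    DLDataType = namedtuple("DLDataType", ["type_code", "bits", "lanes"])

    TORCH_DTYPE_TO_DL_DATA_TYPE = {
        # dtypes present in every supported PyTorch version
        torch.int32: DLDataType(0, 32, 1),
        torch.float32: DLDataType(2, 32, 1),
        torch.bool: DLDataType(6, 8, 1),
    }

    # dtypes that only newer PyTorch releases define (abbreviated list)
    extra_types = [
        ("uint16", DLDataType(1, 16, 1)),
        ("float8_e4m3fn", DLDataType(10, 8, 1)),
        ("float4_e2m1fn_x2", DLDataType(17, 4, 2)),
    ]

    for attr_name, dl_dtype in extra_types:
        # hasattr guards against older PyTorch builds that lack the dtype,
        # so the mapping is built from whatever the installed torch provides.
        if hasattr(torch, attr_name):
            TORCH_DTYPE_TO_DL_DATA_TYPE[getattr(torch, attr_name)] = dl_dtype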