This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 91c64b7  [Fix] Fix missing int8 for torch<2.6.0 (#323)
91c64b7 is described below

commit 91c64b71a7d571de83f2edf42d1be374d2dba7c5
Author: Yichen Yan <[email protected]>
AuthorDate: Mon Dec 8 21:40:19 2025 +0800

    [Fix] Fix missing int8 for torch<2.6.0 (#323)
    
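    This PR fixes an issue where the int8 type (`ScalarType::Char`) was not mapped for `torch<2.6.0`, because its case label sat inside the `#if` guard for torch >= 2.6. It also corrects that version check so it no longer excludes later major versions (e.g. 3.0, whose minor version is below 6).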
---
 .../tvm_ffi/utils/_build_optional_torch_c_dlpack.py  | 20 +++++---------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 4512caf..fd0aff4 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -72,7 +72,7 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
     case ScalarType::UInt64:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
+#if (TORCH_VERSION_MAJOR > 2) || (TORCH_VERSION_MAJOR == 2 && 
TORCH_VERSION_MINOR >= 6)
     case ScalarType::Int1:
     case ScalarType::Int2:
     case ScalarType::Int3:
@@ -80,26 +80,16 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
     case ScalarType::Int5:
     case ScalarType::Int6:
     case ScalarType::Int7:
-    case ScalarType::Char:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
 #endif
-    case ScalarType::Double:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case ScalarType::Float:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
+    case ScalarType::Char:
+    case ScalarType::Short:
     case ScalarType::Int:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
     case ScalarType::Long:
       dtype.code = DLDataTypeCode::kDLInt;
       break;
-    case ScalarType::Short:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
     case ScalarType::Half:
+    case ScalarType::Float:
+    case ScalarType::Double:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
     case ScalarType::Bool:

Reply via email to