This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 929effa  [DLPACK] Update dlpack conversion to enable f4 (#22)
929effa is described below

commit 929effa05c65d52f1608723c96efc9fdd24de746
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Sep 18 11:40:28 2025 -0400

    [DLPACK] Update dlpack conversion to enable f4 (#22)
    
    Also move the install rule for the Python helper header (tvm_ffi_python_helpers.h) into the TVM_FFI_BUILD_PYTHON_MODULE section of CMakeLists.txt.
---
 CMakeLists.txt                             |   2 +-
 pyproject.toml                             |   2 +-
 python/tvm_ffi/_optional_torch_c_dlpack.py | 174 ++++++++++++++++++++++++++++-
 3 files changed, 175 insertions(+), 3 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9580f20..1fcf4e8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -243,13 +243,13 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
   install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Utils/ DESTINATION 
cmake/Utils)
   install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt DESTINATION .)
   install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/tvm_ffi-config.cmake 
DESTINATION cmake)
+  install(FILES 
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h 
DESTINATION include/)
 endif()
 
 ########## Install the related for normal cmake library ##########
 
 install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION 
include/tvm/ffi/)
 install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/ 
DESTINATION include/)
-install(FILES 
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h 
DESTINATION include/)
 install(TARGETS tvm_ffi_shared  DESTINATION lib)
 # ship additional dSYM files for debugging symbols on if available
 if (APPLE)
diff --git a/pyproject.toml b/pyproject.toml
index 77a82d6..c7ebd39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b3"
+version = "0.1.0b4"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py 
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 500b684..b227c46 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -126,6 +126,8 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
 #if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
     case ScalarType::Float4_e2m1fn_x2:
       dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
+      dtype.lanes = 2;
+      dtype.bits = 4;
       break;
 #endif
    default:
@@ -282,6 +284,176 @@ static Device getATenDeviceForDLPackv1(DLDeviceType type, 
c10::DeviceIndex index
   }
 }
 
+// Convert a DLPack DLDataType into the matching ATen ScalarType.
+//
+// Used by fromDLPackImpl in place of ATen's toScalarType so that the FP4
+// encoding emitted by getDLDataTypeForDLPackv1 (code kDLFloat4_e2m1fn,
+// bits == 4, lanes == 2) maps back to ScalarType::Float4_e2m1fn_x2.
+// Every other dtype code must have lanes == 1.
+//
+// Fails via TORCH_CHECK for any unsupported code / bits / lanes combination.
+ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
+  ScalarType stype = ScalarType::Undefined;
+  // kDLFloat4_e2m1fn validates its lanes inside its own case below; all
+  // other codes are accepted only as single-lane types.
+  if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
+    TORCH_CHECK(
+        dtype.lanes == 1,
+        "ATen does not support lanes != 1 for dtype code ",
+        std::to_string(dtype.code));
+  }
+  switch (dtype.code) {
+    case DLDataTypeCode::kDLUInt:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Byte;
+          break;
+        case 16:
+          stype = ScalarType::UInt16;
+          break;
+        case 32:
+          stype = ScalarType::UInt32;
+          break;
+        case 64:
+          stype = ScalarType::UInt64;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLInt:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Char;
+          break;
+        case 16:
+          stype = ScalarType::Short;
+          break;
+        case 32:
+          stype = ScalarType::Int;
+          break;
+        case 64:
+          stype = ScalarType::Long;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kInt bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat:
+      switch (dtype.bits) {
+        case 16:
+          stype = ScalarType::Half;
+          break;
+        case 32:
+          stype = ScalarType::Float;
+          break;
+        case 64:
+          stype = ScalarType::Double;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLBfloat:
+      switch (dtype.bits) {
+        case 16:
+          stype = ScalarType::BFloat16;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kBfloat bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLComplex:
+      // DLPack complex bits count both components (e.g. 64 == two float32).
+      switch (dtype.bits) {
+        case 32:
+          stype = ScalarType::ComplexHalf;
+          break;
+        case 64:
+          stype = ScalarType::ComplexFloat;
+          break;
+        case 128:
+          stype = ScalarType::ComplexDouble;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kComplex bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLBool:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Bool;
+          break;
+        default:
+          TORCH_CHECK(
+              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat8_e5m2:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Float8_e5m2;
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat8_e5m2 bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat8_e5m2fnuz:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Float8_e5m2fnuz;
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat8_e5m2fnuz bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat8_e4m3fn:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Float8_e4m3fn;
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat8_e4m3fn bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat8_e4m3fnuz:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Float8_e4m3fnuz;
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat8_e4m3fnuz bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+    case DLDataTypeCode::kDLFloat8_e8m0fnu:
+      switch (dtype.bits) {
+        case 8:
+          stype = ScalarType::Float8_e8m0fnu;
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat8_e8m0fnu bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 8)
+    // ScalarType::Float4_e2m1fn_x2 is version-guarded here just as it is in
+    // getDLDataTypeForDLPackv1. When the guard is off, this code falls
+    // through to the "Unsupported code" default below.
+    case DLDataTypeCode::kDLFloat4_e2m1fn:
+      switch (dtype.bits) {
+        case 4:
+          switch (dtype.lanes) {
+            case 2:
+              stype = ScalarType::Float4_e2m1fn_x2;
+              break;
+            default:
+              TORCH_CHECK(
+                  false,
+                  "Unsupported kDLFloat4_e2m1fn lanes ",
+                  std::to_string(dtype.lanes));
+          }
+          break;
+        default:
+          TORCH_CHECK(
+              false,
+              "Unsupported kDLFloat4_e2m1fn bits ",
+              std::to_string(dtype.bits));
+      }
+      break;
+#endif
+    default:
+      TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
+  }
+  return stype;
+}
 
 // This function constructs a Tensor from a memory managed DLPack which
 // may be represented as either: DLManagedTensor and DLManagedTensorVersioned.
@@ -297,7 +469,7 @@ at::Tensor fromDLPackImpl(T* src, 
std::function<void(void*)> deleter) {
 
   DLTensor& dl_tensor = src->dl_tensor;
   Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, 
dl_tensor.device.device_id, dl_tensor.data);
-  ScalarType stype = toScalarType(dl_tensor.dtype);
+  ScalarType stype = toScalarTypeForDLPackv1(dl_tensor.dtype);
 
   if (!dl_tensor.strides) {
     return at::from_blob(

Reply via email to