This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 929effa [DLPACK] Update dlpack conversion to enable f4 (#22)
929effa is described below
commit 929effa05c65d52f1608723c96efc9fdd24de746
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Sep 18 11:40:28 2025 -0400
[DLPACK] Update dlpack conversion to enable f4 (#22)
Also migrate the python header install into python build.
---
CMakeLists.txt | 2 +-
pyproject.toml | 2 +-
python/tvm_ffi/_optional_torch_c_dlpack.py | 174 ++++++++++++++++++++++++++++-
3 files changed, 175 insertions(+), 3 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9580f20..1fcf4e8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -243,13 +243,13 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Utils/ DESTINATION
cmake/Utils)
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt DESTINATION .)
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/tvm_ffi-config.cmake
DESTINATION cmake)
+    install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/)
endif()
########## Install the related for normal cmake library ##########
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION
include/tvm/ffi/)
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/
DESTINATION include/)
-install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/)
install(TARGETS tvm_ffi_shared DESTINATION lib)
# ship additional dSYM files for debugging symbols on if available
if (APPLE)
diff --git a/pyproject.toml b/pyproject.toml
index 77a82d6..c7ebd39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b3"
+version = "0.1.0b4"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 500b684..b227c46 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -126,6 +126,8 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
case ScalarType::Float4_e2m1fn_x2:
dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
+ dtype.lanes = 2;
+ dtype.bits = 4;
break;
#endif
default:
@@ -282,6 +284,176 @@ static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index
}
}
+ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
+ ScalarType stype = ScalarType::Undefined;
+ if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
+ TORCH_CHECK(
+ dtype.lanes == 1,
+        "ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code));
+ }
+ switch (dtype.code) {
+ case DLDataTypeCode::kDLUInt:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Byte;
+ break;
+ case 16:
+ stype = ScalarType::UInt16;
+ break;
+ case 32:
+ stype = ScalarType::UInt32;
+ break;
+ case 64:
+ stype = ScalarType::UInt64;
+ break;
+ default:
+ TORCH_CHECK(
+ false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLInt:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Char;
+ break;
+ case 16:
+ stype = ScalarType::Short;
+ break;
+ case 32:
+ stype = ScalarType::Int;
+ break;
+ case 64:
+ stype = ScalarType::Long;
+ break;
+ default:
+ TORCH_CHECK(
+ false, "Unsupported kInt bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat:
+ switch (dtype.bits) {
+ case 16:
+ stype = ScalarType::Half;
+ break;
+ case 32:
+ stype = ScalarType::Float;
+ break;
+ case 64:
+ stype = ScalarType::Double;
+ break;
+ default:
+ TORCH_CHECK(
+ false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLBfloat:
+ switch (dtype.bits) {
+ case 16:
+ stype = ScalarType::BFloat16;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kBfloat bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLComplex:
+ switch (dtype.bits) {
+ case 32:
+ stype = ScalarType::ComplexHalf;
+ break;
+ case 64:
+ stype = ScalarType::ComplexFloat;
+ break;
+ case 128:
+ stype = ScalarType::ComplexDouble;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kComplex bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLBool:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Bool;
+ break;
+ default:
+ TORCH_CHECK(
+ false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat8_e5m2:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Float8_e5m2;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat8_e5m2fnuz:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Float8_e5m2fnuz;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat8_e4m3fn:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Float8_e4m3fn;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat8_e4m3fnuz:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Float8_e4m3fnuz;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat8_e8m0fnu:
+ switch (dtype.bits) {
+ case 8:
+ stype = ScalarType::Float8_e8m0fnu;
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
+ }
+ break;
+ case DLDataTypeCode::kDLFloat4_e2m1fn:
+ switch (dtype.bits) {
+ case 4:
+ switch (dtype.lanes) {
+ case 2:
+ stype = ScalarType::Float4_e2m1fn_x2;
+ break;
+ default:
+ TORCH_CHECK(
+                false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
+ }
+ break;
+ default:
+ TORCH_CHECK(
+            false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
+ }
+ break;
+ default:
+ TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
+ }
+ return stype;
+}
// This function constructs a Tensor from a memory managed DLPack which
// may be represented as either: DLManagedTensor and DLManagedTensorVersioned.
@@ -297,7 +469,7 @@ at::Tensor fromDLPackImpl(T* src,
std::function<void(void*)> deleter) {
DLTensor& dl_tensor = src->dl_tensor;
  Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
- ScalarType stype = toScalarType(dl_tensor.dtype);
+ ScalarType stype = toScalarTypeForDLPackv1(dl_tensor.dtype);
if (!dl_tensor.strides) {
return at::from_blob(