This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ea7a1dd421603cda511865219395b6c6cf8d4860
Author: tqchen <[email protected]>
AuthorDate: Mon Apr 21 18:41:01 2025 -0400

    [FFI] DType support for dlpack v1.1
---
 ffi/include/tvm/ffi/dtype.h | 114 ++++++++++++++++++++++++++++++++++----------
 ffi/tests/cpp/test_dtype.cc |  27 ++++++++++-
 2 files changed, 115 insertions(+), 26 deletions(-)

diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index 804edacd46..257b6bc158 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -43,12 +43,7 @@ namespace ffi {
  *
  * TOTO(tvm-team): update to latest DLPack types.
  */
-enum DLExtDataTypeCode {
-  kDLExtFloat8_e4m3fn = 6,
-  kDLExtFloat8_e5m2 = 7,
-  kDLExtFloat4_e2m1fn = 8,
-  kDLExtCustomBegin = 129
-};
+enum DLExtDataTypeCode { kDLExtCustomBegin = 129 };  // fp8/6/4 codes now come from DLPack v1.1
 
 namespace details {
 /*!
@@ -121,15 +116,47 @@ inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code)
       os << "bfloat";
       break;
     }
-    case kDLExtFloat8_e4m3fn: {
+    case kDLFloat8_e3m4: {  // DLPack v1.1 low-bit float codes
+      os << "float8_e3m4";
+      break;
+    }
+    case kDLFloat8_e4m3: {
+      os << "float8_e4m3";
+      break;
+    }
+    case kDLFloat8_e4m3b11fnuz: {
+      os << "float8_e4m3b11fnuz";
+      break;
+    }
+    case kDLFloat8_e4m3fn: {
       os << "float8_e4m3fn";
       break;
     }
-    case kDLExtFloat8_e5m2: {
+    case kDLFloat8_e4m3fnuz: {
+      os << "float8_e4m3fnuz";
+      break;
+    }
+    case kDLFloat8_e5m2: {
       os << "float8_e5m2";
       break;
     }
-    case kDLExtFloat4_e2m1fn: {
+    case kDLFloat8_e5m2fnuz: {
+      os << "float8_e5m2fnuz";
+      break;
+    }
+    case kDLFloat8_e8m0fnu: {
+      os << "float8_e8m0fnu";
+      break;
+    }
+    case kDLFloat6_e2m3fn: {
+      os << "float6_e2m3fn";
+      break;
+    }
+    case kDLFloat6_e3m2fn: {
+      os << "float6_e3m2fn";
+      break;
+    }
+    case kDLFloat4_e2m1fn: {
       os << "float4_e2m1fn";
       break;
     }
@@ -164,8 +191,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType dtype) {  // NOLINT
   details::PrintDLDataTypeCodeAsStr(os, static_cast<DLDataTypeCode>(dtype.code));
   if (dtype.code == kDLOpaqueHandle) return os;
   int16_t lanes = static_cast<int16_t>(dtype.lanes);
-  if (dtype.code != kDLExtFloat8_e4m3fn && dtype.code != kDLExtFloat8_e5m2 &&
-      dtype.code != kDLExtFloat4_e2m1fn) {
+  if (dtype.code < kDLFloat8_e3m4 || dtype.code >= kDLExtCustomBegin) {  // fp8/6/4 names omit bits
     os << static_cast<int>(dtype.bits);
   }
   if (lanes > 1) {
@@ -223,22 +249,60 @@ inline DLDataType StringToDLDataType(const std::string& str) {
     return dtype;
   };
 
-  if (str.substr(0, 3) == "int") {
+  if (str.compare(0, 3, "int") == 0) {
     dtype.code = kDLInt;
     scan = str.c_str() + 3;
-  } else if (str.substr(0, 4) == "uint") {
+  } else if (str.compare(0, 4, "uint") == 0) {
     dtype.code = kDLUInt;
     scan = str.c_str() + 4;
-  } else if (str.substr(0, 13) == "float4_e2m1fn") {
-    return parse_float(str, 13, DLExtDataTypeCode::kDLExtFloat4_e2m1fn, 4);
-  } else if (str.substr(0, 13) == "float8_e4m3fn") {
-    return parse_float(str, 13, DLExtDataTypeCode::kDLExtFloat8_e4m3fn, 8);
-  } else if (str.substr(0, 11) == "float8_e5m2") {
-    return parse_float(str, 11, DLExtDataTypeCode::kDLExtFloat8_e5m2, 8);
-  } else if (str.substr(0, 5) == "float") {
-    dtype.code = kDLFloat;
-    scan = str.c_str() + 5;
-  } else if (str.substr(0, 6) == "handle") {
+  } else if (str.compare(0, 5, "float") == 0) {
+    if (str.compare(5, 2, "8_") == 0) {
+      if (str.compare(7, 4, "e3m4") == 0) {
+        return parse_float(str, 11, kDLFloat8_e3m4, 8);
+      } else if (str.compare(7, 4, "e4m3") == 0) {
+        if (str.compare(11, 7, "b11fnuz") == 0) {
+          return parse_float(str, 18, kDLFloat8_e4m3b11fnuz, 8);
+        } else if (str.compare(11, 2, "fn") == 0) {
+          if (str.compare(13, 2, "uz") == 0) {
+            return parse_float(str, 15, kDLFloat8_e4m3fnuz, 8);
+          } else {
+            return parse_float(str, 13, kDLFloat8_e4m3fn, 8);
+          }
+        } else {
+          return parse_float(str, 11, kDLFloat8_e4m3, 8);
+        }
+      } else if (str.compare(7, 8, "e5m2fnuz") == 0) {  // must precede the "e5m2" prefix check
+        return parse_float(str, 15, kDLFloat8_e5m2fnuz, 8);
+      } else if (str.compare(7, 4, "e5m2") == 0) {
+        return parse_float(str, 11, kDLFloat8_e5m2, 8);
+      } else if (str.compare(7, 7, "e8m0fnu") == 0) {
+        return parse_float(str, 14, kDLFloat8_e8m0fnu, 8);
+      } else {
+        TVM_FFI_THROW(ValueError) << "unknown float8 type `" << str << '`';
+        TVM_FFI_UNREACHABLE();
+      }
+    } else if (str.compare(5, 2, "6_") == 0) {
+      if (str.compare(7, 6, "e2m3fn") == 0) {
+        return parse_float(str, 13, kDLFloat6_e2m3fn, 6);
+      } else if (str.compare(7, 6, "e3m2fn") == 0) {
+        return parse_float(str, 13, kDLFloat6_e3m2fn, 6);
+      } else {
+        TVM_FFI_THROW(ValueError) << "unknown float6 type `" << str << '`';
+        TVM_FFI_UNREACHABLE();
+      }
+    } else if (str.compare(5, 2, "4_") == 0) {
+      // kDLFloat4_e2m1fn
+      if (str.compare(7, 6, "e2m1fn") == 0) {
+        return parse_float(str, 13, kDLFloat4_e2m1fn, 4);
+      } else {
+        TVM_FFI_THROW(ValueError) << "unknown float4 type `" << str << '`';
+        TVM_FFI_UNREACHABLE();
+      }
+    } else {
+      dtype.code = kDLFloat;
+      scan = str.c_str() + 5;
+    }
+  } else if (str.compare(0, 6, "handle") == 0) {
     dtype.code = kDLOpaqueHandle;
     dtype.bits = 64;  // handle uses 64 bit by default.
     scan = str.c_str() + 6;
@@ -247,11 +311,11 @@ inline DLDataType StringToDLDataType(const std::string& str) {
     dtype.bits = 1;
     dtype.lanes = 1;
     return dtype;
-  } else if (str.substr(0, 6) == "bfloat") {
+  } else if (str.compare(0, 6, "bfloat") == 0) {
     dtype.code = kDLBfloat;
     dtype.bits = 16;
     scan = str.c_str() + 6;
-  } else if (str.substr(0, 6) == "custom") {
+  } else if (str.compare(0, 6, "custom") == 0) {
     dtype.code = details::ParseCustomDataTypeCode(str, &scan);
   } else {
     scan = str.c_str();
diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc
index ad769b740a..3e3e43430e 100644
--- a/ffi/tests/cpp/test_dtype.cc
+++ b/ffi/tests/cpp/test_dtype.cc
@@ -44,11 +44,36 @@ TEST(DType, StringConversion) {
   EXPECT_EQ(StringToDLDataType("bfloat16x2"), dtype);
 
   // test float8
-  dtype = DLDataType{kDLExtFloat8_e4m3fn, 8, 2};
+  dtype = DLDataType{kDLFloat8_e4m3fn, 8, 2};
   EXPECT_EQ(DLDataTypeToString(dtype), "float8_e4m3fnx2");
   EXPECT_EQ(StringToDLDataType("float8_e4m3fnx2"), dtype);
 }
 
+TEST(DType, StringConversionAllDLPackTypes) {
+  std::vector<std::pair<DLDataType, std::string>> test_cases = {
+      {DLDataType{kDLFloat, 32, 1}, "float32"},
+      {DLDataType{kDLInt, 16, 1}, "int16"},
+      {DLDataType{kDLUInt, 16, 1}, "uint16"},
+      {DLDataType{kDLBfloat, 16, 1}, "bfloat16"},
+      {DLDataType{kDLFloat8_e3m4, 8, 1}, "float8_e3m4"},
+      {DLDataType{kDLFloat8_e4m3, 8, 1}, "float8_e4m3"},
+      {DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}, "float8_e4m3b11fnuz"},
+      {DLDataType{kDLFloat8_e4m3fn, 8, 1}, "float8_e4m3fn"},
+      {DLDataType{kDLFloat8_e4m3fnuz, 8, 1}, "float8_e4m3fnuz"},
+      {DLDataType{kDLFloat8_e5m2, 8, 1}, "float8_e5m2"},
+      {DLDataType{kDLFloat8_e5m2fnuz, 8, 1}, "float8_e5m2fnuz"},
+      {DLDataType{kDLFloat8_e8m0fnu, 8, 1}, "float8_e8m0fnu"},
+      {DLDataType{kDLFloat6_e2m3fn, 6, 1}, "float6_e2m3fn"},
+      {DLDataType{kDLFloat6_e3m2fn, 6, 1}, "float6_e3m2fn"},
+      {DLDataType{kDLFloat4_e2m1fn, 4, 1}, "float4_e2m1fn"},
+  };
+
+  for (const auto& [dtype, str] : test_cases) {  // check both conversion directions
+    EXPECT_EQ(DLDataTypeToString(dtype), str);
+    EXPECT_EQ(StringToDLDataType(str), dtype);
+  }
+}
+
 TEST(DataType, AnyConversion) {
   AnyView view0;
   EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);

Reply via email to