This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 775e05064b [DataType] Rename FP8 dtypes to standard names (#17712)
775e05064b is described below

commit 775e05064b52c846390120d9613e7a89bb52aef7
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 6 17:35:49 2025 -0500

    [DataType] Rename FP8 dtypes to standard names (#17712)
    
    This PR renames the FP8 dtypes in TVM to their standard names:
    
    * `e4m3_float8` is renamed to `float8_e4m3fn`,
    * `e5m2_float8` is renamed to `float8_e5m2`.
    
    This aligns with dtype names in PyTorch and ml_dtypes.
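    
    For example, with this change (a minimal sketch; assumes the optional
    ml_dtypes package is installed for the NumPy interop):
    
        import numpy as np
        import ml_dtypes
        import tvm
    
        # The standard name now round-trips through TVM's dtype parser.
        x = tvm.nd.array(np.zeros((4,), dtype=ml_dtypes.float8_e4m3fn))
        assert x.dtype == "float8_e4m3fn"  # previously "e4m3_float8"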
---
 include/tvm/runtime/data_type.h                    | 77 ++++++++++++++--------
 include/tvm/script/ir_builder/tir/ir.h             |  6 +-
 python/tvm/_ffi/runtime_ctypes.py                  | 42 ++++++------
 python/tvm/contrib/tvmjs.py                        |  8 +--
 python/tvm/relax/backend/cuda/cublas.py            | 10 +--
 python/tvm/relax/backend/rocm/hipblas.py           |  8 +--
 python/tvm/runtime/ndarray.py                      |  8 +--
 python/tvm/script/ir_builder/tir/ir.py             | 66 +++++++++----------
 python/tvm/tir/tensor_intrin/cuda.py               | 38 +++++------
 src/ir/expr.cc                                     |  7 +-
 src/relax/backend/contrib/utils.h                  | 10 +--
 src/runtime/contrib/cublas/cublas.cc               |  5 +-
 src/runtime/ndarray.cc                             |  2 +-
 src/script/ir_builder/tir/ir.cc                    | 12 ++--
 src/support/scalars.h                              |  4 +-
 src/target/llvm/codegen_llvm.cc                    |  4 +-
 src/target/source/codegen_cuda.cc                  | 16 ++---
 src/tir/op/op.cc                                   |  8 +--
 src/tir/transforms/dtype_conversion.h              |  2 +-
 .../python/codegen/test_target_codegen_cuda_fp8.py | 44 ++++++-------
 tests/python/ir/test_datatype_nv_fp8.py            |  8 ++-
 tests/python/ir/test_dtype.py                      |  9 +--
 tests/python/relax/test_codegen_cublas.py          | 16 ++---
 tests/python/relax/test_op_inspect.py              |  6 +-
 tests/python/relax/test_op_qdq.py                  | 12 ++--
 ..._tir_schedule_tensorize_ldmatrix_mma_numeric.py | 27 ++++----
 .../test_tir_transform_fp8_legalize.py             | 16 ++---
 .../python/tvmscript/test_tvmscript_printer_tir.py | 11 ++--
 28 files changed, 255 insertions(+), 227 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 76e5e3833f..65fd0c98fd 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -56,9 +56,9 @@ class DataType {
     kFloat = kDLFloat,
     kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
     kBFloat = kDLBfloat,
-    kE4M3Float = 6U,
-    kE5M2Float = 7U,
-    kFloat4E2M1Fn = 8U,
+    kFloat8_e4m3fn = 6U,
+    kFloat8_e5m2 = 7U,
+    kFloat4_e2m1fn = 8U,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -85,10 +85,10 @@ class DataType {
     if (code == kBFloat) {
       ICHECK_EQ(bits, 16);
     }
-    if (code == kE4M3Float || code == kE5M2Float) {
+    if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) {
       ICHECK_EQ(bits, 8);
     }
-    if (code == kFloat4E2M1Fn) {
+    if (code == kFloat4_e2m1fn) {
       ICHECK_EQ(bits, 4);
     }
   }
@@ -126,15 +126,15 @@ class DataType {
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a float8 type. */
   bool is_float8() const {
-    return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
-            code() == DataType::kE5M2Float) &&
+    return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
+            code() == DataType::kFloat8_e5m2) &&
            bits() == 8;
   }
   /*! \return whether type is a float4 type. */
-  bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; }
-  bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
-  bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
-  bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); }
+  bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; }
+  bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); }
+  bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); }
+  bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
@@ -252,19 +252,19 @@ class DataType {
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
-  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
+  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
   /*!
    * \brief Construct NV float8 e5m2 datatype.
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
-  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
+  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
   /*!
    * \brief Construct NV float4_e2m1fn datatype.
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
-  static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); }
+  static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
   /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes.
@@ -393,11 +393,11 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
       return "handle";
     case kDLBfloat:
       return "bfloat";
-    case DataType::kE4M3Float:
-      return "e4m3_float";
-    case DataType::kE5M2Float:
-      return "e5m2_float";
-    case DataType::kFloat4E2M1Fn:
+    case DataType::kFloat8_e4m3fn:
+      return "float8_e4m3fn";
+    case DataType::kFloat8_e5m2:
+      return "float8_e5m2";
+    case DataType::kFloat4_e2m1fn:
       return "float4_e2m1fn";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
@@ -420,7 +420,10 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) {  // NOLINT(*)
   }
   if (t.code == kTVMOpaqueHandle) return os;
   int16_t lanes = static_cast<int16_t>(t.lanes);
-  os << static_cast<int>(t.bits);
+  if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 &&
+      t.code != DataType::kFloat4_e2m1fn) {
+    os << static_cast<int>(t.bits);
+  }
   if (lanes > 1) {
     os << 'x' << lanes;
   } else if (lanes < -1) {
@@ -458,7 +461,7 @@ inline DLDataType String2DLDataType(std::string s) {
     scan = s.c_str() + 4;
   } else if (s.substr(0, 13) == "float4_e2m1fn") {
     // Avoid being treated as "float"
-    t.code = DataType::kFloat4E2M1Fn;
+    t.code = DataType::kFloat4_e2m1fn;
     t.bits = 4;
     scan = s.c_str() + 13;
     char* endpt = nullptr;
@@ -468,6 +471,30 @@ inline DLDataType String2DLDataType(std::string s) {
     }
     ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
     return t;
+  } else if (s.substr(0, 13) == "float8_e4m3fn") {
+    // Avoid being treated as "float"
+    t.code = DataType::kFloat8_e4m3fn;
+    t.bits = 8;
+    scan = s.c_str() + 13;
+    char* endpt = nullptr;
+    if (*scan == 'x') {
+      t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+      scan = endpt;
+    }
+    ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+    return t;
+  } else if (s.substr(0, 11) == "float8_e5m2") {
+    // Avoid being treated as "float"
+    t.code = DataType::kFloat8_e5m2;
+    t.bits = 8;
+    scan = s.c_str() + 11;
+    char* endpt = nullptr;
+    if (*scan == 'x') {
+      t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+      scan = endpt;
+    }
+    ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+    return t;
   } else if (s.substr(0, 5) == "float") {
     t.code = kDLFloat;
     scan = s.c_str() + 5;
@@ -484,14 +511,6 @@ inline DLDataType String2DLDataType(std::string s) {
     t.code = DataType::kBFloat;
     t.bits = 16;
     scan = s.c_str() + 6;
-  } else if (s.substr(0, 10) == "e4m3_float") {
-    t.code = DataType::kE4M3Float;
-    t.bits = 8;
-    scan = s.c_str() + 10;
-  } else if (s.substr(0, 10) == "e5m2_float") {
-    t.code = DataType::kE5M2Float;
-    t.bits = 8;
-    scan = s.c_str() + 10;
   } else if (s.substr(0, 6) == "custom") {
     t.code = ParseCustomDatatype(s, &scan);
   } else {
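
A quick illustration of the new parsing path above (a sketch, not part of the
commit; assumes a TVM build that includes this change):

#include <tvm/runtime/data_type.h>

using tvm::runtime::DataType;

int main() {
  // "float8_e4m3fnx4" now parses to type code 6 (kFloat8_e4m3fn),
  // 8 bits, 4 lanes; the old "e4m3_float8" spelling is no longer accepted.
  DLDataType t = tvm::runtime::String2DLDataType("float8_e4m3fnx4");
  ICHECK_EQ(static_cast<int>(t.code), static_cast<int>(DataType::kFloat8_e4m3fn));
  ICHECK_EQ(t.bits, 8);
  ICHECK_EQ(t.lanes, 4);
  return 0;
}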
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index e78e0d51fd..e60a3859ac 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -502,10 +502,10 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32));              \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
 
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2);
 
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, DataType::NVFloat4E2M1FN);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN);
 
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 3f4ceadd1d..317bd6bead 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -66,9 +66,9 @@ class DataTypeCode(object):
     FLOAT = 2
     HANDLE = 3
     BFLOAT = 4
-    E4M3Float = 6
-    E5M2Float = 7
-    FLOAT4E2M1FN = 8
+    Float8E4M3FN = 6
+    Float8E5M2 = 7
+    Float4E2M1FN = 8
 
 
 class DataType(ctypes.Structure):
@@ -81,9 +81,9 @@ class DataType(ctypes.Structure):
         DataTypeCode.FLOAT: "float",
         DataTypeCode.HANDLE: "handle",
         DataTypeCode.BFLOAT: "bfloat",
-        DataTypeCode.E4M3Float: "e4m3_float",
-        DataTypeCode.E5M2Float: "e5m2_float",
-        DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
+        DataTypeCode.Float8E4M3FN: "float8_e4m3fn",
+        DataTypeCode.Float8E5M2: "float8_e5m2",
+        DataTypeCode.Float4E2M1FN: "float4_e2m1fn",
     }
     NUMPY2STR = {
         np.dtype(np.bool_): "bool",
@@ -112,9 +112,9 @@ class DataType(ctypes.Structure):
         "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
         "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
         "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
-        "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
-        "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
-        "float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1},
+        "float8_e4m3fn": {"type_code": DataTypeCode.Float8E4M3FN, "bits": 8, "lanes": 1},
+        "float8_e5m2": {"type_code": DataTypeCode.Float8E5M2, "bits": 8, "lanes": 1},
+        "float4_e2m1fn": {"type_code": DataTypeCode.Float4E2M1FN, "bits": 4, "lanes": 1},
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -157,9 +157,17 @@ class DataType(ctypes.Structure):
             head = head[4:]
         elif head.startswith("float4_e2m1fn"):
             # Avoid being treated as "float"
-            self.type_code = DataTypeCode.FLOAT4E2M1FN
+            self.type_code = DataTypeCode.Float4E2M1FN
             bits = 4
             head = ""
+        elif head.startswith("float8_e4m3fn"):
+            self.type_code = DataTypeCode.Float8E4M3FN
+            bits = 8
+            head = ""
+        elif head.startswith("float8_e5m2"):
+            self.type_code = DataTypeCode.Float8E5M2
+            bits = 8
+            head = ""
         elif head.startswith("float"):
             self.type_code = DataTypeCode.FLOAT
             head = head[5:]
@@ -170,12 +178,6 @@ class DataType(ctypes.Structure):
         elif head.startswith("bfloat"):
             self.type_code = DataTypeCode.BFLOAT
             head = head[6:]
-        elif head.startswith("e4m3_float"):
-            self.type_code = DataTypeCode.E4M3Float
-            head = head[10:]
-        elif head.startswith("e5m2_float"):
-            self.type_code = DataTypeCode.E5M2Float
-            head = head[10:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
@@ -204,7 +206,9 @@ class DataType(ctypes.Structure):
 
             type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
         if self.type_code in [
-            DataTypeCode.FLOAT4E2M1FN,
+            DataTypeCode.Float8E4M3FN,
+            DataTypeCode.Float8E5M2,
+            DataTypeCode.Float4E2M1FN,
         ]:
             x = type_name
         else:
@@ -243,8 +247,8 @@ class DataType(ctypes.Structure):
 
 if ml_dtypes is not None:
     DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
-    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
-    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
     DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
 
 RPC_SESS_MASK = 128
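
Illustrative usage of the renamed DataTypeCode entries (a sketch, not part of
the diff):

    from tvm import DataType
    from tvm._ffi.runtime_ctypes import DataTypeCode

    dt = DataType("float8_e4m3fnx4")
    assert dt.type_code == DataTypeCode.Float8E4M3FN
    assert dt.bits == 8 and dt.lanes == 4
    # For the fp8/fp4 type codes, the string form elides the bit width,
    # so the canonical spelling round-trips unchanged.
    assert str(DataType("float8_e5m2")) == "float8_e5m2"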
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 9bff724df7..5a5a3a1c80 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -369,19 +369,19 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
             arr = tvm.nd.empty(shape, dtype, device=device)
             assert offset + nbytes <= len(raw_data)
             buffer_source = raw_data[offset : offset + nbytes]
-            if dtype == "e4m3_float8":
+            if dtype == "float8_e4m3fn":
                 if ml_dtypes is not None:
                     dtype = ml_dtypes.float8_e4m3fn
                 else:
                     raise RuntimeError(
-                        "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
+                        "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
                     )
-            if dtype == "e5m2_float8":
+            if dtype == "float8_e5m2":
                 if ml_dtypes is not None:
                     dtype = ml_dtypes.float8_e5m2
                 else:
                     raise RuntimeError(
-                        "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
+                        "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
                     )
             if encode_format == "f32-to-bf16" and dtype == "float32":
                 data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py
index 287b18b440..6828381e68 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -27,18 +27,18 @@ from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
-    make_matmul_pattern,
     make_matmul_dequantize_pattern,
     make_matmul_multiply_pattern,
+    make_matmul_pattern,
 )
 from ..utils import has_leaking_intermediate_variables
 
 
 def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
     """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
-    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
-        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
-        return out_dtype != "e5m2_float8"
+    if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
+        # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
+        return out_dtype != "float8_e5m2"
     return (
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
@@ -83,7 +83,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
         if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
             # Rows number must be multiples of 4 for IGEMM
             return False
-    elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+    elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
         matmul_rhs_var = matmul_call.args[1]
         rhs_transposed = False
         if matmul_rhs_var in context.matched_bindings:
diff --git a/python/tvm/relax/backend/rocm/hipblas.py b/python/tvm/relax/backend/rocm/hipblas.py
index c0accc1473..63c72b660d 100644
--- a/python/tvm/relax/backend/rocm/hipblas.py
+++ b/python/tvm/relax/backend/rocm/hipblas.py
@@ -30,9 +30,9 @@ from ..utils import has_leaking_intermediate_variables
 
 def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):  # pylint: disable=unused-argument
     """Check if dtypes in the given workload are supported by hipblas BYOC."""
-    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
-        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
-        # return out_dtype != "e5m2_float8"
+    if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
+        # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
+        # return out_dtype != "float8_e5m2"
         return False
     return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
         lhs_dtype == "int8" and rhs_dtype == "int8"
@@ -61,7 +61,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
 
     if lhs_dtype == "int8" and rhs_dtype == "int8":
         return False
-    elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+    elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
         return False
 
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 47fcccf52b..d55334a154 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -249,19 +249,19 @@ class NDArray(NDArrayBase):
             dtype = "int8"
         if dtype == "bfloat16":
             dtype = "uint16"
-        if dtype == "e4m3_float8":
+        if dtype == "float8_e4m3fn":
             if ml_dtypes is not None:
                 dtype = ml_dtypes.float8_e4m3fn
             else:
                 raise RuntimeError(
-                    "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
+                    "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
                 )
-        if dtype == "e5m2_float8":
+        if dtype == "float8_e5m2":
             if ml_dtypes is not None:
                 dtype = ml_dtypes.float8_e5m2
             else:
                 raise RuntimeError(
-                    "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
+                    "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
                 )
         if dtype == "float4_e2m1fn":
             if ml_dtypes is not None:
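
A round trip through NDArray.numpy() with the new names (a sketch, not part
of the diff; assumes ml_dtypes is installed):

    import numpy as np
    import ml_dtypes
    import tvm

    src = np.arange(8, dtype=np.float32).astype(ml_dtypes.float8_e5m2)
    arr = tvm.nd.array(src)
    # numpy() now maps the "float8_e5m2" dtype string back to ml_dtypes.float8_e5m2.
    np.testing.assert_array_equal(arr.numpy(), src)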
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index c35df7a093..2fce022da3 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1442,27 +1442,27 @@ float16x64 = func_gen(("Float16x64"))
 float32x64 = func_gen(("Float32x64"))
 float64x64 = func_gen(("Float64x64"))
 
-e4m3_float8 = func_gen(("E4M3Float8"))
-e4m3_float8x4 = func_gen(("E4M3Float8x4"))
-e4m3_float8x8 = func_gen(("E4M3Float8x8"))
-e4m3_float8x16 = func_gen(("E4M3Float8x16"))
-e4m3_float8x32 = func_gen(("E4M3Float8x32"))
-e4m3_float8x64 = func_gen(("E4M3Float8x64"))
-
-e5m2_float8 = func_gen(("E5M2Float8"))
-e5m2_float8x4 = func_gen(("E5M2Float8x4"))
-e5m2_float8x8 = func_gen(("E5M2Float8x8"))
-e5m2_float8x16 = func_gen(("E5M2Float8x16"))
-e5m2_float8x32 = func_gen(("E5M2Float8x32"))
-e5m2_float8x64 = func_gen(("E5M2Float8x64"))
-
-float4_e2m1fn = func_gen(("Float4E2M1fn"))
-float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
-float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
-float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
-float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
-float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
-float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))
+float8_e4m3fn = func_gen(("Float8E4M3FN"))
+float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4"))
+float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8"))
+float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16"))
+float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32"))
+float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64"))
+
+float8_e5m2 = func_gen(("Float8E5M2"))
+float8_e5m2x4 = func_gen(("Float8E5M2x4"))
+float8_e5m2x8 = func_gen(("Float8E5M2x8"))
+float8_e5m2x16 = func_gen(("Float8E5M2x16"))
+float8_e5m2x32 = func_gen(("Float8E5M2x32"))
+float8_e5m2x64 = func_gen(("Float8E5M2x64"))
+
+float4_e2m1fn = func_gen(("Float4E2M1FN"))
+float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2"))
+float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4"))
+float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8"))
+float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16"))
+float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32"))
+float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64"))
 
 
 # pylint: enable=invalid-name
@@ -2011,39 +2011,39 @@ __all__ = [
     "uint16x64",
     "uint32x64",
     "uint64x64",
-    "e4m3_float8",
-    "e5m2_float8",
+    "float8_e4m3fn",
+    "float8_e5m2",
     "float4_e2m1fn",
     "float16",
     "float32",
     "float64",
     "float4_e2m1fnx2",
-    "e4m3_float8x4",
-    "e5m2_float8x4",
+    "float8_e4m3fnx4",
+    "float8_e5m2x4",
     "float4_e2m1fnx4",
     "float16x4",
     "float32x4",
     "float64x4",
-    "e4m3_float8x8",
-    "e5m2_float8x8",
+    "float8_e4m3fnx8",
+    "float8_e5m2x8",
     "float4_e2m1fnx8",
     "float16x8",
     "float32x8",
     "float64x8",
-    "e4m3_float8x16",
-    "e5m2_float8x16",
+    "float8_e4m3fnx16",
+    "float8_e5m2x16",
     "float4_e2m1fnx16",
     "float16x16",
     "float32x16",
     "float64x16",
-    "e4m3_float8x32",
-    "e5m2_float8x32",
+    "float8_e4m3fnx32",
+    "float8_e5m2x32",
     "float4_e2m1fnx32",
     "float16x32",
     "float32x32",
     "float64x32",
-    "e4m3_float8x64",
-    "e5m2_float8x64",
+    "float8_e4m3fnx64",
+    "float8_e5m2x64",
     "float4_e2m1fnx64",
     "float16x64",
     "float32x64",
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index e1ff18bc8f..57b1c3b873 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,13 +16,13 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring,unused-variable
 """Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Dict, Optional, Tuple, Literal
+from typing import Dict, Literal, Optional, Tuple
 
 from tvm._ffi import register_func
 from tvm.runtime import convert
 from tvm.script import tir as T
-from tvm.tir.function import PrimFunc
 from tvm.tir import Cast, IntImm, TensorIntrin
+from tvm.tir.function import PrimFunc
 
 
 def shared_16x16_to_ldmatrix_32x8_layout(i, j):
@@ -123,7 +123,7 @@ def get_ldmatrix_intrin(
             matrix_name == "B" or not transposed
         ), "Now only B matrix can be transposed for int8 matmul"
         assert k_dim == 32 and (
-            dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8"
+            dtype == "int8" or dtype == "float8_e4m3fn" or dtype == "float8_e5m2"
         ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
 
         if matrix_name == "B" and not transposed:
@@ -261,25 +261,25 @@ LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
 TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))
 
 LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a"
-TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False))
+TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "A", False))
 
 LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b"
-TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False))
+TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "B", False))
 
 LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans"
 TensorIntrin.register(
-    LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)
+    LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "B", True)
 )
 
 LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a"
-TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False))
+TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "A", False))
 
 LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b"
-TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False))
+TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "B", False))
 
 LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans"
 TensorIntrin.register(
-    LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)
+    LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "B", True)
 )
 
 
@@ -315,8 +315,8 @@ def get_mma_intrin(
         "float32": "fp32",
         "int8": "int8",
         "int32": "int32",
-        "e4m3_float8": "e4m3",
-        "e5m2_float8": "e5m2",
+        "float8_e4m3fn": "e4m3",
+        "float8_e5m2": "e5m2",
     }
     a_dtype_abbrv = dtype_abbrv[a_dtype]
     b_dtype_abbrv = dtype_abbrv[b_dtype]
@@ -522,25 +522,25 @@ TensorIntrin.register(
 MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32"
 TensorIntrin.register(
     MMA_e5m2e5m2f32_INTRIN,
-    *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False),
+    *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, False),
 )
 
 MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b"
 TensorIntrin.register(
     MMA_e5m2e5m2f32_TRANS_B_INTRIN,
-    *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True),
+    *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, True),
 )
 
 MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32"
 TensorIntrin.register(
     MMA_e4m3e4m3f32_INTRIN,
-    *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False),
+    *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False, False),
 )
 
 MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b"
 TensorIntrin.register(
     MMA_e4m3e4m3f32_TRANS_B_INTRIN,
-    *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True),
+    *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False, True),
 )
 
 
@@ -705,7 +705,7 @@ TensorIntrin.register(
 def get_mma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],
     store_scope: Literal["global", "shared", "shared.dyn"],
-    in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
+    in_dtype: Literal["float16", "int8", "float8_e4m3fn", "float8_e5m2"],
     out_dtype: Literal["float16", "float32", "int32"],
     trans_a: bool,
     trans_b: bool,
@@ -752,7 +752,7 @@ def get_mma_intrin_group(
     """
     assert load_scope in ["shared", "shared.dyn"]
     assert store_scope in ["global", "shared", "shared.dyn"]
-    assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
+    assert in_dtype in ["float16", "int8", "float8_e4m3fn", "float8_e5m2"]
     assert out_dtype in ["float16", "float32", "int32"]
 
     shape = "16x16"
@@ -761,8 +761,8 @@ def get_mma_intrin_group(
         "float16": "f16",
         "float32": "f32",
         "int8": "i8",
-        "e4m3_float8": "e4m3",
-        "e5m2_float8": "e5m2",
+        "float8_e4m3fn": "e4m3",
+        "float8_e5m2": "e5m2",
         "int32": "i32",
     }
     a_dtype = dtype_mapping[in_dtype]
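
The registered intrinsic names themselves are unchanged; only the dtype
strings used to build them were renamed. A lookup sketch (not part of the
diff):

    from tvm.tir import TensorIntrin
    from tvm.tir.tensor_intrin.cuda import MMA_e4m3e4m3f32_INTRIN

    # Importing tvm.tir.tensor_intrin.cuda registers the intrinsics above.
    intrin = TensorIntrin.get(MMA_e4m3e4m3f32_INTRIN)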
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 766abf3483..8f188e95f0 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -132,15 +132,16 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
       ICHECK_LE(value, support::kMaxBFloat16)
           << "ValueError: Literal value " << value << " exceeds maximum of " 
<< dtype;
     } else if (dtype.is_float8()) {
-      double bound = (dtype.code() == DataType::kE4M3Float) ? support::kMaxE4M3 : support::kMaxE5M2;
+      double bound =
+          (dtype.code() == DataType::kFloat8_e4m3fn) ? support::kMaxE4M3FN : support::kMaxE5M2;
       ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " exceeds minimum of "
                                << dtype;
       ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of "
                               << dtype;
     } else if (dtype.is_float4()) {
-      ICHECK_GE(value, -support::kMaxE2M1)
+      ICHECK_GE(value, -support::kMaxE2M1FN)
           << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
-      ICHECK_LE(value, support::kMaxE2M1)
+      ICHECK_LE(value, support::kMaxE2M1FN)
           << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
     }
   }
diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h
index aa3928ce02..e63e99548c 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -72,10 +72,12 @@ inline std::string DType2String(const tvm::DataType dtype) {
   std::ostringstream os;
   if (dtype.is_float()) {
     os << "float";
-  } else if (dtype.is_e4m3_float8()) {
-    os << "e4m3_float";
-  } else if (dtype.is_e5m2_float8()) {
-    os << "e5m2_float";
+  } else if (dtype.is_float8_e4m3fn()) {
+    return "float8_e4m3fn";
+  } else if (dtype.is_float8_e5m2()) {
+    return "float8_e5m2";
+  } else if (dtype.is_float4_e2m1fn()) {
+    return "float4_e2m1fn";
   } else if (dtype.is_int()) {
     os << "int";
   } else if (dtype.is_uint()) {
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index c9a01fc24e..ba01f791d9 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -164,8 +164,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
     ab_type = CUDA_R_16F;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
-  } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
-    ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
+  } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
+    ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8));
     ab_type = CUDA_R_8F_E4M3;
   }
 
@@ -217,7 +217,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
   int N = RowCount(A, transa, batch_offset_A);
   int K = ColumnCount(A, transa, batch_offset_A);
   bool use_batched_gemm = A->ndim > 2 || B->ndim > 2;
-
   // If A is batched but B is not, flatten all non-reduction axes of A to use the regular GEMM.
   // This trick is only applicable if batch axes and the other spatial axis (M or N) are
   // adjacent in both the input and the output matrix. In particular, if A is of shape (M, K)
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index d876065325..5a328413a1 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
       return;
     else if (dtype.bits == 4 && dtype.code == kDLInt)
       return;
-    else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
+    else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn)
       return;
     else
       ICHECK_EQ(dtype.bits % 8, 0);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a73c9cb5b4..a75a357810 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -752,13 +752,13 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
 
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
 
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
diff --git a/src/support/scalars.h b/src/support/scalars.h
index b229a6b338..adc449ffd6 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -63,14 +63,14 @@ constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38;
 
 // 2^8 * (1 + 6/8)
 // See https://arxiv.org/pdf/2209.05433.pdf
-constexpr double kMaxE4M3 = 448;
+constexpr double kMaxE4M3FN = 448;
 
 // 2^15 * (1 + 3/4)
 // See https://arxiv.org/pdf/2209.05433.pdf
 constexpr double kMaxE5M2 = 57344;
 
 // 2^2 * (1 + 1/2)
-constexpr double kMaxE2M1 = 6.0;
+constexpr double kMaxE2M1FN = 6.0;
 
 }  // namespace support
 }  // namespace tvm
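
(As a quick check of the comments above: 2^8 * (1 + 6/8) = 256 * 1.75 = 448
for kMaxE4M3FN, 2^15 * (1 + 3/4) = 32768 * 1.75 = 57344 for kMaxE5M2, and
2^2 * (1 + 1/2) = 4 * 1.5 = 6 for kMaxE2M1FN, matching the constants.)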
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 9c2ce0bbb2..ead0bdff3c 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -579,9 +579,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
       default:
         LOG(FATAL) << "do not support " << dtype;
     }
-  } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
+  } else if (dtype.code() == DataType::kFloat8_e4m3fn || dtype.code() == DataType::kFloat8_e5m2) {
     etype = llvm::Type::getInt8Ty(*ctx);
-  } else if (dtype.code() == DataType::kFloat4E2M1Fn) {
+  } else if (dtype.code() == DataType::kFloat4_e2m1fn) {
     etype = llvm::Type::getIntNTy(*ctx, 4);
   }
   if (!dtype.is_scalar()) {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 20b29750dc..35973776c8 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -60,9 +60,9 @@ std::string GetFP8Type(DataType type) {
   }
   stream << "__nv_fp8";
   std::string suffix;
-  if (type.code() == DataType::kE4M3Float) {
+  if (type.code() == DataType::kFloat8_e4m3fn) {
     suffix = "_e4m3";
-  } else if (type.code() == DataType::kE5M2Float) {
+  } else if (type.code() == DataType::kFloat8_e5m2) {
     suffix = "_e5m2";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -86,7 +86,7 @@ std::string GetFP4Type(DataType type) {
   }
   stream << "__nv_fp4";
   std::string suffix;
-  if (type.code() == DataType::kFloat4E2M1Fn) {
+  if (type.code() == DataType::kFloat4_e2m1fn) {
     suffix = "_e2m1";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -159,9 +159,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#endif\n\n";
 
     decl_stream << "#include <cuda.h>\n";
-    decl_stream << "#if (CUDA_VERSION <12080)\n";
     decl_stream << _cuda_half_util;
-    decl_stream << "#endif\n";
   }
 
   if (enable_bf16_) {
@@ -734,9 +732,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
   // Emit simple C-style type conversion.
   if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
 
-  if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
-      target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() == DataType::kE4M3Float ||
-      from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kFloat4E2M1Fn) {
+  if (target_ty.code() == DataType::kFloat8_e4m3fn || target_ty.code() == DataType::kFloat8_e5m2 ||
+      target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn ||
+      from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) {
     std::ostringstream val;
     val << "(";
     PrintType(target_ty, val);
@@ -1508,7 +1506,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
   }
-  // Type code is kE5M2Float or kE4M4Float
+  // Type code is kFloat8_e5m2 or kE4M4Float
   if (op->dtype.is_float8() || op->dtype.is_float4()) {
     p->PrintType(op->dtype, os);
     os << '(' << std::scientific << op->value << 'f' << ')';
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 3dab634f16..63c82d1d6c 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -273,9 +273,9 @@ PrimExpr max_value(const DataType& dtype, Span span) {
     return FloatImm(dtype, std::numeric_limits<float>::max(), span);
   } else if (dtype.is_float8()) {
     // according to https://arxiv.org/pdf/2209.05433.pdf
-    if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+    if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) {
       return FloatImm(dtype, 57344.0, span);
-    } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+    } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) {
       return FloatImm(dtype, 448.0, span);
     }
   } else if (dtype.is_float4()) {
@@ -316,9 +316,9 @@ PrimExpr min_value(const DataType& dtype, Span span) {
     return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
   } else if (dtype.is_float8()) {
     // according to https://arxiv.org/pdf/2209.05433.pdf
-    if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+    if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) {
       return FloatImm(dtype, -57344.0, span);
-    } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+    } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) {
       return FloatImm(dtype, -448.0, span);
     }
   } else if (dtype.is_float4()) {
diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h
index 8edbf1bc1e..a0ed6b5f6d 100644
--- a/src/tir/transforms/dtype_conversion.h
+++ b/src/tir/transforms/dtype_conversion.h
@@ -121,7 +121,7 @@ class FloatConfig {
       // NVIDIA/Arm/Intel's FP8 formats for Deep Learning
       // Reference: https://arxiv.org/abs/2209.05433
       switch (dtype.code()) {
-        case DataType::kE4M3Float:
+        case DataType::kFloat8_e4m3fn:
           // E4M3 format, not consistent with IEEE-754
           return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes);
         default:
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index d94153003c..b91efd6192 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -16,25 +16,24 @@
 # under the License.
 
 import sys
+from typing import List, Tuple
+
+import numpy as np
 import pytest
 
 import tvm
-from tvm.script import tir as T
-import numpy as np
 import tvm.testing
-
-
-from typing import List, Tuple
 from tvm import DataType, DataTypeCode, IRModule
 from tvm import dlight as dl
 from tvm import relax, te, tir, topi
 from tvm.relax.frontend import nn
 from tvm.runtime import NDArray
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 from tvm.target import Target
 from tvm.topi.utils import get_const_tuple
 
-from tvm.script import ir as I, relax as R, tir as T
-
 try:
     import ml_dtypes
 except ImportError:
@@ -43,7 +42,7 @@ except ImportError:
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 def test_e4m3_conversions():
-    dtype = "e4m3_float8"
+    dtype = "float8_e4m3fn"
 
     @T.prim_func
     def add(
@@ -90,7 +89,7 @@ def test_e4m3_conversions():
 def test_e4m3_packing():
     length = 64
     vector_length = 4
-    native_dtype, packed_dtype = ("e4m3_float8x4", "uint32")
+    native_dtype, packed_dtype = ("float8_e4m3fnx4", "uint32")
 
     @T.prim_func
     def add(
@@ -141,13 +140,13 @@ def test_e4m3_packing():
 
 
 native_dtype, promoted_dtype = tvm.testing.parameters(
-    ("e4m3_float8", "float32"),
-    ("e4m3_float8", "float16"),
-    ("e4m3_float8x2", "float32x2"),
-    ("e4m3_float8x2", "float16x2"),
-    ("e4m3_float8x4", "float32x4"),
+    ("float8_e4m3fn", "float32"),
+    ("float8_e4m3fn", "float16"),
+    ("float8_e4m3fnx2", "float32x2"),
+    ("float8_e4m3fnx2", "float16x2"),
+    ("float8_e4m3fnx4", "float32x4"),
     # Supported via half4 vector type extension in codegen
-    ("e4m3_float8x4", "float16x4"),
+    ("float8_e4m3fnx4", "float16x4"),
 )
 
 
@@ -343,7 +342,7 @@ class BaseFP8E4M3QuantScaleOnly:
         axis,
         output_transpose,
     ) -> IRModule:
-        if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+        if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN:
             quantize_func = cls.quantize_fp8x4_e4m3
         else:
             assert NotImplementedError()
@@ -387,7 +386,7 @@ class BaseFP8E4M3QuantScaleOnly:
         num_elem_per_storage,
         axis,
     ) -> IRModule:
-        if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+        if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN:
             dequantize_func = cls.dequantize_fp8x4_e4m3
         else:
             assert NotImplementedError()
@@ -732,7 +731,7 @@ class TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
 
     @tvm.testing.fixture
     def quantize_dtype(self):
-        return "e4m3_float8"
+        return "float8_e4m3fn"
 
     @tvm.testing.fixture
     def num_el_per_storage(self):
@@ -807,7 +806,7 @@ class TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
 
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
[email protected]("dtype", ["e5m2_float8", "e4m3_float8"])
[email protected]("dtype", ["float8_e5m2", "float8_e4m3fn"])
 def test_const(dtype):
     @T.prim_func
     def func(A: T.Buffer((4,), dtype)) -> None:
@@ -822,7 +821,7 @@ def test_const(dtype):
 
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
[email protected]("dtype", ["e5m2_float8", "e4m3_float8"])
[email protected]("dtype", ["float8_e5m2", "float8_e4m3fn"])
 @pytest.mark.parametrize("vec_len", [2, 4, 8, 16])
 def test_copy(dtype, vec_len):
     @T.prim_func
@@ -867,7 +866,7 @@ def test_moe_gemv_shfl_down_illegal_instr():
         @T.prim_func(private=True)
         def moe_dequantize_gemv(
             x_handle: T.handle,
-            w: T.Buffer((num_experts, spatial_size, reduce_size), "e4m3_float8"),
+            w: T.Buffer((num_experts, spatial_size, reduce_size), "float8_e4m3fn"),
             scale: T.Buffer((1,), "float16"),
             indptr: T.Buffer((1, 2), "int32"),
             o: T.Buffer((2, spatial_size), "float16"),
@@ -905,7 +904,7 @@ def test_moe_gemv_shfl_down_illegal_instr():
         def main(
             x: R.Tensor(("num_seq", reduce_size), dtype="float16"),
             indptr: R.Tensor((1, 2), dtype="int32"),
-            weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="e4m3_float8"),
+            weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn"),
             scale: R.Tensor((1,), dtype="float32"),
         ) -> R.Tensor((2, spatial_size), dtype="float16"):
             num_seq = T.int64()
@@ -965,4 +964,5 @@ def test_moe_gemv_shfl_down_illegal_instr():
 
 
 if __name__ == "__main__":
+    # test_half_broadcast(6)
     tvm.testing.main()
diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py
index 8313a97ee1..b812c70ab6 100644
--- a/tests/python/ir/test_datatype_nv_fp8.py
+++ b/tests/python/ir/test_datatype_nv_fp8.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+
 import tvm
 import tvm.testing
 import tvm.tir as tir
@@ -22,9 +23,10 @@ from tvm import te
 from tvm.script import tir as T
 
 try:
-    from ml_dtypes import float8_e4m3fn as e4m3_float8, float8_e5m2 as e5m2_float8
+    from ml_dtypes import float8_e4m3fn as float8_e4m3fn
+    from ml_dtypes import float8_e5m2 as float8_e5m2
 except ImportError:
-    e4m3_float8, e5m2_float8 = None, None
+    float8_e4m3fn, float8_e5m2 = None, None
 
 
 def fp8_unary(dtype: str):
@@ -58,7 +60,7 @@ def fp8_unary(dtype: str):
 
 
 np_dtype, dtype_str = tvm.testing.parameters(
-    (e4m3_float8, "e4m3_float8"), (e5m2_float8, "e5m2_float8")
+    (float8_e4m3fn, "float8_e4m3fn"), (float8_e5m2, "float8_e5m2")
 )
 
 
diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py
index 77cd1d7e4b..988e360748 100644
--- a/tests/python/ir/test_dtype.py
+++ b/tests/python/ir/test_dtype.py
@@ -15,22 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test data type related API"""
+import pytest
+
 import tvm
-from tvm import DataType
 import tvm.testing
-import pytest
+from tvm import DataType
 
 
 @pytest.mark.parametrize(
     "dtype_str, expected_size",
-    [("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)],
+    [("float32", 4), ("float32x4", 16), ("float8_e5m2x4", 4), ("uint8", 1)],
 )
 def test_dtype_itemsize(dtype_str, expected_size):
     dtype = DataType(dtype_str)
     assert dtype.itemsize() == expected_size
 
 
[email protected]("dtype_str", [("int32xvscalex4")])
[email protected]("dtype_str", ["int32xvscalex4"])
 def test_dtype_itemmize_error(dtype_str):
     with pytest.raises(ValueError):
         size = DataType(dtype_str).itemsize()
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 2fbff8433b..c5514e2727 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -315,7 +315,7 @@ def test_matmul_fp8_offload(
     transpose_y,
     out_dtype,
 ):
-    in_dtype = "e4m3_float8"
+    in_dtype = "float8_e4m3fn"
     mod = get_relax_matmul_module(
         x_shape,
         y_shape,
@@ -342,7 +342,7 @@ def test_matmul_fp8_offload(
 def test_matmul_fp8_dequantize_offload():
     x_shape = (10, 32)
     y_shape = (64, 32)
-    in_dtype = "e4m3_float8"
+    in_dtype = "float8_e4m3fn"
     mod = get_relax_matmul_dequantize_module(
         x_shape,
         y_shape,
@@ -369,7 +369,7 @@ def test_matmul_fp8_multiply_offload():
     x_shape = (10, 32)
     y_shape = (64, 32)
     z_shape = (1,)
-    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+    in_dtype, acc_dtype = ("float8_e4m3fn", "float32")
 
     mod = get_relax_matmul_multiply_module(
         x_shape,
@@ -397,8 +397,8 @@ def test_matmul_fp8_multiply_offload():
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
         (15, 64, 32, "float32", True, True),
-        (15, 64, 32, "e4m3_float8", True, True),
-        (15, 64, 32, "e5m2_float8", True, False),
+        (15, 64, 32, "float8_e4m3fn", True, True),
+        (15, 64, 32, "float8_e5m2", True, False),
         (16, 32, 60, "float32", True, False),
         (16, 30, 64, "float32", True, False),
         (16, 8, 16, "float16", True, True),
@@ -407,7 +407,7 @@ def test_matmul_fp8_multiply_offload():
 )
 def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done):
     mod = get_relax_matmul_module(
-        (M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y
+        (M, K), (N, K), "float8_e4m3fn", out_dtype, transposed_y=transposed_y
     )
     mod = partition_for_cublas(mod)
     func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
@@ -426,7 +426,7 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
     mod = get_relax_matmul_dequantize_module(
         (M, K),
         (N, K),
-        "e4m3_float8",
+        "float8_e4m3fn",
         "float16",
         transposed_y=True,
         scale_const=scale,
@@ -443,7 +443,7 @@ def test_cublas_partition_fp8_matmul_multiply():
         (M, K),
         (N, K),
         (1,),
-        "e4m3_float8",
+        "float8_e4m3fn",
         "float32",
         "float16",
         transposed_y=True,
diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py
index 18d7a88f05..ca4b0fc440 100644
--- a/tests/python/relax/test_op_inspect.py
+++ b/tests/python/relax/test_op_inspect.py
@@ -21,10 +21,10 @@ import numpy as np
 import pytest
 
 import tvm.testing
-
 from tvm import relax
 from tvm.ir import Op
-from tvm.script import ir as I, relax as R
+from tvm.script import ir as I
+from tvm.script import relax as R
 
 # Parameterization for reading dtype of DLTensor.  Chosen to have
 # multiple distinct type codes, number of lanes, and widths.
@@ -34,7 +34,7 @@ dtype = tvm.testing.parameter(
     "float32",
     "float32x4",
     "bfloat",
-    "e4m3_float8",
+    "float8_e4m3fn",
 )
 shape = tvm.testing.parameter(
     [],
diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py
index 8b2d499041..d773a6c7d2 100644
--- a/tests/python/relax/test_op_qdq.py
+++ b/tests/python/relax/test_op_qdq.py
@@ -68,17 +68,17 @@ def test_qdq_op_infer_struct_info_symbolic():
     )
 
 
-def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
+def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic():
     bb = relax.BlockBuilder()
     n = tir.Var("n", "int64")
     x = relax.Var("x", R.Tensor((n, 3), "float32"))
-    dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8"))
+    dx = relax.Var("dx", R.Tensor((n, 3), "float8_e4m3fn"))
     s = relax.Var("s", R.Tensor([3], "float32"))
     zp = relax.Var("zp", R.Tensor([3], "float16"))
     _check_inference(
         bb,
-        relax.op.quantize(x, s, zp, 1, "e4m3_float8"),
-        relax.TensorStructInfo((n, 3), "e4m3_float8"),
+        relax.op.quantize(x, s, zp, 1, "float8_e4m3fn"),
+        relax.TensorStructInfo((n, 3), "float8_e4m3fn"),
     )
     _check_inference(
         bb,
@@ -87,8 +87,8 @@ def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
     )
 
 
-def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
-    dtype = "e5m2_float8"
+def test_qdq_float8_e5m2_op_infer_struct_info_symbolic():
+    dtype = "float8_e5m2"
     bb = relax.BlockBuilder()
     n = tir.Var("n", "int64")
     x = relax.Var("x", R.Tensor((n, 3), "float32"))
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index fe9998bc79..5a80e3e4f6 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -17,21 +17,24 @@
 # pylint: disable=missing-docstring
 import numpy as np
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import te
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
+    LDMATRIX_e4m3_A_INTRIN,
+    LDMATRIX_e4m3_B_TRANS_INTRIN,
+    LDMATRIX_e5m2_A_INTRIN,
+    LDMATRIX_e5m2_B_TRANS_INTRIN,
     LDMATRIX_f16_A_INTRIN,
     LDMATRIX_f16_B_INTRIN,
     LDMATRIX_f16_B_TRANS_INTRIN,
     LDMATRIX_i8_A_INTRIN,
-    LDMATRIX_i8_B_TRANS_INTRIN,
     LDMATRIX_i8_B_INTRIN,
-    LDMATRIX_e4m3_A_INTRIN,
-    LDMATRIX_e4m3_B_TRANS_INTRIN,
-    LDMATRIX_e5m2_A_INTRIN,
-    LDMATRIX_e5m2_B_TRANS_INTRIN,
+    LDMATRIX_i8_B_TRANS_INTRIN,
+    MMA_e4m3e4m3f32_TRANS_B_INTRIN,
+    MMA_e5m2e5m2f32_TRANS_B_INTRIN,
     MMA_f16f16f16_INTRIN,
     MMA_f16f16f16_TRANS_B_INTRIN,
     MMA_f16f16f32_INTRIN,
@@ -41,8 +44,6 @@ from tvm.tir.tensor_intrin.cuda import (
     MMA_fill_16x16_i32_INTRIN,
     MMA_i8i8i32_INTRIN,
     MMA_i8i8i32_TRANS_B_INTRIN,
-    MMA_e5m2e5m2f32_TRANS_B_INTRIN,
-    MMA_e4m3e4m3f32_TRANS_B_INTRIN,
     MMA_store_16x16_f16_global_INTRIN,
     MMA_store_16x16_f32_global_INTRIN,
     MMA_store_16x16_i32_global_INTRIN,
@@ -132,10 +133,10 @@ def run_test(
         else:
             b_np = np.random.normal(size=(K, N)).astype("float16")
             c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
-    elif in_dtype in ["e4m3_float8", "e5m2_float8"]:
+    elif in_dtype in ["float8_e4m3fn", "float8_e5m2"]:
         typemap = {
-            "e4m3_float8": "float8_e4m3fn",
-            "e5m2_float8": "float8_e5m2",
+            "float8_e4m3fn": "float8_e4m3fn",
+            "float8_e5m2": "float8_e5m2",
         }
         a_np = (
             np.random.uniform(low=-5, high=5, size=(M * K))
@@ -174,7 +175,7 @@ def run_test(
 
     f(a, b, c)
 
-    if out_dtype != "float16" and in_dtype not in ["e4m3_float8", "e5m2_float8"]:
+    if out_dtype != "float16" and in_dtype not in ["float8_e4m3fn", "float8_e5m2"]:
         # The numpy reference is computed with fp32 precision (otherwise too slow).
         # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
@@ -384,7 +385,7 @@ def test_e4m3e4m3f32_m16n16k32():
         )
 
     k_inner = 32
-    in_dtype = "e4m3_float8"
+    in_dtype = "float8_e4m3fn"
     out_dtype = "float32"
     i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2]
 
@@ -427,7 +428,7 @@ def test_e5m2e5m2f32_m16n16k32():
         )
 
     k_inner = 32
-    in_dtype = "e5m2_float8"
+    in_dtype = "float8_e5m2"
     out_dtype = "float32"
     i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2]
 
diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
index e1f487c572..0b10fe5c21 100644
--- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
@@ -17,8 +17,8 @@
 import tvm
 import tvm.script
 import tvm.testing
-from tvm.target import Target
 from tvm.script import tir as T
+from tvm.target import Target
 from tvm.tir.transform.transform import BindTarget
 
 # pylint: disable=no-member,invalid-name,unused-variable
@@ -69,7 +69,7 @@ def get_after_compute_legalize(dtype: str, promote_dtype: str):
 
 
 def promote_uint8(f8_dtype: str, promote_dtype: str, v):
-    if f8_dtype == "e4m3_float8":
+    if f8_dtype == "float8_e4m3fn":
         if promote_dtype == "float16":
             mantissa = T.bitwise_and(
                 T.shift_left(T.Cast("uint16", v), T.uint16(7)), T.uint16(0x3FF)
@@ -96,7 +96,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v):
             )
             sign = T.shift_left(T.Cast("uint32", T.shift_right(v, T.uint8(7))), T.uint32(31))
             return T.reinterpret("float32", T.bitwise_or(T.bitwise_or(mantissa, exponent), sign))
-    else:  # f8_dtype == "e5m2_float8"
+    else:  # f8_dtype == "float8_e5m2"
         if promote_dtype == "float16":
             return T.reinterpret("float16", T.shift_left(T.Cast("uint16", v), 
T.uint16(8)))
         else:  # promote_dtype == "float32"
@@ -115,7 +115,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v):
 
 
 def cast_to_uint8(f8_dtype: str, promote_dtype: str, v):
-    if f8_dtype == "e4m3_float8":
+    if f8_dtype == "float8_e4m3fn":
         if promote_dtype == "float16":
             uint16_v = T.reinterpret("uint16", v)
             rounding_bias = T.bitwise_and(
@@ -154,7 +154,7 @@ def cast_to_uint8(f8_dtype: str, promote_dtype: str, v):
             return T.if_then_else(
                 round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, exponent), sign)
             )
-    else:  # f8_dtype == "e5m2_float8"
+    else:  # f8_dtype == "float8_e5m2"
         if promote_dtype == "float16":
             uint16_v = T.reinterpret("uint16", v)
             rounding_bias = T.bitwise_and(
@@ -201,12 +201,12 @@ def get_after_storage_legalize(dtype: str, promote_dtype: str):
     return After
 
 
-dtype = tvm.testing.parameter("e4m3_float8", "e5m2_float8")
+dtype = tvm.testing.parameter("float8_e4m3fn", "float8_e5m2")
 promote_dtype = tvm.testing.parameter("float16", "float32")
 
 
 def test_fp8_compute_legalize(dtype, promote_dtype):
-    target = Target("cuda")
+    target = Target("nvidia/nvidia-a100")
     before = BindTarget(target)(get_before(dtype))
     expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
     # run the transform twice to ensure we can afford to deal
@@ -217,7 +217,7 @@ def test_fp8_compute_legalize(dtype, promote_dtype):
 
 
 def test_fp8_storage_legalize(dtype, promote_dtype):
-    target = Target("cuda")
+    target = Target("nvidia/nvidia-a100")
     before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
     after = tvm.tir.transform.FP8StorageLegalize()(before)
     expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype))
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index b7ba57fa93..943ba54060 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -17,6 +17,7 @@
 # pylint: disable=missing-docstring
 
 import re
+
 import pytest
 
 import tvm.testing
@@ -917,23 +918,23 @@ def func():
     _assert_print(func, expected_output)
 
 
[email protected]("dtype", ["e4m3_float8", "e5m2_float8"])
[email protected]("dtype", ["float8_e4m3fn", "float8_e5m2"])
 def test_float8(dtype):
     from tvm.script import tir as T
 
     def get_func(dtype):
-        if dtype == "e4m3_float8":
+        if dtype == "float8_e4m3fn":
 
             @T.prim_func
             def func():
-                T.evaluate(T.e4m3_float8(0.0))
+                T.evaluate(T.float8_e4m3fn(0.0))
 
             return func
-        elif dtype == "e5m2_float8":
+        elif dtype == "float8_e5m2":
 
             @T.prim_func
             def func():
-                T.evaluate(T.e5m2_float8(0.0))
+                T.evaluate(T.float8_e5m2(0.0))
 
             return func
 
