This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b13be936a0 [DataType] Initial support of fp8 (e4m3/e5m2) (#14863)
b13be936a0 is described below

commit b13be936a056c8c58f96dc63ad1fcfaf51f02768
Author: Zihao Ye <[email protected]>
AuthorDate: Thu Jun 1 06:47:06 2023 -0700

    [DataType] Initial support of fp8 (e4m3/e5m2) (#14863)
    
    Recently NVIDIA announced official support for the fp8 data types e4m3 and
    e5m2. The first has 4 exponent bits and 3 mantissa bits, while the second
    has 5 exponent bits and 2 mantissa bits. NVIDIA encourages using e4m3 for
    the forward pass and e5m2 (larger dynamic range) for the backward pass.
    Currently, TVM has no support for these data types. As a first step toward
    supporting fp8, this PR adds new type codes for `e4m3_float8` and
    `e5m2_float8`, and implements legalization passes [...]
---
 docker/install/ubuntu_install_python_package.sh    |   3 +-
 include/tvm/runtime/data_type.h                    |  34 +++
 include/tvm/tir/op.h                               |   3 +-
 include/tvm/tir/transform.h                        |  14 ++
 python/gen_requirements.py                         |   1 +
 python/tvm/_ffi/runtime_ctypes.py                  |  22 ++
 python/tvm/contrib/nvcc.py                         |  17 ++
 python/tvm/runtime/ndarray.py                      |  19 ++
 python/tvm/tir/transform/transform.py              |  28 +++
 src/driver/driver_api.cc                           |   2 +
 src/ir/expr.cc                                     |  14 +-
 src/support/scalars.h                              |  12 +
 src/target/source/codegen_cuda.cc                  |  17 ++
 src/target/source/codegen_cuda.h                   |   5 +-
 src/tir/op/op.cc                                   |   7 +
 src/tir/transforms/dtype_conversion.cc             | 101 ++++++++
 src/tir/transforms/dtype_conversion.h              | 165 +++++++++++++
 ...6_legalize.cc => unsupported_dtype_legalize.cc} | 264 +++++++++++++--------
 tests/python/unittest/test_datatype_nv_fp8.py      | 104 ++++++++
 .../unittest/test_tir_transform_bf16_legalize.py   |   5 +-
 .../unittest/test_tir_transform_fp8_legalize.py    | 224 +++++++++++++++++
 21 files changed, 959 insertions(+), 102 deletions(-)

diff --git a/docker/install/ubuntu_install_python_package.sh 
b/docker/install/ubuntu_install_python_package.sh
index 93abac52be..41c8697f42 100755
--- a/docker/install/ubuntu_install_python_package.sh
+++ b/docker/install/ubuntu_install_python_package.sh
@@ -44,4 +44,5 @@ pip3 install --upgrade \
     junitparser==2.4.2 \
     six \
     tornado \
-    pytest-lazy-fixture
+    pytest-lazy-fixture \
+    ml_dtypes
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f52e95c756..9fb113f56b 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -32,6 +32,7 @@
 
 namespace tvm {
 namespace runtime {
+
 /*!
  * \brief Runtime primitive data type.
  *
@@ -54,6 +55,8 @@ class DataType {
     kFloat = kDLFloat,
     kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
     kBFloat = kDLBfloat,
+    kE4M3Float = 6U,
+    kE5M2Float = 7U,
     kCustomBegin = 129
   };
   /*! \brief default constructor */
@@ -76,6 +79,9 @@ class DataType {
     if (code == kBFloat) {
       ICHECK_EQ(bits, 16);
     }
+    if (code == kE4M3Float || code == kE5M2Float) {
+      ICHECK_EQ(bits, 8);
+    }
   }
   /*! \return The type code. */
   int code() const { return static_cast<int>(data_.code); }
@@ -91,6 +97,12 @@ class DataType {
   bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
+  /*! \return whether type is a float8 type. */
+  bool is_float8() const {
+    return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
+            code() == DataType::kE5M2Float) &&
+           bits() == 8;
+  }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
@@ -183,6 +195,18 @@ class DataType {
    * \return The constructed data type.
    */
   static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, 
bits, lanes); }
+  /*!
+   * \brief Construct NV float8 e4m3 datatype.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, 
lanes); }
+  /*!
+   * \brief Construct NV float8 e5m2 datatype.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, 
lanes); }
   /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes
@@ -308,6 +332,10 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode 
type_code) {
       return "handle";
     case kDLBfloat:
       return "bfloat";
+    case DataType::kE4M3Float:
+      return "e4m3_float";
+    case DataType::kE5M2Float:
+      return "e5m2_float";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
   }
@@ -376,6 +404,12 @@ inline DLDataType String2DLDataType(std::string s) {
   } else if (s.substr(0, 6) == "bfloat") {
     t.code = DataType::kBFloat;
     scan = s.c_str() + 6;
+  } else if (s.substr(0, 10) == "e4m3_float") {
+    t.code = DataType::kE4M3Float;
+    scan = s.c_str() + 10;
+  } else if (s.substr(0, 10) == "e5m2_float") {
+    t.code = DataType::kE5M2Float;
+    scan = s.c_str() + 10;
   } else if (s.substr(0, 6) == "custom") {
     t.code = ParseCustomDatatype(s, &scan);
   } else {
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 0198feb3cd..3d5e589ab4 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -939,7 +939,8 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType 
value, Span span = Span())
       return LargeUIntImm(t, static_cast<int64_t>(low), 
static_cast<int64_t>(high), span);
     }
   }
-  if (t.is_float() || t.is_bfloat16()) return FloatImm(t, 
static_cast<double>(value), span);
+  if (t.is_float() || t.is_bfloat16() || t.is_float8())
+    return FloatImm(t, static_cast<double>(value), span);
   // For now, we store const scalar values of custom datatypes within doubles; 
later, during the
   // datatypes lowering pass, we will lower the value to its true 
representation in the format
   // specified by the datatype.
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d9d68e0a8b..85f2feaa2c 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -394,12 +394,26 @@ TVM_DLL Pass NarrowDataType(int target_bits);
  */
 TVM_DLL Pass BF16ComputeLegalize();
 
+/*!
+ * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
+ *   before Ops, then add a cast back to fp8.
+ * \param promote_dtype_str The data type used for type promotion, defaults to 
float16
+ * \return The pass.
+ */
+TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
+
 /*!
  * \brief Legalize bf16 storage types to u16.
  * \return The pass.
  */
 TVM_DLL Pass BF16StorageLegalize();
 
+/*!
+ * \brief Legalize fp8 storage types to u8.
+ * \return The pass.
+ */
+TVM_DLL Pass FP8StorageLegalize();
+
 /*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
diff --git a/python/gen_requirements.py b/python/gen_requirements.py
index 1a55dccd11..1cb1ce109a 100644
--- a/python/gen_requirements.py
+++ b/python/gen_requirements.py
@@ -67,6 +67,7 @@ REQUIREMENTS_BY_PIECE: RequirementsByPieceType = [
                 "attrs",
                 "cloudpickle",
                 "decorator",
+                "ml_dtypes",
                 "numpy",
                 "psutil",
                 "scipy",
diff --git a/python/tvm/_ffi/runtime_ctypes.py 
b/python/tvm/_ffi/runtime_ctypes.py
index 999f69bc34..adcc3a8e97 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -19,6 +19,11 @@
 import ctypes
 import json
 import numpy as np
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
 from .base import _LIB, check_call
 
 tvm_shape_index_t = ctypes.c_int64
@@ -59,6 +64,8 @@ class DataTypeCode(object):
     FLOAT = 2
     HANDLE = 3
     BFLOAT = 4
+    E4M3Float = 6
+    E5M2Float = 7
 
 
 class DataType(ctypes.Structure):
@@ -71,6 +78,8 @@ class DataType(ctypes.Structure):
         DataTypeCode.FLOAT: "float",
         DataTypeCode.HANDLE: "handle",
         DataTypeCode.BFLOAT: "bfloat",
+        DataTypeCode.E4M3Float: "e4m3_float",
+        DataTypeCode.E5M2Float: "e5m2_float",
     }
     NUMPY2STR = {
         np.dtype(np.bool_): "bool",
@@ -97,6 +106,8 @@ class DataType(ctypes.Structure):
         "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
         "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
         "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
+        "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, 
"lanes": 1},
+        "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, 
"lanes": 1},
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -141,6 +152,12 @@ class DataType(ctypes.Structure):
         elif head.startswith("bfloat"):
             self.type_code = DataTypeCode.BFLOAT
             head = head[6:]
+        elif head.startswith("e4m3_float"):
+            self.type_code = DataTypeCode.E4M3Float
+            head = head[10:]
+        elif head.startswith("e5m2_float"):
+            self.type_code = DataTypeCode.E5M2Float
+            head = head[10:]
         elif head.startswith("custom"):
             # pylint: disable=import-outside-toplevel
             import tvm.runtime._ffi_api
@@ -182,6 +199,11 @@ class DataType(ctypes.Structure):
         return not self.__eq__(other)
 
 
+if ml_dtypes is not None:
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
+    DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+
 RPC_SESS_MASK = 128
 
 
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 643ad96c02..5eb3480099 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -404,3 +404,20 @@ def have_bf16(compute_version):
         return True
 
     return False
+
+
+def have_fp8(compute_version):
+    """Whether fp8 support is provided in the specified compute capability or 
not
+
+    Parameters
+    ----------
+    compute_version : str
+        GPU capability
+    """
+    major, minor = parse_compute_version(compute_version)
+    # fp8 is suppored in Ada Lovelace (8.9) or later architectures.
+    if major == 8 and minor == 9:
+        return True
+    if major >= 9:
+        return True
+    return False
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 7e08f59644..a78c68ee67 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -19,6 +19,11 @@
 import ctypes
 import warnings
 import numpy as np
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
 import tvm._ffi
 
 from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -217,6 +222,20 @@ class NDArray(NDArrayBase):
             dtype = "int8"
         if dtype == "bfloat16":
             dtype = "uint16"
+        if dtype == "e4m3_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e4m3fn
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e4m3_float8 
array to numpy."
+                )
+        if dtype == "e5m2_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e5m2
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e5m2_float8 
array to numpy."
+                )
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 9e038f618b..ffaeb85f74 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -40,6 +40,7 @@ def Apply(ftransform):
     fpass : tvm.transform.Pass
         The result pass
     """
+
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return ftransform(func)
@@ -297,6 +298,22 @@ def BF16ComputeLegalize():
     return _ffi_api.BF16ComputeLegalize()  # type: ignore
 
 
+def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
+    """Legalize fp8 compute Ops.
+
+    Parameters
+    ----------
+    promote_dtype : str
+        The data type we promote fp8 to, options: float16/float32.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.FP8ComputeLegalize(promote_dtype_str)  # type: ignore
+
+
 def BF16StorageLegalize():
     """Legalize bf16 storage types to u16.
 
@@ -308,6 +325,17 @@ def BF16StorageLegalize():
     return _ffi_api.BF16StorageLegalize()  # type: ignore
 
 
+def FP8StorageLegalize():
+    """Legalize fp8 storage types to u8.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.FP8StorageLegalize()  # type: ignore
+
+
 def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: 
bool = False):
     """Replace redundant computations by new variables.
 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e5f71c3832..cfc7fa80c7 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -210,6 +210,7 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
+  pass_list.push_back(tir::transform::FP8ComputeLegalize());
   pass_list.push_back(tir::transform::BF16ComputeLegalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -586,6 +587,7 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
+  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
 
   mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 0e09568f15..fdd8c2cd8b 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -104,7 +104,8 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
 FloatImm::FloatImm(DataType dtype, double value, Span span) {
   ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
 
-  ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.code() >= 
DataType::kCustomBegin)
+  ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
+         dtype.code() >= DataType::kCustomBegin)
       << "ValueError: FloatImm supports only float, but " << dtype << " was 
supplied.";
 
   // check range for float32 and float16 since they have specified range.
@@ -119,6 +120,17 @@ FloatImm::FloatImm(DataType dtype, double value, Span 
span) {
           << "ValueError: Literal value " << value << " exceeds minimum of " 
<< dtype;
       ICHECK_LE(value, support::kMaxFloat16)
           << "ValueError: Literal value " << value << " exceeds maximum of " 
<< dtype;
+    } else if (dtype.is_bfloat16()) {
+      ICHECK_GE(value, -support::kMaxBFloat16)
+          << "ValueError: Literal value " << value << " exceeds minimum of " 
<< dtype;
+      ICHECK_LE(value, support::kMaxBFloat16)
+          << "ValueError: Literal value " << value << " exceeds maximum of " 
<< dtype;
+    } else if (dtype.is_float8()) {
+      double bound = (dtype.code() == DataType::kE4M3Float) ? 
support::kMaxE4M3 : support::kMaxE5M2;
+      ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " 
exceeds minimum of "
+                               << dtype;
+      ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " 
exceeds maximum of "
+                              << dtype;
     }
   }
   ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
diff --git a/src/support/scalars.h b/src/support/scalars.h
index 2fdbb001d9..2b34914565 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -65,6 +65,18 @@ FloatImm ValueToFloatImm(double value, int width);
 // See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
 constexpr double kMaxFloat16 = 65504.0;
 
+// 2^127 * (1 + 127/128)
+// See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
+constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38;
+
+// 2^8 * (1 + 6/8)
+// See https://arxiv.org/pdf/2209.05433.pdf
+constexpr double kMaxE4M3 = 448;
+
+// 2^15 * (1 + 3/4)
+// See https://arxiv.org/pdf/2209.05433.pdf
+constexpr double kMaxE5M2 = 57344;
+
 }  // namespace support
 }  // namespace tvm
 
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index aaf6660172..ec8695a2a0 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -117,6 +117,12 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_bfloat16_util;
   }
 
+  if (enable_fp8_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
+    decl_stream << "#include <cuda_fp8.h>\n";
+    decl_stream << "#endif\n\n";
+  }
+
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
   }
@@ -250,6 +256,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) 
{  // NOLINT(*)
       fail = true;
     }
     if (!fail) return;
+  } else if (t.is_float8()) {
+    if (t.is_scalar()) {
+      os << "unsigned char";  // __nv_fp8_storage_t is an alias of unsigned 
char
+    } else if (lanes == 2) {
+      os << "unsigned short int";  // __nv_fp8x2_storage_t is an alias of 
unsigned short
+    } else if (lanes == 4) {
+      os << "unsigned int";  // __nv_fp8x4_storage_t is an alias of unsigned 
int
+    } else {
+      fail = true;
+    }
+    if (!fail) return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index bb507c1799..c6cf96d460 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -42,7 +42,8 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   std::string Finish();
   bool need_include_path() {
-    return (enable_fp16_ || enable_bf16_ || enable_int8_ || 
need_math_constants_h_ || need_mma_h_);
+    return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || 
need_math_constants_h_ ||
+            need_mma_h_);
   }
   // override behavior
   void PrintFuncPrefix(std::ostream& os) final;
@@ -93,6 +94,8 @@ class CodeGenCUDA final : public CodeGenC {
   bool enable_fp16_{false};
   // whether enable bf16
   bool enable_bf16_{false};
+  // whether enable fp8
+  bool enable_fp8_{false};
   // whether enable int8
   bool enable_int8_{false};
   // whether enable warp shuffle intrinsics
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 4439a9c3d7..39214c4546 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -143,6 +143,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span 
span) {  // NOLINT(*)
              !rtype.is_bfloat16()) {
     // Cast int->bfloat16 when the other operand is a bfloat16
     rhs = cast(ltype, rhs);
+  } else if (!ltype.is_float8() && rtype.is_float8()) {
+    // Cast int->float8 for lhs when rhs is a float8
+    lhs = cast(rtype, lhs);
+  } else if (ltype.is_float8() && !rtype.is_float8()) {
+    // Cast int->float8 for rhs when lhs is a float8
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && 
rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -165,6 +171,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span 
span) {  // NOLINT(*)
       }
     }
   } else {
+    LOG(INFO) << lhs << " " << rhs;
     LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
   }
 }
diff --git a/src/tir/transforms/dtype_conversion.cc 
b/src/tir/transforms/dtype_conversion.cc
new file mode 100644
index 0000000000..de94cf6473
--- /dev/null
+++ b/src/tir/transforms/dtype_conversion.cc
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dtype_conversion.cc
+ * \brief Header file of data type conversion routines.
+ */
+#include "dtype_conversion.h"
+
+namespace tvm {
+namespace tir {
+
+PrimExpr ReinterpretAsUInt(PrimExpr value) {
+  return reinterpret(GetStorageUIntDType(value.dtype()), value);
+}
+
+DataType GetStorageUIntDType(DataType dtype) { return 
DataType::UInt(dtype.bits(), dtype.lanes()); }
+
+PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode 
round_mode) {
+  DataType src_dtype = src_value.dtype();
+  // Step 1: check dtype
+  // The lanes of src dtype and target dtype must match.
+  CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes())
+      << "The lanes for data type for source value must matches the target 
datatype.";
+  auto is_floating_point = [](DataType dtype) {
+    return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16();
+  };
+  // Both source dtype and target dtype should be floating point.
+  CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype));
+  FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()),
+              tgt_fp = FloatConfig::FromDataType(tgt_dtype);
+  int exponent_delta = tgt_fp.exponent - src_fp.exponent;
+  int bias_delta = tgt_fp.bias - src_fp.bias;
+  int mantissa_delta = tgt_fp.mantissa - src_fp.mantissa;
+  DataType src_uint = GetStorageUIntDType(src_value.dtype()),
+           tgt_uint = GetStorageUIntDType(tgt_dtype);
+  PrimExpr src_uint_value = ReinterpretAsUInt(src_value);
+  if (mantissa_delta < 0) {
+    // use rounding
+    CHECK(round_mode == RoundingMode::kHalfToEven)
+        << "Currently we only support HalfToEven rounding mode.";
+    PrimExpr rounding_bias = ((src_uint_value >> (-mantissa_delta)) & 1) +
+                             make_const(src_uint, (int64_t(1) << 
(-mantissa_delta - 1)) - 1);
+    src_uint_value = src_uint_value + rounding_bias;
+  }
+  if (exponent_delta == 0) {
+    // number of exponent bits exactly matches
+    PrimExpr ret = src_uint_value;
+    if (mantissa_delta >= 0) {
+      ret = cast(tgt_uint, ret) << mantissa_delta;
+    } else {  // mantissa_delta < 0
+      ret = cast(tgt_uint, ret >> (-mantissa_delta));
+    }
+    if (bias_delta > 0) {
+      ret = ret + (make_const(tgt_uint, bias_delta) << tgt_fp.mantissa);
+    } else if (bias_delta < 0) {
+      ret = ret - (make_const(tgt_uint, -bias_delta) << tgt_fp.mantissa);
+    }
+    return reinterpret(tgt_dtype, ret);
+  } else {
+    // number of exponent bits mismatch.
+    PrimExpr ret_mantissa =
+        (mantissa_delta >= 0 ? (cast(tgt_uint, src_uint_value) << 
mantissa_delta)
+                             : (cast(tgt_uint, src_uint_value >> 
(-mantissa_delta)))) &
+        make_const(tgt_uint, (int64_t(1) << (tgt_fp.mantissa)) - 1);
+    PrimExpr exponent_before_delta = ((src_uint_value << 1) >> 
(src_fp.mantissa + 1));
+    PrimExpr ret_sign = cast(tgt_uint, (src_uint_value >> (src_fp.mantissa + 
src_fp.exponent)))
+                        << (tgt_fp.mantissa + tgt_fp.exponent);
+    if (bias_delta >= 0) {
+      PrimExpr ret_exponent =
+          (bias_delta > 0) ? (cast(tgt_uint, exponent_before_delta + 
bias_delta) << tgt_fp.mantissa)
+                           : (cast(tgt_uint, exponent_before_delta) << 
tgt_fp.mantissa);
+      return reinterpret(tgt_dtype, ret_mantissa | ret_exponent | ret_sign);
+    } else {  // bias_delta < 0
+      PrimExpr round_to_zero = exponent_before_delta < (-bias_delta);
+      PrimExpr ret_exponent = cast(tgt_uint, exponent_before_delta - 
(-bias_delta))
+                              << tgt_fp.mantissa;
+      return reinterpret(tgt_dtype, if_then_else(round_to_zero, 
make_const(tgt_uint, 0),
+                                                 ret_mantissa | ret_exponent | 
ret_sign));
+    }
+  }
+}
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/dtype_conversion.h 
b/src/tir/transforms/dtype_conversion.h
new file mode 100644
index 0000000000..b509abb9cd
--- /dev/null
+++ b/src/tir/transforms/dtype_conversion.h
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dtype_conversion.h
+ * \brief Header file of data type conversion routines.
+ */
+#ifndef TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_
+#define TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+
+namespace tir {
+
+/*!
+ * \brief Rounding mode: https://en.wikipedia.org/wiki/Rounding
+ */
+enum class RoundingMode {
+  // Round half to nearest even
+  kHalfToEven = 0U,
+  // Round down
+  kDown = 1U,
+  // Round up
+  kUp = 2U,
+  // Round towards zero
+  kTowardsZero = 3U,
+};
+
+/*!
+ * \brief Floating point internal representation.
+ */
+class FloatConfig {
+ public:
+  /*!
+   * \brief Style of infinite number representation.
+   */
+  enum class InftyStyle {
+    // Exponent all ones, mantissa all zeros
+    kIEEE = 0U,
+    // No representation of infinity
+    kNone = 1U
+  };
+  /*!
+   * \brief Style of NaN (not-a-number) representation.
+   */
+  enum class NaNStyle {
+    // Exponent all ones, mantissa non zeros
+    // - quiet NaN : 1XXXXX...
+    // - signaling NaN : 0XXXXX...
+    kIEEE = 0U,
+    // No representation of infinity
+    kNone = 1U,
+    // Both exponent bits and mantissa bits are all ones.
+    kAllOnes = 2U,
+  };
+  // The number of exponent bits.
+  int exponent;
+  // The number of mantissa (also know as fraction in IEEE format) bits.
+  int mantissa;
+  // The exponent bias in IEEE format.
+  int bias;
+  // The representation of infinity.
+  InftyStyle infty_style;
+  // The representation of NaN (Not a Number).
+  NaNStyle nan_style;
+
+  FloatConfig(int exponent, int mantissa, int bias, InftyStyle infty_style, 
NaNStyle nan_style)
+      : exponent(exponent),
+        mantissa(mantissa),
+        bias(bias),
+        infty_style(infty_style),
+        nan_style(nan_style) {}
+
+  inline int bits() const { return mantissa + exponent + 1; }
+
+  /*!
+   * \brief Create float config from data type.
+   * \param dtype The data type, must be a floating point.
+   * \return The FloatConfig class containing internal floating point 
representation.
+   */
+  static FloatConfig FromDataType(DataType dtype) {
+    CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8())
+        << "FloatConfig is only applicable to floating point data types, got " 
<< dtype
+        << " instead.";
+    if (dtype.is_float()) {
+      // IEEE 754 binary formats
+      // Reference: https://en.wikipedia.org/wiki/Floating-point_arithmetic
+      switch (dtype.bits()) {
+        case 16:
+          return FloatConfig(5, 10, 15, InftyStyle::kIEEE, NaNStyle::kIEEE);
+        case 32:
+          return FloatConfig(8, 23, 127, InftyStyle::kIEEE, NaNStyle::kIEEE);
+        default:
+          // float64
+          return FloatConfig(11, 52, 1023, InftyStyle::kIEEE, NaNStyle::kIEEE);
+      }
+    } else if (dtype.is_bfloat16()) {
+      // bfloat16,
+      return FloatConfig(8, 7, 127, InftyStyle::kIEEE, NaNStyle::kIEEE);
+    } else {  // float8
+      // NVIDIA/Arm/Intel's FP8 formats for Deep Learning
+      // Reference: https://arxiv.org/abs/2209.05433
+      switch (dtype.code()) {
+        case DataType::kE4M3Float:
+          // E4M3 format, not consistent with IEEE-754
+          return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes);
+        default:
+          // E5M2 format, consistent with IEEE-754
+          return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Reinterpret value as unsigned integer with equal number of bits.
+ * \param value The value to interpret.
+ * \return The reinterpreted uint value.
+ */
+PrimExpr ReinterpretAsUInt(PrimExpr value);
+
+/*!
+ * \brief Get the unsigned integer data type used as storage when the 
specified dtype is not
+ *   supported natively.
+ * \param dtype The data type.
+ * \return The uint data type, the number of bits is
+ *   the same as input dtype.
+ */
+DataType GetStorageUIntDType(DataType dtype);
+
+/*!
+ * \brief Conversion routine from value stored in one floating point data type 
to another floating
+ *   point data type.
+ * \param src_value The floating point value to be converted.
+ * \param tgt_dtype The target floating point data type.
+ * \param round_mode The rounding mode to use, defaults to kHalfToEven.
+ * \return The converted value in target floating point data type.
+ * \note Used when there is no native data type conversion implementation.
+ */
+PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype,
+                         RoundingMode round_mode = RoundingMode::kHalfToEven);
+
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_TRANSFORMS_DTYPE_CONVERSION_H_
diff --git a/src/tir/transforms/bf16_legalize.cc 
b/src/tir/transforms/unsupported_dtype_legalize.cc
similarity index 69%
rename from src/tir/transforms/bf16_legalize.cc
rename to src/tir/transforms/unsupported_dtype_legalize.cc
index cc57735df6..be8876b815 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -18,10 +18,9 @@
  */
 
 /*!
- * \file bf16_legalize.cc
- * \brief legalize bf16 type by adding cast_to_fp32
+ * \file unsupported_dtype_legalize.cc
+ * \brief legalize bf16/fp8 type by adding cast_to_fp32
  */
-
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
@@ -31,21 +30,24 @@
 #include <cmath>
 #include <tuple>
 
+#include "dtype_conversion.h"
+
 namespace tvm {
 namespace tir {
 
 // NOTE: do not touch buffer on function boundary
-// remap internal bf16 buffer to f32 if they meet the following condition
+// remap internal fp8/bf16 buffer to f32 if they meet the following condition
 // - constant allocation size
 // - do not have raw pointer access to the buffer
 //
 // populate the buffer_remap and var_remap accordingly.
-class BF16ComputeLegalizePlanner : public StmtExprVisitor {
+class ComputeLegalizePlanner : public StmtExprVisitor {
  public:
-  BF16ComputeLegalizePlanner(
+  ComputeLegalizePlanner(
       std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* 
buffer_remap,
-      std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
-      : buffer_remap_(buffer_remap), var_remap_(var_remap) {}
+      std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
+      DataType promote_dtype)
+      : buffer_remap_(buffer_remap), var_remap_(var_remap), 
promote_dtype_(promote_dtype) {}
 
   // run planning to populate buffer remap and var remap.
   void Plan(PrimFunc func) {
@@ -71,10 +73,12 @@ class BF16ComputeLegalizePlanner : public StmtExprVisitor {
     }
   }
 
+  virtual bool MatchDType(DataType dtype) const = 0;
+
   void VisitStmt_(const AllocateNode* op) final {
-    // remap all intermediate constant buffr to fp32
-    if (op->dtype.is_bfloat16() && op->ConstantAllocationSize() != 0) {
-      DataType dtype = DataType::Float(32, op->dtype.lanes());
+    // remap all intermediate constant buffer to promote data types (fp16/fp32)
+    if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) {
+      DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes());
       Var buffer_var = Var(op->buffer_var->name_hint, 
PointerType(PrimType(dtype)));
       (*var_remap_)[op->buffer_var] = buffer_var;
     }
@@ -109,7 +113,7 @@ class BF16ComputeLegalizePlanner : public StmtExprVisitor {
     auto var_it = var_remap_->find(buf->data);
     if (var_it == var_remap_->end()) return;
 
-    Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()), 
buf->shape,
+    Buffer new_buffer(var_it->second, 
promote_dtype_.with_lanes(buf->dtype.lanes()), buf->shape,
                       buf->strides, buf->elem_offset, buf->name, 
buf->data_alignment,
                       buf->offset_factor, buf->buffer_type, 
buf->axis_separators, buf->span);
     (*buffer_remap_)[buf] = new_buffer;
@@ -118,42 +122,68 @@ class BF16ComputeLegalizePlanner : public StmtExprVisitor 
{
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* 
buffer_remap_;
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap_;
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> opaque_var_access_;
+  DataType promote_dtype_;
+};
+
+/*!
+ * \brief Planner that selects internal bfloat16 allocations/buffers for promotion to the
+ *   promote dtype.
+ */
+class BF16ComputeLegalizePlanner : public ComputeLegalizePlanner {
+ public:
+  explicit BF16ComputeLegalizePlanner(
+      std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap,
+      std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
+      DataType promote_dtype)
+      : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {}
+  // `override` makes the compiler reject any signature drift from the pure virtual in the base.
+  bool MatchDType(DataType dtype) const override { return dtype.is_bfloat16(); }
+};
+
+/*!
+ * \brief Planner that selects internal fp8 allocations/buffers for promotion to the
+ *   promote dtype.
+ */
+class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner {
+ public:
+  explicit FP8ComputeLegalizePlanner(
+      std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap,
+      std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
+      DataType promote_dtype)
+      : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {}
+  // `override` makes the compiler reject any signature drift from the pure virtual in the base.
+  bool MatchDType(DataType dtype) const override { return dtype.is_float8(); }
+};
 
-#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC)                       \
-  PrimExpr VisitExpr_(const OP* op) final {                       \
-    PrimExpr origin_a = PromoteBF16ToF32(this->VisitExpr(op->a)); \
-    PrimExpr origin_b = PromoteBF16ToF32(this->VisitExpr(op->b)); \
-                                                                  \
-    if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) {     \
-      return GetRef<PrimExpr>(op);                                \
-    } else {                                                      \
-      return FUNC(origin_a, origin_b);                            \
-    }                                                             \
+// Generates VisitExpr_ for a binary node OP: both operands are visited and promoted with
+// PromoteToTarget, and the node is rebuilt through FUNC only if either operand changed.
+#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC)                                \
+  PrimExpr VisitExpr_(const OP* op) final {                                \
+    PrimExpr lhs = PromoteToTarget(this->VisitExpr(op->a));                \
+    PrimExpr rhs = PromoteToTarget(this->VisitExpr(op->b));                \
+    const bool unchanged = lhs.same_as(op->a) && rhs.same_as(op->b);       \
+    if (unchanged) return GetRef<PrimExpr>(op);                            \
+    return FUNC(lhs, rhs);                                                 \
   }
 
-// NOTE: Legalize the BF16 computations
+// NOTE: Legalize the FP8/BF16 computations
 // to floating point computations and only keeps the
-// bf16 storage which can further be legalized by BF16StorageLegalizer
-// BF16StorageLegalizer will be called at a much later time
+// fp8/bf16 storage which can further be legalized by FP8/BF16StorageLegalizer
+// FP8/BF16StorageLegalizer will be called at a much later time
 // point in the TIR lowering phases.
-class BF16ComputeLegalizer : public StmtExprMutator {
+class ComputeLegalizer : public StmtExprMutator {
  public:
-  PrimFunc Legalize(PrimFunc func) {
-    BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_);
-    planner.Plan(func);
+  explicit ComputeLegalizer(DataType promote_dtype) : 
promote_dtype_(promote_dtype) {}
+
+  PrimFunc LegalizeWithPlanner(PrimFunc func, ComputeLegalizePlanner* planner) 
{
+    planner->Plan(func);
     auto* n = func.CopyOnWrite();
     n->body = this->VisitStmt(std::move(n->body));
     return func;
   }
 
+  virtual PrimFunc Legalize(PrimFunc func) = 0;
+
+  virtual bool MatchDType(DataType dtype) const = 0;
+
  protected:
   PrimExpr VisitExpr_(const CastNode* op) final {
-    auto op_val = PromoteBF16ToF32(this->VisitExpr(op->value));
+    auto op_val = PromoteToTarget(this->VisitExpr(op->value));
 
-    // all casts to BF16 becomes f32
-    if (op->dtype.is_bfloat16()) {
-      return cast(DataType::Float(32, op->dtype.lanes()), op_val);
+    // all casts to matched data type (fp8/bf16) becomes f32
+    if (MatchDType(op->dtype)) {
+      return cast(promote_dtype_.with_lanes(op->dtype.lanes()), op_val);
     }
 
     if (op_val.same_as(op->value)) {
@@ -165,8 +195,8 @@ class BF16ComputeLegalizer : public StmtExprMutator {
 
   PrimExpr VisitExpr_(const SelectNode* op) final {
     PrimExpr condition = this->VisitExpr(op->condition);
-    PrimExpr true_value = PromoteBF16ToF32(this->VisitExpr(op->true_value));
-    PrimExpr false_value = PromoteBF16ToF32(this->VisitExpr(op->false_value));
+    PrimExpr true_value = PromoteToTarget(this->VisitExpr(op->true_value));
+    PrimExpr false_value = PromoteToTarget(this->VisitExpr(op->false_value));
     if (condition.same_as(op->condition) && true_value.same_as(op->true_value) 
&&
         false_value.same_as(op->false_value)) {
       return GetRef<PrimExpr>(op);
@@ -176,7 +206,7 @@ class BF16ComputeLegalizer : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const BroadcastNode* op) final {
-    PrimExpr value = PromoteBF16ToF32(this->VisitExpr(op->value));
+    PrimExpr value = PromoteToTarget(this->VisitExpr(op->value));
     if (value.same_as(op->value)) {
       return GetRef<PrimExpr>(op);
     } else {
@@ -185,7 +215,7 @@ class BF16ComputeLegalizer : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const ShuffleNode* op) final {
-    auto fexpr = [this](const PrimExpr& e) { return 
PromoteBF16ToF32(this->VisitExpr(e)); };
+    auto fexpr = [this](const PrimExpr& e) { return 
PromoteToTarget(this->VisitExpr(e)); };
     auto vectors = op->vectors.Map(fexpr);
     if (vectors.same_as(op->vectors)) {
       return GetRef<PrimExpr>(op);
@@ -200,10 +230,10 @@ class BF16ComputeLegalizer : public StmtExprMutator {
       return StmtExprMutator::VisitExpr_(op);
     }
     // update normal computations to return f32 instead.
-    auto fmutate = [this](const PrimExpr& e) { return 
PromoteBF16ToF32(this->VisitExpr(e)); };
+    auto fmutate = [this](const PrimExpr& e) { return 
PromoteToTarget(this->VisitExpr(e)); };
     Array<PrimExpr> args = op->args.Map(fmutate);
-    if (op->dtype.is_bfloat16()) {
-      return Call(DataType::Float(32, op->dtype.lanes()), op->op, args);
+    if (MatchDType(op->dtype)) {
+      return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args);
     }
     if (args.same_as(op->args)) {
       return GetRef<PrimExpr>(op);
@@ -213,8 +243,8 @@ class BF16ComputeLegalizer : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const FloatImmNode* op) final {
-    if (op->dtype.is_bfloat16()) {
-      return FloatImm(DataType::Float(32), op->value);
+    if (MatchDType(op->dtype)) {
+      return FloatImm(promote_dtype_, op->value);
     }
     return GetRef<PrimExpr>(op);
   }
@@ -231,7 +261,7 @@ class BF16ComputeLegalizer : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LetNode* op) final {
-    PrimExpr value = PromoteBF16ToF32(op->value);
+    PrimExpr value = PromoteToTarget(op->value);
     Var var = op->var;
     if (value.dtype() != op->value.dtype()) {
       var = op->var.copy_with_dtype(op->value.dtype());
@@ -261,7 +291,7 @@ class BF16ComputeLegalizer : public StmtExprMutator {
   DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=);
 
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    PrimExpr value = PromoteBF16ToF32(op->value);
+    PrimExpr value = PromoteToTarget(op->value);
     Var var = op->var;
     if (value.dtype() != op->value.dtype()) {
       var = op->var.copy_with_dtype(op->value.dtype());
@@ -287,13 +317,16 @@ class BF16ComputeLegalizer : public StmtExprMutator {
     if (value.same_as(op->value) && indices.same_as(op->indices) && 
new_buf.same_as(op->buffer)) {
       return GetRef<Stmt>(op);
     } else {
-      if (new_buf->dtype.is_bfloat16()) {
-        value = CastF32ToBF16(value);
+      if (MatchDType(new_buf->dtype)) {
+        int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
+        int buffer_lanes = new_buf->dtype.lanes();
+        DataType legalized_dtype = new_buf->dtype.with_lanes(index_lanes * 
buffer_lanes);
+        value = CastTargetToDType(value, legalized_dtype);
       }
       if (value.dtype() != new_buf->dtype) {
         // this happens when buffer get rewritten to f32
-        // but values remain as bf16
-        ICHECK(value.dtype().is_bfloat16());
+        // but values remain as fp8/bf16
+        ICHECK(MatchDType(value->dtype));
         value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
       }
       return BufferStore(new_buf, value, indices);
@@ -373,41 +406,29 @@ class BF16ComputeLegalizer : public StmtExprMutator {
 
  private:
   /*!
-   * \brief promote BF16 to F32 and keep other values unchanged.
+   * \brief promote value to target datatype F16/F32 and keep other values 
unchanged.
    * \param value The input value.
    * \return The converted value.
    */
-  PrimExpr PromoteBF16ToF32(PrimExpr value) {
-    if (!value.dtype().is_bfloat16()) return value;
+  PrimExpr PromoteToTarget(PrimExpr value) {
+    if (!MatchDType(value.dtype())) return value;
     if (const CastNode* cast = value.as<CastNode>()) {
-      if (cast->value.dtype() == DataType::Float(32)) return cast->value;
+      if (cast->value.dtype() == 
promote_dtype_.with_lanes(value.dtype().lanes()))
+        return cast->value;
     }
-    DataType f32 = DataType::Float(32, value.dtype().lanes());
-    DataType u16 = DataType::UInt(16, value.dtype().lanes());
-    DataType u32 = DataType::UInt(32, value.dtype().lanes());
-    // reinterpret<f32>((cast<u32>(reinterpret<u16>(bf16_value)) << 16))
-    return reinterpret(f32, cast(u32, reinterpret(u16, value)) << 16);
+    return DTypeConversion(value, 
promote_dtype_.with_lanes(value.dtype().lanes()));
   }
 
   /*!
-   * \brief Cast value to F32 to BF16 and keep other values unchanged.
+   * \brief Cast value from promoted datatype (FP16/FP32) back to BF16/FP8 and 
keep other values
+   *   unchanged.
    * \param value The input value
    * \return The converted value.
    */
-  PrimExpr CastF32ToBF16(PrimExpr value) {
+  PrimExpr CastTargetToDType(PrimExpr value, DataType dtype) {
     if (!value.dtype().is_float()) return value;
-    ICHECK_EQ(value.dtype().bits(), 32);
-    DataType bf16 = DataType::BFloat(16, value.dtype().lanes());
-    DataType u16 = DataType::UInt(16, value.dtype().lanes());
-    DataType u32 = DataType::UInt(32, value.dtype().lanes());
-    PrimExpr u32_val = reinterpret(u32, value);
-
-    if (round_to_even_) {
-      PrimExpr rounding_bias = ((u32_val >> 16) & 1) + make_const(u32, 0x7FFF);
-      u32_val = u32_val + rounding_bias;
-    }
-    // reinterpret<bf16>((cast<u16>(reinterpret<u32>(f32_value)) >> 16))
-    return reinterpret(bf16, cast(u16, u32_val >> 16));
+    ICHECK_EQ(value.dtype(), 
this->promote_dtype_.with_lanes(value.dtype().lanes()));
+    return DTypeConversion(value, dtype);
   }
 
   Buffer GetRemappedBuffer(Buffer buf) {
@@ -418,19 +439,40 @@ class BF16ComputeLegalizer : public StmtExprMutator {
     return buf;
   }
 
-  bool round_to_even_{true};
-
+ protected:
+  DataType promote_dtype_;
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> 
buffer_remap_;
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
 
+/*! \brief Compute legalizer that promotes bfloat16 computation to float32. */
+class BF16ComputeLegalizer : public ComputeLegalizer {
+ public:
+  BF16ComputeLegalizer() : ComputeLegalizer(DataType::Float(32)) {}
+  PrimFunc Legalize(PrimFunc func) override {
+    // The planner only fills this legalizer's remap tables during LegalizeWithPlanner, so it can
+    // live on the stack for the duration of the call.
+    BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_);
+    return LegalizeWithPlanner(func, &planner);
+  }
+  bool MatchDType(DataType dtype) const override { return dtype.is_bfloat16(); }
+};
+
+/*! \brief Compute legalizer that promotes fp8 computation to a caller-chosen promote dtype. */
+class FP8ComputeLegalizer : public ComputeLegalizer {
+ public:
+  explicit FP8ComputeLegalizer(DataType promote_dtype) : ComputeLegalizer(promote_dtype) {}
+  PrimFunc Legalize(PrimFunc func) override {
+    // The planner only fills this legalizer's remap tables during LegalizeWithPlanner, so it can
+    // live on the stack for the duration of the call.
+    FP8ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_);
+    return LegalizeWithPlanner(func, &planner);
+  }
+  bool MatchDType(DataType dtype) const override { return dtype.is_float8(); }
+};
+
 /*!
- * \brief This Pass legalizes remaining BF16 storages to u16
+ * \brief This pass legalizes remaining FP8/BF16 storage to unsigned integers with an equal
+ * number of bits.
  *
- * This pass needs to happens after BF16ComputeLegalizer and serves
- * as a way to support BF16 on platforms that do not have native support.
+ * This pass needs to happen after FP8/BF16ComputeLegalizer and serves
+ * as a way to support FP8/BF16 on platforms that do not have native support.
  */
-class BF16StorageLegalizer : public StmtExprMutator {
+class StorageLegalizer : public StmtExprMutator {
  public:
   PrimFunc Legalize(PrimFunc func) {
     ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after 
MakePackedAPI";
@@ -452,8 +494,8 @@ class BF16StorageLegalizer : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const AllocateNode* op) final {
-    if (op->dtype.is_bfloat16()) {
-      DataType dtype = DataType::UInt(16, op->dtype.lanes());
+    if (MatchDType(op->dtype)) {
+      DataType dtype = GetStorageUIntDType(op->dtype);
       Var buffer_var = Var(op->buffer_var->name_hint, 
PointerType(PrimType(dtype)));
       var_remap_[op->buffer_var] = buffer_var;
       return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, 
op->body));
@@ -467,8 +509,8 @@ class BF16StorageLegalizer : public StmtExprMutator {
     // in a rare case the buffer didn't get remapped
     // because the original var is not bfloat*
     // force remap here
-    if (buf->dtype.is_bfloat16()) {
-      buf = Buffer(buf->data, DataType::UInt(16, buf->dtype.lanes()), 
buf->shape, buf->strides,
+    if (MatchDType(buf->dtype)) {
+      buf = Buffer(buf->data, GetStorageUIntDType(buf->dtype), buf->shape, 
buf->strides,
                    buf->elem_offset, buf->name, buf->data_alignment, 
buf->offset_factor,
                    buf->buffer_type, buf->axis_separators, buf->span);
       buffer_remap_[op->buffer] = buf;
@@ -506,13 +548,13 @@ class BF16StorageLegalizer : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
-    PrimExpr value = this->ChangeBF16ToU16(VisitExpr(op->value));
+    PrimExpr value = this->ChangeToUInt(VisitExpr(op->value));
     Buffer new_buf = GetRemappedBuffer(op->buffer);
     auto indices = op->indices.Map([this](PrimExpr expr) { return 
this->VisitExpr(expr); });
     if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && 
value.same_as(op->value)) {
       return GetRef<Stmt>(op);
     } else {
-      if (op->value.dtype().is_bfloat16()) {
+      if (MatchDType(op->value.dtype())) {
         ICHECK(new_buf->dtype.is_uint());
       }
       return BufferStore(new_buf, value, indices);
@@ -558,8 +600,8 @@ class BF16StorageLegalizer : public StmtExprMutator {
       PrimExpr value = VisitExpr(op->args[0]);
       // sometimes the input dtype can change and we can skip.
       if (value.dtype() == op->dtype) return value;
-      if (op->dtype.is_bfloat16()) {
-        return reinterpret(DataType::UInt(16, op->dtype.lanes()), value);
+      if (MatchDType(op->dtype)) {
+        return reinterpret(GetStorageUIntDType(op->dtype), value);
       }
       if (op->args[0].same_as(value)) {
         return GetRef<PrimExpr>(op);
@@ -570,17 +612,19 @@ class BF16StorageLegalizer : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  virtual bool MatchDType(DataType dtype) const = 0;
+
  private:
   /*!
-   * \brief Change BF16 value to U16 value.
+   * \brief Change float value to uint value.
    * \param value The input value.
    * \return The converted value.
    */
-  PrimExpr ChangeBF16ToU16(PrimExpr value) {
-    if (!value.dtype().is_bfloat16()) return value;
+  PrimExpr ChangeToUInt(PrimExpr value) {
+    if (!MatchDType(value->dtype)) return value;
     auto* call = value.as<CallNode>();
     if (call && call->op.same_as(builtin::reinterpret())) {
-      return reinterpret(DataType::UInt(16, value.dtype().lanes()), 
call->args[0]);
+      return reinterpret(GetStorageUIntDType(value->dtype), call->args[0]);
     } else {
       return value;
     }
@@ -591,9 +635,9 @@ class BF16StorageLegalizer : public StmtExprMutator {
     if (var.dtype().is_handle()) {
       if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
         if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
-          if (elem_type->dtype.is_bfloat16()) {
-            Var new_var = Var(var->name_hint,
-                              PointerType(PrimType(DataType::UInt(16, 
elem_type->dtype.lanes()))));
+          if (MatchDType(elem_type->dtype)) {
+            Var new_var =
+                Var(var->name_hint, 
PointerType(PrimType(GetStorageUIntDType(elem_type->dtype))));
             var_remap_[var] = new_var;
             return new_var;
           }
@@ -611,13 +655,12 @@ class BF16StorageLegalizer : public StmtExprMutator {
     Buffer new_buf = buf;
     auto var_it = var_remap_.find(buf->data);
     if (var_it != var_remap_.end()) {
-      DataType dtype =
-          buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : 
buf->dtype;
+      DataType dtype = MatchDType(buf->dtype) ? 
GetStorageUIntDType(buf->dtype) : buf->dtype;
       new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, 
buf->elem_offset, buf->name,
                        buf->data_alignment, buf->offset_factor, 
buf->buffer_type,
                        buf->axis_separators, buf->span);
     } else {
-      ICHECK(!buf->dtype.is_bfloat16()) << "Cannot find var remap for " << buf;
+      ICHECK(!MatchDType(buf->dtype)) << "Cannot find var remap for " << buf;
     }
 
     buffer_remap_[buf] = new_buf;
@@ -629,6 +672,16 @@ class BF16StorageLegalizer : public StmtExprMutator {
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
 
+/*! \brief Storage legalizer that rewrites remaining bfloat16 storage to uint storage. */
+class BF16StorageLegalizer : public StorageLegalizer {
+ public:
+  bool MatchDType(DataType dtype) const override { return dtype.is_bfloat16(); }
+};
+
+/*! \brief Storage legalizer that rewrites remaining fp8 storage to uint storage. */
+class FP8StorageLegalizer : public StorageLegalizer {
+ public:
+  bool MatchDType(DataType dtype) const override { return dtype.is_float8(); }
+};
+
 namespace transform {
 
 Pass BF16ComputeLegalize() {
@@ -651,6 +704,27 @@ Pass BF16StorageLegalize() {
 
 
TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
 
+/*!
+ * \brief Legalize fp8 computation into computation in \p promote_dtype_str (e.g. "float16" or
+ *   "float32") while keeping fp8 storage for FP8StorageLegalize to handle later.
+ */
+Pass FP8ComputeLegalize(String promote_dtype_str) {
+  // Parse the promote dtype once when the pass is created instead of once per PrimFunc; this also
+  // reports an invalid dtype string at pass construction time.
+  const DataType promote_dtype = DataType(String2DLDataType(promote_dtype_str));
+  auto pass_func = [promote_dtype](PrimFunc f, IRModule m, PassContext ctx) {
+    // TODO(tvm-team): skip if the target supports fp8
+    return FP8ComputeLegalizer(promote_dtype).Legalize(f);
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);
+
+/*!
+ * \brief Rewrite remaining fp8 storage to unsigned integer storage with the same number of bits.
+ *   Like BF16StorageLegalize, this must run after MakePackedAPI.
+ */
+Pass FP8StorageLegalize() {
+  // No debug dump of the PrimFunc here: logging every function at INFO level floods the build
+  // output of any pipeline that runs this pass.
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // TODO(tvm-team): skip if the target supports fp8
+    return FP8StorageLegalizer().Legalize(f);
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize);
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_datatype_nv_fp8.py 
b/tests/python/unittest/test_datatype_nv_fp8.py
new file mode 100644
index 0000000000..8313a97ee1
--- /dev/null
+++ b/tests/python/unittest/test_datatype_nv_fp8.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+import tvm.tir as tir
+from tvm import te
+from tvm.script import tir as T
+
+try:
+    from ml_dtypes import float8_e4m3fn as e4m3_float8, float8_e5m2 as 
e5m2_float8
+except ImportError:
+    e4m3_float8, e5m2_float8 = None, None
+
+
+def fp8_unary(dtype: str):
+    """Return a TVMScript PrimFunc exercising elementwise fp8 computation.
+
+    Despite the name, the kernel performs binary arithmetic (A + B, A - B, A * B) on two
+    128-element ``dtype`` buffers. It also copies A into a float32 buffer (``A_fp32``) and
+    writes that float32 value back into a ``dtype`` buffer (``A_roundtrip``).
+
+    :param dtype: fp8 dtype string, e.g. "e4m3_float8" or "e5m2_float8".
+    """
+
+    @T.prim_func
+    def func(
+        a: T.handle,
+        b: T.handle,
+        a_add_b: T.handle,
+        a_sub_b: T.handle,
+        a_mul_b: T.handle,
+        a_fp32: T.handle,
+        a_roundtrip: T.handle,
+    ) -> None:
+        A = T.match_buffer(a, [128], dtype=dtype)
+        B = T.match_buffer(b, [128], dtype=dtype)
+        A_add_B = T.match_buffer(a_add_b, [128], dtype=dtype)
+        A_sub_B = T.match_buffer(a_sub_b, [128], dtype=dtype)
+        A_mul_B = T.match_buffer(a_mul_b, [128], dtype=dtype)
+        A_fp32 = T.match_buffer(a_fp32, [128], dtype="float32")
+        A_roundtrip = T.match_buffer(a_roundtrip, [128], dtype=dtype)
+        for i in range(128):
+            with T.block("fp8_unary"):
+                vi = T.axis.spatial(128, i)
+                A_add_B[vi] = A[vi] + B[vi]
+                A_sub_B[vi] = A[vi] - B[vi]
+                A_mul_B[vi] = A[vi] * B[vi]
+                # fp8 -> float32 promotion, then float32 -> fp8 conversion of the promoted value.
+                A_fp32[vi] = A[vi]
+                A_roundtrip[vi] = A_fp32[vi]
+
+    return func
+
+
+# Correlated parameters: each test taking these fixtures runs once per (ml_dtypes numpy dtype,
+# TVM dtype string) pair. The numpy dtype is None when ml_dtypes could not be imported.
+np_dtype, dtype_str = tvm.testing.parameters(
+    (e4m3_float8, "e4m3_float8"), (e5m2_float8, "e5m2_float8")
+)
+
+
+def test_create_nv_fp8_nd_array(np_dtype, dtype_str):
+    """An ml_dtypes fp8 numpy array converts to an NDArray with the matching TVM dtype."""
+    if np_dtype is None:
+        # Report a skip rather than a silent pass when the optional dependency is missing.
+        pytest.skip("ml_dtypes is not installed")
+    x = np.random.rand(128, 128).astype(np_dtype)
+    x_nd = tvm.nd.array(x)
+    assert x_nd.dtype == dtype_str
+
+
+def test_fp8_unary_op(np_dtype, dtype_str):
+    """Build the fp8 elementwise kernel for llvm and check it against an fp32 reference.
+
+    The kernel covers add/sub/mul plus the fp8 -> float32 promotion and the
+    float32 -> fp8 round trip (see fp8_unary).
+    """
+    # Parse the TVMScript first so parser regressions are caught even when the build is skipped.
+    func = fp8_unary(dtype_str)
+    if not tvm.testing.device_enabled("llvm"):
+        pytest.skip("llvm is not enabled")
+    if np_dtype is None:
+        pytest.skip("ml_dtypes is not installed")
+
+    f = tvm.build(func, target="llvm")
+    a = np.random.randn(128).astype(np_dtype)
+    b = np.random.randn(128).astype(np_dtype)
+    a_add_b = np.zeros(128).astype(np_dtype)
+    a_sub_b = np.zeros(128).astype(np_dtype)
+    a_mul_b = np.zeros(128).astype(np_dtype)
+    a_fp32 = np.zeros(128).astype(np.float32)
+    a_roundtrip = np.zeros(128).astype(np_dtype)
+    args = list(
+        map(lambda _: tvm.nd.array(_), [a, b, a_add_b, a_sub_b, a_mul_b, a_fp32, a_roundtrip])
+    )
+    f(*args)
+
+    a_add_b_nd, a_sub_b_nd, a_mul_b_nd, a_fp32_nd, a_roundtrip_nd = args[2:]
+    a_f32 = a.astype(np.float32)
+    b_f32 = b.astype(np.float32)
+    # A small absolute tolerance absorbs fp8 subnormals, which the legalized conversion might
+    # not preserve exactly; larger magnitudes are checked by rtol.
+    atol = 2.0**-6
+
+    # fp8 -> float32 is exact, and converting an fp8-representable float32 back to fp8 is lossless.
+    np.testing.assert_allclose(a_fp32_nd.numpy(), a_f32, rtol=0, atol=atol)
+    np.testing.assert_allclose(
+        a_roundtrip_nd.numpy().astype(np.float32), a_f32, rtol=0, atol=atol
+    )
+
+    # Arithmetic runs in the promote dtype and is rounded back to fp8; tolerate one fp8 ulp
+    # (at most 1/4 relative for e5m2, 1/8 for e4m3) of rounding difference.
+    cases = (
+        (a_add_b_nd, a_f32 + b_f32),
+        (a_sub_b_nd, a_f32 - b_f32),
+        (a_mul_b_nd, a_f32 * b_f32),
+    )
+    for out_nd, exact in cases:
+        np.testing.assert_allclose(
+            out_nd.numpy().astype(np.float32),
+            exact.astype(np_dtype).astype(np.float32),
+            rtol=0.25,
+            atol=atol,
+        )
+
+
+def test_nv_fp8_buffer(np_dtype, dtype_str):
+    """decl_buffer accepts an fp8 dtype string and reports it back unchanged."""
+    # np_dtype is not used by this check.
+    shape = (te.size_var("m"), te.size_var("n"))
+    fp8_buf = tvm.tir.decl_buffer(shape, dtype_str)
+    assert fp8_buf.dtype == dtype_str
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py 
b/tests/python/unittest/test_tir_transform_bf16_legalize.py
index ababfd489a..20de9dc594 100644
--- a/tests/python/unittest/test_tir_transform_bf16_legalize.py
+++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py
@@ -53,12 +53,11 @@ def f32tou16(v):
     rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) & 
tvm.tir.const(1, "uint32")
     rounding_bias += tvm.tir.const(0x7FFF, "uint32")
     uint32_v = uint32_v + rounding_bias
-    return uint32_v >> tvm.tir.const(16, "uint32")
+    return (uint32_v >> tvm.tir.const(16, "uint32")).astype("uint16")
 
 
 def f32tobf16(v):
+    # Get the uint16 bit pattern from f32tou16 (which adds a round-to-nearest-even bias before
+    # the shift), then view those bits as bfloat16.
-    uint32_v = f32tou16(v)
-    return T.reinterpret("bfloat16", uint32_v.astype("uint16"))
+    return T.reinterpret("bfloat16", f32tou16(v))
 
 
 def get_after_compute_legalize():
diff --git a/tests/python/unittest/test_tir_transform_fp8_legalize.py 
b/tests/python/unittest/test_tir_transform_fp8_legalize.py
new file mode 100644
index 0000000000..f5786808a6
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_fp8_legalize.py
@@ -0,0 +1,224 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.script
+import tvm.testing
+from tvm.script import tir as T
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+def get_before(dtype: str):
+    """Build the input module for the fp8 legalization tests.
+
+    ``main`` reads the 100-element ``dtype`` buffers A and B, stores their sum
+    in an intermediate buffer C of the same ``dtype``, and writes ``exp(C)``
+    into D.
+    """
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: 
T.handle(dtype)):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), dtype, data=Aptr)
+            B = T.decl_buffer((100,), dtype, data=Bptr)
+            D = T.decl_buffer((100,), dtype, data=Dptr)
+            # C is not bound to a parameter handle (no data=); it is the fp8
+            # intermediate that compute legalization is expected to widen.
+            C = T.decl_buffer((100,), dtype)
+            for i in T.grid(100):
+                C[i] = A[i] + B[i]
+                D[i] = T.exp(C[i])
+
+    return Before
+
+
+def promote_f8(f8_dtype: str, promote_dtype: str, v):
+    """Widen fp8 value ``v``: view its bits as uint8, then decode them as ``promote_dtype``."""
+    raw_bits = T.reinterpret("uint8", v)
+    return promote_uint8(f8_dtype, promote_dtype, raw_bits)
+
+
+def cast_to_f8(f8_dtype: str, promote_dtype: str, v):
+    """Narrow ``promote_dtype`` value ``v`` to fp8: encode it as uint8 bits, then view them as ``f8_dtype``."""
+    encoded = cast_to_uint8(f8_dtype, promote_dtype, v)
+    return T.reinterpret(f8_dtype, encoded)
+
+
+def get_after_compute_legalize(dtype: str, promote_dtype: str):
+    """Expected module after ``FP8ComputeLegalize(promote_dtype)`` on ``get_before(dtype)``.
+
+    The parameter buffers A, B and D keep the fp8 ``dtype``. The intermediate
+    C is widened to ``promote_dtype``. Each fp8 load is decoded with
+    ``promote_f8``, and the value stored to D is re-encoded with ``cast_to_f8``.
+    This module is also the input of the storage-legalization test.
+    """
+    @tvm.script.ir_module
+    class After:
+        @T.prim_func
+        def main(Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: 
T.handle(dtype)):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), dtype, data=Aptr)
+            B = T.decl_buffer((100,), dtype, data=Bptr)
+            D = T.decl_buffer((100,), dtype, data=Dptr)
+            # Intermediate computed in the wider type instead of fp8.
+            C = T.decl_buffer((100,), promote_dtype)
+            for i in T.grid(100):
+                C[i] = promote_f8(dtype, promote_dtype, A[i]) + promote_f8(
+                    dtype, promote_dtype, B[i]
+                )
+                D[i] = cast_to_f8(dtype, promote_dtype, T.exp(C[i]))
+
+    return After
+
+
+def promote_uint8(f8_dtype: str, promote_dtype: str, v):
+    """Build the TIR expression that decodes fp8 bits ``v`` (a uint8 expr) into ``promote_dtype``.
+
+    Used to spell out the IR expected from the fp8 legalization passes.
+    - The sign bit is moved to the top bit of the wider type.
+    - The exponent field is re-biased.
+    - The mantissa is shifted into the high mantissa bits.
+
+    Only "float16" is special-cased: any other ``promote_dtype`` takes the
+    float32 branch. Any ``f8_dtype`` other than "e4m3_float8" is treated as e5m2.
+
+    NOTE(review): zero, subnormal and NaN/inf encodings are not special-cased.
+    For example, the e4m3 byte 0x00 still gets +8 (or +120) added to its
+    exponent, so it does not decode to 0. The helper mirrors the pass's
+    expected IR and is not an exact decoder.
+    """
+    if f8_dtype == "e4m3_float8":
+        if promote_dtype == "float16":
+            # 3 mantissa bits -> float16's 10-bit field (mask 0x3FF): shift by 10 - 3 = 7.
+            mantissa = T.bitwise_and(
+                T.shift_left(T.Cast("uint16", v), T.uint16(7)), T.uint16(0x3FF)
+            )
+            # (v << 1) >> 4 drops the sign and keeps the 4-bit exponent; +8 re-biases
+            # (float16 bias 15 minus e4m3 bias 7) before moving it to bits 10..14.
+            exponent = T.shift_left(
+                T.Cast(
+                    "uint16",
+                    T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(4)) + 
T.uint8(8),
+                ),
+                T.uint16(10),
+            )
+            sign = T.shift_left(T.Cast("uint16", T.shift_right(v, 
T.uint8(7))), T.uint16(15))
+            return T.reinterpret("float16", 
T.bitwise_or(T.bitwise_or(mantissa, exponent), sign))
+        else:  # promote_dtype == "float32"
+            # 3 mantissa bits -> float32's 23-bit field (mask 0x7FFFFF): shift by 23 - 3 = 20.
+            mantissa = T.bitwise_and(
+                T.shift_left(T.Cast("uint32", v), T.uint32(20)), 
T.uint32(0x7FFFFF)
+            )
+            # Re-bias by +120 (float32 bias 127 minus e4m3 bias 7), then move the
+            # exponent to bits 23..30.
+            exponent = T.shift_left(
+                T.Cast(
+                    "uint32",
+                    T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(4)) + 
T.uint8(120),
+                ),
+                T.uint32(23),
+            )
+            sign = T.shift_left(T.Cast("uint32", T.shift_right(v, 
T.uint8(7))), T.uint32(31))
+            return T.reinterpret("float32", 
T.bitwise_or(T.bitwise_or(mantissa, exponent), sign))
+    else:  # f8_dtype == "e5m2_float8"
+        if promote_dtype == "float16":
+            # e5m2 has the same 5-bit exponent layout as float16, so the byte is
+            # simply the high 8 bits of the float16 pattern.
+            return T.reinterpret("float16", T.shift_left(T.Cast("uint16", v), 
T.uint16(8)))
+        else:  # promote_dtype == "float32"
+            # 2 mantissa bits -> 23-bit field: shift by 23 - 2 = 21.
+            mantissa = T.bitwise_and(
+                T.shift_left(T.Cast("uint32", v), T.uint32(21)), 
T.uint32(0x7FFFFF)
+            )
+            # (v << 1) >> 3 keeps the 5-bit exponent; +112 = 127 - 15 re-biases it.
+            exponent = T.shift_left(
+                T.Cast(
+                    "uint32",
+                    T.shift_right(T.shift_left(v, T.uint8(1)), T.uint8(3)) + 
T.uint8(112),
+                ),
+                T.uint32(23),
+            )
+            sign = T.shift_left(T.Cast("uint32", T.shift_right(v, 
T.uint8(7))), T.uint32(31))
+            return T.reinterpret("float32", 
T.bitwise_or(T.bitwise_or(mantissa, exponent), sign))
+
+
+def cast_to_uint8(f8_dtype: str, promote_dtype: str, v):
+    """Build the TIR expression that encodes ``promote_dtype`` value ``v`` as fp8 bits in a uint8.
+
+    This is the counterpart of ``promote_uint8``, used to spell out the IR
+    expected from the fp8 legalization passes.
+
+    Rounding: the k low mantissa bits that are dropped are rounded to nearest
+    even. The code adds ``((bits >> k) & 1) + (2**(k-1) - 1)`` before
+    truncating, the same bias trick as the bf16 test's 0x7FFF.
+
+    Exponent: the exponent is re-biased. If it lies below the fp8 range, the
+    result is flushed to 0. The e5m2-from-float16 path only rounds and keeps
+    the high byte, because the two exponent fields match.
+
+    NOTE(review):
+    - Underflow returns T.uint8(0) regardless of sign, so tiny negative values
+      become +0.
+    - No overflow or saturation handling is visible: exponents above the fp8
+      range are not clamped.
+
+    The helper mirrors the pass's expected IR. Confirm against the C++
+    legalization if that pass changes.
+    """
+    if f8_dtype == "e4m3_float8":
+        if promote_dtype == "float16":
+            uint16_v = T.reinterpret("uint16", v)
+            # Round to nearest even on the 7 dropped mantissa bits (0x3F = 2**6 - 1).
+            rounding_bias = T.bitwise_and(
+                T.shift_right(uint16_v, T.uint16(7)),
+                T.uint16(1),
+            ) + T.uint16(0x3F)
+            uint16_v = uint16_v + rounding_bias
+            # NOTE(review): shift amount is T.uint8 on a uint16 operand (the other
+            # shifts here use T.uint16); presumably mirrors the pass output -- confirm.
+            mantissa = T.bitwise_and(
+                T.Cast("uint8", T.shift_right(uint16_v, T.uint8(7))), 
T.uint8(0x7)
+            )
+            # (bits << 1) >> 11 drops the sign and keeps float16's 5-bit exponent.
+            exponent_before_delta = T.shift_right(T.shift_left(uint16_v, 
T.uint16(1)), T.uint16(11))
+            # Exponents below the bias delta (8) cannot be represented: flush to 0.
+            round_to_zero = exponent_before_delta < T.uint16(8)
+            exponent = T.shift_left(
+                T.Cast("uint8", exponent_before_delta - T.uint16(8)),
+                T.uint8(3),
+            )
+            sign = T.shift_left(T.Cast("uint8", T.shift_right(uint16_v, 
T.uint16(15))), T.uint8(7))
+            return T.if_then_else(
+                round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, 
exponent), sign)
+            )
+        else:  # promote_dtype == "float32"
+            uint32_v = T.reinterpret("uint32", v)
+            # Round to nearest even on the 20 dropped mantissa bits (0x7FFFF = 2**19 - 1).
+            rounding_bias = T.bitwise_and(
+                T.shift_right(uint32_v, T.uint32(20)), T.uint32(1)
+            ) + T.uint32(0x7FFFF)
+            uint32_v = uint32_v + rounding_bias
+            # NOTE(review): shift amount is T.uint8 on a uint32 operand; presumably
+            # mirrors the pass output -- confirm.
+            mantissa = T.bitwise_and(
+                T.Cast("uint8", T.shift_right(uint32_v, T.uint8(20))), 
T.uint8(0x7)
+            )
+            # (bits << 1) >> 24 drops the sign and keeps float32's 8-bit exponent.
+            exponent_before_delta = T.shift_right(T.shift_left(uint32_v, 
T.uint32(1)), T.uint32(24))
+            round_to_zero = exponent_before_delta < T.uint32(120)
+            exponent = T.shift_left(
+                T.Cast("uint8", exponent_before_delta - T.uint32(120)), 
T.uint8(3)
+            )
+            sign = T.shift_left(T.Cast("uint8", T.shift_right(uint32_v, 
T.uint32(31))), T.uint8(7))
+            return T.if_then_else(
+                round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, 
exponent), sign)
+            )
+    else:  # f8_dtype == "e5m2_float8"
+        if promote_dtype == "float16":
+            uint16_v = T.reinterpret("uint16", v)
+            # Round to nearest even on the low 8 bits, then keep the high byte;
+            # exponents already line up, so no re-bias or flush is needed.
+            rounding_bias = T.bitwise_and(
+                T.shift_right(uint16_v, T.uint16(8)), T.uint16(1)
+            ) + T.uint16(0x7F)
+            uint16_v = uint16_v + rounding_bias
+            return T.Cast("uint8", T.shift_right(uint16_v, T.uint16(8)))
+        else:  # promote_dtype == "float32"
+            uint32_v = T.reinterpret("uint32", v)
+            # Round to nearest even on the 21 dropped mantissa bits (0xFFFFF = 2**20 - 1).
+            rounding_bias = T.bitwise_and(
+                T.shift_right(uint32_v, T.uint32(21)), T.uint32(1)
+            ) + T.uint32(0xFFFFF)
+            uint32_v = uint32_v + rounding_bias
+            # NOTE(review): shift amount is T.uint8 on a uint32 operand; presumably
+            # mirrors the pass output -- confirm.
+            mantissa = T.bitwise_and(
+                T.Cast("uint8", T.shift_right(uint32_v, T.uint8(21))), 
T.uint8(0x3)
+            )
+            exponent_before_delta = T.shift_right(T.shift_left(uint32_v, 
T.uint32(1)), T.uint32(24))
+            round_to_zero = exponent_before_delta < T.uint32(112)
+            exponent = T.shift_left(
+                T.Cast("uint8", exponent_before_delta - T.uint32(112)), 
T.uint8(2)
+            )
+            sign = T.shift_left(T.Cast("uint8", T.shift_right(uint32_v, 
T.uint32(31))), T.uint8(7))
+            return T.if_then_else(
+                round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, 
exponent), sign)
+            )
+
+
+def get_after_storage_legalize(dtype: str, promote_dtype: str):
+    """Expected module after ``FP8StorageLegalize`` on ``get_after_compute_legalize``.
+
+    The fp8 handles and buffers A, B and D become "uint8". C stays in
+    ``promote_dtype``. The reinterpret wrappers of ``promote_f8`` and
+    ``cast_to_f8`` disappear, leaving ``promote_uint8`` and ``cast_to_uint8``
+    applied directly to the uint8 data.
+    """
+    @tvm.script.ir_module
+    class After:
+        @T.prim_func
+        def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: 
T.handle("uint8")):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), "uint8", data=Aptr)
+            B = T.decl_buffer((100,), "uint8", data=Bptr)
+            D = T.decl_buffer((100,), "uint8", data=Dptr)
+            C = T.decl_buffer((100,), promote_dtype)
+            for i in T.grid(100):
+                C[i] = promote_uint8(dtype, promote_dtype, A[i]) + 
promote_uint8(
+                    dtype, promote_dtype, B[i]
+                )
+                D[i] = cast_to_uint8(dtype, promote_dtype, T.exp(C[i]))
+
+    return After
+
+
+# Test parameters: tests below that take ``dtype`` / ``promote_dtype`` are run
+# for each listed value (presumably every combination -- see tvm.testing.parameter).
+dtype = tvm.testing.parameter("e4m3_float8", "e5m2_float8")
+promote_dtype = tvm.testing.parameter("float16", "float32")
+
+
+def test_fp8_compute_legalize(dtype, promote_dtype):
+    """FP8ComputeLegalize widens fp8 arithmetic to ``promote_dtype`` as in the expected module."""
+    before = get_before(dtype)
+    expected = get_after_compute_legalize(dtype, promote_dtype)
+    # Run the transform twice to make sure that re-running it on
+    # already-legalized IR does not change the result.
+    after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before)
+    after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fp8_storage_legalize(dtype, promote_dtype):
+    """FP8StorageLegalize turns the compute-legalized module's fp8 storage into uint8."""
+    storage_legalize = tvm.tir.transform.FP8StorageLegalize()
+    legalized = storage_legalize(get_after_compute_legalize(dtype, promote_dtype))
+    tvm.ir.assert_structural_equal(
+        legalized, get_after_storage_legalize(dtype, promote_dtype)
+    )
+
+
+# Entry point when this file is executed directly; delegates to tvm.testing.main().
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to