This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new feb104393c [TIR][CUDA] Add native FP8 support to codegen (#16548)
feb104393c is described below

commit feb104393cde1347a47d5b30d8f0d0f0defcdf06
Author: Chris Sullivan <[email protected]>
AuthorDate: Fri Mar 15 08:21:53 2024 -0700

    [TIR][CUDA] Add native FP8 support to codegen (#16548)
    
    * [TIR][CUDA] Add native FP8 support to codegen
    
    Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide
    explicit type conversions that target hardware native conversion ops.
    
    * Conditionally run Storage and Compute legalization for targets that don't
    support FP8. This could be changed to only support conversion operators and do
    legalization on any compute operations other than builtin wmma calls.
    
    * Implement support for float16x4 (half4) for use with e4m3_float8x4
    (__nv_fp8x4_e4m3)
    
    * Add test for e4m3 <-> half conversion which lowers to ptx intrins.
    
    * Introduce half4 and support native fp8 vector types (1, 2, 4), and
    conversion between float and half vector types with equal lanes
    
    * Only cast to half2 for vector loads/stores of non native half struct
    types (lanes > 4).
    
    * Test e4m3 x4 vector quant/dequant
    
    ---------
    
    Co-authored-by: Joseph McMahan <[email protected]>
---
 include/tvm/tir/transform.h                        |   6 +-
 python/tvm/contrib/nvcc.py                         |   3 +
 src/driver/driver_api.cc                           |   5 +-
 src/target/llvm/codegen_llvm.cc                    |   2 +
 src/target/source/codegen_cuda.cc                  | 113 ++-
 src/target/source/literal/cuda_half_t.h            |  42 ++
 src/tir/transforms/unsupported_dtype_legalize.cc   |  28 +-
 .../python/codegen/test_target_codegen_cuda_fp8.py | 803 +++++++++++++++++++++
 8 files changed, 957 insertions(+), 45 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 934c2756f6..e219cc6846 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -398,6 +398,7 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
 /*!
  * \brief Legalize bf16 compute Ops. Add a cast to fp32
  *   before Ops, then add a cast back to bf16.
+ * \param target The target used for checking native bf16 support
  * \return The pass.
  */
 TVM_DLL Pass BF16ComputeLegalize();
@@ -405,10 +406,11 @@ TVM_DLL Pass BF16ComputeLegalize();
 /*!
  * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
  *   before Ops, then add a cast back to fp8.
+ * \param target The target used for checking native fp8 support
  * \param promote_dtype_str The data type used for type promotion, defaults to 
float16
  * \return The pass.
  */
-TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = 
"float16");
 
 /*!
  * \brief Legalize bf16 storage types to u16.
@@ -420,7 +422,7 @@ TVM_DLL Pass BF16StorageLegalize();
  * \brief Legalize fp8 storage types to u8.
  * \return The pass.
  */
-TVM_DLL Pass FP8StorageLegalize();
+TVM_DLL Pass FP8StorageLegalize(Target target);
 
 /*!
  * \brief Inline calls to private functions
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index d203007dd1..b1f042c1a5 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -270,6 +270,7 @@ def callback_libdevice_path(arch):
         return ""
 
 
+@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version")
 def get_target_compute_version(target=None):
     """Utility function to get compute capability of compilation target.
 
@@ -406,6 +407,7 @@ def have_cudagraph():
         return False
 
 
+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16")
 def have_bf16(compute_version):
     """Either bf16 support is provided in the compute capability or not
 
@@ -421,6 +423,7 @@ def have_bf16(compute_version):
     return False
 
 
+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8")
 def have_fp8(compute_version):
     """Whether fp8 support is provided in the specified compute capability or 
not
 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index bdadb6db0f..33b4514e6b 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -216,7 +216,6 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   pass_list.push_back(tir::transform::TransformMmaBufferLayout());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::FP8ComputeLegalize());
   pass_list.push_back(tir::transform::BF16ComputeLegalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -570,6 +569,8 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
 
   Array<Pass> mixed_pass_list;
 
+  mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
+
   // VerifyVTCMLimit must occur before LowerVtcmAlloc
   mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
   // LowerVtcmAlloc must occur after any transformations that modify memory 
allocation locations
@@ -619,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
-  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
+  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
 
   mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bba1488274..8fe740dad1 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -586,6 +586,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& 
dtype) const {
       default:
         LOG(FATAL) << "do not support " << dtype;
     }
+  } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == 
DataType::kE5M2Float) {
+    etype = llvm::Type::getInt8Ty(*ctx);
   }
   if (!dtype.is_scalar()) {
 #if TVM_LLVM_VERSION >= 110
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 15905b0304..d352616f55 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -41,6 +41,31 @@
 namespace tvm {
 namespace codegen {
 
+std::string GetFP8Type(DataType type) {
+  std::stringstream stream;
+  int32_t lanes = type.lanes();
+  std::string vec;
+  if (type.is_scalar()) {
+    vec = "";
+  } else if (lanes == 2) {
+    vec = "_2";
+  } else if (lanes == 4) {
+    vec = "_4";
+  } else if (lanes == 8) {
+    vec = "_8";
+  } else {
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for 
FP8";
+  }
+  if (type.code() == DataType::kE4M3Float) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kE5M2Float) {
+    stream << "fp8_e5" << vec << "_t";
+  } else {
+    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
+  }
+  return stream.str();
+}
+
 CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
 
 void CodeGenCUDA::Init(bool output_ssa) {
@@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
+    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
+    decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
+    decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
+    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
+    decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
+    decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
     decl_stream << "#endif\n\n";
   }
+  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
 
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
@@ -214,17 +246,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) 
{  // NOLINT(*)
         if (t.is_scalar()) {
           os << "half";
         } else if (lanes <= 8) {
-          // Emit CUDA code to access fp16 vector elements.
-          //
-          // half4 is stored as uint2
-          //
-          // h4.x is emitted as *(half2*)(&(u2.x)).x
-          // h4.y is emitted as *(half2*)(&(u2.x)).y
-          // h4.z is emitted as *(half2*)(&(u2.y)).x
-          // h4.w is emitted as *(half2*)(&(u2.y)).y
-          //
-          ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-          os << "uint" << lanes / 2;
+          ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for 
half type";
+          if (lanes <= 4) {
+            os << "half" << lanes;
+          } else {
+            os << "uint" << lanes / 2;
+          }
         } else {
           fail = true;
         }
@@ -271,16 +298,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) 
{  // NOLINT(*)
     }
     if (!fail) return;
   } else if (t.is_float8()) {
-    if (t.is_scalar()) {
-      os << "unsigned char";  // __nv_fp8_storage_t is an alias of unsigned 
char
-    } else if (lanes == 2) {
-      os << "unsigned short int";  // __nv_fp8x2_storage_t is an alias of 
unsigned short
-    } else if (lanes == 4) {
-      os << "unsigned int";  // __nv_fp8x4_storage_t is an alias of unsigned 
int
-    } else {
-      fail = true;
-    }
-    if (!fail) return;
+    enable_fp8_ = true;
+    os << GetFP8Type(t);
+    return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -446,7 +466,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, 
std::ostream& os) {
 
 void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr 
lhs, PrimExpr rhs,
                                    std::ostream& os) {  // NOLINT(*)
-  // Delcare the result.
+  // Declare the result.
   std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(t, stream);
@@ -497,7 +517,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, 
DataType t, int i,
       os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
     }
   } else if (t.is_float16()) {
-    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i 
% 2];
+    if (t.lanes() <= 4) {
+      os << vec << "." << access[i];
+    } else {
+      os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2];
+    }
   } else if (t.is_bfloat16()) {
     os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2];
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
@@ -543,8 +567,13 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& 
vec, DataType t, int i,
       stream << "(" << value << " << " << i % 4 * 8 << ");\n";
     }
   } else if (t.is_float16()) {
-    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2] << " = "
-           << value << ";\n";
+    if (t.lanes() <= 4) {
+      stream << vec << "." << access[i] << " = " << value << ";\n";
+    } else {
+      stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2] << " = "
+             << value << ";\n";
+    }
+
   } else if (t.is_bfloat16()) {
     stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" 
<< access[i % 2]
            << " = " << value << ";\n";
@@ -648,6 +677,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, 
std::ostream& os) {
   // Emit simple C-style type conversion.
   if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
 
+  if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == 
DataType::kE5M2Float ||
+      from_ty.code() == DataType::kE4M3Float || from_ty.code() == 
DataType::kE5M2Float) {
+    std::ostringstream val;
+    val << "(";
+    PrintType(target_ty, val);
+    val << ")(" << PrintExpr(op->value) << ")";
+    os << val.str();
+    return;
+  }
+
   // We could emit make_float4 like calls, but the emitted code looks
   // too compact to read. Emit this as vectorized unary ops.
   std::string sret = name_supply_->FreshName("_");
@@ -1194,9 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, 
std::ostream& os) {  // NO
     std::string v = PrintExpr(op->value);
     PrintVecConstructor(op->dtype, os);
     os << '(';
-    for (int i = 0; i < lanes / 2; ++i) {
-      if (i != 0) os << ", ";
-      os << "__pack_half2(" << v << ", " << v << ")";
+    if (lanes <= 4) {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << v << ", " << v;
+      }
+    } else {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << "__pack_half2(" << v << ", " << v << ")";
+      }
     }
     os << ')';
     return;
@@ -1448,15 +1494,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int 
i, const std::string& val
       PrintVecConstructor(t, os);
       os << '(';
     }
-    if (i % 2 == 0) {
-      os << "__pack_half2(" << value;
+    if (i == t.lanes() - 1) {
+      os << value << ")";
     } else {
-      os << "," << value << ")";
-      if (i != t.lanes() - 1) {
-        os << ",";
-      } else {
-        os << ")";
-      }
+      os << value << ",";
     }
     return;
   }
diff --git a/src/target/source/literal/cuda_half_t.h 
b/src/target/source/literal/cuda_half_t.h
index 67471daf82..bf3e83928e 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -24,6 +24,8 @@
 #ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
 #define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
 
+#include <string>
+
 static constexpr const char* _cuda_half_t_def = R"(
 typedef unsigned short uint16_t;
 typedef unsigned char uint8_t;
@@ -379,4 +381,44 @@ static constexpr const char* _cuda_warp_intrinsic_util = 
R"(
 
 )";
 
+void declare_vector_type_extensions(std::ostringstream& stream, bool 
enable_fp16, bool enable_fp8) {
+  if (enable_fp16 || enable_fp8) {
+    stream << R"(
+struct __align__(8) half4 {
+  __half x, y, z, w;
+  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), 
w(__half(0)) {}
+  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), 
y(y), z(z), w(w) {}
+)";
+    if (enable_fp8) {
+      stream << R"(
+  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
+    __nv_fp8x2_e4m3 lo_part, hi_part;
+    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 
0xFFFF);
+    __half2 lo_half2 = static_cast<__half2>(lo_part);
+    __half2 hi_half2 = static_cast<__half2>(hi_part);
+    x = reinterpret_cast<__half*>(&lo_half2)[0];
+    y = reinterpret_cast<__half*>(&lo_half2)[1];
+    z = reinterpret_cast<__half*>(&hi_half2)[0];
+    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  }
+  __host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
+    __nv_fp8x4_e4m3 result;
+    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
+    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
+    result.__x =
+        (static_cast<__uint32_t>(lo_part.__x) | 
(static_cast<__uint32_t>(hi_part.__x) << 16));
+    return result;
+  })";
+    }
+    stream << R"(
+};
+__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
+    return half4(x, y, z, w);
+}
+)";
+  }
+}
+
 #endif  // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc 
b/src/tir/transforms/unsupported_dtype_legalize.cc
index 030dbd01ba..c037879074 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -693,6 +693,20 @@ class FP8StorageLegalizer : public StorageLegalizer {
 
 namespace transform {
 
+bool CheckDataTypeSupport(const Target& target, const std::string& 
support_func_name) {
+  bool has_native_support = false;
+  if (target->kind->name == "cuda") {
+    if (const PackedFunc* get_cv =
+            
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
+      std::string compute_version = (*get_cv)(target);
+      if (const PackedFunc* check_support = 
tvm::runtime::Registry::Get(support_func_name)) {
+        has_native_support = (*check_support)(compute_version);
+      }
+    }
+  }
+  return has_native_support;
+}
+
 Pass BF16ComputeLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     // TODO(tvm-team): skip if the target supports bf16
@@ -713,9 +727,11 @@ Pass BF16StorageLegalize() {
 
 
TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
 
-Pass FP8ComputeLegalize(String promote_dtype_str) {
+Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports fp8
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+      return f;
+    }
     return 
FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
@@ -723,9 +739,11 @@ Pass FP8ComputeLegalize(String promote_dtype_str) {
 
 
TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);
 
-Pass FP8StorageLegalize() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports fp8
+Pass FP8StorageLegalize(Target target) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+      return f;
+    }
     return FP8StorageLegalizer().Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {});
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py 
b/tests/python/codegen/test_target_codegen_cuda_fp8.py
new file mode 100644
index 0000000000..dade970418
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -0,0 +1,803 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import pytest
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+from typing import List, Tuple
+from tvm import DataType, DataTypeCode, IRModule
+from tvm import dlight as dl
+from tvm import relax, te, tir, topi
+from tvm.relax.frontend import nn
+from tvm.runtime import NDArray
+from tvm.target import Target
+from tvm.topi.utils import get_const_tuple
+
+
[email protected]_cuda_compute_version(9)
+def test_e4m3_conversions():
+    dtype = "e4m3_float8"
+
+    @T.prim_func
+    def add(
+        A: T.Buffer((64,), dtype),
+        B: T.Buffer((64,), dtype),
+        C: T.Buffer((64,), dtype),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i in range(64):
+            with T.block("C"):
+                v_i = T.axis.spatial(64, i)
+                T.reads(A[v_i], B[v_i])
+                T.writes(C[v_i])
+                C[v_i] = T.Cast(dtype, T.Cast("float16", A[v_i]) + 
T.Cast("float16", B[v_i]))
+
+    sch = tvm.tir.Schedule(add)
+    block = sch.get_block("C")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    fadd = tvm.build(sch.mod, target=target)
+
+    cuda_src = fadd.imported_modules[0].get_source()
+    assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in 
generated CUDA"
+
+    dev = tvm.device(target, 0)
+
+    numpytype = "float8_e4m3fn"
+    a = tvm.nd.array(np.random.uniform(low=0, high=5, 
size=64).astype(numpytype), dev)
+    b = tvm.nd.array(np.random.uniform(low=0, high=5, 
size=64).astype(numpytype), dev)
+    c = tvm.nd.array(np.zeros(64, dtype=numpytype), dev)
+    fadd(a, b, c)
+
+    tvm.testing.assert_allclose(
+        c.numpy().astype("float16"), (a.numpy() + b.numpy()).astype("float16")
+    )
+
+
[email protected]_cuda_compute_version(9)
+def test_e4m3_packing():
+    length = 64
+    vector_length = 4
+    native_dtype, packed_dtype = ("e4m3_float8x4", "uint32")
+
+    @T.prim_func
+    def add(
+        A: T.Buffer((length,), native_dtype),
+        R: T.Buffer((length,), packed_dtype),
+        B: T.Buffer((length,), native_dtype),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i in range(length):
+            with T.block("R"):
+                v_i = T.axis.spatial(length, i)
+                T.reads(A[v_i])
+                T.writes(R[v_i])
+                R[v_i] = T.reinterpret(packed_dtype, A[v_i])
+        for i in range(length):
+            with T.block("B"):
+                v_i = T.axis.spatial(length, i)
+                T.reads(R[v_i])
+                T.writes(B[v_i])
+                B[v_i] = T.reinterpret(native_dtype, R[v_i])
+
+    sch = tvm.tir.Schedule(add)
+    block = sch.get_block("R")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+    block = sch.get_block("B")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    f = tvm.build(sch.mod, target=target)
+    dev = tvm.device(target, 0)
+
+    numpytype = "float8_e4m3fn"
+    np_shape = (length, vector_length)
+    a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
+    r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev)
+    b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
+    a.copyfrom(a_np)
+    f(a, r, b)
+    tvm.testing.assert_allclose(a.numpy().astype("float16"), 
b.numpy().astype("float16"))
+
+
+native_dtype, promoted_dtype = tvm.testing.parameters(
+    ("e4m3_float8", "float32"),
+    ("e4m3_float8", "float16"),
+    ("e4m3_float8x2", "float32x2"),
+    ("e4m3_float8x2", "float16x2"),
+    ("e4m3_float8x4", "float32x4"),
+    # Supported via half4 vector type extension in codegen
+    ("e4m3_float8x4", "float16x4"),
+)
+
+
[email protected]_cuda_compute_version(9)
+def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
+    vector_length = 64
+
+    @T.prim_func
+    def add(
+        A: T.Buffer((vector_length,), native_dtype),
+        B: T.Buffer((vector_length,), native_dtype),
+        C: T.Buffer((vector_length,), native_dtype),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i in range(vector_length):
+            with T.block("C"):
+                v_i = T.axis.spatial(vector_length, i)
+                T.reads(A[v_i], B[v_i])
+                T.writes(C[v_i])
+                C[v_i] = T.Cast(
+                    native_dtype, T.Cast(promoted_dtype, A[v_i]) + 
T.Cast(promoted_dtype, B[v_i])
+                )
+
+    sch = tvm.tir.Schedule(add)
+    block = sch.get_block("C")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    fadd = tvm.build(sch.mod, target=target)
+    cuda_src = fadd.imported_modules[0].get_source()
+    dev = tvm.device(target, 0)
+
+    numpytype = "float8_e4m3fn"
+    if "x" in native_dtype:
+        lanes = int(native_dtype.split("x")[-1])
+    else:
+        lanes = 1
+
+    if "x" in promoted_dtype:
+        promoted_base_dtype = promoted_dtype.split("x")[0]
+    else:
+        promoted_base_dtype = promoted_dtype
+
+    np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,)
+    a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    a.copyfrom(a_np)
+    b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    b.copyfrom(b_np)
+    c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
+    fadd(a, b, c)
+
+    tvm.testing.assert_allclose(
+        c.numpy().astype(promoted_base_dtype), (a_np + 
b_np).astype(promoted_base_dtype)
+    )
+
+
+bcast_length = tvm.testing.parameter(2, 4, 6, 8)
+
+
[email protected]_cuda_compute_version(8)
+def test_half_broadcast(bcast_length):
+    dtype = "float16"
+
+    @T.prim_func
+    def vector_broadcast(a: T.Buffer[(), dtype], vec: 
T.Buffer[(bcast_length,), dtype]):
+        for t in range(1):
+            with T.block("broadcast"):
+                vec[0:bcast_length] = T.broadcast(a[()], bcast_length)
+
+    sch = tvm.tir.Schedule(vector_broadcast)
+    block = sch.get_block("broadcast")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 1])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    func = tvm.build(sch.mod, target=target)
+    dev = tvm.device(target, 0)
+
+    a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype)
+    a = tvm.nd.array(a_np, device=dev)
+    b = tvm.nd.empty((bcast_length,), dtype=dtype, device=dev)
+
+    func(a, b)
+
+    b_np = np.full((bcast_length,), a_np)
+
+    tvm.testing.assert_allclose(b.numpy(), b_np)
+
+
+vector_length = tvm.testing.parameter(2, 4)
+
+
[email protected]_cuda_compute_version(8)
+def test_half_misaligned_vector_load(vector_length):
+    dtype = "float16"
+    vec_dtype = dtype + "x" + str(vector_length)
+    length = 256
+
+    @T.prim_func
+    def vector_load(
+        A: T.Buffer[(length,), dtype], B: T.Buffer[(length // vector_length,), 
vec_dtype]
+    ):
+        for b in T.thread_binding(1, thread="blockIdx.x"):
+            for i in T.thread_binding(length // vector_length, 
thread="threadIdx.x"):
+                vec_index = T.ramp((i + 1) * vector_length - 1, -1, 
vector_length)
+                B[i] = A[vec_index]
+
+    target = "cuda"
+    f = tvm.build(vector_load, target=target)
+
+    dev = tvm.device(target, 0)
+    a_np = np.random.uniform(low=0, high=1, size=(length,)).astype(dtype)
+    a = tvm.nd.array(a_np, device=dev)
+
+    b = tvm.nd.empty((length // vector_length,), dtype=vec_dtype, device=dev)
+
+    f(a, b)
+
+    b_np = np.empty((length // vector_length, vector_length), dtype=dtype)
+
+    for i in range(length // vector_length):
+        start_index = (i + 1) * vector_length - 1
+        b_np[i, :] = a_np[start_index - vector_length + 1 : start_index + 
1][::-1]
+
+    tvm.testing.assert_allclose(b.numpy(), b_np)
+
+
[email protected]_cuda_compute_version(8)
+def test_half4_vector_add():
+    dtype = "float16"
+    length = 64
+    vector_length = 4
+    vec_dtype = dtype + "x" + str(vector_length)
+
+    @T.prim_func
+    def add(
+        A: T.Buffer((length,), vec_dtype),
+        B: T.Buffer((length,), vec_dtype),
+        C: T.Buffer((length,), vec_dtype),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i in range(length):
+            with T.block("C"):
+                v_i = T.axis.spatial(length, i)
+                T.reads(A[v_i], B[v_i])
+                T.writes(C[v_i])
+                C[v_i] = A[v_i] + B[v_i]
+
+    sch = tvm.tir.Schedule(add)
+    block = sch.get_block("C")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+
+    target = "cuda"
+    fadd = tvm.build(sch.mod, target=target)
+    dev = tvm.device(target, 0)
+
+    a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
+    a = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+    a.copyfrom(a_np)
+    b_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
+    b = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+    b.copyfrom(b_np)
+    c = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+
+    fadd(a, b, c)
+    c_expected = a_np + b_np
+    tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5)
+
+
+class BaseFP8E4M3QuantScaleOnly:
+    @classmethod
+    def create_quantize_func(
+        cls,
+        weight_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        max_int_value,
+        axis,
+        output_transpose,
+    ) -> IRModule:
+        if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+            quantize_func = cls.quantize_fp8x4_e4m3
+        else:
+            assert NotImplementedError()
+
+        bb = relax.BlockBuilder()  # pylint: disable=invalid-name
+        weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, 
model_dtype))
+        compute_scale, compute_quantize, compute_transpose = quantize_func(
+            weight_shape,
+            model_dtype,
+            quantize_dtype,
+            storage_dtype,
+            group_size,
+            num_elem_per_storage,
+            max_int_value,
+            axis,
+            output_transpose,
+        )
+        with bb.function(name="main", params=[weight_var]):
+            with bb.dataflow():
+                lv_scale = bb.emit_te(compute_scale, weight_var)
+                lv_quantized_weight = compute_quantize(bb, (weight_var, 
lv_scale))
+                if compute_transpose:
+                    lv_output = bb.emit_te(compute_transpose, 
lv_quantized_weight, lv_scale)
+                    lv_quantized_weight = lv_output[0]
+                    lv_scale = lv_output[1]
+                tuple_output = bb.emit((lv_quantized_weight, lv_scale))
+                gv = bb.emit_output(tuple_output)
+            bb.emit_func_output(gv)
+        return bb.finalize()
+
+    @classmethod
+    def create_dequantize_func(
+        cls,
+        packed_weight_shape,
+        scale_shape,
+        dequantized_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        axis,
+    ) -> IRModule:
+        """Build a Relax module whose "main" function dequantizes a packed weight.
+
+        "main" takes two parameters: the packed weight (shape
+        ``packed_weight_shape``, dtype ``storage_dtype``) and the scale (shape
+        ``scale_shape``, dtype ``model_dtype``). Its output is whatever the
+        dtype-specific dequantize emitter produces for those two inputs.
+
+        Raises
+        ------
+        NotImplementedError
+            If ``quantize_dtype`` is not an e4m3 float8 type.
+        """
+        if DataType(quantize_dtype).type_code != DataTypeCode.E4M3Float:
+            # This used to be `assert NotImplementedError()`, which never fails
+            # (an exception instance is truthy) and left `dequantize_func` unbound.
+            raise NotImplementedError(
+                f"Dequantization is not implemented for quantize_dtype={quantize_dtype}"
+            )
+        dequantize_func = cls.dequantize_fp8x4_e4m3
+
+        bb = relax.BlockBuilder()  # pylint: disable=invalid-name
+        packed_weight_var = relax.Var(
+            "weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype)
+        )
+        scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype))
+        compute_dequantize = dequantize_func(
+            packed_weight_shape,
+            scale_shape,
+            dequantized_shape,
+            model_dtype,
+            quantize_dtype,
+            storage_dtype,
+            group_size,
+            num_elem_per_storage,
+            axis,
+        )
+        with bb.function(name="main", params=[packed_weight_var, scale_var]):
+            with bb.dataflow():
+                lv = compute_dequantize(bb, (packed_weight_var, scale_var))
+                gv = bb.emit_output(lv)
+            bb.emit_func_output(gv)
+        return bb.finalize()
+
+    @classmethod
+    def quantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
+        cls,
+        weight_shape: List[tir.PrimExpr],
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        max_int_value,
+        axis: int = -1,
+        output_transpose: bool = False,
+    ) -> Tuple:
+        """Group quantization for a weight tensor.
+
+        Returns
+        -------
+        Tuple
+            ``(compute_scale, compute_quantize_weight, compute_transpose)``:
+
+            * ``compute_scale(weight)`` is a tensor-expression function producing
+              the per-group scale: the maximum absolute value in each group along
+              ``axis`` divided by ``max_int_value``, clamped from below by
+              ``1 / (max_int_value * 512)``.
+            * ``compute_quantize_weight(bb, args)`` adds the quantize-and-pack
+              PrimFunc to the block builder ``bb`` and emits a ``call_tir`` to it.
+              The packed output shape is
+              ``(weight_shape[0], weight_shape[1] // num_elem_per_storage)``.
+            * ``compute_transpose(quantized_weight, scale)`` transposes both 2-D
+              tensors; it is ``None`` unless ``output_transpose`` is true.
+        """
+        max_int = tir.const(max_int_value, model_dtype)
+        shape = weight_shape  # pylint: disable=invalid-name
+        axis = axis if axis >= 0 else len(shape) + axis
+        k = shape[axis]
+        quantize_dtype = DataType(quantize_dtype)
+        # compute scale per group
+        r = te.reduce_axis((0, group_size), name="r")  # pylint: disable=invalid-name
+        num_group = tir.ceildiv(k, group_size)
+        # (4096, 4096) -> quantize axis = 0, group size = 32 -> (128, 4096)
+        # for channel quant group_size = 4096 -> (1, 4096)
+        scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :])
+
+        def compute_scale(weight: te.Tensor):
+            # Lower bound for the scale so later division by it stays finite.
+            min_scaling_factor = tir.const(1.0 / (max_int_value * 512.0), model_dtype)
+            max_abs = te.compute(
+                shape=scale_shape,
+                fcompute=lambda *idx: te.max(
+                    tir.if_then_else(
+                        # Positions past the end of the axis (last, partial group)
+                        # contribute the minimum value and so never win the max.
+                        idx[axis] * group_size + r < k,
+                        te.abs(weight(*idx[:axis], idx[axis] * group_size + r, *idx[axis + 1 :])),
+                        te.min_value(model_dtype),
+                    ),
+                    axis=r,
+                ),
+                name="max_abs_value",
+            )
+            scale = te.compute(
+                scale_shape,
+                lambda *idx: te.max(
+                    max_abs(*idx).astype(model_dtype) / max_int, min_scaling_factor
+                ),
+                name="scale",
+            )
+            return scale
+
+        def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
+            # compute scaled weight
+            packed_shape = (weight_shape[0], weight_shape[1] // num_elem_per_storage)
+            quant = cls.quant_and_pack_fp8x4_e4m3_sm90(
+                weight_shape,
+                packed_shape,
+                scale_shape,
+                group_size,
+                axis,
+                model_dtype,
+                storage_dtype,
+                quantize_dtype,
+            )
+
+            global_var = bb.add_func(quant, "quantized_weight")
+            lv_quantized_weight = bb.emit(
+                relax.call_tir(
+                    global_var, args, relax.TensorStructInfo(packed_shape, storage_dtype)
+                )
+            )
+            return lv_quantized_weight
+
+        compute_transpose = None
+        if output_transpose:
+
+            def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor):
+                if len(quantized_weight.shape) != 2 or len(scale.shape) != 2:
+                    raise ValueError(
+                        "Does not support transpose output quantized weight with ndim != 2"
+                    )
+
+                quantized_weight = topi.transpose(quantized_weight)
+                scale = topi.transpose(scale)
+                return quantized_weight, scale
+
+        return compute_scale, compute_quantize_weight, compute_transpose
+
+    @classmethod
+    def dequantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
+        cls,
+        packed_weight_shape: List[tir.PrimExpr],
+        scale_shape,
+        dequant_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        axis: int = -1,
+    ):
+        """Group dequantization for a packed weight tensor.
+
+        Returns a callback ``compute_dequantize_weight(bb, args)`` that adds the
+        dequantize PrimFunc to the block builder ``bb`` and emits a ``call_tir``
+        to it, producing a tensor of shape ``dequant_shape`` and dtype
+        ``model_dtype``. ``num_elem_per_storage`` is accepted for signature
+        parity with the quantize builder but is not used here.
+        """
+        # Normalize a negative axis against the rank of the dequantized tensor.
+        # (This previously referenced an undefined name `shape`.)
+        axis = axis if axis >= 0 else len(dequant_shape) + axis
+
+        def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
+            dequant = cls.dequant_fp8x4_e4m3_sm90(
+                packed_weight_shape,
+                scale_shape,
+                dequant_shape,
+                group_size,
+                axis,
+                model_dtype,
+                storage_dtype,
+                quantize_dtype,
+            )
+
+            global_var = bb.add_func(dequant, "dequantize_weight")
+            lv_dequantized_weight = bb.emit(
+                relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype))
+            )
+            return lv_dequantized_weight
+
+        return compute_dequantize_weight
+
+    @classmethod
+    def quant_and_pack_fp8x4_e4m3_sm90(
+        cls,
+        weight_shape,
+        packed_shape,
+        scale_shape,
+        group_size,
+        axis,
+        model_dtype,
+        storage_dtype,
+        quantized_dtype,
+    ):
+        """Create the TIR PrimFunc that quantizes and packs a 2-D weight.
+
+        For every group of four consecutive elements along dimension 1 the
+        PrimFunc divides the values by the scale of their quantization group,
+        casts the 4-lane vector to ``quantized_dtype`` and reinterprets the
+        result as a single ``storage_dtype`` element.
+
+        ``group_size`` must be a multiple of the vector length (4) so that a
+        vector never straddles two scale groups. ``axis`` is not used by the
+        kernel, which always indexes dimension 1.
+        """
+        vector_length = 4
+        vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
+        # TODO(csullivan) assert on storage dtype / quantize type bytes == vector length
+        assert (
+            group_size % vector_length == 0
+        ), f"Number of elements in a group must be divisible by fp8 vector length {vector_length}"
+
+        @T.prim_func(private=True)
+        def quant_pack(
+            A: T.Buffer(weight_shape, model_dtype),
+            scale: T.Buffer(scale_shape, model_dtype),
+            compute: T.Buffer(
+                packed_shape,
+                storage_dtype,
+            ),
+        ):
+            for i0, i1 in T.grid(
+                T.int64(weight_shape[0]), T.int64(weight_shape[1] // vector_length)
+            ):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    # Declared regions mirror the accesses below: a 4-element
+                    # window of A starting at v_i1 * vector_length, and a single
+                    # packed output element at v_i1.
+                    T.reads(
+                        A[v_i0, v_i1 * vector_length : v_i1 * vector_length + vector_length],
+                        scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
+                    )
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.reinterpret(
+                        storage_dtype,
+                        T.Cast(
+                            vec_quantized_dtype,
+                            A[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)]
+                            / scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
+                        ),
+                    )
+
+        return quant_pack
+
+    @classmethod
+    def dequant_fp8x4_e4m3_sm90(
+        cls,
+        packed_weight_shape,
+        scale_shape,
+        out_shape,
+        group_size,
+        axis,
+        model_dtype,
+        storage_dtype,
+        quantized_dtype,
+    ):
+        """Create the TIR PrimFunc that unpacks and dequantizes a 2-D weight.
+
+        Each ``storage_dtype`` element of the packed weight is reinterpreted as
+        a 4-lane ``quantized_dtype`` vector, cast to ``model_dtype`` and
+        multiplied by the (broadcast) scale of its quantization group. The four
+        results are written to consecutive positions along dimension 1 of the
+        output. ``axis`` is not used by the kernel.
+        """
+        vector_length = 4
+        vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
+        vec_model_dtype = f"{model_dtype}x{vector_length}"
+
+        @T.prim_func
+        def dequant(
+            packed_weight: T.Buffer(packed_weight_shape, storage_dtype),
+            scale: T.Buffer(scale_shape, model_dtype),
+            dequantize: T.Buffer(out_shape, model_dtype),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(packed_weight_shape[0]), T.int64(packed_weight_shape[1])):
+                with T.block("dequantize"):
+                    v_i0 = T.axis.spatial(T.int64(packed_weight_shape[0]), i0)
+                    v_i1 = T.axis.spatial(T.int64(packed_weight_shape[1]), i1)
+                    T.reads(
+                        packed_weight[v_i0, v_i1],
+                        scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
+                    )
+
+                    dequantize[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)] = T.Cast(
+                        vec_model_dtype,
+                        T.reinterpret(vec_quantized_dtype, packed_weight[v_i0, v_i1]),
+                    ) * T.Broadcast(
+                        scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
+                        vector_length,
+                    )
+
+        return dequant
+
+    @classmethod
+    def compile_quant_and_dequant_by_scale(
+        cls,
+        weight_shape,
+        scales_shape,
+        quant_weight_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_el_per_storage,
+        max_int_value,
+        axis,
+        target_str,
+        dev,
+    ):
+        """Build the quantize and dequantize modules and return their entry points.
+
+        Both Relax modules are scheduled with the default GPU rules, built for
+        ``target_str`` and loaded into a Relax virtual machine on ``dev``.
+        Returns the pair ``(quantize_main, dequantize_main)``.
+        """
+        quantize_module = cls.create_quantize_func(
+            weight_shape,
+            model_dtype,
+            quantize_dtype,
+            storage_dtype,
+            group_size,
+            num_el_per_storage,
+            max_int_value,
+            axis,
+            output_transpose=False,
+        )
+
+        target = tvm.target.Target(target_str)
+
+        def schedule_for_gpu(module):
+            # Apply the default GPU schedule rules inside the target scope.
+            with target:
+                return dl.ApplyDefaultSchedule(
+                    dl.gpu.Reduction(),
+                    dl.gpu.GeneralReduction(),
+                    dl.gpu.Fallback(),
+                )(module)
+
+        def dump_cuda_source(module, func_name=None):
+            # Debug aid: build a single function and print its generated source.
+            selected = module[func_name] if func_name else module
+            built = tvm.build(selected, target=target)
+            print(built.imported_modules[0].get_source())
+
+        quantize_module = schedule_for_gpu(quantize_module)
+        quantize_vm = relax.VirtualMachine(relax.build(quantize_module, target=target), dev)
+
+        dequantize_module = cls.create_dequantize_func(
+            packed_weight_shape=quant_weight_shape,
+            scale_shape=scales_shape,
+            dequantized_shape=weight_shape,
+            model_dtype=model_dtype,
+            quantize_dtype=quantize_dtype,
+            storage_dtype=storage_dtype,
+            group_size=group_size,
+            num_elem_per_storage=num_el_per_storage,
+            axis=axis,
+        )
+        dequantize_module = schedule_for_gpu(dequantize_module)
+        # NOTE(review): the show()/dump below are debugging output; consider removing.
+        dequantize_module.show()
+
+        dequantize_vm = relax.VirtualMachine(relax.build(dequantize_module, target=target), dev)
+
+        dump_cuda_source(dequantize_module, func_name="dequant")
+
+        return quantize_vm["main"], dequantize_vm["main"]
+
+
+class TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
+    """Round-trip test: quantize a float16 weight to packed e4m3 float8 and back.
+
+    The fixtures below supply the configuration that is forwarded to
+    ``compile_quant_and_dequant_by_scale``.
+    """
+
+    # weight_shape = tvm.testing.parameter((32000, 4096), (4096, 14336))
+    weight_shape = tvm.testing.parameter((128, 256), (128, 64))
+
+    @tvm.testing.fixture
+    def group_size(self):
+        return 64
+
+    @tvm.testing.fixture
+    def axis(self):
+        return 1
+
+    @tvm.testing.fixture
+    def model_dtype(self):
+        return "float16"
+
+    @tvm.testing.fixture
+    def storage_dtype(self):
+        return "uint32"
+
+    @tvm.testing.fixture
+    def quantize_dtype(self):
+        return "e4m3_float8"
+
+    @tvm.testing.fixture
+    def num_el_per_storage(self):
+        return 4
+
+    @tvm.testing.fixture
+    def max_int_value(self):
+        return 448
+
+    @tvm.testing.fixture
+    def target_str(self):
+        return "cuda"
+
+    @tvm.testing.fixture
+    def scale_shape(self, weight_shape, group_size, axis):
+        """Shape of the scale tensor: ``weight_shape`` with ``axis`` ceil-divided by ``group_size``."""
+        return [
+            (d + group_size - 1) // group_size if axis == i else d
+            for i, d in enumerate(weight_shape)
+        ]
+
+    @tvm.testing.fixture
+    def quant_weight_shape(self, weight_shape, num_el_per_storage, axis):
+        """Shape of the packed weight: ``axis`` ceil-divided by ``num_el_per_storage``."""
+        return [
+            (d + num_el_per_storage - 1) // num_el_per_storage if axis == i else d
+            for i, d in enumerate(weight_shape)
+        ]
+
+    @tvm.testing.fixture
+    def compiled_functions(
+        self,
+        weight_shape,
+        scale_shape,
+        quant_weight_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_el_per_storage,
+        max_int_value,
+        axis,
+        target_str,
+    ):
+        """Compiled ``(quantize, dequantize)`` entry points for device 0 of the target."""
+        dev = tvm.device(target_str, 0)
+        return self.compile_quant_and_dequant_by_scale(
+            weight_shape,
+            scale_shape,
+            quant_weight_shape,
+            model_dtype,
+            quantize_dtype,
+            storage_dtype,
+            group_size,
+            num_el_per_storage,
+            max_int_value,
+            axis,
+            target_str,
+            dev,
+        )
+
+    @tvm.testing.requires_cuda_compute_version(9)
+    def test_main(self, weight_shape, model_dtype, target_str, compiled_functions):
+        """Quantize then dequantize a random weight and compare with the input."""
+        quant, dequant = compiled_functions
+        dev = tvm.device(target_str, 0)
+
+        weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype)
+        weight = tvm.nd.array(weight_np, device=dev)
+        quant_weight, scales = quant(weight)
+
+        dequant_weight = dequant(quant_weight, scales)
+        dequant_weight_np = dequant_weight.numpy()
+        # Loose tolerances: the round trip goes through an 8-bit float format.
+        tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)
+
+
+# Entry point when this file is executed as a script; delegates to tvm.testing.main().
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to