This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd9c091097 [CUTLASS] Add GeMM kernels for Blackwell GPUs (#18033)
fd9c091097 is described below

commit fd9c091097ea3a6b421a0e526185f017e6931141
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Jun 6 06:43:39 2025 -0400

    [CUTLASS] Add GeMM kernels for Blackwell GPUs (#18033)
    
    This PR introduces CUTLASS GEMM kernels, groupwise-scaled GEMM
    kernels, and group GEMM kernels for Blackwell GPUs.
    
    Files are reorganized so that the exposed global functions are
    now architecture-agnostic. Prior to this PR, our global function
    names for CUTLASS kernels usually ended with `"_sm90"`, which made
    it harder for the frontend compiler to decide which kernel to
    dispatch to when multiple architectures, such as Hopper and
    Blackwell, are supported.
    
    Therefore, this PR renames those global functions so that the
    function names are architecture-agnostic. At build time, only
    the kernels that the target architecture supports are built.
---
 3rdparty/cutlass                                   |   2 +-
 cmake/modules/contrib/CUTLASS.cmake                |  16 +-
 .../{fp16_group_gemm.cu => fp16_group_gemm.cuh}    |  51 ++---
 ...runner.cuh => fp16_group_gemm_runner_sm100.cuh} |  79 +++++---
 ..._runner.cuh => fp16_group_gemm_runner_sm90.cuh} |  10 +-
 .../contrib/cutlass/fp16_group_gemm_sm100.cu       |  54 +++++
 ...{fp16_group_gemm.cu => fp16_group_gemm_sm90.cu} |  53 +++--
 .../contrib/cutlass/fp8_blockwise_scaled_gemm.cu   | 164 ---------------
 .../{fp8_group_gemm.cu => fp8_group_gemm_sm90.cu}  |   2 +-
 .../contrib/cutlass/fp8_groupwise_scaled_gemm.cuh  | 172 ++++++++++++++++
 .../fp8_groupwise_scaled_gemm_runner_sm100.cuh     | 155 +++++++++++++++
 ...h => fp8_groupwise_scaled_gemm_runner_sm90.cuh} |  53 +----
 .../cutlass/fp8_groupwise_scaled_gemm_sm100.cu     |  77 ++++++++
 .../cutlass/fp8_groupwise_scaled_gemm_sm90.cu      |  77 ++++++++
 ...p8_groupwise_scaled_group_gemm_runner_sm100.cuh | 220 +++++++++++++++++++++
 .../fp8_groupwise_scaled_group_gemm_sm100.cu       |  93 +++++++++
 src/target/tag.cc                                  |   2 +
 tests/python/contrib/test_cutlass_gemm.py          |  32 ++-
 18 files changed, 1001 insertions(+), 311 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index afa1772203..ad7b2f5e84 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
+Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e
diff --git a/cmake/modules/contrib/CUTLASS.cmake 
b/cmake/modules/contrib/CUTLASS.cmake
index d11777e851..b74ce4c8df 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,19 +58,27 @@ if(USE_CUDA AND USE_CUTLASS)
   set(TVM_CUTLASS_RUNTIME_SRCS "")
 
   if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp16_group_gemm.cu)
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_gemm.cu)
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu)
+  endif()
+  if (CMAKE_CUDA_ARCHITECTURES MATCHES "100a")
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS 
src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu)
   endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
-    target_compile_options(tvm_cutlass_objs PRIVATE 
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
+    target_compile_options(tvm_cutlass_objs PRIVATE 
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo --expt-relaxed-constexpr>)
     target_include_directories(tvm_cutlass_objs PRIVATE
       ${CUTLASS_DIR}/include
       
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
     )
+    target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header)
     target_compile_definitions(tvm_cutlass_objs PRIVATE 
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+    # Note: enable this to get more detailed logs for cutlass kernels
+    # target_compile_definitions(tvm_cutlass_objs PRIVATE 
CUTLASS_DEBUG_TRACE_LEVEL=2)
     list(APPEND CUTLASS_RUNTIME_OBJS 
"$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
   endif()
 
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu 
b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
similarity index 56%
copy from src/runtime/contrib/cutlass/fp16_group_gemm.cu
copy to src/runtime/contrib/cutlass/fp16_group_gemm.cuh
index dffe7dc4ff..ebb8f58a6b 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -19,31 +19,24 @@
 
 #include <cuda_fp16.h>
 #include <float.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/ffi/function.h>
 #include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
 
-#include "group_gemm_runner.cuh"
-
-#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
-
-template <>
-struct KernelTraits<cutlass::half_t> {
-  using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
-  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
-  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a 
cluster
-};
+#include "cutlass/bfloat16.h"
+#include "cutlass/half.h"
 
 namespace tvm {
 namespace runtime {
 
-template <typename ElementA, typename ElementB, typename ElementC>
-void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
+template <int Arch, typename ElementA, typename ElementB, typename ElementC>
+struct CutlassGroupGemm;
+
+template <int Arch>
+void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
                                  NDArray out) {
   // Workspace is used for storing device-side group gemm arguments and 
cutlass internal workspace.
   // Recommened size is 4MB.
   static auto func = 
tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
   CHECK_EQ(x->ndim, 2);
   CHECK_EQ(weight->ndim, 3);
   CHECK_EQ(indptr->ndim, 1);
@@ -54,16 +47,26 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, 
NDArray indptr, NDAr
   int k = weight->shape[2];
   float alpha = 1.0f;
   float beta = 0.0f;
-  cutlass_group_gemm(static_cast<ElementA*>(x->data), 
static_cast<ElementB*>(weight->data),
-                     static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
-                     workspace->shape[0], n, k, num_groups, alpha, beta,
-                     static_cast<ElementC*>(out->data), stream);
-}
+  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
 
-TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
-    .set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, 
cutlass::half_t, cutlass::half_t>);
+  if (DataType(x->dtype) == DataType::Float(16)) {
+    CHECK(DataType(weight->dtype) == DataType::Float(16));
+    CHECK(DataType(out->dtype) == DataType::Float(16));
+    using Dtype = cutlass::half_t;
+    CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, 
static_cast<Dtype*>(out->data), stream);
+  } else if (DataType(x->dtype) == DataType::BFloat(16)) {
+    CHECK(DataType(weight->dtype) == DataType::BFloat(16));
+    CHECK(DataType(out->dtype) == DataType::BFloat(16));
+    using Dtype = cutlass::bfloat16_t;
+    CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, 
static_cast<Dtype*>(out->data), stream);
+  }
+}
 
 }  // namespace runtime
 }  // namespace tvm
-
-#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh 
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
similarity index 73%
copy from src/runtime/contrib/cutlass/group_gemm_runner.cuh
copy to src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
index a3c52e27a9..f38664915d 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -54,10 +54,25 @@ inline size_t aligned(size_t value, size_t alignment = 16) {
   return (value + alignment - 1) / alignment * alignment;
 }
 
-template <typename T>
-struct KernelTraits;
+template <typename ElementA>
+struct MMA1SMConfig {
+  using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;                
// Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;  
// Epilogue to launch
+};
+
+template <typename ElementA>
+struct MMA2SMConfig {
+  using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;                
// Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;  
// Epilogue to launch
+};
 
-template <typename ElementA, typename ElementB, typename ElementC,
+template <typename ScheduleConfig, typename ElementA, typename ElementB, 
typename ElementC,
           typename LayoutA = cutlass::layout::RowMajor,
           typename LayoutB = cutlass::layout::ColumnMajor,
           typename LayoutC = cutlass::layout::RowMajor>
@@ -77,28 +92,25 @@ struct CutlassGroupGemmRunner {
   // Core kernel configurations
   using ElementAccumulator = float;  // Element type for internal accumulation
   using ScaleType = std::variant<ElementAccumulator, const 
ElementAccumulator*>;
-  using ArchTag =
-      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the 
intended feature
   using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
-  using TileShape = typename KernelTraits<ElementA>::TileShape;
-  using ClusterShape = typename KernelTraits<ElementA>::ClusterShape;
   using StageCountType =
       cutlass::gemm::collective::StageCountAuto;  // Stage count maximized 
based on the tile size
-  using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule;     
// Kernel to launch
-  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  
// Epilogue to launch
 
+  // Different configs for 1SM and 2SM MMA kernel
   using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBuilder<
-      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, 
ClusterShape,
-      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, 
ElementAccumulator,
-      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
-      EpilogueSchedule>::CollectiveOp;
+      cutlass::arch::Sm100, OperatorClass, typename 
ScheduleConfig::MmaTileShape,
+      typename ScheduleConfig::ClusterShape, 
cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, 
ElementC, LayoutC*,
+      AlignmentC, typename ScheduleConfig::EpilogueSchedule,
+      cutlass::epilogue::fusion::LinearCombination<ElementC, 
ElementAccumulator>>::CollectiveOp;
 
   using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, 
LayoutB*, AlignmentB,
-      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, 
ElementB, LayoutB*,
+      AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
+      typename ScheduleConfig::ClusterShape,
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
           sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      KernelSchedule>::CollectiveOp;
+      typename ScheduleConfig::KernelSchedule>::CollectiveOp;
 
   using GemmKernel =
       cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, 
CollectiveEpilogue>;
@@ -117,14 +129,16 @@ struct CutlassGroupGemmRunner {
                       StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, 
StrideD* stride_D,
                       uint8_t* workspace, int64_t workspace_size, int 
num_groups, ScaleType alpha,
                       ScaleType beta, cudaStream_t stream) {
-    typename Gemm::EpilogueOutputOp::Params epilogue_params = [&]() {
+    typename Gemm::Arguments arguments;
+    decltype(arguments.epilogue.thread) fusion_args;
+    [&]() {
       ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the 
same type";
       if (std::holds_alternative<ElementAccumulator>(alpha)) {
-        return typename 
Gemm::EpilogueOutputOp::Params{std::get<ElementAccumulator>(alpha),
-                                                       
std::get<ElementAccumulator>(beta)};
+        fusion_args.alpha = std::get<ElementAccumulator>(alpha);
+        fusion_args.beta = std::get<ElementAccumulator>(beta);
       } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
-        return typename Gemm::EpilogueOutputOp::Params{std::get<const 
ElementAccumulator*>(alpha),
-                                                       std::get<const 
ElementAccumulator*>(beta)};
+        fusion_args.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
+        fusion_args.beta_ptr = std::get<const ElementAccumulator*>(beta);
       } else {
         LOG(FATAL) << "Unsupported alpha and beta type";
         throw;
@@ -135,11 +149,11 @@ struct CutlassGroupGemmRunner {
     hw_info.device_id = 0;
     hw_info.sm_count =
         
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-    typename Gemm::Arguments 
arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
-                                       {num_groups, problem_sizes, 
problem_sizes_host},
-                                       {ptr_A, stride_A, ptr_B, stride_B},
-                                       {epilogue_params, ptr_C, stride_C, 
ptr_D, stride_D},
-                                       hw_info};
+    arguments = typename 
Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                         {num_groups, problem_sizes, 
problem_sizes_host},
+                                         {ptr_A, stride_A, ptr_B, stride_B},
+                                         {fusion_args, ptr_C, stride_C, ptr_D, 
stride_D},
+                                         hw_info};
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
     CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
@@ -169,12 +183,13 @@ __global__ void prepare_group_gemm_arguments(
 }
 
 template <typename ElementA, typename ElementB, typename ElementC>
-void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, 
uint8_t* workspace,
-                        int64_t workspace_size, int64_t n, int64_t k, int64_t 
num_groups,
-                        std::variant<float, const float*> alpha,
-                        std::variant<float, const float*> beta, ElementC* out,
-                        cudaStream_t stream) {
-  using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
+void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, 
uint8_t* workspace,
+                              int64_t workspace_size, int64_t n, int64_t k, 
int64_t num_groups,
+                              std::variant<float, const float*> alpha,
+                              std::variant<float, const float*> beta, 
ElementC* out,
+                              cudaStream_t stream) {
+  // Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if 
needed.
+  using Runner = CutlassGroupGemmRunner<MMA2SMConfig<ElementA>, ElementA, 
ElementB, ElementC>;
   using StrideA = typename Runner::StrideA;
   using StrideB = typename Runner::StrideB;
   using StrideC = typename Runner::StrideC;
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh 
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
similarity index 96%
rename from src/runtime/contrib/cutlass/group_gemm_runner.cuh
rename to src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index a3c52e27a9..38e1beb2b8 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -169,11 +169,11 @@ __global__ void prepare_group_gemm_arguments(
 }
 
 template <typename ElementA, typename ElementB, typename ElementC>
-void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, 
uint8_t* workspace,
-                        int64_t workspace_size, int64_t n, int64_t k, int64_t 
num_groups,
-                        std::variant<float, const float*> alpha,
-                        std::variant<float, const float*> beta, ElementC* out,
-                        cudaStream_t stream) {
+void cutlass_group_gemm_sm90(ElementA* x, ElementB* weight, int64_t* indptr, 
uint8_t* workspace,
+                             int64_t workspace_size, int64_t n, int64_t k, 
int64_t num_groups,
+                             std::variant<float, const float*> alpha,
+                             std::variant<float, const float*> beta, ElementC* 
out,
+                             cudaStream_t stream) {
   using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
   using StrideA = typename Runner::StrideA;
   using StrideB = typename Runner::StrideB;
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu 
b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu
new file mode 100644
index 0000000000..29efcbe088
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+
+#include "fp16_group_gemm.cuh"
+#include "fp16_group_gemm_runner_sm100.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+struct CutlassGroupGemm<100, ElementA, ElementB, ElementC> {
+  static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* 
workspace, int workspace_size,
+                  int N, int K, int num_groups, float alpha, float beta, 
ElementC* C,
+                  cudaStream_t stream) {
+    cutlass_group_gemm_sm100<ElementA, ElementB, ElementC>(
+        A, B, indptr, workspace, workspace_size, N, K, num_groups, alpha, 
beta, C, stream);
+  }
+};
+
+void tvm_cutlass_group_gemm_sm100(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
+                                  NDArray out) {
+  tvm_cutlass_group_gemm_impl<100>(x, weight, indptr, workspace, out);
+}
+
+TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm100);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_SM100_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu 
b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
similarity index 60%
rename from src/runtime/contrib/cutlass/fp16_group_gemm.cu
rename to src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
index dffe7dc4ff..93a03a0675 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
@@ -19,14 +19,28 @@
 
 #include <cuda_fp16.h>
 #include <float.h>
-#include <tvm/runtime/ndarray.h>
 #include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/ffi/function.h>
 
-#include "group_gemm_runner.cuh"
+#include "fp16_group_gemm.cuh"
+#include "fp16_group_gemm_runner_sm90.cuh"
+
+namespace tvm {
+namespace runtime {
 
 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
 
+template <typename ElementA, typename ElementB, typename ElementC>
+struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> {
+  static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* 
workspace, int workspace_size,
+                  int N, int K, int num_groups, float alpha, float beta, 
ElementC* C,
+                  cudaStream_t stream) {
+    cutlass_group_gemm_sm90<ElementA, ElementB, ElementC>(A, B, indptr, 
workspace, workspace_size,
+                                                          N, K, num_groups, 
alpha, beta, C, stream);
+  }
+};
+
 template <>
 struct KernelTraits<cutlass::half_t> {
   using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
@@ -34,36 +48,21 @@ struct KernelTraits<cutlass::half_t> {
   using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a 
cluster
 };
 
-namespace tvm {
-namespace runtime {
+template <>
+struct KernelTraits<cutlass::bfloat16_t> {
+  using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a 
cluster
+};
 
-template <typename ElementA, typename ElementB, typename ElementC>
 void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, 
NDArray workspace,
                                  NDArray out) {
-  // Workspace is used for storing device-side group gemm arguments and 
cutlass internal workspace.
-  // Recommened size is 4MB.
-  static auto func = 
tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
-  CHECK_EQ(x->ndim, 2);
-  CHECK_EQ(weight->ndim, 3);
-  CHECK_EQ(indptr->ndim, 1);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 2);
-  int num_groups = weight->shape[0];
-  int n = weight->shape[1];
-  int k = weight->shape[2];
-  float alpha = 1.0f;
-  float beta = 0.0f;
-  cutlass_group_gemm(static_cast<ElementA*>(x->data), 
static_cast<ElementB*>(weight->data),
-                     static_cast<int64_t*>(indptr->data), 
static_cast<uint8_t*>(workspace->data),
-                     workspace->shape[0], n, k, num_groups, alpha, beta,
-                     static_cast<ElementC*>(out->data), stream);
+  tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out);
 }
 
-TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
-    .set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, 
cutlass::half_t, cutlass::half_t>);
+TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm90);
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
 
 }  // namespace runtime
 }  // namespace tvm
-
-#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu 
b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
deleted file mode 100644
index 5164958afe..0000000000
--- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <cuda_fp16.h>
-#include <float.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/function.h>
-
-#include "../cublas/cublas_utils.h"
-#include "blockwise_scaled_gemm_runner.cuh"
-
-#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
-
-namespace tvm {
-namespace runtime {
-
-void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray 
scales_a, NDArray scales_b,
-                                           NDArray workspace, int64_t 
block_size_0,
-                                           int64_t block_size_1, NDArray out) {
-  using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_1, _1, _1>;
-
-  // Workspace is used for storing device-side gemm arguments and cutlass 
internal workspace.
-  // Recommened size is 4MB.
-  const auto get_stream_func = 
tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream");
-  ICHECK(get_stream_func.has_value());
-  cudaStream_t stream = 
static_cast<cudaStream_t>((*get_stream_func)().cast<void*>());
-
-  CHECK_GE(a->ndim, 2);
-  CHECK_EQ(scales_a->ndim, a->ndim);
-  CHECK_EQ(b->ndim, 2);
-  CHECK_EQ(scales_b->ndim, 2);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, a->ndim);
-  int64_t m = 1;
-  for (int64_t i = 0; i < a->ndim - 1; ++i) {
-    m *= a->shape[i];
-  }
-  int64_t n = b->shape[0];
-  CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is 
supported now.";
-  int64_t k = a->shape[a->ndim - 1];
-
-  // scales_a is col-major of (*a_shape[:-1], k / block_size)
-  CHECK_EQ(scales_a->shape[0] * block_size_1, k);
-  for (int64_t i = 1; i < scales_a->ndim; ++i) {
-    CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
-  }
-  // scales_b is col-major of (k / block_size, n / block_size)
-  CHECK_EQ(scales_b->shape[0] * block_size_0, n);
-  CHECK_EQ(scales_b->shape[1] * block_size_1, k);
-
-  using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
-
-  if (DataType(out->dtype) == DataType::Float(16)) {
-    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape, 
cutlass::float_e4m3_t,
-                                      cutlass::float_e4m3_t, cutlass::half_t, 
float>(
-        static_cast<cutlass::float_e4m3_t*>(a->data), 
static_cast<cutlass::float_e4m3_t*>(b->data),
-        static_cast<float*>(scales_a->data), 
static_cast<float*>(scales_b->data),
-        static_cast<cutlass::half_t*>(out->data), 
static_cast<uint8_t*>(workspace->data),
-        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, 
stream);
-  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
-    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape, 
cutlass::float_e4m3_t,
-                                      cutlass::float_e4m3_t, 
cutlass::bfloat16_t, float>(
-        static_cast<cutlass::float_e4m3_t*>(a->data), 
static_cast<cutlass::float_e4m3_t*>(b->data),
-        static_cast<float*>(scales_a->data), 
static_cast<float*>(scales_b->data),
-        static_cast<cutlass::bfloat16_t*>(out->data), 
static_cast<uint8_t*>(workspace->data),
-        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, 
stream);
-  } else {
-    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
-  }
-}
-
-void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray 
scales_a, NDArray scales_b,
-                                          NDArray workspace, int64_t 
block_size_0,
-                                          int64_t block_size_1, NDArray out) {
-  using TileShape = Shape<_128, _128, _128>;
-  using ClusterShape = Shape<_1, _1, _1>;
-
-  // Workspace is used for storing device-side gemm arguments and cutlass 
internal workspace.
-  // Recommened size is 4MB.
-  const auto get_stream_func = 
tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream");
-  ICHECK(get_stream_func.has_value());
-  cudaStream_t stream = 
static_cast<cudaStream_t>((*get_stream_func)().cast<void*>());
-
-  CHECK_EQ(a->ndim, 3);
-  CHECK_EQ(scales_a->ndim, 3);
-  CHECK_EQ(b->ndim, 3);
-  CHECK_EQ(scales_b->ndim, 3);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 3);
-  int64_t batch_size = a->shape[0];
-  int64_t m = a->shape[1];
-  int64_t n = b->shape[1];
-  CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
-  int64_t k = a->shape[2];
-  CHECK_EQ(b->shape[0], batch_size);
-  CHECK_EQ(scales_a->shape[0], batch_size);
-  CHECK_EQ(scales_b->shape[0], batch_size);
-  CHECK_EQ(out->shape[0], batch_size);
-
-  // scales_a is col-major of (batch_size, m, k / block_size)
-  CHECK_EQ(scales_a->shape[1] * block_size_1, k);
-  CHECK_EQ(scales_a->shape[2], m);
-  // scales_b is col-major of (k / block_size, n / block_size)
-  CHECK_EQ(scales_b->shape[1] * block_size_0, n);
-  CHECK_EQ(scales_b->shape[2] * block_size_1, k);
-
-  using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
-
-  if (DataType(out->dtype) == DataType::Float(16)) {
-    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape, 
cutlass::float_e4m3_t,
-                                     cutlass::float_e4m3_t, cutlass::half_t, 
float>(
-        static_cast<cutlass::float_e4m3_t*>(a->data), 
static_cast<cutlass::float_e4m3_t*>(b->data),
-        static_cast<float*>(scales_a->data), 
static_cast<float*>(scales_b->data),
-        static_cast<cutlass::half_t*>(out->data), 
static_cast<uint8_t*>(workspace->data),
-        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, 
batch_size, stream);
-  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
-    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape, 
cutlass::float_e4m3_t,
-                                     cutlass::float_e4m3_t, 
cutlass::bfloat16_t, float>(
-        static_cast<cutlass::float_e4m3_t*>(a->data), 
static_cast<cutlass::float_e4m3_t*>(b->data),
-        static_cast<float*>(scales_a->data), 
static_cast<float*>(scales_b->data),
-        static_cast<cutlass::bfloat16_t*>(out->data), 
static_cast<uint8_t*>(workspace->data),
-        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, 
batch_size, stream);
-  } else {
-    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
-  }
-}
-
-TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn")
-    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm);
-TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn")
-    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm);
-
-}  // namespace runtime
-}  // namespace tvm
-
-#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu 
b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
similarity index 98%
rename from src/runtime/contrib/cutlass/fp8_group_gemm.cu
rename to src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
index 62a91dec18..686a6ebcff 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
@@ -23,7 +23,7 @@
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/function.h>
 
-#include "group_gemm_runner.cuh"
+#include "fp16_group_gemm_runner_sm90.cuh"
 
 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
 
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
new file mode 100644
index 0000000000..4ecca5f1d8
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/half.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Architecture-dispatched launcher for FP8 groupwise-scaled GEMMs.
+ *
+ * Only declared here. Each architecture-specific .cu file supplies a partial specialization for
+ * its `Arch` value (e.g. 90 or 100) that exposes a static
+ * `run(a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream)`.
+ */
+template <int Arch, typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementC, typename ElementBlockScale>
+struct CutlassFP8GroupwiseGemm;
+
+/*!
+ * \brief Validate the operands of an FP8 (e4m3) groupwise-scaled GEMM and launch it through the
+ *        `Arch`-specific CutlassFP8GroupwiseGemm specialization on the current CUDA stream.
+ *
+ * All leading dimensions of `a` are flattened into M. The problem is launched as a single
+ * (l = 1) GEMM of size (m, n, k).
+ *
+ * \param a FP8 e4m3 tensor of shape (*lead_dims, k).
+ * \param b FP8 e4m3 tensor of shape (n, k) (column-major B).
+ * \param scales_a float32 tensor of shape (k / block_size_1, *lead_dims).
+ * \param scales_b float32 tensor of shape (ceil(n / block_size_0), k / block_size_1).
+ * \param workspace 1-D uint8 scratch buffer for device-side GEMM arguments and CUTLASS's internal
+ *        workspace; 4MB is the recommended size.
+ * \param block_size_0 Scale block size along N (only used for shape validation here).
+ * \param block_size_1 Scale block size along K (only used for shape validation here).
+ * \param out float16 or bfloat16 tensor of shape (*lead_dims, n).
+ *
+ * NOTE(review): the block sizes are not forwarded to the kernel. The scale granularity is fixed by
+ * the Arch-specific runner, so confirm that callers pass block sizes matching it.
+ */
+template <int Arch, typename TileShape, typename ClusterShape>
+void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray scales_a,
+                                                NDArray scales_b, NDArray workspace,
+                                                int64_t block_size_0, int64_t block_size_1,
+                                                NDArray out) {
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  static tvm::ffi::Function get_stream_func =
+      tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  cudaStream_t stream = static_cast<cudaStream_t>(get_stream_func().cast<void*>());
+
+  CHECK_GE(a->ndim, 2);
+  CHECK_EQ(scales_a->ndim, a->ndim);
+  CHECK_EQ(b->ndim, 2);
+  CHECK_EQ(scales_b->ndim, 2);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, a->ndim);
+  int64_t m = 1;
+  for (int64_t i = 0; i < a->ndim - 1; ++i) {
+    m *= a->shape[i];
+  }
+  int64_t n = b->shape[0];
+  CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[a->ndim - 1];
+
+  // The kernel is launched with problem size (m, n, k) and writes its result into `out`, so `out`
+  // must be exactly (*a_shape[:-1], n); a mismatched buffer could be written out of bounds.
+  for (int64_t i = 0; i < a->ndim - 1; ++i) {
+    CHECK_EQ(out->shape[i], a->shape[i]);
+  }
+  CHECK_EQ(out->shape[out->ndim - 1], n);
+
+  // scales_a is col-major of (*a_shape[:-1], k / block_size)
+  CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+  for (int64_t i = 1; i < scales_a->ndim; ++i) {
+    CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+  }
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]);
+  CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
+                            cutlass::float_e4m3_t, cutlass::half_t,
+                            float>::run(static_cast<cutlass::float_e4m3_t*>(a->data),
+                                        static_cast<cutlass::float_e4m3_t*>(b->data),
+                                        static_cast<float*>(scales_a->data),
+                                        static_cast<float*>(scales_b->data),
+                                        static_cast<cutlass::half_t*>(out->data),
+                                        static_cast<uint8_t*>(workspace->data),
+                                        workspace->shape[0] * DataType(workspace->dtype).bytes(), m,
+                                        n, k, 1, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
+                            cutlass::float_e4m3_t, cutlass::bfloat16_t,
+                            float>::run(static_cast<cutlass::float_e4m3_t*>(a->data),
+                                        static_cast<cutlass::float_e4m3_t*>(b->data),
+                                        static_cast<float*>(scales_a->data),
+                                        static_cast<float*>(scales_b->data),
+                                        static_cast<cutlass::bfloat16_t*>(out->data),
+                                        static_cast<uint8_t*>(workspace->data),
+                                        workspace->shape[0] * DataType(workspace->dtype).bytes(), m,
+                                        n, k, 1, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+/*!
+ * \brief Validate the operands of a batched FP8 (e4m3) groupwise-scaled GEMM and launch it through
+ *        the `Arch`-specific CutlassFP8GroupwiseGemm specialization on the current CUDA stream,
+ *        with `l = batch_size`.
+ *
+ * \param a FP8 e4m3 tensor of shape (batch_size, m, k).
+ * \param b FP8 e4m3 tensor of shape (batch_size, n, k) (column-major B).
+ * \param scales_a float32 tensor of shape (batch_size, k / block_size_1, m).
+ * \param scales_b float32 tensor of shape (batch_size, n / block_size_0, k / block_size_1).
+ * \param workspace 1-D uint8 scratch buffer for device-side GEMM arguments and CUTLASS's internal
+ *        workspace; 4MB is the recommended size.
+ * \param block_size_0 Scale block size along N (only used for shape validation here).
+ * \param block_size_1 Scale block size along K (only used for shape validation here).
+ * \param out float16 or bfloat16 tensor of shape (batch_size, m, n).
+ *
+ * NOTE(review): the block sizes are not forwarded to the kernel. The scale granularity is fixed by
+ * the Arch-specific runner, so confirm that callers pass block sizes matching it.
+ */
+template <int Arch, typename TileShape, typename ClusterShape>
+void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray scales_a,
+                                               NDArray scales_b, NDArray workspace,
+                                               int64_t block_size_0, int64_t block_size_1,
+                                               NDArray out) {
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  static tvm::ffi::Function get_stream_func =
+      tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  cudaStream_t stream = static_cast<cudaStream_t>(get_stream_func().cast<void*>());
+
+  CHECK_EQ(a->ndim, 3);
+  CHECK_EQ(scales_a->ndim, 3);
+  CHECK_EQ(b->ndim, 3);
+  CHECK_EQ(scales_b->ndim, 3);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 3);
+  int64_t batch_size = a->shape[0];
+  int64_t m = a->shape[1];
+  int64_t n = b->shape[1];
+  CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[2];
+  CHECK_EQ(b->shape[0], batch_size);
+  CHECK_EQ(scales_a->shape[0], batch_size);
+  CHECK_EQ(scales_b->shape[0], batch_size);
+  CHECK_EQ(out->shape[0], batch_size);
+  // The kernel writes a (batch_size, m, n) result into `out`; a mismatched buffer could be
+  // written out of bounds.
+  CHECK_EQ(out->shape[1], m);
+  CHECK_EQ(out->shape[2], n);
+
+  // scales_a is col-major of (batch_size, m, k / block_size)
+  CHECK_EQ(scales_a->shape[1] * block_size_1, k);
+  CHECK_EQ(scales_a->shape[2], m);
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[1] * block_size_0, n);
+  CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
+                            cutlass::float_e4m3_t, cutlass::half_t,
+                            float>::run(static_cast<cutlass::float_e4m3_t*>(a->data),
+                                        static_cast<cutlass::float_e4m3_t*>(b->data),
+                                        static_cast<float*>(scales_a->data),
+                                        static_cast<float*>(scales_b->data),
+                                        static_cast<cutlass::half_t*>(out->data),
+                                        static_cast<uint8_t*>(workspace->data),
+                                        workspace->shape[0] * DataType(workspace->dtype).bytes(), m,
+                                        n, k, batch_size, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
+                            cutlass::float_e4m3_t, cutlass::bfloat16_t,
+                            float>::run(static_cast<cutlass::float_e4m3_t*>(a->data),
+                                        static_cast<cutlass::float_e4m3_t*>(b->data),
+                                        static_cast<float*>(scales_a->data),
+                                        static_cast<float*>(scales_b->data),
+                                        static_cast<cutlass::bfloat16_t*>(out->data),
+                                        static_cast<uint8_t*>(workspace->data),
+                                        workspace->shape[0] * DataType(workspace->dtype).bytes(), m,
+                                        n, k, batch_size, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git 
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
new file mode 100644
index 0000000000..95fc578fd4
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <type_traits>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_grouped.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_ref.h"
+// clang-format on
+
+// Abort (via CHECK) with the CUTLASS status string when a CUTLASS call does not return kSuccess.
+// The expansion is a bare `{ ... }` block, so wrap it in braces when used as the body of an
+// if/else. Keep this definition token-identical to the other CUTLASS_CHECK definitions in the
+// sibling runner headers: an identical redefinition is benign, a different one is not.
+#define CUTLASS_CHECK(status)                                      \
+  {                                                                \
+    cutlass::Status error = status;                                \
+    CHECK(error == cutlass::Status::kSuccess)                      \
+        << "Got cutlass error: " << cutlassGetStatusString(error); \
+  }
+
+// Global using-directive: the code below and the .cu files including this header use cute names
+// such as Shape, Int and _128 unqualified.
+using namespace cute;
+using tvm::runtime::NDArray;
+
+/*!
+ * \brief CUTLASS SM100 (Blackwell) GEMM runner with FP8 e4m3 A/B operands and float blockwise scale
+ *        factors for both operands.
+ *
+ * A is row-major, B is column-major and the output D (ElementD) is row-major. There is no C source
+ * operand (ElementC is void). The scale granularity is 1 x 128 x 128 along (M, N, K).
+ */
+template <typename TileShape, typename ClusterShape, typename ElementD>
+struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 {
+  using ElementA = cutlass::float_e4m3_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = cutlass::float_e4m3_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  // ElementC is void, so the C alignment is derived from ElementD.
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using LayoutD = LayoutC;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // MMA type
+  using ElementAccumulator = float;  // Element Accumulator will also be our scale factor type
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  static constexpr int ScaleGranularityM = 1;
+  static constexpr int ScaleGranularityN = 128;
+  static constexpr int ScaleGranularityK = 128;
+  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::MN, UMMA::Major::K>;
+
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC,
+      LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA,
+      cute::tuple<LayoutA, LayoutSFA>, AlignmentA, ElementB, cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB, ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
+      void>;  // Default to ClusterLaunchControl (CLC) based tile scheduler
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  /*!
+   * \brief Launch the (optionally batched) GEMM on `stream`.
+   *
+   * \param a_ptr A, `l` batches of row-major (m, k) with batch stride m * k.
+   * \param b_ptr B, `l` batches of (n, k) rows (column-major B) with batch stride n * k.
+   * \param scales_a_ptr Scale factors for A, laid out per ScaleConfig::tile_atom_to_shape_SFA.
+   * \param scales_b_ptr Scale factors for B, laid out per ScaleConfig::tile_atom_to_shape_SFB.
+   * \param o_ptr Output D, `l` batches of row-major (m, n) with batch stride m * n.
+   * \param workspace Device scratch buffer of `workspace_size` bytes; it must be at least
+   *        Gemm::get_workspace_size(arguments).
+   */
+  void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr,
+                const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, int m, int n, int k, int l,
+                uint8_t* workspace, int64_t workspace_size, cudaStream_t stream) {
+    cutlass::KernelHardwareInfo hw_info;
+    // Query the device that is current on this thread rather than hard-coding device 0, so the
+    // SM count matches the GPU the kernel is actually launched on.
+    int device_id = 0;
+    cudaError_t device_err = cudaGetDevice(&device_id);
+    CHECK(device_err == cudaSuccess) << "cudaGetDevice failed: " << cudaGetErrorString(device_err);
+    hw_info.device_id = device_id;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+    // Widen to int64_t before multiplying: the batch strides m * k, n * k and m * n can exceed
+    // INT_MAX for large operands, and an int product would overflow.
+    const int64_t m64 = m;
+    const int64_t n64 = n;
+    const int64_t k64 = k;
+    StrideA stride_a = cute::make_stride(k64, Int<1>{}, m64 * k64);
+    StrideB stride_b = cute::make_stride(k64, Int<1>{}, n64 * k64);
+    StrideD stride_d = cute::make_stride(n64, Int<1>{}, m64 * n64);
+    auto layout_scales_a = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l));
+    auto layout_scales_b = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l));
+
+    // Epilogue arguments: default thread arguments, then o_ptr for both C and D. ElementC is void,
+    // so presumably C is never read.
+    typename Gemm::Arguments arguments = {cutlass::gemm::GemmUniversalMode::kGemm,
+                                          {m, n, k, l},
+                                          {a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr,
+                                           layout_scales_a, scales_b_ptr, layout_scales_b},
+                                          {{}, o_ptr, stride_d, o_ptr, stride_d},
+                                          hw_info};
+
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+/*!
+ * \brief Free-function entry point for the SM100 groupwise-scaled GEMM.
+ *
+ * Builds a CutlassFP8ScaledGroupwiseGemmRunnerSM100 for the given tile/cluster shapes and output
+ * type, then forwards every argument to its run_gemm (problem size m x n x k with `l` batches).
+ * The runner fixes A/B to cutlass::float_e4m3_t and the scales to float.
+ */
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementD, typename ElementBlockScale>
+void cutlass_fp8_groupwise_scaled_mm_sm100(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                                           ElementBlockScale* scales_b, ElementD* out,
+                                           uint8_t* workspace, int64_t workspace_size, int64_t m,
+                                           int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
+  CutlassFP8ScaledGroupwiseGemmRunnerSM100<TileShape, ClusterShape, ElementD>{}.run_gemm(
+      a, b, scales_a, scales_b, out, m, n, k, l, workspace, workspace_size, stream);
+}
diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
similarity index 75%
rename from src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
rename to src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
index f520bf815a..5ec9ed0839 100644
--- a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
@@ -58,7 +58,7 @@ using tvm::runtime::NDArray;
 
 template <typename TileShape, typename ClusterShape, typename ElementD, 
typename SchedulerType,
           int ScaleGranularityM = 1>
-struct CutlassFP8ScaledBlockwiseGemmRunner {
+struct CutlassFP8GroupwiseScaledGemmRunner {
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
@@ -149,53 +149,14 @@ struct CutlassFP8ScaledBlockwiseGemmRunner {
 
 template <typename TileShape, typename ClusterShape, typename ElementA, 
typename ElementB,
           typename ElementD, typename ElementBlockScale>
-void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, 
ElementBlockScale* scales_a,
-                                       ElementBlockScale* scales_b, ElementD* 
out,
-                                       uint8_t* workspace, int64_t 
workspace_size, int64_t m,
-                                       int64_t n, int64_t k, cudaStream_t 
stream) {
+void cutlass_fp8_groupwise_scaled_mm_sm90(ElementA* a, ElementB* b, 
ElementBlockScale* scales_a,
+                                          ElementBlockScale* scales_b, 
ElementD* out,
+                                          uint8_t* workspace, int64_t 
workspace_size, int64_t m,
+                                          int64_t n, int64_t k, int64_t l, 
cudaStream_t stream) {
   if (k > 3 * n) {
     using SchedulerType = cutlass::gemm::StreamKScheduler;
     using Runner =
-        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
-    using StrideA = typename Runner::StrideA;
-    using StrideB = typename Runner::StrideB;
-    using StrideD = typename Runner::StrideD;
-
-    Runner runner;
-    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
-    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
-    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
-    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), 
static_cast<int>(k), 1};
-    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, 
&stride_b, &stride_d,
-                    workspace, workspace_size, stream);
-  } else {
-    using SchedulerType = cutlass::gemm::PersistentScheduler;
-    using Runner =
-        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
-    using StrideA = typename Runner::StrideA;
-    using StrideB = typename Runner::StrideB;
-    using StrideD = typename Runner::StrideD;
-
-    Runner runner;
-    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
-    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
-    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
-    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), 
static_cast<int>(k), 1};
-    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, 
&stride_b, &stride_d,
-                    workspace, workspace_size, stream);
-  }
-}
-
-template <typename TileShape, typename ClusterShape, typename ElementA, 
typename ElementB,
-          typename ElementD, typename ElementBlockScale>
-void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, 
ElementBlockScale* scales_a,
-                                      ElementBlockScale* scales_b, ElementD* 
out,
-                                      uint8_t* workspace, int64_t 
workspace_size, int64_t m,
-                                      int64_t n, int64_t k, int64_t l, 
cudaStream_t stream) {
-  if (k > 3 * n) {
-    using SchedulerType = cutlass::gemm::StreamKScheduler;
-    using Runner =
-        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
+        CutlassFP8GroupwiseScaledGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
     using StrideA = typename Runner::StrideA;
     using StrideB = typename Runner::StrideB;
     using StrideD = typename Runner::StrideD;
@@ -211,7 +172,7 @@ void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, 
ElementB* b, ElementBlockScal
   } else {
     using SchedulerType = cutlass::gemm::PersistentScheduler;
     using Runner =
-        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
+        CutlassFP8GroupwiseScaledGemmRunner<TileShape, ClusterShape, ElementD, 
SchedulerType>;
     using StrideA = typename Runner::StrideA;
     using StrideB = typename Runner::StrideB;
     using StrideD = typename Runner::StrideD;
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu
new file mode 100644
index 0000000000..ffa3ae6653
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+
+#include "../cublas/cublas_utils.h"
+#include "fp8_groupwise_scaled_gemm.cuh"
+#include "fp8_groupwise_scaled_gemm_runner_sm100.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief SM100 (Blackwell) specialization of the arch-dispatched launcher; forwards to
+ *        cutlass_fp8_groupwise_scaled_mm_sm100 with identical arguments.
+ */
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementC, typename ElementBlockScale>
+struct CutlassFP8GroupwiseGemm<100, TileShape, ClusterShape, ElementA, ElementB, ElementC,
+                               ElementBlockScale> {
+  static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                  ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace,
+                  int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l,
+                  cudaStream_t stream) {
+    // Problem size is (m, n, k) repeated over `l` batches.
+    cutlass_fp8_groupwise_scaled_mm_sm100<TileShape, ClusterShape, ElementA, ElementB, ElementC,
+                                          ElementBlockScale>(a, b, scales_a, scales_b, out,
+                                                             workspace, workspace_size, m, n, k,
+                                                             l, stream);
+  }
+};
+
+/*!
+ * \brief SM100 groupwise-scaled GEMM entry: 128x128x128 tiles, 1x1x1 cluster. Shape/dtype
+ *        validation and dtype dispatch happen in tvm_cutlass_fp8_groupwise_scaled_gemm_impl.
+ */
+void tvm_cutlass_fp8_groupwise_scaled_gemm_sm100(NDArray a, NDArray b, NDArray scales_a,
+                                                 NDArray scales_b, NDArray workspace,
+                                                 int64_t block_size_0, int64_t block_size_1,
+                                                 NDArray out) {
+  constexpr int kArch = 100;
+  using Tile128x128x128 = Shape<_128, _128, _128>;
+  using Cluster1x1x1 = Shape<_1, _1, _1>;
+  tvm_cutlass_fp8_groupwise_scaled_gemm_impl<kArch, Tile128x128x128, Cluster1x1x1>(
+      a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out);
+}
+
+/*!
+ * \brief SM100 batched groupwise-scaled GEMM entry: 128x128x128 tiles, 1x1x1 cluster. Shape/dtype
+ *        validation and dtype dispatch happen in tvm_cutlass_fp8_groupwise_scaled_bmm_impl.
+ */
+void tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(NDArray a, NDArray b, NDArray scales_a,
+                                                NDArray scales_b, NDArray workspace,
+                                                int64_t block_size_0, int64_t block_size_1,
+                                                NDArray out) {
+  constexpr int kArch = 100;
+  using Tile128x128x128 = Shape<_128, _128, _128>;
+  using Cluster1x1x1 = Shape<_1, _1, _1>;
+  tvm_cutlass_fp8_groupwise_scaled_bmm_impl<kArch, Tile128x128x128, Cluster1x1x1>(
+      a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out);
+}
+
+// Expose the SM100 kernels under architecture-agnostic global names.
+// NOTE(review): fp8_groupwise_scaled_gemm_sm90.cu registers these same two names when
+// CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is defined. If a build ever compiles both files
+// with their macros enabled, the registrations would likely collide; confirm the build selects
+// only one of them.
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm100);
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm100);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_SM100_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu
new file mode 100644
index 0000000000..e445e97da3
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+
+#include "../cublas/cublas_utils.h"
+#include "fp8_groupwise_scaled_gemm.cuh"
+#include "fp8_groupwise_scaled_gemm_runner_sm90.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief SM90 (Hopper) specialization of the arch-dispatched launcher; forwards to
+ *        cutlass_fp8_groupwise_scaled_mm_sm90 with identical arguments.
+ */
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementC, typename ElementBlockScale>
+struct CutlassFP8GroupwiseGemm<90, TileShape, ClusterShape, ElementA, ElementB, ElementC,
+                               ElementBlockScale> {
+  static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                  ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace,
+                  int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l,
+                  cudaStream_t stream) {
+    // Problem size is (m, n, k) repeated over `l` batches.
+    cutlass_fp8_groupwise_scaled_mm_sm90<TileShape, ClusterShape, ElementA, ElementB, ElementC,
+                                         ElementBlockScale>(a, b, scales_a, scales_b, out,
+                                                            workspace, workspace_size, m, n, k,
+                                                            l, stream);
+  }
+};
+
+/*!
+ * \brief SM90 groupwise-scaled GEMM entry: 128x128x128 tiles, 1x1x1 cluster. Shape/dtype
+ *        validation and dtype dispatch happen in tvm_cutlass_fp8_groupwise_scaled_gemm_impl.
+ */
+void tvm_cutlass_fp8_groupwise_scaled_gemm_sm90(NDArray a, NDArray b, NDArray scales_a,
+                                                NDArray scales_b, NDArray workspace,
+                                                int64_t block_size_0, int64_t block_size_1,
+                                                NDArray out) {
+  constexpr int kArch = 90;
+  using Tile128x128x128 = Shape<_128, _128, _128>;
+  using Cluster1x1x1 = Shape<_1, _1, _1>;
+  tvm_cutlass_fp8_groupwise_scaled_gemm_impl<kArch, Tile128x128x128, Cluster1x1x1>(
+      a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out);
+}
+
+/*!
+ * \brief SM90 batched groupwise-scaled GEMM entry: 128x128x128 tiles, 1x1x1 cluster. Shape/dtype
+ *        validation and dtype dispatch happen in tvm_cutlass_fp8_groupwise_scaled_bmm_impl.
+ */
+void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(NDArray a, NDArray b, NDArray scales_a,
+                                               NDArray scales_b, NDArray workspace,
+                                               int64_t block_size_0, int64_t block_size_1,
+                                               NDArray out) {
+  constexpr int kArch = 90;
+  using Tile128x128x128 = Shape<_128, _128, _128>;
+  using Cluster1x1x1 = Shape<_1, _1, _1>;
+  tvm_cutlass_fp8_groupwise_scaled_bmm_impl<kArch, Tile128x128x128, Cluster1x1x1>(
+      a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out);
+}
+
+// Expose the SM90 kernels under architecture-agnostic global names.
+// NOTE(review): fp8_groupwise_scaled_gemm_sm100.cu registers these same two names when
+// CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined. If a build ever compiles both files with their
+// macros enabled, the registrations would likely collide; confirm the build selects only one.
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm90);
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm90);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git 
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
new file mode 100644
index 0000000000..19c6b699aa
--- /dev/null
+++ 
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+// Abort (via CHECK) with the CUTLASS status string when a CUTLASS call does not return kSuccess.
+// The expansion is a bare `{ ... }` block, so wrap it in braces when used as the body of an
+// if/else. Keep this definition token-identical to the other CUTLASS_CHECK definitions in the
+// sibling runner headers: an identical redefinition is benign, a different one is not.
+#define CUTLASS_CHECK(status)                                      \
+  {                                                                \
+    cutlass::Status error = status;                                \
+    CHECK(error == cutlass::Status::kSuccess)                      \
+        << "Got cutlass error: " << cutlassGetStatusString(error); \
+  }
+
+using namespace cute;
+// Problem-shape type for the grouped (ptr-array) GEMM: one (M, N, K) triple per group.
+// NOTE(review): this alias lives in the global namespace; a different global `ProblemShape` from
+// another runner header in the same translation unit would conflict.
+using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
+
+inline size_t aligned(size_t value, size_t alignment = 16) {
+  return (value + alignment - 1) / alignment * alignment;
+}
+
+// CUTLASS grouped (pointer-array) GEMM runner for SM100 with blockwise-scaled FP8 operands.
+//
+// Layouts: A is row-major, B is column-major, C and D are row-major; every operand uses a
+// 128-bit alignment. Scale factors for A and B use a granularity of
+// ScaleGranularityM x ScaleGranularityN x ScaleGranularityK = 1 x 128 x 128 along (M, N, K).
+// The epilogue is configured with alpha = 1 and beta = 0, so the values behind ptr_C do not
+// contribute to D.
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementC, typename ElementBlockScale>
+struct CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100 {
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+
+  static constexpr int ScaleGranularityM = 1;
+  static constexpr int ScaleGranularityN = 128;
+  static constexpr int ScaleGranularityK = 128;
+  using ScaleConfig =
+      cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
+                                                 ScaleGranularityK, UMMA::Major::K, UMMA::Major::K>;
+
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC,
+      LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule>::CollectiveOp;
+  // The mainloop stage count is chosen automatically after reserving the epilogue's shared
+  // storage.
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA,
+      cute::tuple<LayoutA*, LayoutSFA*>, AlignmentA, ElementB, cute::tuple<LayoutB*, LayoutSFB*>,
+      AlignmentB, ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
+                                                          CollectiveEpilogue, void>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+  // Launches the grouped GEMM over `num_groups` problems on `stream`.
+  //
+  // ptr_*, stride_*, layout_scales_* and problem_sizes are arrays with one entry per group
+  // (presumably device memory: the caller in this file fills them from a CUDA kernel).
+  // problem_sizes_host may be nullptr. `workspace` must hold at least
+  // Gemm::get_workspace_size(arguments) bytes; this is CHECKed.
+  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B,
+                      const ElementBlockScale** ptr_scales_a,
+                      const ElementBlockScale** ptr_scales_b, const ElementC** ptr_C,
+                      ElementC** ptr_D,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
+                      StrideA* stride_A, StrideB* stride_B, LayoutSFA* layout_scales_a,
+                      LayoutSFB* layout_scales_b, StrideC* stride_C, StrideD* stride_D,
+                      uint8_t* workspace, int64_t workspace_size, int num_groups,
+                      cudaStream_t stream) {
+    cutlass::KernelHardwareInfo hw_info;
+    // Query the device that is current on this thread instead of assuming device 0, so the
+    // SM count matches the GPU the kernel is launched on.
+    int device_id = 0;
+    cudaError_t device_status = cudaGetDevice(&device_id);
+    CHECK(device_status == cudaSuccess)
+        << "cudaGetDevice failed: " << cudaGetErrorString(device_status);
+    hw_info.device_id = device_id;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                       {num_groups, problem_sizes, problem_sizes_host},
+                                       {ptr_A, stride_A, ptr_B, stride_B, ptr_scales_a,
+                                        layout_scales_a, ptr_scales_b, layout_scales_b},
+                                       {{}, ptr_C, stride_C, ptr_D, stride_D},
+                                       hw_info};
+    auto& fusion_args = arguments.epilogue.thread;
+    fusion_args.alpha = 1.0f;
+    fusion_args.beta = 0.0f;
+
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    // Compare as signed integers. In a mixed int64_t/size_t comparison a negative
+    // workspace_size would be converted to a huge unsigned value and pass the check.
+    CHECK_GE(workspace_size, static_cast<int64_t>(gemm_op.get_workspace_size(arguments)))
+        << "Insufficient workspace for the CUTLASS grouped GEMM";
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+// Fills the per-group argument arrays consumed by the CUTLASS grouped GEMM.
+//
+// Each thread handles the group whose id is threadIdx.x, so the kernel must be launched as a
+// single block with at least `num_groups` threads. `indptr` holds cumulative row counts:
+// group g covers rows [indptr[g - 1], indptr[g]) of `a` and `out`, and group 0 starts at
+// row 0. `b` stores one n x k matrix per group, back to back. For the scale pointers,
+// scales_a holds ceil(k / 128) values per row of `a`, and scales_b holds
+// ceil(k / 128) * ceil(n / 128) values per group.
+template <typename ScaleConfig, typename ElementA, typename ElementB, typename ElementC,
+          typename ElementBlockScale, typename StrideA, typename StrideB, typename StrideC,
+          typename LayoutSFA, typename LayoutSFB>
+__global__ void prepare_group_gemm_arguments(
+    const ElementA** ptr_A, const ElementB** ptr_B, const ElementBlockScale** ptr_scales_a,
+    const ElementBlockScale** ptr_scales_b, ElementC** ptr_D,
+    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
+    StrideB* stride_B, LayoutSFA* layout_scales_a, LayoutSFB* layout_scales_b, StrideC* stride_D,
+    const ElementA* a, const ElementB* b, const ElementBlockScale* scales_a,
+    const ElementBlockScale* scales_b, ElementC* out, int64_t* indptr, int64_t n, int64_t k,
+    int num_groups) {
+  int group_id = threadIdx.x;
+  if (group_id >= num_groups) return;
+  // Keep the row offset in 64 bits. indptr is int64_t, and narrowing it to int would corrupt
+  // the pointer arithmetic below once the cumulative row count exceeds INT_MAX.
+  int64_t prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
+  int64_t num_k_blocks = (k + 127) / 128;
+  int64_t num_n_blocks = (n + 127) / 128;
+  ptr_A[group_id] = a + prev_rows * k;
+  ptr_B[group_id] = b + group_id * k * n;
+  ptr_D[group_id] = out + prev_rows * n;
+  ptr_scales_a[group_id] = scales_a + prev_rows * num_k_blocks;
+  ptr_scales_b[group_id] = scales_b + group_id * num_k_blocks * num_n_blocks;
+  int64_t m = indptr[group_id] - prev_rows;
+  problem_sizes[group_id] = {static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)};
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
+  // The scale-factor layouts are derived from the (m, n, k, 1) problem shape of this group.
+  layout_scales_a[group_id] = ScaleConfig::tile_atom_to_shape_SFA(
+      make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
+  layout_scales_b[group_id] = ScaleConfig::tile_atom_to_shape_SFB(
+      make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
+}
+
+// Runs the SM100 FP8 groupwise-scaled grouped GEMM on `stream`.
+//
+// a:         rows of all groups stacked, k elements per row. The rows of group g are
+//            [indptr[g - 1], indptr[g]), and group 0 starts at row 0.
+// b:         num_groups matrices of n x k elements each, stored back to back.
+// scales_a, scales_b: scale factors, indexed as described on prepare_group_gemm_arguments.
+// workspace: scratch buffer of workspace_size bytes. Its front holds the per-group argument
+//            arrays, and the rest (from a 256-byte aligned offset) is CUTLASS's workspace.
+// out:       output rows, n elements per row, grouped like `a`.
+template <typename ElementA, typename ElementB, typename ElementC, typename ElementBlockScale>
+void cutlass_fp8_groupwise_scaled_group_gemm_sm100(
+    ElementA* a, ElementB* b, const ElementBlockScale* scales_a, const ElementBlockScale* scales_b,
+    int64_t* indptr, uint8_t* workspace, int64_t workspace_size, int64_t n, int64_t k,
+    int64_t num_groups, ElementC* out, cudaStream_t stream) {
+  using TileShape = Shape<_256, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Runner =
+      CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100<TileShape, ClusterShape, ElementA, ElementB,
+                                                    ElementC, ElementBlockScale>;
+  using ScaleConfig = typename Runner::ScaleConfig;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+  using LayoutSFA = typename Runner::LayoutSFA;
+  using LayoutSFB = typename Runner::LayoutSFB;
+  using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape;
+
+  // The preparation kernel runs as a single block with one thread per group, so the group
+  // count is bounded by the maximum block size.
+  CHECK_LE(num_groups, 1024) << "At most 1024 groups are supported, got " << num_groups;
+
+  // Reserve `bytes` bytes (rounded up to 16) from the front of the workspace.
+  std::ptrdiff_t offset = 0;
+  auto carve = [&](size_t bytes) -> uint8_t* {
+    uint8_t* ptr = workspace + offset;
+    offset += aligned(bytes);
+    return ptr;
+  };
+  const ElementA** ptr_A =
+      reinterpret_cast<const ElementA**>(carve(sizeof(ElementA*) * num_groups));
+  const ElementB** ptr_B =
+      reinterpret_cast<const ElementB**>(carve(sizeof(ElementB*) * num_groups));
+  const ElementBlockScale** ptr_scales_a =
+      reinterpret_cast<const ElementBlockScale**>(carve(sizeof(ElementBlockScale*) * num_groups));
+  const ElementBlockScale** ptr_scales_b =
+      reinterpret_cast<const ElementBlockScale**>(carve(sizeof(ElementBlockScale*) * num_groups));
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(carve(sizeof(ElementC*) * num_groups));
+  UnderlyingProblemShape* problem_sizes = reinterpret_cast<UnderlyingProblemShape*>(
+      carve(sizeof(UnderlyingProblemShape) * num_groups));
+  StrideA* stride_A = reinterpret_cast<StrideA*>(carve(sizeof(StrideA) * num_groups));
+  StrideB* stride_B = reinterpret_cast<StrideB*>(carve(sizeof(StrideB) * num_groups));
+  StrideC* stride_D = reinterpret_cast<StrideC*>(carve(sizeof(StrideC) * num_groups));
+  LayoutSFA* layout_scales_a =
+      reinterpret_cast<LayoutSFA*>(carve(sizeof(LayoutSFA) * num_groups));
+  LayoutSFB* layout_scales_b =
+      reinterpret_cast<LayoutSFB*>(carve(sizeof(LayoutSFB) * num_groups));
+
+  // CUTLASS's own workspace starts at the next 256-byte boundary. Verify that the argument
+  // arrays fit before the preparation kernel writes into the buffer, and that the size left
+  // for CUTLASS is non-negative.
+  offset = aligned(offset, 256);
+  CHECK_LE(offset, workspace_size) << "Workspace of " << workspace_size
+                                   << " bytes is too small for " << num_groups << " groups";
+
+  prepare_group_gemm_arguments<ScaleConfig, ElementA, ElementB, ElementC, ElementBlockScale,
+                               StrideA, StrideB, StrideC, LayoutSFA, LayoutSFB>
+      <<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b, ptr_D, problem_sizes,
+                                     stride_A, stride_B, layout_scales_a, layout_scales_b, stride_D,
+                                     a, b, scales_a, scales_b, out, indptr, n, k, num_groups);
+  Runner runner;
+  // C aliases D. The runner's epilogue uses beta = 0, so C's contents are not read into D.
+  runner.run_group_gemm(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b,
+                        const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
+                        /*problem_sizes_host=*/nullptr, stride_A, stride_B, layout_scales_a,
+                        layout_scales_b, stride_D, stride_D, workspace + offset,
+                        workspace_size - offset, num_groups, stream);
+}
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
new file mode 100644
index 0000000000..d13481e9dd
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+
+#include "fp8_groupwise_scaled_group_gemm_runner_sm100.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+// FP8 (e4m3fn) grouped GEMM with groupwise scaling for SM100 GPUs.
+//
+// Expected arguments (validated below):
+//   a:         (m, k) float8_e4m3fn; the rows of all groups stacked together.
+//   b:         (num_groups, n, k) float8_e4m3fn.
+//   scales_a:  (m, ceil(k / block_size_1)) float32.
+//   scales_b:  (num_groups, ceil(n / block_size_0), ceil(k / block_size_1)) float32.
+//   indptr:    1-D int64 with at least num_groups entries. indptr[g] is the end row of
+//              group g in `a` and `out`, and group 0 starts at row 0.
+//   workspace: 1-D uint8 buffer for the device-side group GEMM arguments and the CUTLASS
+//              internal workspace. Recommended size is 4MB.
+//   block_size_0, block_size_1: scale block sizes along N and K. The kernel uses a fixed
+//              128 x 128 scale granularity, so both must be 128.
+//   out:       (m, n) float16 or bfloat16.
+void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray scales_a,
+                                               NDArray scales_b, NDArray indptr, NDArray workspace,
+                                               int64_t block_size_0, int64_t block_size_1,
+                                               NDArray out) {
+  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  CHECK_EQ(a->ndim, 2);
+  CHECK_EQ(b->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = b->shape[0];
+  int n = b->shape[1];
+  int k = b->shape[2];
+
+  // Cross-operand shape consistency; the kernel derives every row stride from n and k.
+  CHECK_EQ(a->shape[1], k);
+  CHECK_EQ(out->shape[0], a->shape[0]);
+  CHECK_EQ(out->shape[1], n);
+  // The argument-preparation kernel reads indptr[0 .. num_groups).
+  CHECK_GE(indptr->shape[0], num_groups);
+
+  // The CUTLASS kernel hard-codes a 128 x 128 (N x K) scale granularity. Any other block
+  // size would make it misread the scale tensors.
+  CHECK_EQ(block_size_0, 128) << "Only block_size_0 == 128 is supported";
+  CHECK_EQ(block_size_1, 128) << "Only block_size_1 == 128 is supported";
+
+  CHECK_EQ(scales_a->ndim, a->ndim);
+  CHECK_EQ(scales_b->ndim, b->ndim);
+  // scales_a is row-major of (m, k / block_size)
+  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1]);
+  CHECK_EQ(scales_a->shape[0], a->shape[0]);
+  // Per group, scales_b is col-major of (k / block_size, n / block_size), i.e. the tensor
+  // shape is (num_groups, n / block_size, k / block_size).
+  CHECK_EQ(scales_b->shape[0], num_groups);
+  CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]);
+  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(indptr->dtype), DataType::Int(64));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    using Dtype = cutlass::half_t;
+    cutlass_fp8_groupwise_scaled_group_gemm_sm100<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
+                                                  Dtype, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, static_cast<Dtype*>(out->data), stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    using Dtype = cutlass::bfloat16_t;
+    cutlass_fp8_groupwise_scaled_group_gemm_sm100<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
+                                                  Dtype, float>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, static_cast<Dtype*>(out->data), stream);
+  } else {
+    // Fail loudly instead of returning with `out` left unwritten.
+    LOG(FATAL) << "Unsupported output dtype " << DataType(out->dtype)
+               << "; expected float16 or bfloat16";
+  }
+}
+
+// Exposed under an architecture-agnostic name. This definition is compiled only when
+// CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined.
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_fp8_groupwise_scaled_group_gemm_sm100);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_SM100_SUPPORTED
diff --git a/src/target/tag.cc b/src/target/tag.cc
index f6e2307b75..0df0d8d2c7 100644
--- a/src/target/tag.cc
+++ b/src/target/tag.cc
@@ -161,6 +161,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 
65536)
     .with_config("l2_cache_size_bytes", 41943040);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
     .with_config("l2_cache_size_bytes", 52428800);
+// NOTE(review): l2_cache_size_bytes equals the H100 entry's value (52428800 bytes); confirm it
+// against the B100 specification.
+TVM_REGISTER_CUDA_TAG("nvidia/nvidia-b100", "sm_100a", 49152, 65536)
+    .with_config("l2_cache_size_bytes", 52428800);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536);
diff --git a/tests/python/contrib/test_cutlass_gemm.py 
b/tests/python/contrib/test_cutlass_gemm.py
index 7c259e6f7d..33f7ef1160 100644
--- a/tests/python/contrib/test_cutlass_gemm.py
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -44,8 +44,8 @@ def verify_group_gemm(
     def get_ref_data():
         assert M % num_groups == 0
         M_per_group = M // num_groups
-        a_np = get_random_ndarray((M, K), "float16")
-        b_np = get_random_ndarray((num_groups, N, K), "float16")
+        a_np = get_random_ndarray((M, K), x_dtype)
+        b_np = get_random_ndarray((num_groups, N, K), weight_dtype)
         indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
         c_np = np.concatenate(
             [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i 
in range(num_groups)],
@@ -76,7 +76,7 @@ def verify_group_gemm(
 @tvm.testing.requires_cuda_compute_version(9)
 def test_group_gemm_sm90():
     verify_group_gemm(
-        "cutlass.group_gemm_fp16_sm90",
+        "cutlass.group_gemm",
         8,
         128,
         128,
@@ -116,6 +116,24 @@ def test_group_gemm_sm90():
     )
 
 
[email protected]_cutlass
[email protected]_cuda_compute_version(10)
+def test_group_gemm_sm100():
+    verify_group_gemm(
+        "cutlass.group_gemm",
+        8,
+        128,
+        128,
+        4,
+        "bfloat16",
+        "bfloat16",
+        "bfloat16",
+        False,
+        rtol=1e-2,
+        atol=1e-3,
+    )
+
+
 def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, 
int], dtype: str):
     x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
     x_scale_shape = (
@@ -283,14 +301,14 @@ def blockwise_bmm(
 
 @tvm.testing.requires_cutlass
 @tvm.testing.requires_cuda_compute_version(9)
-def test_fp8_e4m3_blockwise_scaled_gemm():
+def test_fp8_e4m3_groupwise_scaled_gemm():
     M = 16
     N = 4608
     K = 896
     block_size = (128, 128)
     assert N % 128 == 0 and K % 128 == 0  # Only support N/K are multiple of 
128
 
-    func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn"
+    func_name = "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn"
     gemm_func = tvm.get_global_func(func_name, allow_missing=True)
     if gemm_func is None:
         print(f"Skipped as {func_name} is not available")
@@ -316,7 +334,7 @@ def test_fp8_e4m3_blockwise_scaled_gemm():
 
 @tvm.testing.requires_cutlass
 @tvm.testing.requires_cuda_compute_version(9)
-def test_fp8_e4m3_blockwise_scaled_bmm():
+def test_fp8_e4m3_groupwise_scaled_bmm():
     B = 16
     M = 40
     N = 512
@@ -324,7 +342,7 @@ def test_fp8_e4m3_blockwise_scaled_bmm():
     block_size = (128, 128)
     assert N % 128 == 0 and K % 128 == 0  # Only support N/K are multiple of 
128
 
-    func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn"
+    func_name = "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn"
     gemm_func = tvm.get_global_func(func_name, allow_missing=True)
     if gemm_func is None:
         print(f"Skipped as {func_name} is not available")

Reply via email to