This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d2c7167913 [Cutlass] Fix usage of cuda stream for group gemm (#16818)
d2c7167913 is described below

commit d2c7167913fabe0ac46c5bd50b0a9984d5b174c5
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Mar 29 09:17:25 2024 -0700

    [Cutlass] Fix usage of cuda stream for group gemm (#16818)
---
 src/runtime/contrib/cutlass/group_gemm_runner.cuh | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
index 50bdcf7bec..71979672b9 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -40,14 +40,11 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                                                    \
-  {                                                                                              \
-    cutlass::Status error = status;                                                              \
-    if (error != cutlass::Status::kSuccess) {                                                    \
-      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
-                << std::endl;                                                                    \
-      exit(EXIT_FAILURE);                                                                        \
-    }                                                                                            \
+#define CUTLASS_CHECK(status)                                      \
+  {                                                                \
+    cutlass::Status error = status;                                \
+    CHECK(error == cutlass::Status::kSuccess)                      \
+        << "Got cutlass error: " << cutlassGetStatusString(error); \
   }
 
 using namespace cute;
@@ -147,7 +144,7 @@ struct CutlassGroupGemmRunner {
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
     CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
-    CUTLASS_CHECK(gemm_op.run());
+    CUTLASS_CHECK(gemm_op.run(stream));
   }
 };
 

Reply via email to