This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d2c7167913 [Cutlass] Fix usage of cuda stream for group gemm (#16818)
d2c7167913 is described below
commit d2c7167913fabe0ac46c5bd50b0a9984d5b174c5
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Mar 29 09:17:25 2024 -0700
[Cutlass] Fix usage of cuda stream for group gemm (#16818)
---
src/runtime/contrib/cutlass/group_gemm_runner.cuh | 15 ++++++---------
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
index 50bdcf7bec..71979672b9 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -40,14 +40,11 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- if (error != cutlass::Status::kSuccess) { \
- std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
- << std::endl; \
- exit(EXIT_FAILURE); \
- } \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ CHECK(error == cutlass::Status::kSuccess) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -147,7 +144,7 @@ struct CutlassGroupGemmRunner {
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
- CUTLASS_CHECK(gemm_op.run());
+ CUTLASS_CHECK(gemm_op.run(stream));
}
};