This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 34d4bc75d8 [Fix] Codegen fix for relax cutlass (#18190)
34d4bc75d8 is described below
commit 34d4bc75d862c8571b3499a1e30046720770ed3a
Author: Chenfan <[email protected]>
AuthorDate: Sun Aug 10 14:24:24 2025 +0800
[Fix] Codegen fix for relax cutlass (#18190)
* Codegen fix
---------
Co-authored-by: Tianqi Chen <[email protected]>
---
python/tvm/contrib/cutlass/conv2d_operation.py | 8 ++++----
python/tvm/contrib/cutlass/gemm_operation.py | 6 +++---
src/relax/backend/contrib/codegen_c/codegen_c.h | 17 +++++++----------
3 files changed, 14 insertions(+), 17 deletions(-)
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index a37e46f404..361bcb54e5 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -418,17 +418,17 @@ def instantiate_conv2d_template(attrs):
size_t workspace_size = conv2d_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = conv2d_op.can_implement(arguments);
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_reset}
status = conv2d_op.initialize(arguments, workspace.get());
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_update}
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
status = conv2d_op(stream);
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_reduction}
"""
@@ -439,7 +439,7 @@ def instantiate_conv2d_template(attrs):
split_k_update = """
arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)};
status = conv2d_op.update(arguments, workspace.get());
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
"""
split_k_reduction = """
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 46b68c29ee..65dc5da772 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -341,15 +341,15 @@ def instantiate_gemm_template(attrs):
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
${kernel} gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
status = gemm_op.initialize(arguments, workspace.get());
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
status = gemm_op(stream);
- CHECK(status == cutlass::Status::kSuccess);
+ TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
"""
op_type = attrs["op_type"]
has_bias = "bias" in op_type
diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h
index 795b691dec..7f04091fc1 100644
--- a/src/relax/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relax/backend/contrib/codegen_c/codegen_c.h
@@ -83,31 +83,28 @@ class CodegenCBase {
code_stream_ << "#ifdef __cplusplus\n";
code_stream_ << "extern \"C\" {\n";
code_stream_ << "#endif\n";
- code_stream_ << "TVM_DLL int32_t ";
+ code_stream_ << "TVM_FFI_DLL_EXPORT int32_t ";
code_stream_ << func_name << "(";
- code_stream_ << "TVMValue* args, ";
- code_stream_ << "int* type_code, ";
- code_stream_ << "int num_args, ";
- code_stream_ << "TVMValue* out_value, ";
- code_stream_ << "int* out_type_code) {\n";
+ code_stream_ << "tvm::ffi::PackedArgs args, ";
+ code_stream_ << "tvm::ffi::AnyView* out_value) {\n";
}
/*!
- * \brief Adds a line to convert TVMValue args to DLTensors
+ * \brief Adds a line to convert tvm::ffi::PackedArgs args to DLTensors
*/
void PrintArgToData(int idx) {
PrintIndents();
code_stream_ << "DLTensor* arg" << idx << " = ";
- code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
+ code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
}
/*!
- * \brief Adds a line to convert TVMValue rets to DLTensors
+ * \brief Adds a line to convert tvm::ffi::PackedArgs rets to DLTensors
*/
void PrintRetToData(int idx) {
PrintIndents();
code_stream_ << "DLTensor* ret" << idx << " = ";
- code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
+ code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
}
/*!