This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 34d4bc75d8 [Fix] Codegen fix for relax cutlass (#18190)
34d4bc75d8 is described below

commit 34d4bc75d862c8571b3499a1e30046720770ed3a
Author: Chenfan <[email protected]>
AuthorDate: Sun Aug 10 14:24:24 2025 +0800

    [Fix] Codegen fix for relax cutlass (#18190)
    
    * Codegen fix
    
    ---------
    
    Co-authored-by: Tianqi Chen <[email protected]>
---
 python/tvm/contrib/cutlass/conv2d_operation.py  |  8 ++++----
 python/tvm/contrib/cutlass/gemm_operation.py    |  6 +++---
 src/relax/backend/contrib/codegen_c/codegen_c.h | 17 +++++++----------
 3 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index a37e46f404..361bcb54e5 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -418,17 +418,17 @@ def instantiate_conv2d_template(attrs):
   size_t workspace_size = conv2d_op.get_workspace_size(arguments);
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
   cutlass::Status status = conv2d_op.can_implement(arguments);
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
   ${split_k_reset}
   status = conv2d_op.initialize(arguments, workspace.get());
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
   ${split_k_update}
 
   auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
   cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
 
   status = conv2d_op(stream);
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
   ${split_k_reduction}
 """
 
@@ -439,7 +439,7 @@ def instantiate_conv2d_template(attrs):
     split_k_update = """
   arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)};
   status = conv2d_op.update(arguments, workspace.get());
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
 """
 
     split_k_reduction = """
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 46b68c29ee..65dc5da772 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -341,15 +341,15 @@ def instantiate_gemm_template(attrs):
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
   ${kernel} gemm_op;
   cutlass::Status status = gemm_op.can_implement(arguments);
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
   status = gemm_op.initialize(arguments, workspace.get());
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
 
   auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
   cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
 
   status = gemm_op(stream);
-  CHECK(status == cutlass::Status::kSuccess);
+  TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
 """
     op_type = attrs["op_type"]
     has_bias = "bias" in op_type
diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h
index 795b691dec..7f04091fc1 100644
--- a/src/relax/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relax/backend/contrib/codegen_c/codegen_c.h
@@ -83,31 +83,28 @@ class CodegenCBase {
     code_stream_ << "#ifdef __cplusplus\n";
     code_stream_ << "extern \"C\" {\n";
     code_stream_ << "#endif\n";
-    code_stream_ << "TVM_DLL int32_t ";
+    code_stream_ << "TVM_FFI_DLL_EXPORT int32_t ";
     code_stream_ << func_name << "(";
-    code_stream_ << "TVMValue* args, ";
-    code_stream_ << "int* type_code, ";
-    code_stream_ << "int num_args, ";
-    code_stream_ << "TVMValue* out_value, ";
-    code_stream_ << "int* out_type_code) {\n";
+    code_stream_ << "tvm::ffi::PackedArgs args, ";
+    code_stream_ << "tvm::ffi::AnyView* out_value) {\n";
   }
 
   /*!
-   * \brief Adds a line to convert TVMValue args to DLTensors
+   * \brief Adds a line to convert tvm::ffi::PackedArgs args to DLTensors
    */
   void PrintArgToData(int idx) {
     PrintIndents();
     code_stream_ << "DLTensor* arg" << idx << " = ";
-    code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
+    code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
   }
 
   /*!
-   * \brief Adds a line to convert TVMValue rets to DLTensors
+   * \brief Adds a line to convert tvm::ffi::PackedArgs rets to DLTensors
    */
   void PrintRetToData(int idx) {
     PrintIndents();
     code_stream_ << "DLTensor* ret" << idx << " = ";
-    code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n";
+    code_stream_ << "(DLTensor*)(args[" << idx << "].cast<DLTensor*>());\n";
   }
 
   /*!

Reply via email to