This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 68ac909713 [CodeGenC] Use PrimFuncNode::ret_type in function signature (#15073)
68ac909713 is described below

commit 68ac9097132b40cca88733267b628c0eb42bbbfb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jun 13 10:58:19 2023 -0400

    [CodeGenC] Use PrimFuncNode::ret_type in function signature (#15073)
    
    Prior to this PR, the return type for `CodeGenC` was hard-coded as
    part of `virtual CodeGenC::PrintFuncPrefix`, regardless of the return
    type specified in the `PrimFunc`.  This PR updates `CodeGenC` to use
    `PrimFuncNode::ret_type` for the return type in the generated C code.
    
    This change should have no effect on observable behavior.  The
    majority of codegen classes specified a `void` return type, which
    matches the default `DataType::Void()` for a `PrimFunc`.  The one
    exception is `CodeGenCHost::PrintFuncPrefix`, which specified an
    `int32_t` return type, matching the `DataType::Int(32)` used for the
    functions generated by `MakePackedAPI` and `MakeUnpackedAPI`.
---
 src/target/source/codegen_c.cc      | 10 ++++++++--
 src/target/source/codegen_c.h       |  7 +++++--
 src/target/source/codegen_c_host.cc | 15 +++++++++------
 src/target/source/codegen_c_host.h  |  5 +++--
 src/target/source/codegen_cuda.cc   |  2 +-
 src/target/source/codegen_opencl.cc |  2 +-
 src/target/source/codegen_vhls.cc   |  2 +-
 7 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index f6792c1a4e..bcdd0bfea0 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
 
   this->PrintFuncPrefix(stream);
+  PrintType(f->ret_type, stream);
   this->PrintExtraAttrs(f);
   this->stream << " " << static_cast<std::string>(global_symbol.value()) << 
"(";
 
@@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
   this->stream << "}\n\n";
 }
 
-void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
+void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
 
 void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
 
@@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOLINT(*)
       ICHECK_GE(op->args.size(), 1U);
       auto func = Downcast<StringImm>(op->args[0]);
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, 
op->args, true, os);
-      this->GenerateForwardFunctionDeclarations(func->value, op->args);
+      Array<Type> arg_types;
+      for (size_t i = 1; i < op->args.size(); i++) {
+        arg_types.push_back(GetType(op->args[i]));
+      }
+      Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+      this->GenerateForwardFunctionDeclarations(func->value, arg_types, 
ret_type);
     } else if (op_attr_global_symbol_.count(call_op)) {
       // call extern if the op itself have a global symbol.
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), 
op_attr_global_symbol_[call_op],
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 4f0da5a9db..de9c2f1745 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, 
std::ostream&)>,
   /*!
    * \brief Generate forward function declarations.
    * \param global_symbol The symbolc of the target function.
-   * \param args The arguments to the function.
+   * \param arg_types The argument types to the function.
+   * \param ret_type The return type of the function
    * \param os The output stream.
    */
   virtual void GenerateForwardFunctionDeclarations(String global_symbol,
-                                                   const Array<PrimExpr>& 
args) {}
+                                                   const Array<Type>& 
arg_types,
+                                                   const Type& ret_type) {}
+
   /*!
    * \brief Print external function call.
    * \param ret_type The return type.
diff --git a/src/target/source/codegen_c_host.cc 
b/src/target/source/codegen_c_host.cc
index 1d8071774e..e98852c270 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool 
emit_fwd_func_decl) {
     function_names_.push_back(runtime::symbol::tvm_module_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
     PrintFuncPrefix(stream);
+    PrintType(f->ret_type, stream);
     stream << " " << tvm::runtime::symbol::tvm_module_main
            << "(void* args, int* arg_type_ids, int num_args, void* 
out_ret_value, "
            << "int* out_ret_tcode, void* resource_handle) {\n";
@@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool 
emit_fwd_func_decl) {
 }
 
 void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
-                                                       const Array<PrimExpr>& 
args) {
+
+                                                       const Array<Type>& 
arg_types,
+                                                       const Type& ret_type) {
   if (!emit_fwd_func_decl_) {
     return;
   }
@@ -107,13 +110,13 @@ void 
CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
     }
   }
   this->PrintFuncPrefix(fwd_decl_stream);
+  this->PrintType(ret_type, fwd_decl_stream);
   fwd_decl_stream << " " << global_symbol << "(";
-  for (size_t i = 1; i < args.size(); ++i) {
-    CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
-    fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
-    if (i < args.size() - 1) {
+  for (size_t i = 0; i < arg_types.size(); ++i) {
+    if (i > 0) {
       fwd_decl_stream << ", ";
     }
+    CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
   }
   fwd_decl_stream << ");\n";
 }
@@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // 
NOLINT(*)
   os << "#ifdef __cplusplus\n"
      << "extern \"C\"\n"
      << "#endif\n"
-     << "TVM_DLL int32_t";
+     << "TVM_DLL ";
 }
 
 void CodeGenCHost::PrintFinalReturn() {  // NOLINT(*)
diff --git a/src/target/source/codegen_c_host.h 
b/src/target/source/codegen_c_host.h
index 6bae574627..9c71f197f0 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC {
   void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, 
tvm::BaseFunc>> functions);
   void DefineModuleName();
 
+  using CodeGenC::PrintType;
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
   void PrintFuncPrefix(std::ostream& os) final;        // NOLINT(*)
   void PrintFinalReturn() final;                       // NOLINT(*)
@@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC {
 
   void VisitStmt_(const AssertStmtNode* op) final;  // NOLINT(*)
 
-  virtual void GenerateForwardFunctionDeclarations(String global_symbol,
-                                                   const Array<PrimExpr>& 
args);  // NOLINT(*)
+  void GenerateForwardFunctionDeclarations(String global_symbol, const 
Array<Type>& arg_types,
+                                           const Type& ret_type) override;
   Array<String> GetFunctionNames() { return function_names_; }
 
  private:
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index ec8695a2a0..cd0ec0e34f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
   ICHECK_EQ(vid_global_barrier_state_, 
runtime::symbol::tvm_global_barrier_state);
 }
 
-void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" 
__global__ void"; }
+void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" 
__global__ "; }
 
 class ThreadIdxExtractor : public tir::StmtVisitor {
  private:
diff --git a/src/target/source/codegen_opencl.cc 
b/src/target/source/codegen_opencl.cc
index fa4ca7d34b..c15d2253d7 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
   }
 }
 
-void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; 
}
+void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; }
 
 void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
   for (Var arg : f->params) {
diff --git a/src/target/source/codegen_vhls.cc 
b/src/target/source/codegen_vhls.cc
index 8463d6ac41..83046de107 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& 
os) {
   }
 }
 
-void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" 
void"; }
+void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" 
"; }
 
 void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
   for (size_t i = 0; i < f->params.size(); ++i) {

Reply via email to