This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 68ac909713 [CodeGenC] Use PrimFuncNode::ret_type in function signature
(#15073)
68ac909713 is described below
commit 68ac9097132b40cca88733267b628c0eb42bbbfb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jun 13 10:58:19 2023 -0400
[CodeGenC] Use PrimFuncNode::ret_type in function signature (#15073)
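Prior to this PR, the return type of functions emitted by `CodeGenC`
was hard-coded as part of `virtual CodeGenC::PrintFuncPrefix`,
regardless of the return type specified in the `PrimFunc`. This PR
updates `CodeGenC` to use `PrimFuncNode::ret_type` for the return type
in the generated C code: the `PrintFuncPrefix` overrides no longer
print a return type, and the return type is instead printed from
`f->ret_type` immediately after the prefix.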
This change should have no effect on observable behavior. Most
codegen classes hard-coded a `void` return type, which matches the
default `DataType::Void()` for a `PrimFunc`. The one exception is
`CodeGenCHost::PrintFuncPrefix`, which hard-coded an `int32_t` return
type, matching the `DataType::Int(32)` used for the functions
generated by `MakePackedAPI` and `MakeUnpackedAPI`.
---
src/target/source/codegen_c.cc | 10 ++++++++--
src/target/source/codegen_c.h | 7 +++++--
src/target/source/codegen_c_host.cc | 15 +++++++++------
src/target/source/codegen_c_host.h | 5 +++--
src/target/source/codegen_cuda.cc | 2 +-
src/target/source/codegen_opencl.cc | 2 +-
src/target/source/codegen_vhls.cc | 2 +-
7 files changed, 28 insertions(+), 15 deletions(-)
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index f6792c1a4e..bcdd0bfea0 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(stream);
+ PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) <<
"(";
@@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n";
}
-void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
+void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
@@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value,
op->args, true, os);
- this->GenerateForwardFunctionDeclarations(func->value, op->args);
+ Array<Type> arg_types;
+ for (size_t i = 1; i < op->args.size(); i++) {
+ arg_types.push_back(GetType(op->args[i]));
+ }
+ Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+ this->GenerateForwardFunctionDeclarations(func->value, arg_types,
ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)),
op_attr_global_symbol_[call_op],
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 4f0da5a9db..de9c2f1745 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
/*!
* \brief Generate forward function declarations.
* \param global_symbol The symbolc of the target function.
- * \param args The arguments to the function.
+ * \param arg_types The argument types to the function.
+ * \param ret_type The return type of the function
* \param os The output stream.
*/
virtual void GenerateForwardFunctionDeclarations(String global_symbol,
- const Array<PrimExpr>&
args) {}
+ const Array<Type>&
arg_types,
+ const Type& ret_type) {}
+
/*!
* \brief Print external function call.
* \param ret_type The return type.
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 1d8071774e..e98852c270 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool
emit_fwd_func_decl) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
+ PrintType(f->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void*
out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
@@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool
emit_fwd_func_decl) {
}
void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
- const Array<PrimExpr>&
args) {
+
+ const Array<Type>&
arg_types,
+ const Type& ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
@@ -107,13 +110,13 @@ void
CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
}
}
this->PrintFuncPrefix(fwd_decl_stream);
+ this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
- for (size_t i = 1; i < args.size(); ++i) {
- CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
- fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
- if (i < args.size() - 1) {
+ for (size_t i = 0; i < arg_types.size(); ++i) {
+ if (i > 0) {
fwd_decl_stream << ", ";
}
+ CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
@@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { //
NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
- << "TVM_DLL int32_t";
+ << "TVM_DLL ";
}
void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
diff --git a/src/target/source/codegen_c_host.h
b/src/target/source/codegen_c_host.h
index 6bae574627..9c71f197f0 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC {
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar,
tvm::BaseFunc>> functions);
void DefineModuleName();
+ using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
@@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC {
void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)
- virtual void GenerateForwardFunctionDeclarations(String global_symbol,
- const Array<PrimExpr>&
args); // NOLINT(*)
+ void GenerateForwardFunctionDeclarations(String global_symbol, const
Array<Type>& arg_types,
+ const Type& ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }
private:
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index ec8695a2a0..cd0ec0e34f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_,
runtime::symbol::tvm_global_barrier_state);
}
-void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\"
__global__ void"; }
+void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\"
__global__ "; }
class ThreadIdxExtractor : public tir::StmtVisitor {
private:
diff --git a/src/target/source/codegen_opencl.cc
b/src/target/source/codegen_opencl.cc
index fa4ca7d34b..c15d2253d7 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
}
}
-void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void";
}
+void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; }
void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
for (Var arg : f->params) {
diff --git a/src/target/source/codegen_vhls.cc
b/src/target/source/codegen_vhls.cc
index 8463d6ac41..83046de107 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream&
os) {
}
}
-void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\"
void"; }
+void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\"
"; }
void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {