This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 698531e6d7 [CodeGenC][Redo] Handle GlobalVar callee as internal
function call (#15835)
698531e6d7 is described below
commit 698531e6d7e1493b7a73c71137132de87de8aad0
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Oct 18 14:07:43 2023 -0500
[CodeGenC][Redo] Handle GlobalVar callee as internal function call (#15835)
* [CodeGenC][Redo] Handle GlobalVar callee as internal function call
This reverts commit
[`e88d0d`](https://github.com/apache/tvm/pull/15725), which
itself reverted
[`9ff71f`](https://github.com/apache/tvm/pull/15103) for
breakages on the metal backend. Now that the CI contains
compile-time testing of the metal codegen, the original
breakage should be identifiable.
* Added codegen metal CI debug print
* Print function decl to the argument stream
* Remove the codegen metal CI debug print-outs
---
.../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py | 8 +-
.../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py | 87 +++++++++---
.../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py | 13 +-
.../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py | 7 +-
.../backend/contrib/cmsisnn/tir_to_runtime.cc | 28 ++--
.../contrib/example_target_hooks/tir_to_runtime.cc | 26 +++-
src/relay/backend/contrib/uma/tir_to_runtime.cc | 34 +++--
src/target/opt/build_cuda_on.cc | 18 ++-
src/target/source/codegen_aocl.cc | 19 ++-
src/target/source/codegen_c.cc | 153 ++++++++++++++-------
src/target/source/codegen_c.h | 59 +++++++-
src/target/source/codegen_c_host.cc | 93 ++++++-------
src/target/source/codegen_c_host.h | 3 +-
src/target/source/codegen_cuda.cc | 4 +-
src/target/source/codegen_cuda.h | 2 +-
src/target/source/codegen_metal.cc | 89 ++++++------
src/target/source/codegen_metal.h | 3 +-
src/target/source/codegen_opencl.cc | 24 ++--
src/target/source/codegen_vhls.cc | 34 +++--
src/target/source/codegen_webgpu.cc | 79 +++++------
src/target/source/codegen_webgpu.h | 4 +-
src/target/source/source_module.cc | 6 +-
src/tir/op/op.cc | 26 ++++
.../relay/aot/test_crt_forward_declarations.py | 6 +-
.../topi/python/test_topi_conv2d_tensordot_opts.py | 28 +++-
.../python/unittest/test_target_codegen_c_host.py | 48 +++++--
.../test_tir_transform_inject_ptx_async_copy.py | 1 +
27 files changed, 597 insertions(+), 305 deletions(-)
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
index e8e45152aa..3eb32d8fdb 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- cc.dtype,
+ "int32",
f"{func_prefix}_{width}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
- tvm.tir.call_extern(cc.dtype,
f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+ tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}",
cc.access_ptr("w"))
)
return ib.get()
@@ -113,8 +113,8 @@ extern "C"
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
- long arr_offset,
- int reset) {{
+ int32_t arr_offset,
+ int32_t reset) {{
int n;
int32_t *p32;
int32_t res = reset ? 0 : *res16;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index 929dcc6557..e26e818fbd 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -156,9 +156,14 @@ __attribute__((always_inline)) static inline const int8_t
*read_and_pad(const in
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{N}_body_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
@@ -200,7 +205,12 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
+
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -221,7 +231,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
@@ -265,9 +279,14 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{N}_update_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
@@ -309,7 +328,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -327,7 +350,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
@@ -368,9 +395,14 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{N}_body_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
@@ -387,7 +419,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -408,7 +444,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
@@ -450,9 +490,14 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{N}_update_rest_{uniq_id}(
- int K,
+ int32_t K_arg,
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int K = K_arg;
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
@@ -469,7 +514,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -487,7 +536,11 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t
gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int A_stride, int B_stride, int C_stride) {{
+ int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+ int A_stride = A_stride_arg;
+ int B_stride = B_stride_arg;
+ int C_stride = C_stride_arg;
+
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
@@ -520,7 +573,7 @@ out:
#ifdef __cplusplus
extern "C"
#endif
-__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
+__attribute__((always_inline)) static inline int32_t
gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
index 66d712a4a0..cfed417c9f 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- cc.dtype,
+ "int32",
f"{func_prefix}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- cc.dtype, f"{func_prefix}_reset_{uniq_id}",
cc.access_ptr("w"), cc.strides[0]
+ "int32", f"{func_prefix}_reset_{uniq_id}",
cc.access_ptr("w"), cc.strides[0]
)
)
return ib.get()
@@ -96,7 +96,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
- int N) {{
+ int32_t N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
return 0;
}}
@@ -107,7 +107,9 @@ extern "C"
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
- int N) {{
+ int32_t N_arg) {{
+ int N = N_arg;
+
for ( int i = 0; i < N; ++ i )
if ( arg[i] > res[i] )
res[i] = arg[i];
@@ -120,7 +122,8 @@ extern "C"
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
- int N) {{
+ int32_t N_arg) {{
+ int N = N_arg;
int32_t *parg32, *pres32;
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
int32_t retcode = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
index d2a8f1ef69..af3b23e01d 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
@@ -390,8 +390,13 @@ def tensordot_int16_impl(
#define {function_name.upper()}_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t {function_name}(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias,
int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
) {{
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
{_init_biased_accumulators(num_outputs)}
{insert_lines(load_tensor_lines)}
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 186fa30f20..6febfe3486 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl,
target_str, devices);
}
- /*!
- * \brief Emit code that offloads a subgraph to the Cortex-M
- *
- * \return string of code that offloads a subgraph to the Cortex-M
- */
- void AddFunction(const PrimFunc& prim_func) {
CodeGenC::AddFunction(prim_func); }
-
private:
/*! * \brief Enable storing the last error */
bool debug_last_error;
@@ -575,11 +568,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target)
{
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
- Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(),
debug_last_error);
- std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
- for (auto kv : mod->functions) {
- funcs.push_back(kv);
+
+ std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
+ for (auto [gvar, base_func] : mod->functions) {
+ funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
}
std::sort(funcs.begin(), funcs.end(),
@@ -594,13 +587,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target)
{
return name_hint_a < name_hint_b;
});
- for (auto kv : funcs) {
- auto prim_func = Downcast<PrimFunc>(kv.second);
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- function_names.push_back(global_symbol.value());
- codegen.AddFunction(prim_func);
+ for (auto [gvar, prim_func] : funcs) {
+ codegen.AddFunction(gvar, prim_func);
}
std::string code = codegen.Finish();
+
+ Array<String> function_names;
+ for (auto [gvar, prim_func] : funcs) {
+ function_names.push_back(codegen.GetFunctionName(gvar));
+ }
+
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 0db8d06c31..6f09e0a0c3 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;
- Array<String> function_names;
+
std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(),
devices);
- for (auto kv : mod->functions) {
- auto prim_func = Downcast<PrimFunc>(kv.second);
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- function_names.push_back(global_symbol.value());
- codegen.AddFunction(prim_func);
+
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ codegen.DeclareFunction(gvar, prim_func);
+ }
+ for (auto [gvar, prim_func] : functions) {
+ codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}
+
std::string code = codegen.Finish();
+
+ Array<String> function_names;
+ for (auto [gvar, prim_func] : functions) {
+ function_names.push_back(codegen.GetFunctionName(gvar));
+ }
+
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc
b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 3b58fda54b..487e247f5d 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -49,13 +49,6 @@ class UMACodegen : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl,
target_str_, devices);
}
- /*!
- * \brief Emit code that offloads a subgraph to the UMA target
- *
- * \return string of code that offloads a subgraph to the UMA target
- */
- void AddFunction(const PrimFunc& prim_func) {
CodeGenC::AddFunction(prim_func); }
-
private:
String target_str_;
};
@@ -63,17 +56,30 @@ class UMACodegen : public codegen::CodeGenCHost {
runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
- bool emit_fwd_func_decl = false;
+ bool emit_fwd_func_decl = true;
UMACodegen codegen(target->kind->name);
- Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
- for (auto kv : mod->functions) {
- auto prim_func = Downcast<PrimFunc>(kv.second);
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- function_names.push_back(global_symbol.value());
- codegen.AddFunction(prim_func);
+
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ codegen.DeclareFunction(gvar, prim_func);
+ }
+ for (auto [gvar, prim_func] : functions) {
+ codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}
+
std::string code = codegen.Finish();
+
+ Array<String> function_names;
+ for (auto [gvar, prim_func] : functions) {
+ function_names.push_back(codegen.GetFunctionName(gvar));
+ }
+
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 1c0b5094ef..e0f53e3509 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -131,13 +131,21 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
CodeGenCUDA cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- cg.AddFunction(f);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ cg.DeclareFunction(gvar, prim_func);
+ }
+ for (auto [gvar, prim_func] : functions) {
+ cg.AddFunction(gvar, prim_func);
}
std::string code = cg.Finish();
diff --git a/src/target/source/codegen_aocl.cc
b/src/target/source/codegen_aocl.cc
index 700d85b4cc..dc3ba08751 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -40,13 +40,22 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool
emulation) {
CodeGenOpenCL cg;
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodegenOpenCL: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- cg.AddFunction(f);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ cg.DeclareFunction(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ cg.AddFunction(gvar, prim_func);
}
std::string code = cg.Finish();
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index a7cc320562..187bdc74fe 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -42,6 +42,7 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear();
handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState();
+ ReserveKeywordsAsUnique();
}
void CodeGenC::ReserveKeywordsAsUnique() {
@@ -75,51 +76,92 @@ void CodeGenC::ReserveKeywordsAsUnique() {
name_supply_->ReserveName("return");
}
-void CodeGenC::AddFunction(const PrimFunc& f) {
- // clear previous generated state.
- this->InitFuncState(f);
- // reserve keywords
- ReserveKeywordsAsUnique();
+void CodeGenC::PrintFunctionSignature(const String& function_name, const
PrimFunc& func,
+ std::ostream& os) {
+ PrintFuncPrefix(os);
+ PrintType(func->ret_type, os);
+ PrintExtraAttrs(func, os);
+ os << " " << function_name << "(";
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ tir::Var v = func->params[i];
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
- << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
- bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
-
- this->PrintFuncPrefix(stream);
- PrintType(f->ret_type, stream);
- this->PrintExtraAttrs(f);
- this->stream << " " << static_cast<std::string>(global_symbol.value()) <<
"(";
-
- for (size_t i = 0; i < f->params.size(); ++i) {
- tir::Var v = f->params[i];
- std::string vid = AllocVarID(v.get());
- if (i != 0) stream << ", ";
- if (v.dtype().is_handle()) {
- auto it = alloc_storage_scope_.find(v.get());
- if (it != alloc_storage_scope_.end()) {
- PrintStorageScope(it->second, stream);
- }
+ if (i > 0) {
+ os << ", ";
+ }
- PrintType(GetType(v), stream);
- // Register handle data type
- // TODO(tvm-team): consider simply keep type info in the
- // type annotation(via a normalizing rewriting).
- if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
- if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
- RegisterHandleType(v.get(), prim->dtype);
- }
- }
+ if (auto it = alloc_storage_scope_.find(v.get()); it !=
alloc_storage_scope_.end()) {
+ PrintStorageScope(it->second, os);
+ }
- if (no_alias) {
- PrintRestrict(v, stream);
+ PrintType(GetType(v), os);
+
+ bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
+ bool is_handle = v.dtype().is_handle();
+ if (no_alias && is_handle) {
+ PrintRestrict(v, os);
+ }
+
+ os << " " << AllocVarID(v.get());
+ }
+ os << ")";
+
+ // Register handle data type
+ // TODO(tvm-team): consider simply keep type info in the
+ // type annotation(via a normalizing rewriting).
+ for (const auto& param : func->params) {
+ if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) {
+ if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
+ RegisterHandleType(param.get(), prim->dtype);
}
- } else {
- PrintType(GetType(v), stream);
}
- stream << ' ' << vid;
}
- stream << ") {\n";
+}
+
+void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
+ if (internal_functions_.count(gvar)) {
+ return;
+ }
+
+ auto function_name = [&]() -> String {
+ if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+ auto name = global_symbol.value();
+ ICHECK(!func_name_supply_->ContainsName(name))
+ << "Function " << gvar << " must use global symbol " << name
+ << ", but this name has already been used.";
+ func_name_supply_->ReserveName(name);
+ return name;
+ } else {
+ func_name_supply_->ReserveName(gvar->name_hint);
+ return gvar->name_hint;
+ }
+ }();
+
+ internal_functions_.insert({gvar, function_name});
+
+ InitFuncState(func);
+ PrintFunctionSignature(function_name, func, fwd_decl_stream);
+ fwd_decl_stream << ";\n";
+}
+
+String CodeGenC::GetFunctionName(const GlobalVar& gvar) {
+ auto it = internal_functions_.find(gvar);
+ ICHECK(it != internal_functions_.end())
+ << "Attempted to find name of " << gvar
+ << ", but no function with this GlobalVar has been declared";
+ return it->second;
+}
+
+void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
+ // If the function has already been forward-declared, this is a
+ // no-op.
+ DeclareFunction(gvar, f);
+ auto function_name = GetFunctionName(gvar);
+
+ // clear previous generated state.
+ InitFuncState(f);
+
+ PrintFunctionSignature(function_name, f, stream);
+ stream << " {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
@@ -130,9 +172,15 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
-void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {}
-std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
+std::string CodeGenC::Finish() {
+ std::ostringstream code;
+ code << decl_stream.str();
+ code << fwd_decl_stream.str();
+ code << stream.str();
+ return code.str();
+}
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) {
@@ -542,12 +590,17 @@ void CodeGenC::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value,
op->args, true, os);
- Array<Type> arg_types;
- for (size_t i = 1; i < op->args.size(); i++) {
- arg_types.push_back(GetType(op->args[i]));
+
+ // If the call_extern refers to an function within the IRModule, then
+ // the forward declaration is already provided from DeclareFunction.
+ if (!func_name_supply_->ContainsName(func->value)) {
+ Array<Type> arg_types;
+ for (size_t i = 1; i < op->args.size(); i++) {
+ arg_types.push_back(GetType(op->args[i]));
+ }
+ Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+ this->GenerateForwardFunctionDeclarations(func->value, arg_types,
ret_type);
}
- Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
- this->GenerateForwardFunctionDeclarations(func->value, arg_types,
ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)),
op_attr_global_symbol_[call_op],
@@ -615,9 +668,13 @@ void CodeGenC::VisitExpr_(const CallNode* op,
std::ostream& os) { // NOLINT(*)
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
+ } else if (auto opt = op->op.as<GlobalVar>()) {
+ auto gvar = opt.value();
+ auto callee_name = GetFunctionName(gvar);
+ PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args,
false, os);
} else {
- ICHECK(op->op.as<GlobalVarNode>());
- LOG(FATAL) << "Do not yet support cross function call";
+ LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a
recognized built-in, "
+ << "nor a GlobalVar reference to another function in the
IRModule";
}
}
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 93f9ea519c..2921a56ef3 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -65,12 +65,33 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
* \param output_ssa Whether output SSA.
*/
void Init(bool output_ssa);
+
/*!
- * \brief Add the function to the generated module.
- * \param f The function to be compiled.
+ * \brief Add the function declaration to the generated module,
+ * without defining it.
+ *
+ * \param gvar The GlobalVar representing the function.
+ * \param func The function to be compiled.
* \param whether to append return 0 in the end.
*/
- void AddFunction(const PrimFunc& f);
+ virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+ /*!
+ * \brief Add the function to the generated module, including its
+ * declaration and definition.
+ *
+ * \param gvar The GlobalVar representing the function.
+ * \param func The function to be compiled.
+ */
+ virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+ /*!
+ * \brief Get the name of a declared function
+ * \param gvar The GlobalVar of the function
+ * \returns The string name of the function
+ */
+ String GetFunctionName(const GlobalVar& gvar);
+
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
@@ -96,7 +117,23 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
PrintExpr(n, os);
return os.str();
}
+
// The following parts are overloadable print operations.
+
+ /*! \brief Print the function signature before the argument list
+ *
+ * The default implementation delegates out to PrintFuncPrefix and
+ * PrintExtraAttrs.
+ *
+ * \param function_name The name of the function
+ *
+ * \param func The function whose signature should be printed
+ *
+ * \param os The output stream
+ */
+ virtual void PrintFunctionSignature(const String& function_name, const
PrimFunc& func,
+ std::ostream& os);
+
/*!
* \brief Print the function header before the argument list
* \param os The output stream
@@ -109,7 +146,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
*
* Example: __launch_bounds__(256) for CUDA functions
*/
- virtual void PrintExtraAttrs(const PrimFunc& f);
+ virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); //
NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
@@ -284,10 +321,24 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
private:
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
+
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
+
// binding of let variables. Enables duplicate var defs that map to same
value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual>
let_binding_;
+
+ /* \brief Map of GlobalVar to their symbol.
+ *
+ * For externally-exposed functions, this is given by the
+ * tvm::attr::kTarget attribute of the PrimFunc. For internal
+ * functions, this is the name of the function's GlobalVar, possibly
+ * altered to prevent duplicate names.
+ */
+ std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual>
internal_functions_;
+
+ /* \brief Name supply to generate unique function names */
+ NameSupply func_name_supply_{""};
};
} // namespace codegen
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 3255e11c5d..caef43e8af 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -75,19 +75,24 @@ void CodeGenCHost::InitGlobalContext() {
void CodeGenCHost::DefineModuleName() { decl_stream << "void* " <<
module_name_ << " = NULL;\n"; }
-void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
- << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
- function_names_.push_back(global_symbol.value());
+void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
+ bool emit_fwd_func_decl) {
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ if (global_symbol) {
+ function_names_.push_back(global_symbol.value());
+ }
emit_fwd_func_decl_ = emit_fwd_func_decl;
- CodeGenC::AddFunction(f);
- if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+ CodeGenC::AddFunction(gvar, func);
+ if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+ ICHECK(global_symbol.defined())
+ << "CodeGenCHost: The entry func must have the global_symbol
attribute, "
+ << "but function " << gvar << " only has attributes " << func->attrs;
+
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
- PrintType(f->ret_type, stream);
+ PrintType(func->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void*
out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
@@ -128,15 +133,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { //
NOLINT(*)
<< "TVM_DLL ";
}
-std::string CodeGenCHost::Finish() { // NOLINT(*)
- std::string ret = decl_stream.str();
- if (emit_fwd_func_decl_) {
- ret += fwd_decl_stream.str();
- }
- ret += stream.str();
- return ret;
-}
-
void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
@@ -437,42 +433,38 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(),
devices);
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
- PrimFunc aot_executor_fn;
-
- std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
- for (auto kv : mod->functions) {
- // Make sure that the executor function is the last one to be code
generated so that all the
- // symbols are available to __tvm_main__
- auto fun_name = std::string(kv.first->name_hint);
- bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function",
Bool(false)).value();
-
- if (is_aot_executor_fn) {
- aot_executor_fn = Downcast<PrimFunc>(kv.second);
- continue;
- }
- funcs.push_back(kv);
+
+ auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
+ return func->GetAttr<Bool>("runner_function", Bool(false)).value();
+ };
+
+ std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ funcs.push_back({gvar, prim_func});
}
// Sort functions
- std::sort(funcs.begin(), funcs.end(),
- [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
- std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
- std::string name_hint_a = kv_a.first->name_hint;
- std::string name_hint_b = kv_b.first->name_hint;
- return name_hint_a < name_hint_b;
- });
-
- // Add all functions except __tvm_main__
- for (auto& kv : funcs) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
- cg.AddFunction(f);
+ auto sort_key = [&is_aot_executor_fn](const auto& kv) {
+ return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
+ };
+ std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const
auto& kv_b) {
+ return sort_key(kv_a) < sort_key(kv_b);
+ });
+
+ // Declare all functions first. This ensures that all functions,
+ // including the __tvm_main__ used in AOT, have access to forward
+ // declarations of other functions in the IRModule.
+ for (const auto& [gvar, prim_func] : funcs) {
+ cg.DeclareFunction(gvar, prim_func);
}
- // Add __tvm_main__
- if (aot_executor_fn.defined()) {
- emit_fwd_func_decl = true;
- cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
+ // Codegen all functions. Passing emit_fwd_func_decl=true adds a
+ // forward declaration for any `builtin::call_extern`, based on the
+ // arguments provided to it.
+ for (const auto& [gvar, prim_func] : funcs) {
+ cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}
// NOTE: it's possible that kRuntime attr is not attached when the mod was
built with tvm.build().
@@ -484,7 +476,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
} else {
runtime = relay::Runtime::Create("cpp", {});
}
- if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
+
+ bool has_aot_executor_fn = std::any_of(
+ funcs.begin(), funcs.end(), [&](const auto& kv) { return
is_aot_executor_fn(kv.second); });
+ if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) {
cg.InitGlobalContext();
}
diff --git a/src/target/source/codegen_c_host.h
b/src/target/source/codegen_c_host.h
index 694104afc0..aeba685f74 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -44,8 +44,7 @@ class CodeGenCHost : public CodeGenC {
const std::unordered_set<std::string>& devices);
void InitGlobalContext();
- void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
- std::string Finish() final;
+ void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool
emit_fwd_func_decl = false);
/*!
* \brief Add functions from the (unordered) range to the current module in
a deterministic
* order. This helps with debugging.
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index a91f8b0164..7639ce6065 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
PrimExpr threadIdx_z_ext = Integer(1);
};
-void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
ThreadIdxExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
@@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
// unable to extract the number of threads per block, hence directly
return
return;
}
- stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+ os << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 3ec0c3bc2d..bc7b34b500 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC {
}
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
- void PrintExtraAttrs(const PrimFunc& f) final;
+ void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; //
NOLINT(*)
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final;
// NOLINT(*)
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index b8c30691e2..ebb7566489 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -36,6 +36,8 @@ namespace codegen {
void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
+ // skip the first underscore, so SSA variable starts from _1
+ name_supply_->FreshName("v_");
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
@@ -52,37 +54,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target)
{
<< "};\n\n";
}
-void CodeGenMetal::AddFunction(const PrimFunc& f) {
- // clear previous generated state.
- this->InitFuncState(f);
- // skip the first underscore, so SSA variable starts from _1
- name_supply_->FreshName("v_");
-
+void CodeGenMetal::PrintFunctionSignature(const String& function_name, const
PrimFunc& func,
+ std::ostream& os) {
// add to alloc buffer type.
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header.
- this->stream << "kernel void " <<
static_cast<std::string>(global_symbol.value()) << "(";
+ os << "kernel void " << static_cast<std::string>(global_symbol.value()) <<
"(";
// Buffer arguments
size_t num_buffer = 0;
size_t limit =
target_->GetAttr<Integer>("max_function_args").value().IntValue();
- if (f->params.size() > limit) {
+ if (func->params.size() > limit) {
LOG(WARNING) << "Probably you won't be able to execute your kernel due to
high number of "
"buffers in the kernel";
}
- for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
- Var v = f->params[i];
+ for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
+ Var v = func->params[i];
if (!v.dtype().is_handle()) break;
- stream << " ";
+ os << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
- PrintStorageScope(it->second, stream);
+ PrintStorageScope(it->second, os);
}
- PrintType(GetType(v), stream);
+ PrintType(GetType(v), os);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
@@ -91,19 +89,18 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
RegisterHandleType(v.get(), prim->dtype);
}
}
- stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+ os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
- size_t nargs = f->params.size() - num_buffer;
+ size_t nargs = func->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value())
+ "_args_t";
- stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer("
<< num_buffer
- << ") ]],\n";
+ os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" <<
num_buffer << ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
- for (size_t i = num_buffer; i < f->params.size(); ++i) {
- Var v = f->params[i];
+ for (size_t i = num_buffer; i < func->params.size(); ++i) {
+ Var v = func->params[i];
ICHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
@@ -131,7 +128,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
int work_dim = 0;
- auto launch_params =
f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+ auto launch_params =
func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
for (const auto& tag : launch_params) {
if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
@@ -141,22 +138,16 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
if (work_dim != 0) {
// use ushort by default for now
- stream << " ";
- PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " blockIdx [[threadgroup_position_in_grid]],\n";
- stream << " ";
- PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
- stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+ os << " ";
+ PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
+ os << " blockIdx [[threadgroup_position_in_grid]],\n";
+ os << " ";
+ PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
+ os << " threadIdx [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;
- // the function scope.
- stream << ") {\n";
- int func_scope = this->BeginScope();
- this->PrintStmt(f->body);
- this->EndScope(func_scope);
- this->PrintIndent();
- this->stream << "}\n\n";
+ os << ")";
}
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
@@ -342,27 +333,33 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only
take PrimFunc";
- auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined());
- std::string func_name = global_symbol.value();
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only
take PrimFunc";
+ auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenMetal: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ functions.Set(gvar, prim_func);
+ }
- source_maker << "// Function: " << func_name << "\n";
+ for (auto [gvar, prim_func] : functions) {
+ source_maker << "// Function: " << gvar->name_hint << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
- ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
- << "CodeGenMetal: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- cg.AddFunction(f);
+ for (auto [other_gvar, other_prim_func] : functions) {
+ cg.DeclareFunction(other_gvar, other_prim_func);
+ }
+ cg.AddFunction(gvar, prim_func);
+
std::string fsource = cg.Finish();
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).operator std::string();
}
- smap[func_name] = fsource;
+ smap[cg.GetFunctionName(gvar)] = fsource;
}
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt,
source_maker.str());
diff --git a/src/target/source/codegen_metal.h
b/src/target/source/codegen_metal.h
index 36be10d163..26c991e60d 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -38,7 +38,8 @@ class CodeGenMetal final : public CodeGenC {
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
- void AddFunction(const PrimFunc& f); // NOLINT(*)
+ void PrintFunctionSignature(const String& function_name, const PrimFunc&
func,
+ std::ostream& os) override;
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final;
// NOLINT(*)
void PrintStorageSync(const CallNode* op) final;
// NOLINT(*)
diff --git a/src/target/source/codegen_opencl.cc
b/src/target/source/codegen_opencl.cc
index c15d2253d7..da6a4de619 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -595,18 +595,26 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenOpenCL: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+ functions.Set(gvar, prim_func);
+ }
+
std::stringstream code;
const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only
take PrimFunc";
- code << "// Function: " << kv.first->name_hint << std::endl;
+ for (auto [gvar, prim_func] : functions) {
+ code << "// Function: " << gvar->name_hint << std::endl;
CodeGenOpenCL cg;
cg.Init(output_ssa);
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
- ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
- << "CodeGenOpenCL: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- cg.AddFunction(f);
+ for (auto [other_gvar, other_prim_func] : functions) {
+ cg.DeclareFunction(other_gvar, other_prim_func);
+ }
+ cg.AddFunction(gvar, prim_func);
std::string fsource = cg.Finish();
if (fpostproc) {
fsource = (*fpostproc)(fsource, target).operator std::string();
diff --git a/src/target/source/codegen_vhls.cc
b/src/target/source/codegen_vhls.cc
index 83046de107..aa7a32320c 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -145,13 +145,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target)
{
// Generate source code for get_source().
cg.Init(output_ssa);
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenVHLS: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- cg.AddFunction(f);
+ functions.Set(gvar, prim_func);
+ }
+
+ for (auto [gvar, prim_func] : functions) {
+ cg.DeclareFunction(gvar, prim_func);
+ }
+ for (auto [gvar, prim_func] : functions) {
+ cg.AddFunction(gvar, prim_func);
}
std::string whole_code = cg.Finish();
@@ -159,21 +167,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target)
{
// Generate source code for compilation.
Array<Array<runtime::String>> kernel_info;
- for (auto kv : mod->functions) {
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
+ for (auto [gvar, prim_func] : functions) {
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
- cg.AddFunction(f);
+
+ for (auto [other_gvar, other_prim_func] : functions) {
+ cg.DeclareFunction(other_gvar, other_prim_func);
+ }
+ cg.AddFunction(gvar, prim_func);
std::string code = cg.Finish();
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code, target).operator std::string();
}
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
- << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
- kernel_info.push_back({global_symbol.value(), code});
+ auto function_name = cg.GetFunctionName(gvar);
+ kernel_info.push_back({function_name, code});
}
std::string xclbin;
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 4d1d834c7f..6a6712a4ce 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -45,6 +45,12 @@ std::string CodeGenWebGPU::Finish() {
void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
+ // skip the first underscore, so SSA variable starts from
+ name_supply_->FreshName("v_");
+ // Setup the thread group info.
+ ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+ ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
@@ -56,28 +62,12 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
-void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
- // clear previous generated state.
- this->InitFuncState(f);
- // skip the first underscore, so SSA variable starts from
- name_supply_->FreshName("v_");
- // Setup the thread group info.
- ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
- ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
-
- // add to alloc buffer type.
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(global_symbol.defined())
- << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
-
- decl_stream << "//----------------------------------------\n"
- << "// function: " << global_symbol.value() << "\n"
- << "//----------------------------------------\n";
-
+void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const
PrimFunc& func,
+ std::ostream& os) {
std::vector<Var> pod_args;
int num_buffer = 0;
// setup buffer arguments
- for (Var arg : f->params) {
+ for (Var arg : func->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
@@ -111,16 +101,18 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
}
// add to alloc buffer type.
// Function header.
- this->stream << "fn main(\n"
- << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
- << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
- << ") {\n";
- // the function scope.
- int func_scope = this->BeginScope();
- this->PrintStmt(f->body);
- this->EndScope(func_scope);
- this->PrintIndent();
- this->stream << "}\n\n";
+ os << "fn main(\n"
+ << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+ << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
+ << ")";
+}
+
+void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
+ CodeGenC::AddFunction(gvar, func);
+ decl_stream << "//----------------------------------------\n"
+ << "// function: " << GetFunctionName(gvar) << "\n"
+ << "//----------------------------------------\n";
+
// annotate workgroup
this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0]
<< ", "
<< workgroup_size_[1] << ", " << workgroup_size_[2] <<
")\n";
@@ -524,22 +516,31 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
bool output_ssa = false;
- std::unordered_map<std::string, std::string> smap;
- for (auto kv : mod->functions) {
- CodeGenWebGPU cg(target);
- ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only
take PrimFunc";
- auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ Map<GlobalVar, PrimFunc> functions;
+ for (auto [gvar, base_func] : mod->functions) {
+ ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only
take PrimFunc";
+ auto prim_func = Downcast<PrimFunc>(base_func);
+ auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenWebGPU: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
- auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol
attribute";
- std::string f_name = global_symbol.value();
+ functions.Set(gvar, prim_func);
+ }
+
+ std::unordered_map<std::string, std::string> smap;
+ for (auto [gvar, prim_func] : functions) {
+ CodeGenWebGPU cg(target);
cg.Init(output_ssa);
- cg.AddFunction(f);
+
+ for (auto [other_gvar, other_prim_func] : functions) {
+ cg.DeclareFunction(other_gvar, other_prim_func);
+ }
+ cg.AddFunction(gvar, prim_func);
+
std::string code = cg.Finish();
- smap[f_name] = code;
+ smap[cg.GetFunctionName(gvar)] = code;
}
auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod));
return runtime::Module(n);
diff --git a/src/target/source/codegen_webgpu.h
b/src/target/source/codegen_webgpu.h
index 57f226ba8a..6ae942a3ad 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,9 @@ class CodeGenWebGPU final : public CodeGenC {
explicit CodeGenWebGPU(Target target);
// overrides
std::string Finish() final;
- void AddFunction(const PrimFunc& f); // NOLINT(*)
+ void PrintFunctionSignature(const String& function_name, const PrimFunc&
func,
+ std::ostream& os) final;
+ void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final;
void InitFuncState(const PrimFunc& f) final;
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
diff --git a/src/target/source/source_module.cc
b/src/target/source/source_module.cc
index a6f4b5bb3e..90640a6db6 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -613,12 +613,14 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
}
for (const tir::Var& pool_var : metadata_->pools) {
+ call_args_ss << "((uint8_t*)";
String pool_name =
metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
if (IsInternalWorkspaceBuffer(pool_var)) {
- call_args_ss << "&" << pool_name << ",";
+ call_args_ss << "&" << pool_name;
} else {
- call_args_ss << "workspace_pools->" <<
tvm::runtime::SanitizeName(pool_name) << ",";
+ call_args_ss << "workspace_pools->" <<
tvm::runtime::SanitizeName(pool_name);
}
+ call_args_ss << "),";
}
for (const String& device : metadata_->devices) {
call_args_ss << "devices->" << device << ",";
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 39214c4546..fd14f48921 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) {
return ptr->type_annotation;
}
}
+
+ if (auto* access = expr.as<tir::CallNode>()) {
+ if (access->op.same_as(builtin::tvm_access_ptr())) {
+ ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have
empty arguments";
+ auto type_annotation = Downcast<Call>(access->args[0]);
+ static auto builtin_op = Op::Get("tir.type_annotation");
+ ICHECK(type_annotation->op.same_as(builtin_op))
+ << "Expected the first argument of builtin tvm_access_ptr() "
+ << "to be a type annotation, but found " << type_annotation->op;
+ return PointerType(PrimType(type_annotation->dtype));
+ }
+ }
+
+ if (auto* address_of = expr.as<tir::CallNode>()) {
+ if (address_of->op.same_as(builtin::address_of())) {
+ ICHECK_EQ(address_of->args.size(), 1)
+ << "Builtin address_of() expects a single argument, but received
arguments "
+ << address_of->args;
+ auto* address = address_of->args[0].as<BufferLoadNode>();
+ ICHECK(address)
+ << "Builtin address_of() expects the argument to be a BufferLoad,
but received argument "
+ << address_of->args[0];
+
+ return PointerType(PrimType(address->dtype));
+ }
+ }
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
return GetTypeFromRuntimeDataType(dtype);
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py
b/tests/python/relay/aot/test_crt_forward_declarations.py
index e001a62ab9..99e2f0c923 100644
--- a/tests/python/relay/aot/test_crt_forward_declarations.py
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -33,8 +33,6 @@ from tvm.micro.testing.aot_test_utils import (
AOTTestRunner,
)
-pytestmark = pytest.mark.skip(reason="regression introduced in #15725")
-
def _change_ndarray_layout(arr, src_layout, dst_layout):
"""Makes a copy of an ndarray, reshaping it to a new data layout.
@@ -162,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api,
test_runner):
lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
main_source = lib_mod.get_source()
- assert main_source.count("int32_t
tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1
- assert main_source.count("int32_t tvmgen_default_fused_layout_transform")
== 3
+ assert main_source.count("int32_t
tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
+ assert main_source.count("int32_t tvmgen_default_fused_layout_transform")
== 6
@tvm.testing.requires_corstone300
diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
index 7bea7577b6..f6145cd1c5 100644
--- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
+++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
@@ -135,8 +135,13 @@ def test_write_3x3_depthwise_code():
#define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t
tensordot_opt_x1_int16_w48_3x3_000(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias,
int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
) {
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
int32_t sum_0 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -188,8 +193,13 @@ def test_odd_width_3x3_depthwise_strides_code():
#define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t
tensordot_opt_x2_int16_w49_3x3_000_2_4(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias,
int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
) {
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
int32_t sum_0 = *bias, sum_1 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -251,8 +261,13 @@ def test_1x1x8_convolution_code():
#define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t
tensordot_opt_x4_int16_w384_1x8_000_8_1(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias,
int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
) {
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -349,8 +364,13 @@ def test_3x3x3_offset_convolution_code():
#define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t
tensordot_opt_x1_int16_w288_3x9_111(
- int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias,
int32_t *scale
+ int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+ int32_t *bias, int32_t *scale
) {
+ int32_t *output = output_arg;
+ int32_t *tensor = tensor_arg;
+ int32_t *kernel = kernel_arg;
+
int32_t sum_0 = *bias;
int32_t tensor__unknown__y00_x00 = tensor[0];
diff --git a/tests/python/unittest/test_target_codegen_c_host.py
b/tests/python/unittest/test_target_codegen_c_host.py
index d02f8744f1..3aca0fc8c7 100644
--- a/tests/python/unittest/test_target_codegen_c_host.py
+++ b/tests/python/unittest/test_target_codegen_c_host.py
@@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import tvm
import tvm.testing
+
from tvm import te
-import numpy as np
from tvm.contrib import utils
+from tvm.script import tir as T, ir as I
+
+import numpy as np
def test_add():
@@ -228,11 +232,39 @@ def test_call_packed():
check_global_packed_func()
+def test_subroutine_call():
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, dtype="float32")):
+ mod.subroutine(A.data)
+
+ @T.prim_func(private=True)
+ def subroutine(A_data: T.handle("float32")):
+ A = T.decl_buffer(1, dtype="float32", data=A_data)
+ A[0] = 42.0
+
+ built = tvm.build(mod, target="c")
+
+ func_names = list(built["get_func_names"]())
+ assert (
+ "main" in func_names
+ ), "Externally exposed functions should be listed in available functions."
+ assert (
+ "subroutine" not in func_names
+ ), "Internal function should not be listed in available functions."
+
+ source = built.get_source()
+ assert (
+ source.count("main(void*") == 2
+ ), "Expected two occurrences, for forward-declaration and definition"
+ assert (
+ source.count("subroutine(float*") == 2
+ ), "Expected two occurrences, for forward-declaration and definition"
+ assert (
+ source.count("subroutine(") == 3
+ ), "Expected three occurrences, for forward-declaration, definition, and
call from main."
+
+
if __name__ == "__main__":
- test_add()
- test_add_pipeline()
- test_reinterpret()
- test_ceil()
- test_floor()
- test_round()
- test_call_packed()
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 588a92d87c..61f0892a9c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -268,6 +268,7 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
#define int64_t long long
#define uint64_t unsigned long long
#endif
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C);
extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
__shared__ float A_shared[64];
__shared__ float B_shared[64];