This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3a81aef40b [Fix] Use proper target in VerifyGPUCode (#13548)
3a81aef40b is described below
commit 3a81aef40bca9479d4a691b3a80e42b01f3f8a0d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Dec 4 18:22:38 2022 -0500
[Fix] Use proper target in VerifyGPUCode (#13548)
Previously, the VerifyGPUCode post-processor used a hardcoded target,
`Target("cuda")`, when applying the LowerIntrin pass. This is problematic because
the actual target may be a different GPU target (e.g., Metal). Therefore, this PR
replaces the hardcoded target with the actual target taken from the tune context.
---
src/meta_schedule/postproc/verify_gpu_code.cc | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc
b/src/meta_schedule/postproc/verify_gpu_code.cc
index ae6f3474bb..99ffc1bfcd 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -113,19 +113,20 @@ Integer Extract(const Target& target, const char* name) {
/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
public:
+ Target target_{nullptr};
Map<String, PrimExpr> target_constraints_{nullptr};
int thread_warp_size_ = -1;
void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
- Target target = context->target.value();
+ this->target_ = context->target.value();
this->target_constraints_ = Map<String, PrimExpr>{
- {"max_shared_memory_per_block", Extract(target,
"max_shared_memory_per_block")},
- {"max_threads_per_block", Extract(target, "max_threads_per_block")},
+ {"max_shared_memory_per_block", Extract(this->target_,
"max_shared_memory_per_block")},
+ {"max_threads_per_block", Extract(this->target_,
"max_threads_per_block")},
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
- thread_warp_size_ = Extract(target, "thread_warp_size").IntValue();
+ thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
}
bool Verify(const IRModule& mod) const {
@@ -180,7 +181,7 @@ class VerifyGPUCodeNode : public PostprocNode {
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func),
"global_symbol",
runtime::String(g_var->name_hint));
- f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required
for LowerIntrin
+ f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for
LowerIntrin
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias",
Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));