This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a81aef40b [Fix] Use proper target in VerifyGPUCode (#13548)
3a81aef40b is described below

commit 3a81aef40bca9479d4a691b3a80e42b01f3f8a0d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Dec 4 18:22:38 2022 -0500

    [Fix] Use proper target in VerifyGPUCode (#13548)
    
    Previously, the VerifyGPUCode post-processor used the hardcoded target
    `Target("cuda")` when applying the LowerIntrin pass. This is problematic
    because the actual target may be a different GPU target (e.g., Metal).
    Therefore, this PR replaces the hardcoded target with the actual target
    taken from the tune context.
---
 src/meta_schedule/postproc/verify_gpu_code.cc | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc 
b/src/meta_schedule/postproc/verify_gpu_code.cc
index ae6f3474bb..99ffc1bfcd 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -113,19 +113,20 @@ Integer Extract(const Target& target, const char* name) {
 /*! \brief Verify the correctness of the generated GPU code. */
 class VerifyGPUCodeNode : public PostprocNode {
  public:
+  Target target_{nullptr};
   Map<String, PrimExpr> target_constraints_{nullptr};
   int thread_warp_size_ = -1;
 
   void InitializeWithTuneContext(const TuneContext& context) final {
     ICHECK(context->target.defined());
-    Target target = context->target.value();
+    this->target_ = context->target.value();
     this->target_constraints_ = Map<String, PrimExpr>{
-        {"max_shared_memory_per_block", Extract(target, 
"max_shared_memory_per_block")},
-        {"max_threads_per_block", Extract(target, "max_threads_per_block")},
+        {"max_shared_memory_per_block", Extract(this->target_, 
"max_shared_memory_per_block")},
+        {"max_threads_per_block", Extract(this->target_, 
"max_threads_per_block")},
         {"max_vthread", Integer(8)},
         {"max_vector_bytes", Integer(16)},
     };
-    thread_warp_size_ = Extract(target, "thread_warp_size").IntValue();
+    thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
   }
 
   bool Verify(const IRModule& mod) const {
@@ -180,7 +181,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           transform::PassContext pass_ctx = transform::PassContext::Current();
           tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), 
"global_symbol",
                                      runtime::String(g_var->name_hint));
-          f = WithAttr(f, tvm::attr::kTarget, Target("cuda"));  // Required 
for LowerIntrin
+          f = WithAttr(f, tvm::attr::kTarget, this->target_);  // Required for 
LowerIntrin
           bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", 
Bool(true)).value();
           if (noalias) {
             f = WithAttr(std::move(f), "tir.noalias", Bool(true));

Reply via email to