This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s1
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit e48ec28531ea795ff85546604c16a11b2b6acc90
Author: tqchen <[email protected]>
AuthorDate: Sun Apr 13 16:36:47 2025 -0400

    Update
---
 ffi/include/tvm/ffi/container/variant.h | 4 ++++
 src/script/ir_builder/tir/ir.cc         | 5 +++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h
index a3e5f3ad44..120819f3d4 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -102,6 +102,10 @@ class Variant {
     return std::move(data_).operator T();
   }
 
+  TVM_FFI_INLINE std::string GetTypeKey() const {
+    return data_.GetTypeKey();
+  }
+
  private:
   friend struct TypeTraits<Variant<V...>>;
   friend struct ObjectPtrHash;
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index bd24606b0c..383a31a112 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -18,6 +18,7 @@
  */
 #include <tvm/arith/analyzer.h>
 #include <tvm/script/ir_builder/tir/ir.h>
+#include <tvm/ffi/container/variant.h>
 
 #include "./utils.h"
 
@@ -710,14 +711,14 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
-    .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
+    .set_body_typed([](ffi::Variant<tvm::tir::Var, String> thread_tag_or_var, PrimExpr extent) {
       if (auto var = thread_tag_or_var.as<tvm::tir::Var>()) {
         return LaunchThread(var.value(), extent);
       } else if (auto str = thread_tag_or_var.as<String>()) {
         return LaunchThread(str.value(), extent);
       } else {
         LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
-                   << thread_tag_or_var->GetTypeKey();
+                   << thread_tag_or_var.GetTypeKey();
         throw;
       }
     });

Reply via email to