This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit e48ec28531ea795ff85546604c16a11b2b6acc90 Author: tqchen <[email protected]> AuthorDate: Sun Apr 13 16:36:47 2025 -0400 Update --- ffi/include/tvm/ffi/container/variant.h | 4 ++++ src/script/ir_builder/tir/ir.cc | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index a3e5f3ad44..120819f3d4 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -102,6 +102,10 @@ class Variant { return std::move(data_).operator T(); } + TVM_FFI_INLINE std::string GetTypeKey() const { + return data_.GetTypeKey(); + } + private: friend struct TypeTraits<Variant<V...>>; friend struct ObjectPtrHash; diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index bd24606b0c..383a31a112 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -18,6 +18,7 @@ */ #include <tvm/arith/analyzer.h> #include <tvm/script/ir_builder/tir/ir.h> +#include <tvm/ffi/container/variant.h> #include "./utils.h" @@ -710,14 +711,14 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") - .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) { + .set_body_typed([](ffi::Variant<tvm::tir::Var, String> thread_tag_or_var, PrimExpr extent) { if (auto var = thread_tag_or_var.as<tvm::tir::Var>()) { return LaunchThread(var.value(), extent); } else if (auto str = thread_tag_or_var.as<String>()) { return LaunchThread(str.value(), extent); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " - << thread_tag_or_var->GetTypeKey(); + << thread_tag_or_var.GetTypeKey(); throw; } });
