This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ab1351c9d [FFI][REFACTOR] Move Downcast out of ffi for now (#18198)
0ab1351c9d is described below

commit 0ab1351c9da0d5918d58789d359ee1b5bf470a15
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Aug 8 14:57:00 2025 -0400

    [FFI][REFACTOR] Move Downcast out of ffi for now (#18198)
    
    Downcast was added for backward-compatibility reasons and duplicates the
    functionality of Any.cast. This PR moves it out of ffi into node for now,
    so that the ffi part contains a minimal set of implementations.
---
 ffi/include/tvm/ffi/any.h                        |  1 -
 ffi/include/tvm/ffi/cast.h                       | 94 +-----------------------
 ffi/tests/cpp/test_string.cc                     |  3 +-
 {ffi/include/tvm/ffi => include/tvm/node}/cast.h | 79 ++++----------------
 include/tvm/node/node.h                          |  3 +-
 include/tvm/runtime/disco/session.h              | 15 +++-
 include/tvm/runtime/object.h                     |  1 -
 include/tvm/runtime/vm/vm.h                      |  8 +-
 src/node/container_printing.cc                   |  3 +-
 src/node/repr_printer.cc                         |  1 +
 src/runtime/profiling.cc                         | 10 +--
 src/runtime/vm/attn_backend.cc                   | 28 +++----
 src/runtime/vm/cuda/cuda_graph_builtin.cc        |  6 +-
 src/runtime/vm/vm.cc                             |  6 +-
 src/target/source/codegen_metal.cc               |  2 +-
 src/tir/transforms/memhammer_coalesce.cc         |  4 +-
 16 files changed, 69 insertions(+), 195 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 55eff8802a..ed34328d1e 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -635,7 +635,6 @@ struct AnyEqual {
     }
   }
 };
-
 }  // namespace ffi
 
 // Expose to the tvm namespace for usability
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h
index 997c0bb178..c75d4a075f 100644
--- a/ffi/include/tvm/ffi/cast.h
+++ b/ffi/include/tvm/ffi/cast.h
@@ -18,21 +18,18 @@
  */
 /*!
  * \file tvm/ffi/cast.h
- * \brief Value casting support
+ * \brief Extra value casting helpers
  */
 #ifndef TVM_FFI_CAST_H_
 #define TVM_FFI_CAST_H_
 
 #include <tvm/ffi/any.h>
-#include <tvm/ffi/dtype.h>
-#include <tvm/ffi/error.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/optional.h>
 
-#include <utility>
-
 namespace tvm {
 namespace ffi {
+
 /*!
  * \brief Get a reference type from a raw object ptr type
  *
@@ -46,7 +43,7 @@ namespace ffi {
  * \return The corresponding RefType
  */
 template <typename RefType, typename ObjectType>
-TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) {
+inline RefType GetRef(const ObjectType* ptr) {
   static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
                 "Can only cast to the ref of same container type");
 
@@ -75,92 +72,9 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr) {
                 "Can only cast to the ref of same container type");
   return details::ObjectUnsafe::ObjectPtrFromUnowned<BaseType>(ptr);
 }
-
-/*!
- * \brief Downcast a base reference type to a more specific type.
- *
- * \param ref The input reference
- * \return The corresponding SubRef.
- * \tparam SubRef The target specific reference type.
- * \tparam BaseRef the current reference type.
- */
-template <typename SubRef, typename BaseRef,
-          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
-inline SubRef Downcast(BaseRef ref) {
-  if (ref.defined()) {
-    if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
-      TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " 
to "
-                               << SubRef::ContainerType::_type_key << " 
failed.";
-    }
-    return 
SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(ref)));
-  } else {
-    if constexpr (is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
-      return SubRef(ObjectPtr<Object>(nullptr));
-    }
-    TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `"
-                             << SubRef::ContainerType::_type_key
-                             << "` is not allowed. Use Downcast<Optional<T>> 
instead.";
-    TVM_FFI_UNREACHABLE();
-  }
-}
-
-/*!
- * \brief Downcast any to a specific type
- *
- * \param ref The input reference
- * \return The corresponding SubRef.
- * \tparam T The target specific reference type.
- */
-template <typename T>
-inline T Downcast(const Any& ref) {
-  if constexpr (std::is_same_v<T, Any>) {
-    return ref;
-  } else {
-    return ref.cast<T>();
-  }
-}
-
-/*!
- * \brief Downcast any to a specific type
- *
- * \param ref The input reference
- * \return The corresponding SubRef.
- * \tparam T The target specific reference type.
- */
-template <typename T>
-inline T Downcast(Any&& ref) {
-  if constexpr (std::is_same_v<T, Any>) {
-    return std::move(ref);
-  } else {
-    return std::move(ref).cast<T>();
-  }
-}
-
-/*!
- * \brief Downcast std::optional<Any> to std::optional<T>
- *
- * \param ref The input reference
- * \return The corresponding SubRef.
- * \tparam OptionalType The target optional type
- */
-template <typename OptionalType, typename = 
std::enable_if_t<is_optional_type_v<OptionalType>>>
-inline OptionalType Downcast(const std::optional<Any>& ref) {
-  if (ref.has_value()) {
-    if constexpr (std::is_same_v<OptionalType, Any>) {
-      return *ref;
-    } else {
-      return (*ref).cast<OptionalType>();
-    }
-  } else {
-    return OptionalType(std::nullopt);
-  }
-}
-
 }  // namespace ffi
 
-// Expose to the tvm namespace
-// Rationale: convinience and no ambiguity
-using ffi::Downcast;
+using ffi::GetObjectPtr;
 using ffi::GetRef;
 }  // namespace tvm
 #endif  // TVM_FFI_CAST_H_
diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc
index 364f2f6540..8522aa93a3 100644
--- a/ffi/tests/cpp/test_string.cc
+++ b/ffi/tests/cpp/test_string.cc
@@ -18,7 +18,6 @@
  */
 #include <gtest/gtest.h>
 #include <tvm/ffi/any.h>
-#include <tvm/ffi/cast.h>
 #include <tvm/ffi/string.h>
 
 namespace {
@@ -266,7 +265,7 @@ TEST(String, Cast) {
   string source = "this is a string";
   String s{source};
   Any r = s;
-  String s2 = Downcast<String>(r);
+  String s2 = r.cast<String>();
 }
 
 TEST(String, Concat) {
diff --git a/ffi/include/tvm/ffi/cast.h b/include/tvm/node/cast.h
similarity index 54%
copy from ffi/include/tvm/ffi/cast.h
copy to include/tvm/node/cast.h
index 997c0bb178..ae23c9e9aa 100644
--- a/ffi/include/tvm/ffi/cast.h
+++ b/include/tvm/node/cast.h
@@ -17,13 +17,14 @@
  * under the License.
  */
 /*!
- * \file tvm/ffi/cast.h
- * \brief Value casting support
+ * \file tvm/node/cast.h
+ * \brief Value casting helpers
  */
-#ifndef TVM_FFI_CAST_H_
-#define TVM_FFI_CAST_H_
+#ifndef TVM_NODE_CAST_H_
+#define TVM_NODE_CAST_H_
 
 #include <tvm/ffi/any.h>
+#include <tvm/ffi/cast.h>
 #include <tvm/ffi/dtype.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/object.h>
@@ -32,49 +33,6 @@
 #include <utility>
 
 namespace tvm {
-namespace ffi {
-/*!
- * \brief Get a reference type from a raw object ptr type
- *
- *  It is always important to get a reference type
- *  if we want to return a value as reference or keep
- *  the object alive beyond the scope of the function.
- *
- * \param ptr The object pointer
- * \tparam RefType The reference type
- * \tparam ObjectType The object type
- * \return The corresponding RefType
- */
-template <typename RefType, typename ObjectType>
-TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) {
-  static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
-                "Can only cast to the ref of same container type");
-
-  if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
-    if (ptr == nullptr) {
-      return RefType(ObjectPtr<Object>(nullptr));
-    }
-  } else {
-    TVM_FFI_ICHECK_NOTNULL(ptr);
-  }
-  return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
-      const_cast<Object*>(static_cast<const Object*>(ptr))));
-}
-
-/*!
- * \brief Get an object ptr type from a raw object ptr.
- *
- * \param ptr The object pointer
- * \tparam BaseType The reference type
- * \tparam ObjectType The object type
- * \return The corresponding RefType
- */
-template <typename BaseType, typename ObjectType>
-inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr) {
-  static_assert(std::is_base_of<BaseType, ObjectType>::value,
-                "Can only cast to the ref of same container type");
-  return details::ObjectUnsafe::ObjectPtrFromUnowned<BaseType>(ptr);
-}
 
 /*!
  * \brief Downcast a base reference type to a more specific type.
@@ -85,17 +43,17 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr) {
  * \tparam BaseRef the current reference type.
  */
 template <typename SubRef, typename BaseRef,
-          typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
+          typename = std::enable_if_t<std::is_base_of_v<ffi::ObjectRef, 
BaseRef>>>
 inline SubRef Downcast(BaseRef ref) {
   if (ref.defined()) {
     if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
       TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " 
to "
                                << SubRef::ContainerType::_type_key << " 
failed.";
     }
-    return 
SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(ref)));
+    return 
SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(std::move(ref)));
   } else {
-    if constexpr (is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
-      return SubRef(ObjectPtr<Object>(nullptr));
+    if constexpr (ffi::is_optional_type_v<SubRef> || 
SubRef::_type_is_nullable) {
+      return SubRef(ffi::ObjectPtr<ffi::Object>(nullptr));
     }
     TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `"
                              << SubRef::ContainerType::_type_key
@@ -112,7 +70,7 @@ inline SubRef Downcast(BaseRef ref) {
  * \tparam T The target specific reference type.
  */
 template <typename T>
-inline T Downcast(const Any& ref) {
+inline T Downcast(const ffi::Any& ref) {
   if constexpr (std::is_same_v<T, Any>) {
     return ref;
   } else {
@@ -128,7 +86,7 @@ inline T Downcast(const Any& ref) {
  * \tparam T The target specific reference type.
  */
 template <typename T>
-inline T Downcast(Any&& ref) {
+inline T Downcast(ffi::Any&& ref) {
   if constexpr (std::is_same_v<T, Any>) {
     return std::move(ref);
   } else {
@@ -143,10 +101,10 @@ inline T Downcast(Any&& ref) {
  * \return The corresponding SubRef.
  * \tparam OptionalType The target optional type
  */
-template <typename OptionalType, typename = 
std::enable_if_t<is_optional_type_v<OptionalType>>>
-inline OptionalType Downcast(const std::optional<Any>& ref) {
+template <typename OptionalType, typename = 
std::enable_if_t<ffi::is_optional_type_v<OptionalType>>>
+inline OptionalType Downcast(const std::optional<ffi::Any>& ref) {
   if (ref.has_value()) {
-    if constexpr (std::is_same_v<OptionalType, Any>) {
+    if constexpr (std::is_same_v<OptionalType, ffi::Any>) {
       return *ref;
     } else {
       return (*ref).cast<OptionalType>();
@@ -155,12 +113,5 @@ inline OptionalType Downcast(const std::optional<Any>& 
ref) {
     return OptionalType(std::nullopt);
   }
 }
-
-}  // namespace ffi
-
-// Expose to the tvm namespace
-// Rationale: convinience and no ambiguity
-using ffi::Downcast;
-using ffi::GetRef;
 }  // namespace tvm
-#endif  // TVM_FFI_CAST_H_
+#endif  // TVM_NODE_CAST_H_
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 4398f5881d..734a28c133 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -35,6 +35,7 @@
 #define TVM_NODE_NODE_H_
 
 #include <tvm/ffi/memory.h>
+#include <tvm/node/cast.h>
 #include <tvm/node/repr_printer.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/node/structural_hash.h>
@@ -57,8 +58,6 @@ using ffi::ObjectPtrHash;
 using ffi::ObjectRef;
 using ffi::PackedArgs;
 using ffi::TypeIndex;
-using runtime::Downcast;
-using runtime::GetRef;
 
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_
diff --git a/include/tvm/runtime/disco/session.h 
b/include/tvm/runtime/disco/session.h
index 0c1ed7ca0a..4fe0e72e79 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -124,6 +124,8 @@ inline std::string DiscoAction2String(DiscoAction action) {
   LOG(FATAL) << "ValueError: Unknown DiscoAction: " << 
static_cast<int>(action);
 }
 
+class SessionObj;
+
 /*!
  * \brief An object that exists on all workers.
  *
@@ -156,6 +158,9 @@ class DRefObj : public Object {
   int64_t reg_id;
   /*! \brief Back-pointer to the host controler session */
   ObjectRef session{nullptr};
+
+ private:
+  inline SessionObj* GetSession();
 };
 
 /*!
@@ -321,18 +326,22 @@ class WorkerZeroData {
 
 // Implementation details
 
+inline SessionObj* DRefObj::GetSession() {
+  return const_cast<SessionObj*>(static_cast<const 
SessionObj*>(session.get()));
+}
+
 DRefObj::~DRefObj() {
   if (this->session.defined()) {
-    Downcast<Session>(this->session)->DeallocReg(reg_id);
+    GetSession()->DeallocReg(reg_id);
   }
 }
 
 ffi::Any DRefObj::DebugGetFromRemote(int worker_id) {
-  return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, 
worker_id);
+  return GetSession()->DebugGetFromRemote(this->reg_id, worker_id);
 }
 
 void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) {
-  return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id, 
value, worker_id);
+  return GetSession()->DebugSetRegister(this->reg_id, value, worker_id);
 }
 
 template <typename... Args>
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 1fa9f248e8..302b161b6f 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -39,7 +39,6 @@ using tvm::ffi::ObjectPtrEqual;
 using tvm::ffi::ObjectPtrHash;
 using tvm::ffi::ObjectRef;
 
-using tvm::ffi::Downcast;
 using tvm::ffi::GetObjectPtr;
 using tvm::ffi::GetRef;
 
diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h
index 9aa34c9b44..ed74ba7b7b 100644
--- a/include/tvm/runtime/vm/vm.h
+++ b/include/tvm/runtime/vm/vm.h
@@ -189,10 +189,12 @@ class VirtualMachine : public runtime::ModuleNode {
     using ContainerType = typename T::ContainerType;
     uint32_t key = ContainerType::RuntimeTypeIndex();
     if (auto it = extensions.find(key); it != extensions.end()) {
-      return Downcast<T>((*it).second);
+      ffi::Any value = (*it).second;
+      return value.cast<T>();
     }
     auto [it, _] = extensions.emplace(key, T::Create());
-    return Downcast<T>((*it).second);
+    ffi::Any value = (*it).second;
+    return value.cast<T>();
   }
 
   /*!
@@ -224,7 +226,7 @@ class VirtualMachine : public runtime::ModuleNode {
   std::vector<Device> devices;
   /*! \brief The VM extensions. Mapping from the type index of the extension 
to the extension
    * instance. */
-  std::unordered_map<uint32_t, VMExtension> extensions;
+  std::unordered_map<uint32_t, Any> extensions;
 };
 
 }  // namespace vm
diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc
index 7441db7832..b4773b2a81 100644
--- a/src/node/container_printing.cc
+++ b/src/node/container_printing.cc
@@ -22,6 +22,7 @@
  * \file node/container_printint.cc
  */
 #include <tvm/ffi/function.h>
+#include <tvm/node/cast.h>
 #include <tvm/node/functor.h>
 #include <tvm/node/repr_printer.h>
 
@@ -62,6 +63,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ffi::ShapeObj>([](const ObjectRef& node, ReprPrinter* p) {
-      p->stream << ffi::Downcast<ffi::Shape>(node);
+      p->stream << Downcast<ffi::Shape>(node);
     });
 }  // namespace tvm
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index 6a60b9723d..04a6f7533a 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -23,6 +23,7 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/node/cast.h>
 #include <tvm/node/repr_printer.h>
 #include <tvm/runtime/device_api.h>
 
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index d60030729f..4cce0d40d1 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -479,15 +479,15 @@ String ReportNode::AsTable(bool sort, bool aggregate, 
bool compute_col_sums) con
     for (size_t i = 0; i < calls.size(); i++) {
       auto& frame = calls[i];
       auto it = frame.find("Hash");
-      std::string name = Downcast<String>(frame["Name"]);
+      std::string name = frame["Name"].cast<String>();
       if (it != frame.end()) {
-        name = Downcast<String>((*it).second);
+        name = (*it).second.cast<String>();
       }
       if (frame.find("Argument Shapes") != frame.end()) {
-        name += Downcast<String>(frame["Argument Shapes"]);
+        name += frame["Argument Shapes"].cast<String>();
       }
       if (frame.find("Device") != frame.end()) {
-        name += Downcast<String>(frame["Device"]);
+        name += frame["Device"].cast<String>();
       }
 
       if (aggregates.find(name) == aggregates.end()) {
@@ -680,7 +680,7 @@ Report Profiler::Report() {
   for (size_t i = 0; i < devs_.size(); i++) {
     auto row = rows[rows.size() - 1];
     rows.pop_back();
-    device_metrics[Downcast<String>(row["Device"])] = row;
+    device_metrics[row["Device"].cast<String>()] = row;
     overall_time_us =
         std::max(overall_time_us, row["Duration 
(us)"].as<DurationNode>()->microseconds);
   }
diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc
index 04e5094d8e..c8fbd90821 100644
--- a/src/runtime/vm/attn_backend.cc
+++ b/src/runtime/vm/attn_backend.cc
@@ -33,13 +33,13 @@ std::unique_ptr<PagedPrefillFunc> 
ConvertPagedPrefillFunc(Array<ffi::Any> args,
   String backend_name = args[0].cast<String>();
   if (backend_name == "tir") {
     CHECK_EQ(args.size(), 2);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
     return std::make_unique<TIRPagedPrefillFunc>(std::move(attn_func), 
attn_kind);
   }
   if (backend_name == "flashinfer") {
     CHECK_EQ(args.size(), 3);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
-    ffi::Function plan_func = Downcast<ffi::Function>(args[2]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
+    ffi::Function plan_func = args[2].cast<ffi::Function>();
     return std::make_unique<FlashInferPagedPrefillFunc>(std::move(attn_func), 
std::move(plan_func),
                                                         attn_kind);
   }
@@ -55,13 +55,13 @@ std::unique_ptr<RaggedPrefillFunc> 
ConvertRaggedPrefillFunc(Array<ffi::Any> args
   String backend_name = args[0].cast<String>();
   if (backend_name == "tir") {
     CHECK_EQ(args.size(), 2);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
     return std::make_unique<TIRRaggedPrefillFunc>(std::move(attn_func), 
attn_kind);
   }
   if (backend_name == "flashinfer") {
     CHECK_EQ(args.size(), 3);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
-    ffi::Function plan_func = Downcast<ffi::Function>(args[2]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
+    ffi::Function plan_func = args[2].cast<ffi::Function>();
     return std::make_unique<FlashInferRaggedPrefillFunc>(std::move(attn_func), 
std::move(plan_func),
                                                          attn_kind);
   }
@@ -73,16 +73,16 @@ std::unique_ptr<PagedDecodeFunc> 
ConvertPagedDecodeFunc(Array<ffi::Any> args, At
   if (args.empty()) {
     return nullptr;
   }
-  String backend_name = Downcast<String>(args[0]);
+  String backend_name = args[0].cast<String>();
   if (backend_name == "tir") {
     CHECK_EQ(args.size(), 2);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
     return std::make_unique<TIRPagedDecodeFunc>(std::move(attn_func), 
attn_kind);
   }
   if (backend_name == "flashinfer") {
     CHECK_EQ(args.size(), 3);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
-    ffi::Function plan_func = Downcast<ffi::Function>(args[2]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
+    ffi::Function plan_func = args[2].cast<ffi::Function>();
     return std::make_unique<FlashInferPagedDecodeFunc>(std::move(attn_func), 
std::move(plan_func),
                                                        attn_kind);
   }
@@ -95,10 +95,10 @@ std::unique_ptr<PagedPrefillTreeMaskFunc> 
ConvertPagedPrefillTreeMaskFunc(Array<
   if (args.empty()) {
     return nullptr;
   }
-  String backend_name = Downcast<String>(args[0]);
+  String backend_name = args[0].cast<String>();
   if (backend_name == "tir") {
     CHECK_EQ(args.size(), 2);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
     return std::make_unique<TIRPagedPrefillTreeMaskFunc>(std::move(attn_func), 
attn_kind);
   }
   LOG(FATAL) << "Cannot reach here";
@@ -110,10 +110,10 @@ std::unique_ptr<RaggedPrefillTreeMaskFunc> 
ConvertRaggedPrefillTreeMaskFunc(Arra
   if (args.empty()) {
     return nullptr;
   }
-  String backend_name = Downcast<String>(args[0]);
+  String backend_name = args[0].cast<String>();
   if (backend_name == "tir") {
     CHECK_EQ(args.size(), 2);
-    ffi::Function attn_func = Downcast<ffi::Function>(args[1]);
+    ffi::Function attn_func = args[1].cast<ffi::Function>();
     return 
std::make_unique<TIRRaggedPrefillTreeMaskFunc>(std::move(attn_func), attn_kind);
   }
   LOG(FATAL) << "Cannot reach here";
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc 
b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index 8844517973..691246c3bf 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -149,7 +149,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
    * \param entry_index The unique index of the capture function used for 
lookup.
    * \return The return value of the capture function.
    */
-  ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, 
ObjectRef args,
+  ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, 
Any args,
                          int64_t entry_index, Optional<ffi::Shape> shape_expr) 
{
     CUDAGraphCaptureKey entry_key{entry_index, shape_expr};
     if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
@@ -160,7 +160,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
     }
 
     // Set up arguments for the graph execution
-    Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
+    Array<Any> tuple_args = args.cast<ffi::Array<Any>>();
     int nargs = static_cast<int>(tuple_args.size());
 
     std::vector<AnyView> packed_args(nargs);
@@ -250,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
                     VirtualMachine* vm = 
VirtualMachine::GetContextPtr(args[0]);
                     auto extension = 
vm->GetOrCreateExtension<CUDAGraphExtension>();
                     auto capture_func = args[1].cast<ObjectRef>();
-                    auto func_args = args[2].cast<ObjectRef>();
+                    Any func_args = args[2];
                     int64_t entry_index = args[3].cast<int64_t>();
                     Optional<ffi::Shape> shape_expr = std::nullopt;
                     if (args.size() == 5) {
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 17b99bbc8c..c28e30084f 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -74,7 +74,7 @@ ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs 
args, int starting_
       LOG(FATAL) << "ValueError: Attempted to index into an object that is not 
an Array.";
     }
     int index = args[i].cast<int>();
-    auto arr = Downcast<ffi::Array<ffi::Any>>(obj);
+    auto arr = obj.cast<ffi::Array<ffi::Any>>();
     // make sure the index is in bounds
     if (index >= static_cast<int>(arr.size())) {
       LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << 
arr.size() << ").";
@@ -96,10 +96,10 @@ NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& 
dev, Allocator* allo
 
 Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) {
   if (src.as<NDArray::ContainerType>()) {
-    return ConvertNDArrayToDevice(Downcast<NDArray>(src), dev, alloc);
+    return ConvertNDArrayToDevice(src.cast<NDArray>(), dev, alloc);
   } else if (src.as<ffi::ArrayObj>()) {
     std::vector<Any> ret;
-    auto arr = Downcast<ffi::Array<Any>>(src);
+    auto arr = src.cast<ffi::Array<Any>>();
     for (size_t i = 0; i < arr.size(); i++) {
       ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc));
     }
diff --git a/src/target/source/codegen_metal.cc 
b/src/target/source/codegen_metal.cc
index 3cd4a6ed0d..ffb1737a70 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -370,7 +370,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOLINT
   };
   if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) {
     ICHECK_EQ(op->args.size(), 5);
-    Var var = runtime::Downcast<Var>(op->args[0]);
+    Var var = Downcast<Var>(op->args[0]);
     // Get the data type of the simdgroup matrix
     auto it = simdgroup_dtype_.find(var.get());
     ICHECK(it != simdgroup_dtype_.end())
diff --git a/src/tir/transforms/memhammer_coalesce.cc 
b/src/tir/transforms/memhammer_coalesce.cc
index 2be5e148fb..43a976fa89 100644
--- a/src/tir/transforms/memhammer_coalesce.cc
+++ b/src/tir/transforms/memhammer_coalesce.cc
@@ -204,9 +204,9 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const 
ConstraintSet& constraints,
     if (is_one(write_region->region[i]->extent)) {
       write_index.push_back(write_region->region[i]->min);
     } else {
-      Var var = 
runtime::Downcast<Var>(loop_vars[j]).copy_with_suffix("_inverse");
+      Var var = Downcast<Var>(loop_vars[j]).copy_with_suffix("_inverse");
       new_loop_vars.push_back(var);
-      substitute_map.Set(runtime::Downcast<Var>(loop_vars[j++]), var);
+      substitute_map.Set(Downcast<Var>(loop_vars[j++]), var);
       write_index.push_back(write_region->region[i]->min + var);
     }
   }

Reply via email to