This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch small-str-v1
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 19bc811523c156f15c649fcd7278cb1087c30d45
Author: tqchen <[email protected]>
AuthorDate: Sat Aug 2 09:32:08 2025 -0400

    [FFI] Extra cleanup to move ObjectRef to Any
---
 include/tvm/relax/transform.h                 |  4 +--
 src/meta_schedule/mutator/mutate_tile_size.cc |  5 ++--
 src/node/serialization.cc                     | 36 +++++++++++++--------------
 src/relax/transform/bind_params.cc            |  8 +++---
 src/relax/transform/bind_symbolic_vars.cc     | 13 ++++++----
 src/runtime/profiling.cc                      |  4 +--
 6 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 4068f7c682..1567294a4b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -196,7 +196,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
  *
  * \return The Pass.
  */
-TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
+TVM_DLL Pass BindParams(String func_name, Map<Any, ObjectRef> params);
 
 /*!
  * \brief Bind symbolic vars to constant shape values.
@@ -213,7 +213,7 @@ TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
  *
  * \return The Pass.
  */
-TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
+TVM_DLL Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> 
binding_map,
                               Optional<String> func_name = std::nullopt);
 
 /*!
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index af5fb3ebab..36a38cac75 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -94,9 +94,8 @@ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
   decisions.reserve(trace->decisions.size());
   for (const auto& kv : trace->decisions) {
     const Instruction& inst = kv.first;
-    const ObjectRef& decision = kv.second.cast<ObjectRef>();
     if (inst->kind.same_as(inst_sample_perfect_tile)) {
-      std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+      std::vector<int64_t> tiles = 
DowncastTilingDecision(kv.second.cast<ObjectRef>());
       if (tiles.size() >= 2 && Product(tiles) >= 2) {
         instructions.push_back(inst);
         decisions.push_back(tiles);
@@ -130,7 +129,6 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
   // Find sampling instruction that generates the annotation
   for (const auto& kv : trace->decisions) {
     const Instruction& inst = kv.first;
-    const ObjectRef& decision = kv.second.cast<ObjectRef>();
     if (inst->kind.same_as(inst_sample_categorical)) {
       ICHECK_EQ(inst->outputs.size(), 1);
       if (annotated.count(inst->outputs[0].as<Object>())) {
@@ -141,6 +139,7 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
           // Skip mutating the sampling instructions who have only single 
candidate.
           continue;
         }
+        const ObjectRef& decision = kv.second.cast<ObjectRef>();
         const auto* d = TVM_TYPE_AS(decision, IntImmNode);
         instructions.push_back(inst);
         decisions.push_back(d->value);
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 65b9728317..853d6f2507 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -108,6 +108,7 @@ class NodeIndexer {
         }
       }
     } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
+               node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
                node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
       // skip content index for string and bytes
     } else if (auto opt_object = node.as<const Object*>()) {
@@ -126,8 +127,8 @@ class NodeIndexer {
         << "` misses reflection registration and do not support serialization";
     ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* 
field_info) {
       Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
-      // only make index for ObjectRef
-      if (field_value.as<Object>()) {
+      // only make index for ObjectRef and String (which may not be an object for small str)
+      if (field_value.as<Object>() || field_value.as<String>()) {
         this->MakeIndex(field_value);
       }
     });
@@ -234,9 +235,9 @@ class JSONAttrGetter {
     }
   }
 
-  void Visit(const char* key, ObjectRef* value) {
-    if (value->defined()) {
-      node_->attrs[key] = std::to_string(node_index_->at(Any(*value)));
+  void Visit(const char* key, Any* value) {
+    if (value != nullptr) {
+      node_->attrs[key] = std::to_string(node_index_->at(*value));
     } else {
       node_->attrs[key] = "null";
     }
@@ -249,6 +250,10 @@ class JSONAttrGetter {
       return;
     }
     node_->type_key = node.GetTypeKey();
+    // canonicalize type key for str
+    if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) {
+      node_->type_key = ffi::StaticTypeKey::kTVMFFIStr;
+    }
     // populates the fields.
     node_->attrs.clear();
     node_->data.clear();
@@ -344,19 +349,9 @@ class JSONAttrGetter {
           this->Visit(field_info->name.data, &value);
           break;
         }
-        case ffi::TypeIndex::kTVMFFINDArray: {
-          runtime::NDArray value = field_value.cast<runtime::NDArray>();
-          this->Visit(field_info->name.data, &value);
-          break;
-        }
         default: {
-          if (field_value.type_index() >= 
ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
-            ObjectRef obj = field_value.cast<ObjectRef>();
-            this->Visit(field_info->name.data, &obj);
-            break;
-          } else {
-            LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
-          }
+          this->Visit(field_info->name.data, &field_value);
+          break;
         }
       }
     });
@@ -405,6 +400,7 @@ class FieldDependencyFinder {
       return;
     }
     if (node.type_index() == ffi::TypeIndex::kTVMFFIStr ||
+        node.type_index() == ffi::TypeIndex::kTVMFFISmallStr ||
         node.type_index() == ffi::TypeIndex::kTVMFFIBytes) {
       // skip indexing content of string and bytes
       return;
@@ -562,7 +558,8 @@ class JSONAttrSetter {
       setter.ParseValue("v_device_type", &device_type);
       setter.ParseValue("v_device_id", &device_id);
       return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id});
-    } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) {
+    } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
+               jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) {
       return Any(String(jnode->repr_bytes));
     } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
       return Any(Bytes(jnode->repr_bytes));
@@ -596,6 +593,7 @@ class JSONAttrSetter {
       }
       *node = result;
     } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr ||
+               jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr ||
                jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) {
       // skip set attrs for string and bytes
     } else if (auto opt_object = node->as<const Object*>()) {
@@ -652,7 +650,7 @@ class JSONAttrSetter {
         ParseOptionalValue(field_info->name.data, &index,
                            [this](const char* key, int64_t* value) { 
ParseValue(key, value); });
         if (index.has_value()) {
-          Any value = node_list_->at(*index).cast<ObjectRef>();
+          Any value = node_list_->at(*index);
           setter(obj, value);
         } else {
           setter(obj, Any());
diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc
index 49fe469e89..13b138ecce 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -83,7 +83,7 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
 }
 
 std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
-    const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
+    const Function& func, const Map<Any, ObjectRef>& untyped_params) {
   ICHECK(func.defined());
   ICHECK(untyped_params.defined());
 
@@ -158,7 +158,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
  * \param params params dict
  * \return Function
  */
-Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& 
untyped_params) {
+Function FunctionBindParams(Function func, const Map<Any, ObjectRef>& 
untyped_params) {
   auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
 
   Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
@@ -172,7 +172,7 @@ Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& unty
  * \param param The param dict
  * \return The module after binding params.
  */
-IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> 
bind_params) {
+IRModule BindParam(IRModule m, String func_name, Map<Any, ObjectRef> 
bind_params) {
   IRModuleNode* new_module = m.CopyOnWrite();
   Map<GlobalVar, BaseFunc> functions = m->functions;
   for (const auto& func_pr : functions) {
@@ -203,7 +203,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 namespace transform {
 
-Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
+Pass BindParams(String func_name, Map<Any, ObjectRef> params) {
   auto pass_func = [=](IRModule mod, PassContext pc) {
     return BindParam(std::move(mod), func_name, params);
   };
diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc
index 22c557874c..5ba25b7e16 100644
--- a/src/relax/transform/bind_symbolic_vars.cc
+++ b/src/relax/transform/bind_symbolic_vars.cc
@@ -31,7 +31,8 @@
 namespace tvm {
 namespace relax {
 
-Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> 
obj_remap) {
+Function FunctionBindSymbolicVars(Function func,
+                                  Map<Variant<tir::Var, String>, PrimExpr> 
obj_remap) {
   // Early bail-out if no updates need to be made.
   if (obj_remap.empty()) {
     return func;
@@ -90,7 +91,8 @@ Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> obj_rem
 }
 
 namespace {
-IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> 
binding_map) {
+IRModule ModuleBindSymbolicVars(IRModule mod,
+                                Map<Variant<tir::Var, String>, PrimExpr> 
binding_map) {
   std::unordered_set<ffi::Any, ffi::AnyHash, ffi::AnyEqual> used;
   IRModule updates;
   for (const auto& [gvar, base_func] : mod->functions) {
@@ -98,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma
       auto func = opt.value();
 
       // Collect bindings that are used by this function.
-      auto func_binding_map = [&]() -> Map<ffi::Any, PrimExpr> {
+      auto func_binding_map = [&]() -> Map<Variant<tir::Var, String>, 
PrimExpr> {
         std::unordered_set<std::string> var_names;
         std::unordered_set<const tir::VarNode*> vars;
         for (const auto& var : DefinedSymbolicVars(func)) {
@@ -106,7 +108,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma
           vars.insert(var.get());
         }
 
-        Map<ffi::Any, PrimExpr> out;
+        Map<Variant<tir::Var, String>, PrimExpr> out;
         for (const auto& [key, replacement] : binding_map) {
           bool used_by_function = false;
           if (auto opt = key.as<String>()) {
@@ -156,7 +158,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 namespace transform {
 
-Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String> 
func_name) {
+Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map,
+                      Optional<String> func_name) {
   auto pass_func = [=](IRModule mod, PassContext context) -> IRModule {
     if (func_name) {
       auto gvar = mod->GetGlobalVar(func_name.value());
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index ddd5462c68..e9652618e4 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -613,7 +613,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con
         // fill empty data with empty strings
         cols[i].push_back("");
       } else {
-        cols[i].push_back(print_metric((*it).second.cast<ObjectRef>()));
+        cols[i].push_back(print_metric((*it).second));
       }
     }
   }
@@ -653,7 +653,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con
   // Add configuration information. It will not be aligned with the columns.
   s << std::endl << "Configuration" << std::endl << "-------------" << 
std::endl;
   for (auto kv : configuration) {
-    s << kv.first << ": " << print_metric(kv.second.cast<ObjectRef>()) << 
std::endl;
+    s << kv.first << ": " << print_metric(kv.second) << std::endl;
   }
   return s.str();
 }

Reply via email to