This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch small-str-v1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 19bc811523c156f15c649fcd7278cb1087c30d45 Author: tqchen <[email protected]> AuthorDate: Sat Aug 2 09:32:08 2025 -0400 [FFI] Extra cleanup to move ObjectRef to Any --- include/tvm/relax/transform.h | 4 +-- src/meta_schedule/mutator/mutate_tile_size.cc | 5 ++-- src/node/serialization.cc | 36 +++++++++++++-------------- src/relax/transform/bind_params.cc | 8 +++--- src/relax/transform/bind_symbolic_vars.cc | 13 ++++++---- src/runtime/profiling.cc | 4 +-- 6 files changed, 35 insertions(+), 35 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 4068f7c682..1567294a4b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -196,7 +196,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false); * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params); +TVM_DLL Pass BindParams(String func_name, Map<Any, ObjectRef> params); /*! * \brief Bind symbolic vars to constant shape values. @@ -213,7 +213,7 @@ TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params); * * \return The Pass. */ -TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, +TVM_DLL Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map, Optional<String> func_name = std::nullopt); /*! diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index af5fb3ebab..36a38cac75 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -94,9 +94,8 @@ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst, decisions.reserve(trace->decisions.size()); for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; - const ObjectRef& decision = kv.second.cast<ObjectRef>(); if (inst->kind.same_as(inst_sample_perfect_tile)) { - std::vector<int64_t> tiles = DowncastTilingDecision(decision); + std::vector<int64_t> tiles = DowncastTilingDecision(kv.second.cast<ObjectRef>()); if (tiles.size() >= 2 && Product(tiles) >= 2) { instructions.push_back(inst); decisions.push_back(tiles); @@ -130,7 +129,6 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, // Find sampling instruction that generates the annotation for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; - const ObjectRef& decision = kv.second.cast<ObjectRef>(); if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].as<Object>())) { @@ -141,6 +139,7 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, // Skip mutating the sampling instructions who have only single candidate. continue; } + const ObjectRef& decision = kv.second.cast<ObjectRef>(); const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 65b9728317..853d6f2507 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -108,6 +108,7 @@ class NodeIndexer { } } } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { // skip content index for string and bytes } else if (auto opt_object = node.as<const Object*>()) { @@ -126,8 +127,8 @@ class NodeIndexer { << "` misses reflection registration and do not support serialization"; ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - // only make index for ObjectRef - if (field_value.as<Object>()) { + // only make index for ObjectRef and String(which may not be object for small str) + if (field_value.as<Object>() || field_value.as<String>()) { this->MakeIndex(field_value); } }); @@ -234,9 +235,9 @@ class JSONAttrGetter { } } - void Visit(const char* key, ObjectRef* value) { - if (value->defined()) { - node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); + void Visit(const char* key, Any* value) { + if (value != nullptr) { + node_->attrs[key] = std::to_string(node_index_->at(*value)); } else { node_->attrs[key] = "null"; } @@ -249,6 +250,10 @@ class JSONAttrGetter { return; } node_->type_key = node.GetTypeKey(); + // canonicalize type key for str + if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { + node_->type_key = ffi::StaticTypeKey::kTVMFFIStr; + } // populates the fields. node_->attrs.clear(); node_->data.clear(); @@ -344,19 +349,9 @@ class JSONAttrGetter { this->Visit(field_info->name.data, &value); break; } - case ffi::TypeIndex::kTVMFFINDArray: { - runtime::NDArray value = field_value.cast<runtime::NDArray>(); - this->Visit(field_info->name.data, &value); - break; - } default: { - if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef obj = field_value.cast<ObjectRef>(); - this->Visit(field_info->name.data, &obj); - break; - } else { - LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); - } + this->Visit(field_info->name.data, &field_value); + break; } } }); @@ -405,6 +400,7 @@ class FieldDependencyFinder { return; } if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { // skip indexing content of string and bytes return; @@ -562,7 +558,8 @@ class JSONAttrSetter { setter.ParseValue("v_device_type", &device_type); setter.ParseValue("v_device_id", &device_id); return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id}); - } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) { + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { return Any(String(jnode->repr_bytes)); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { return Any(Bytes(jnode->repr_bytes)); @@ -596,6 +593,7 @@ class JSONAttrSetter { } *node = result; } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr || jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { // skip set attrs for string and bytes } else if (auto opt_object = node->as<const Object*>()) { @@ -652,7 +650,7 @@ class JSONAttrSetter { ParseOptionalValue(field_info->name.data, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); if (index.has_value()) { - Any value = node_list_->at(*index).cast<ObjectRef>(); + Any value = node_list_->at(*index); setter(obj, value); } else { setter(obj, Any()); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 49fe469e89..13b138ecce 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -83,7 +83,7 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings( - const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) { + const Function& func, const Map<Any, ObjectRef>& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); @@ -158,7 +158,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings( * \param params params dict * \return Function */ -Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& untyped_params) { +Function FunctionBindParams(Function func, const Map<Any, ObjectRef>& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -172,7 +172,7 @@ Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& unty * \param param The param dict * \return The module after binding params. */ -IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> bind_params) { +IRModule BindParam(IRModule m, String func_name, Map<Any, ObjectRef> bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); Map<GlobalVar, BaseFunc> functions = m->functions; for (const auto& func_pr : functions) { @@ -203,7 +203,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) { +Pass BindParams(String func_name, Map<Any, ObjectRef> params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 22c557874c..5ba25b7e16 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -31,7 +31,8 @@ namespace tvm { namespace relax { -Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> obj_remap) { +Function FunctionBindSymbolicVars(Function func, + Map<Variant<tir::Var, String>, PrimExpr> obj_remap) { // Early bail-out if no updates need to be made. if (obj_remap.empty()) { return func; @@ -90,7 +91,8 @@ Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> obj_rem } namespace { -IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_map) { +IRModule ModuleBindSymbolicVars(IRModule mod, + Map<Variant<tir::Var, String>, PrimExpr> binding_map) { std::unordered_set<ffi::Any, ffi::AnyHash, ffi::AnyEqual> used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -98,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> Map<ffi::Any, PrimExpr> { + auto func_binding_map = [&]() -> Map<Variant<tir::Var, String>, PrimExpr> { std::unordered_set<std::string> var_names; std::unordered_set<const tir::VarNode*> vars; for (const auto& var : DefinedSymbolicVars(func)) { @@ -106,7 +108,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma vars.insert(var.get()); } - Map<ffi::Any, PrimExpr> out; + Map<Variant<tir::Var, String>, PrimExpr> out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; if (auto opt = key.as<String>()) { @@ -156,7 +158,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String> func_name) { +Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map, + Optional<String> func_name) { auto pass_func = [=](IRModule mod, PassContext context) -> IRModule { if (func_name) { auto gvar = mod->GetGlobalVar(func_name.value()); diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index ddd5462c68..e9652618e4 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -613,7 +613,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // fill empty data with empty strings cols[i].push_back(""); } else { - cols[i].push_back(print_metric((*it).second.cast<ObjectRef>())); + cols[i].push_back(print_metric((*it).second)); } } } @@ -653,7 +653,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // Add configuration information. It will not be aligned with the columns. s << std::endl << "Configuration" << std::endl << "-------------" << std::endl; for (auto kv : configuration) { - s << kv.first << ": " << print_metric(kv.second.cast<ObjectRef>()) << std::endl; + s << kv.first << ": " << print_metric(kv.second) << std::endl; } return s.str(); }
