This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch small-str-v1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 240ccd721ccea3693279a7e5c8c4882aadbe4b4c Author: tqchen <[email protected]> AuthorDate: Sat Aug 2 09:32:08 2025 -0400 [FFI] Extra fixes to bring up SmallStr --- include/tvm/relax/transform.h | 4 +-- include/tvm/script/ir_builder/tir/frame.h | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 5 ++- src/node/serialization.cc | 42 +++++++++++----------- src/relax/backend/contrib/cutlass/codegen.cc | 2 +- src/relax/transform/bind_params.cc | 8 ++--- src/relax/transform/bind_symbolic_vars.cc | 13 ++++--- src/runtime/profiling.cc | 4 +-- src/script/ir_builder/tir/ir.cc | 5 +-- src/tir/ir/stmt.cc | 5 +-- src/tir/schedule/concrete_schedule.cc | 4 +-- src/tir/schedule/instruction.cc | 3 +- src/tir/schedule/trace.cc | 17 +++++---- .../test_tir_transform_lower_tvm_builtin.py | 12 ++++--- 14 files changed, 68 insertions(+), 58 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 4068f7c682..1567294a4b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -196,7 +196,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false); * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params); +TVM_DLL Pass BindParams(String func_name, Map<Any, ObjectRef> params); /*! * \brief Bind symbolic vars to constant shape values. @@ -213,7 +213,7 @@ TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params); * * \return The Pass. */ -TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, +TVM_DLL Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map, Optional<String> func_name = std::nullopt); /*! diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index e9087588ff..1e205edc43 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -555,7 +555,7 @@ class AllocateConstFrame : public TIRFrame { class AttrFrameNode : public TIRFrameNode { public: /*! \brief The node to annotate the attribute. */ - ObjectRef node; + Any node; /*! \brief Attribute type key. */ String attr_key; /*! \brief The value of the attribute. */ diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index af5fb3ebab..36a38cac75 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -94,9 +94,8 @@ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst, decisions.reserve(trace->decisions.size()); for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; - const ObjectRef& decision = kv.second.cast<ObjectRef>(); if (inst->kind.same_as(inst_sample_perfect_tile)) { - std::vector<int64_t> tiles = DowncastTilingDecision(decision); + std::vector<int64_t> tiles = DowncastTilingDecision(kv.second.cast<ObjectRef>()); if (tiles.size() >= 2 && Product(tiles) >= 2) { instructions.push_back(inst); decisions.push_back(tiles); @@ -130,7 +129,6 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, // Find sampling instruction that generates the annotation for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; - const ObjectRef& decision = kv.second.cast<ObjectRef>(); if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].as<Object>())) { @@ -141,6 +139,7 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, // Skip mutating the sampling instructions who have only single candidate. continue; } + const ObjectRef& decision = kv.second.cast<ObjectRef>(); const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 65b9728317..3d0175bcfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -108,6 +108,7 @@ class NodeIndexer { } } } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { // skip content index for string and bytes } else if (auto opt_object = node.as<const Object*>()) { @@ -126,8 +127,8 @@ class NodeIndexer { << "` misses reflection registration and do not support serialization"; ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - // only make index for ObjectRef - if (field_value.as<Object>()) { + // only make index for ObjectRef and String(which may not be object for small str) + if (field_value.as<Object>() || field_value.as<String>()) { this->MakeIndex(field_value); } }); @@ -234,9 +235,9 @@ class JSONAttrGetter { } } - void Visit(const char* key, ObjectRef* value) { - if (value->defined()) { - node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); + void Visit(const char* key, Any* value) { + if (value != nullptr) { + node_->attrs[key] = std::to_string(node_index_->at(*value)); } else { node_->attrs[key] = "null"; } @@ -249,6 +250,10 @@ class JSONAttrGetter { return; } node_->type_key = node.GetTypeKey(); + // canonicalize type key for str + if (node_->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { + node_->type_key = ffi::StaticTypeKey::kTVMFFIStr; + } // populates the fields. node_->attrs.clear(); node_->data.clear(); @@ -344,19 +349,9 @@ class JSONAttrGetter { this->Visit(field_info->name.data, &value); break; } - case ffi::TypeIndex::kTVMFFINDArray: { - runtime::NDArray value = field_value.cast<runtime::NDArray>(); - this->Visit(field_info->name.data, &value); - break; - } default: { - if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef obj = field_value.cast<ObjectRef>(); - this->Visit(field_info->name.data, &obj); - break; - } else { - LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); - } + this->Visit(field_info->name.data, &field_value); + break; } } }); @@ -401,14 +396,15 @@ class FieldDependencyFinder { if (node == nullptr) { return; } - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return; - } if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFISmallStr || node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { // skip indexing content of string and bytes return; } + if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + return; + } // Skip the objects that have their own string repr if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node.cast<const Object*>(), nullptr)) { @@ -562,7 +558,8 @@ class JSONAttrSetter { setter.ParseValue("v_device_type", &device_type); setter.ParseValue("v_device_id", &device_id); return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id}); - } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) { + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr) { return Any(String(jnode->repr_bytes)); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { return Any(Bytes(jnode->repr_bytes)); @@ -596,6 +593,7 @@ class JSONAttrSetter { } *node = result; } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || + jnode->type_key == ffi::StaticTypeKey::kTVMFFISmallStr || jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { // skip set attrs for string and bytes } else if (auto opt_object = node->as<const Object*>()) { @@ -652,7 +650,7 @@ class JSONAttrSetter { ParseOptionalValue(field_info->name.data, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); if (index.has_value()) { - Any value = node_list_->at(*index).cast<ObjectRef>(); + Any value = node_list_->at(*index); setter(obj, value); } else { setter(obj, Any()); diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 874dced500..932fdadddf 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -221,7 +221,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator<OutputType>, } OutputType VisitExpr_(const FunctionNode* fn) final { - ICHECK(fn->GetAttr<String>(attr::kComposite).defined()) + ICHECK(fn->GetAttr<String>(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 49fe469e89..13b138ecce 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -83,7 +83,7 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings( - const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) { + const Function& func, const Map<Any, ObjectRef>& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); @@ -158,7 +158,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings( * \param params params dict * \return Function */ -Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& untyped_params) { +Function FunctionBindParams(Function func, const Map<Any, ObjectRef>& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -172,7 +172,7 @@ Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& unty * \param param The param dict * \return The module after binding params. */ -IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> bind_params) { +IRModule BindParam(IRModule m, String func_name, Map<Any, ObjectRef> bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); Map<GlobalVar, BaseFunc> functions = m->functions; for (const auto& func_pr : functions) { @@ -203,7 +203,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) { +Pass BindParams(String func_name, Map<Any, ObjectRef> params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 22c557874c..5ba25b7e16 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -31,7 +31,8 @@ namespace tvm { namespace relax { -Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> obj_remap) { +Function FunctionBindSymbolicVars(Function func, + Map<Variant<tir::Var, String>, PrimExpr> obj_remap) { // Early bail-out if no updates need to be made. if (obj_remap.empty()) { return func; @@ -90,7 +91,8 @@ Function FunctionBindSymbolicVars(Function func, Map<ffi::Any, PrimExpr> obj_rem } namespace { -IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_map) { +IRModule ModuleBindSymbolicVars(IRModule mod, + Map<Variant<tir::Var, String>, PrimExpr> binding_map) { std::unordered_set<ffi::Any, ffi::AnyHash, ffi::AnyEqual> used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -98,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> Map<ffi::Any, PrimExpr> { + auto func_binding_map = [&]() -> Map<Variant<tir::Var, String>, PrimExpr> { std::unordered_set<std::string> var_names; std::unordered_set<const tir::VarNode*> vars; for (const auto& var : DefinedSymbolicVars(func)) { @@ -106,7 +108,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map<ffi::Any, PrimExpr> binding_ma vars.insert(var.get()); } - Map<ffi::Any, PrimExpr> out; + Map<Variant<tir::Var, String>, PrimExpr> out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; if (auto opt = key.as<String>()) { @@ -156,7 +158,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String> func_name) { +Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map, + Optional<String> func_name) { auto pass_func = [=](IRModule mod, PassContext context) -> IRModule { if (func_name) { auto gvar = mod->GetGlobalVar(func_name.value()); diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index ddd5462c68..e9652618e4 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -613,7 +613,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // fill empty data with empty strings cols[i].push_back(""); } else { - cols[i].push_back(print_metric((*it).second.cast<ObjectRef>())); + cols[i].push_back(print_metric((*it).second)); } } } @@ -653,7 +653,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // Add configuration information. It will not be aligned with the columns. s << std::endl << "Configuration" << std::endl << "-------------" << std::endl; for (auto kv : configuration) { - s << kv.first << ": " << print_metric(kv.second.cast<ObjectRef>()) << std::endl; + s << kv.first << ": " << print_metric(kv.second) << std::endl; } return s.str(); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 5ff42db74d..78bccb829c 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -520,11 +520,12 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) { // convert POD value to PrimExpr - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + node.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { node = node.cast<PrimExpr>(); } ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>(); - n->node = node.cast<ObjectRef>(); + n->node = std::move(node); n->attr_key = attr_key; n->value = value; return AttrFrame(n); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6803e01f50..5a2b95844b 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -100,10 +100,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to // primexpr. - if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + node.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { return AttrStmt(node.cast<PrimExpr>(), attr_key, value, body, span); } - return AttrStmt(node.cast<ObjectRef>(), attr_key, value, body, span); + return AttrStmt(node, attr_key, value, body, span); }); }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c00c946852..db175c77f2 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -916,8 +916,8 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { if (auto opt_str = ann_val.try_cast<ffi::String>()) { return *std::move(opt_str); } - - if (ann_val.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (ann_val.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + ann_val.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { return ann_val; } // prefer to return int/float literals for annotations diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 3ee43c698a..fdc0dd41c4 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -74,7 +74,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) inputs.push_back(String('"' + (*opt_str).operator std::string() + '"')); } else if (obj.as<BlockRVNode>() || obj.as<LoopRVNode>()) { inputs.push_back(String("_")); - } else if (obj.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + } else if (obj.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + obj.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { inputs.push_back(obj); } else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) { inputs.push_back(obj); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 43c2ce0a7b..b1fb7881a6 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -71,7 +71,8 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs, }; for (const Any& input : inputs) { - if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type result.push_back(input); } else if (auto expr = input.as<ffi::String>()) { @@ -110,8 +111,12 @@ Array<Any> TranslateInputRVs( results.push_back(String("None")); continue; } - if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - // directly put back POD type + // string => "content" + if (auto opt_str = input.as<ffi::String>()) { + results.push_back(String('"' + (*opt_str).operator std::string() + '"')); + } else if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { + // directly put back POD type and not string results.push_back(input); } else if (input.as<BlockRVNode>() || // RV: block input.as<LoopRVNode>() || // RV: loop @@ -124,9 +129,6 @@ Array<Any> TranslateInputRVs( LOG(FATAL) << "IndexError: Random variable is not defined " << input; throw; } - } else if (auto opt_str = input.as<ffi::String>()) { - // Case 2. string => "content" - results.push_back(String('"' + (*opt_str).operator std::string() + '"')); } else if (input.as<IntImmNode>() || input.as<FloatImmNode>()) { // Case 3. integer or floating-point number results.push_back(input); @@ -159,7 +161,8 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs, Array<Any> results; results.reserve(inputs.size()); for (const Any& input : inputs) { - if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (input.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin && + input.type_index() != ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type results.push_back(input); continue; diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index 299c193146..08f377829f 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -97,13 +97,17 @@ def test_lower_call_packed(): T.tvm_struct_set(stack_array, 2, 9, 0) T.tvm_struct_set(stack_array, 2, 10, 1) T.tvm_struct_set(stack_ffi_any, 0, 13, 7) - T.tvm_struct_set(stack_ffi_any, 0, 14, T.tvm_struct_get(stack_array, 0, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 0, 14, 0) + T.tvm_struct_set(stack_ffi_any, 0, 15, T.tvm_struct_get(stack_array, 0, 0, "handle")) T.tvm_struct_set(stack_ffi_any, 1, 13, 7) - T.tvm_struct_set(stack_ffi_any, 1, 14, T.tvm_struct_get(stack_array, 1, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 1, 14, 0) + T.tvm_struct_set(stack_ffi_any, 1, 15, T.tvm_struct_get(stack_array, 1, 0, "handle")) T.tvm_struct_set(stack_ffi_any, 2, 13, 7) - T.tvm_struct_set(stack_ffi_any, 2, 14, T.tvm_struct_get(stack_array, 2, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 2, 14, 0) + T.tvm_struct_set(stack_ffi_any, 2, 15, T.tvm_struct_get(stack_array, 2, 0, "handle")) T.tvm_struct_set(stack_ffi_any, 3, 13, 0) - T.tvm_struct_set(stack_ffi_any, 3, 14, T.int64(0)) + T.tvm_struct_set(stack_ffi_any, 3, 14, 0) + T.tvm_struct_set(stack_ffi_any, 3, 15, T.int64(0)) T.call_packed_lowered("tvm.test_matmul", stack_ffi_any, 0, 3) After = tvm.tir.transform.LowerTVMBuiltin()(Before)
