This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 054de924e1366b2e19a2c0eafcb38daa67138813 Author: tqchen <[email protected]> AuthorDate: Thu Apr 24 09:43:59 2025 -0400 update code after array/map node to obj rename --- include/tvm/ir/source_map.h | 12 +++++----- include/tvm/relax/nested_msg.h | 2 +- include/tvm/runtime/container/array.h | 4 ++-- include/tvm/runtime/container/map.h | 2 +- src/contrib/msc/core/utils.cc | 2 +- .../msc/framework/tensorrt/transform_tensorrt.cc | 2 +- src/ir/attr_functor.h | 4 ++-- src/ir/module.cc | 2 +- src/ir/source_map.cc | 4 ++-- src/meta_schedule/arg_info.cc | 4 ++-- src/meta_schedule/database/database.cc | 6 ++--- src/meta_schedule/database/database_utils.cc | 4 ++-- src/meta_schedule/database/json_database.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 2 +- src/meta_schedule/space_generator/schedule_fn.cc | 2 +- src/meta_schedule/utils.h | 4 ++-- src/node/container_printing.cc | 8 +++---- src/node/reflection.cc | 2 +- src/node/serialization.cc | 22 ++++++++--------- src/node/structural_equal.cc | 4 ++-- src/node/structural_hash.cc | 28 +++++++++++----------- .../backend/contrib/codegen_json/codegen_json.h | 2 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/manipulate.cc | 8 +++---- src/relax/transform/tuning_api/database.cc | 6 ++--- src/relax/transform/tuning_api/primitives.cc | 16 ++++++------- src/runtime/container.cc | 14 +++++------ src/runtime/relax_vm/builtin.cc | 6 ++--- src/runtime/relax_vm/vm.cc | 8 +++---- src/script/printer/ir_docsifier.cc | 8 +++---- src/script/printer/legacy_repr.cc | 8 +++---- src/target/target.cc | 20 ++++++++-------- src/te/operation/create_primfunc.cc | 2 +- src/tir/ir/data_type_rewriter.cc | 2 +- src/tir/ir/expr.cc | 6 ++--- src/tir/schedule/concrete_schedule.cc | 4 ++-- src/tir/schedule/instruction_traits.h | 4 ++-- .../schedule/primitive/layout_transformation.cc | 6 ++--- src/tir/schedule/primitive/loop_transformation.cc | 6 ++--- src/tir/schedule/state.cc | 2 +- src/tir/schedule/trace.cc | 18 +++++++------- web/emcc/wasm_runtime.cc | 2 +- 42 files changed, 136 insertions(+), 136 deletions(-) diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 9b3041f3c0..7b79a2c894 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -214,7 +214,7 @@ class SourceMap; /*! * \brief Stores locations in frontend source that generated a node. */ -class SourceMapNode : public Object { +class SourceMapObj : public Object { public: /*! \brief The source mapping. */ Map<SourceName, Source> source_map; @@ -222,12 +222,12 @@ class SourceMapNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } - bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + bool SEqualReduce(const SourceMapObj* other, SEqualReducer equal) const { return equal(source_map, other->source_map); } static constexpr const char* _type_key = "SourceMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); }; class SourceMap : public ObjectRef { @@ -241,12 +241,12 @@ class SourceMap : public ObjectRef { void Add(const Source& source); - SourceMapNode* operator->() { + SourceMapObj* operator->() { ICHECK(get() != nullptr); - return static_cast<SourceMapNode*>(get_mutable()); + return static_cast<SourceMapObj*>(get_mutable()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapObj); }; } // namespace tvm diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 40b3839a36..0ddcb271ab 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -176,7 +176,7 @@ class NestedMsg : public ObjectRef { bool IsNull() const { return data_ == nullptr; } /*! \return Whether the nested message is nested */ - bool IsNested() const { return data_ != nullptr && data_->IsInstance<ArrayNode>(); } + bool IsNested() const { return data_ != nullptr && data_->IsInstance<ArrayObj>(); } /*! * \return The underlying leaf value. diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 7545dc2c57..7d8de1c234 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -30,12 +30,12 @@ namespace tvm { namespace runtime { using tvm::ffi::Array; -using tvm::ffi::ArrayNode; +using tvm::ffi::ArrayObj; } // namespace runtime // expose class to root namespace using tvm::ffi::Array; -using tvm::ffi::ArrayNode; +using tvm::ffi::ArrayObj; } // namespace tvm #endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 7896276c2b..cd63cc94ad 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -35,6 +35,6 @@ using tvm::ffi::Map; // expose the functions to the root namespace. using tvm::ffi::Map; -using tvm::ffi::MapNode; +using tvm::ffi::MapObj; } // namespace tvm #endif // TVM_RUNTIME_CONTAINER_MAP_H_ diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index d18c721ec7..dcef830452 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -269,7 +269,7 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as<FloatImmNode>()) { obj_string = std::to_string(n->value); - } else if (const auto* n = obj.as<ArrayNode>()) { + } else if (const auto* n = obj.as<ArrayObj>()) { for (size_t i = 0; i < n->size(); i++) { obj_string = obj_string + ToString((*n)[i]); if (n->size() == 1 || i < n->size() - 1) { diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index f1866a9b90..5821950141 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -795,7 +795,7 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, split_begins.push_back(i * size); split_ends.push_back(i * size + size); } - } else if (src_attrs->indices_or_sections->IsInstance<ArrayNode>()) { + } else if (src_attrs->indices_or_sections->IsInstance<ArrayObj>()) { const auto& indices = Downcast<Array<Integer>>(src_attrs->indices_or_sections); int64_t last_index = 0; for (size_t i = 0; i < indices.size(); ++i) { diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 008e63fffc..ae26756431 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -75,7 +75,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> { } } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; - virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ArrayObj* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -112,7 +112,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> { using namespace tir; FType vtable; // Set dispatch - ATTR_FUNCTOR_DISPATCH(ArrayNode); + ATTR_FUNCTOR_DISPATCH(ArrayObj); ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); diff --git a/src/ir/module.cc b/src/ir/module.cc index 3ce6cfd900..336a90fa98 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -250,7 +250,7 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return DictAttrs(); } else if (auto* as_dict_attrs = attrs.as<tvm::DictAttrsNode>()) { return GetRef<tvm::DictAttrs>(as_dict_attrs); - } else if (attrs.as<tvm::MapNode>()) { + } else if (attrs.as<tvm::MapObj>()) { return tvm::DictAttrs(Downcast<Map<String, Any>>(attrs)); } else { LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map<String,ObjectRef>"; diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 339b08d6ad..8e25b25a4c 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -208,10 +208,10 @@ tvm::String Source::GetLine(int line) { return line_text; } -TVM_REGISTER_NODE_TYPE(SourceMapNode); +TVM_REGISTER_NODE_TYPE(SourceMapObj); SourceMap::SourceMap(Map<SourceName, Source> source_map) { - auto n = make_object<SourceMapNode>(); + auto n = make_object<SourceMapObj>(); n->source_map = std::move(source_map); data_ = std::move(n); } diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index ff2b215e5f..40c8005d5b 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -69,7 +69,7 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { // Step 1. Extract the tag String tag{runtime::ObjectPtr<runtime::StringObj>(nullptr)}; try { - const ArrayNode* json_array = json_obj.as<ArrayNode>(); + const ArrayObj* json_array = json_obj.as<ArrayObj>(); CHECK(json_array && json_array->size() >= 1); tag = json_array->at(0); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error @@ -129,7 +129,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { DLDataType dtype; Array<Integer> shape; try { - const ArrayNode* json_array = json_obj.as<ArrayNode>(); + const ArrayObj* json_array = json_obj.as<ArrayObj>(); CHECK(json_array && json_array->size() == 3); // Load json[1] => dtype { diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 4146ff1b0a..153cb3e9ee 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -51,7 +51,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { IRModule mod{nullptr}; THashCode shash = 0; try { - const ArrayNode* json_array = json_obj.as<ArrayNode>(); + const ArrayObj* json_array = json_obj.as<ArrayObj>(); CHECK(json_array && json_array->size() == 2); // Load json[0] => shash String str_shash = json_array->at(0); @@ -134,7 +134,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w Optional<Target> target; Optional<Array<ArgInfo>> args_info; try { - const ArrayNode* json_array = json_obj.as<ArrayNode>(); + const ArrayObj* json_array = json_obj.as<ArrayObj>(); CHECK(json_array && json_array->size() == 4); // Load json[1] => run_secs if (json_array->at(1) != nullptr) { @@ -146,7 +146,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } // Load json[3] => args_info if (json_array->at(3) != nullptr) { - const ArrayNode* json_args_info = json_array->at(3).operator const ArrayNode*(); + const ArrayObj* json_args_info = json_array->at(3).operator const ArrayObj*(); Array<ArgInfo> info; info.reserve(json_args_info->size()); for (Any json_arg_info : *json_args_info) { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index f79e91cefc..badba34bca 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -45,7 +45,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { os << std::setprecision(20) << float_imm->value; } else if (const auto* str = json_obj.as<runtime::StringObj>()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; - } else if (const auto* array = json_obj.as<ffi::ArrayNode>()) { + } else if (const auto* array = json_obj.as<ffi::ArrayObj>()) { os << "["; int n = array->size(); for (int i = 0; i < n; ++i) { @@ -55,7 +55,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { JSONDumps(array->at(i), os); } os << "]"; - } else if (const auto* dict = json_obj.as<ffi::MapNode>()) { + } else if (const auto* dict = json_obj.as<ffi::MapObj>()) { int n = dict->size(); std::vector<std::pair<String, ffi::Any>> key_values; key_values.reserve(n); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 8a514798ba..ff6b363bb7 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -189,7 +189,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, const ObjectRef& json_obj = json_objs[task_id]; Workload workload{nullptr}; try { - const ArrayNode* arr = json_obj.as<ArrayNode>(); + const ArrayObj* arr = json_obj.as<ArrayObj>(); ICHECK_EQ(arr->size(), 2); int64_t workload_index = arr->at(0).operator IntImm()->value; ICHECK(workload_index >= 0 && static_cast<size_t>(workload_index) < workloads.size()); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 03c438b0f5..d7b3e707fe 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -34,7 +34,7 @@ using tir::Trace; * \return The result of downcast */ std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) { - const auto* arr = TVM_TYPE_AS(decision, ffi::ArrayNode); + const auto* arr = TVM_TYPE_AS(decision, ffi::ArrayObj); return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr)); } diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index b1b9f550a5..4d583dcb5f 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -54,7 +54,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { if (auto sch = obj.as<tir::Schedule>()) { return {sch.value()}; } - if (const auto* arr = obj.as<ffi::ArrayNode>()) { + if (const auto* arr = obj.as<ffi::ArrayObj>()) { Array<tir::Schedule> result; result.reserve(arr->size()); for (Any val : *arr) { diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index bbd62f72a5..f8812dab37 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -419,7 +419,7 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { * \return The array of floating point numbers */ inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) { - const ArrayNode* arr = obj.as<ArrayNode>(); + const ArrayObj* arr = obj.as<ArrayObj>(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); Array<FloatImm> results; results.reserve(arr->size()); @@ -446,7 +446,7 @@ inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) { * \return The array of integers */ inline Array<Integer> AsIntArray(const ObjectRef& obj) { - const ArrayNode* arr = obj.as<ArrayNode>(); + const ArrayObj* arr = obj.as<ArrayObj>(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); Array<Integer> results; results.reserve(arr->size()); diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index c635cef3fe..261ae4825a 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -29,8 +29,8 @@ namespace tvm { // Container printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast<const ArrayNode*>(node.get()); + .set_dispatch<ArrayObj>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const ArrayObj*>(node.get()); p->stream << '['; for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { @@ -42,8 +42,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast<const MapNode*>(node.get()); + .set_dispatch<MapObj>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const MapObj*>(node.get()); p->stream << '{'; for (auto it = op->begin(); it != op->end(); ++it) { if (it != op->begin()) { diff --git a/src/node/reflection.cc b/src/node/reflection.cc index cd59ac5123..142e009194 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -253,7 +253,7 @@ ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, std::vector<AnyView> packed_args(kwargs.size() * 2); int index = 0; - for (const auto& kv : *static_cast<const MapNode*>(kwargs.get())) { + for (const auto& kv : *static_cast<const MapObj*>(kwargs.get())) { packed_args[index] = kv.first.operator String().c_str(); packed_args[index + 1] = kv.second; index += 2; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 51625aca36..86f1fe1703 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -111,13 +111,13 @@ class NodeIndexer : public AttrVisitor { return; } MakeNodeIndex(node); - if (auto opt_array = node.as<const ArrayNode*>()) { - const ArrayNode* n = opt_array.value(); + if (auto opt_array = node.as<const ArrayObj*>()) { + const ArrayObj* n = opt_array.value(); for (auto elem : *n) { MakeIndex(elem); } - } else if (auto opt_map = node.as<const MapNode*>()) { - const MapNode* n = opt_map.value(); + } else if (auto opt_map = node.as<const MapObj*>()) { + const MapObj* n = opt_map.value(); bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { return v.first.template as<const ffi::StringObj*>().has_value(); }); @@ -272,13 +272,13 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs.clear(); node_->data.clear(); - if (auto opt_array = node.as<const ArrayNode*>()) { - const ArrayNode* n = opt_array.value(); + if (auto opt_array = node.as<const ArrayObj*>()) { + const ArrayObj* n = opt_array.value(); for (size_t i = 0; i < n->size(); ++i) { node_->data.push_back(node_index_->at(n->at(i))); } - } else if (auto opt_map = node.as<const MapNode*>()) { - const MapNode* n = opt_map.value(); + } else if (auto opt_map = node.as<const MapObj*>()) { + const MapObj* n = opt_map.value(); bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { return v.first.template as<const ffi::StringObj*>().has_value(); }); @@ -380,7 +380,7 @@ class FieldDependencyFinder : public AttrVisitor { return; } // Skip containers - if (jnode->type_key == ArrayNode::_type_key || jnode->type_key == MapNode::_type_key) { + if (jnode->type_key == ArrayObj::_type_key || jnode->type_key == MapObj::_type_key) { return; } jnode_ = jnode; @@ -517,13 +517,13 @@ class JSONAttrSetter : public AttrVisitor { void SetAttrs(Any* node, JSONNode* jnode) { jnode_ = jnode; // handling Array - if (jnode->type_key == ArrayNode::_type_key) { + if (jnode->type_key == ArrayObj::_type_key) { Array<Any> result; for (auto index : jnode->data) { result.push_back(node_list_->at(index)); } *node = result; - } else if (jnode->type_key == MapNode::_type_key) { + } else if (jnode->type_key == MapObj::_type_key) { Map<Any, Any> result; if (jnode->keys.empty()) { ICHECK_EQ(jnode->data.size() % 2, 0U); diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 260a02da8f..01331d2546 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -421,7 +421,7 @@ class SEqualHandlerDefault::Impl { cfg->path_to_underline.push_back(first_mismatch_->value()->lhs_path); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, - // e.g. Relax nodes, ArrayNode, MapNode, etc. + // e.g. Relax nodes, ArrayObj, MapObj, etc. oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg); } } else { @@ -436,7 +436,7 @@ class SEqualHandlerDefault::Impl { cfg->path_to_underline.push_back(first_mismatch_->value()->rhs_path); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, - // e.g. Relax nodes, ArrayNode, MapNode, etc. + // e.g. Relax nodes, ArrayObj, MapObj, etc. oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg); } } else { diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index bd9f39d617..4835518e10 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -406,17 +406,17 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrai return blob; }); -struct ArrayNodeTrait { +struct ArrayObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { + static void SHashReduce(const ArrayObj* key, SHashReducer hash_reduce) { hash_reduce(static_cast<uint64_t>(key->size())); for (uint32_t i = 0; i < key->size(); ++i) { hash_reduce(key->at(i)); } } - static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { + static bool SEqualReduce(const ArrayObj* lhs, const ArrayObj* rhs, SEqualReducer equal) { if (equal.IsPathTracingEnabled()) { return SEqualReduceTraced(lhs, rhs, equal); } @@ -429,7 +429,7 @@ struct ArrayNodeTrait { } private: - static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs, + static bool SEqualReduceTraced(const ArrayObj* lhs, const ArrayObj* rhs, const SEqualReducer& equal) { uint32_t min_size = std::min(lhs->size(), rhs->size()); const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); @@ -487,9 +487,9 @@ struct ArrayNodeTrait { return false; } }; -TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) +TVM_REGISTER_REFLECTION_VTABLE(ArrayObj, ArrayObjTrait) .set_creator([](const std::string&) -> ObjectPtr<Object> { - return ::tvm::runtime::make_object<ArrayNode>(); + return ::tvm::runtime::make_object<ArrayObj>(); }); struct ShapeTupleObjTrait { @@ -536,10 +536,10 @@ TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) return blob; }); -struct MapNodeTrait { +struct MapObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) { + static void SHashReduceForOMap(const MapObj* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store @@ -575,7 +575,7 @@ struct MapNodeTrait { } } - static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) { + static void SHashReduceForSMap(const MapObj* key, SHashReducer hash_reduce) { // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store // Map<Var, Value> where Var is defined in the function @@ -598,7 +598,7 @@ struct MapNodeTrait { } } - static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { + static void SHashReduce(const MapObj* key, SHashReducer hash_reduce) { bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { return v.first.template as<const ffi::StringObj*>(); }); @@ -609,7 +609,7 @@ struct MapNodeTrait { } } - static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs, + static bool SEqualReduceTraced(const MapObj* lhs, const MapObj* rhs, const SEqualReducer& equal) { const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); // First, check that every key from `lhs` is also in `rhs`, @@ -649,7 +649,7 @@ struct MapNodeTrait { TVM_FFI_UNREACHABLE(); } - static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + static bool SEqualReduce(const MapObj* lhs, const MapObj* rhs, SEqualReducer equal) { if (equal.IsPathTracingEnabled()) { return SEqualReduceTraced(lhs, rhs, equal); } @@ -670,8 +670,8 @@ struct MapNodeTrait { return true; } }; -TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) - .set_creator([](const std::string&) -> ObjectPtr<Object> { return MapNode::Empty(); }); +TVM_REGISTER_REFLECTION_VTABLE(MapObj, MapObjTrait) + .set_creator([](const std::string&) -> ObjectPtr<Object> { return MapObj::Empty(); }); struct ReportNodeTrait { static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) { diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 8e6313d0a0..ce8ad2f8a7 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -111,7 +111,7 @@ class OpAttrExtractor : public AttrVisitor { } void Visit(const char* key, runtime::ObjectRef* value) final { - if (const auto* an = (*value).as<ArrayNode>()) { + if (const auto* an = (*value).as<ArrayObj>()) { std::vector<std::string> attr; for (size_t i = 0; i < an->size(); ++i) { if (const auto* im = (*an)[i].as<IntImmNode>()) { diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 5d556cc356..361f90c7b0 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -40,7 +40,7 @@ Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, Optional<DataTy Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as<ExprNode>()) { shape_in_expr = GetRef<Expr>(expr); - } else if (const auto* _array = shape.as<ArrayNode>()) { + } else if (const auto* _array = shape.as<ArrayObj>()) { shape_in_expr = ShapeExpr(GetRef<Array<PrimExpr>>(_array)); } else { LOG(FATAL) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 99eb9a46d0..fb84de79cf 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -812,16 +812,16 @@ TVM_REGISTER_OP("relax.permute_dims") /* relax.reshape */ Expr ConvertNewShapeToExpr(const Expr& data, const Variant<Expr, Array<PrimExpr>>& shape) { - const ArrayNode* array; + const ArrayObj* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as<ShapeExprNode>()) { - array = e->values.as<ArrayNode>(); + array = e->values.as<ArrayObj>(); // Other non-shape expressions are used directly. } else if (const auto* e = shape.as<ExprNode>()) { return GetRef<Expr>(e); // Process special values in constants and produce an expression. } else { - array = shape.as<ArrayNode>(); + array = shape.as<ArrayObj>(); } CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " @@ -973,7 +973,7 @@ Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis) ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>(); ObjectRef indices_or_sections_obj; - if (const auto* indices = indices_or_sections.as<ArrayNode>()) { + if (const auto* indices = indices_or_sections.as<ArrayObj>()) { for (int i = 0; i < static_cast<int>(indices->size()); ++i) { const auto* idx = indices->at(i).as<IntImmNode>(); CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. " diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index a52d9a9e1a..3bcb1b6da6 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -58,7 +58,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { Trace trace{nullptr}; Optional<Array<FloatImm>> run_secs{nullptr}; try { - const ArrayNode* json_array = json_obj.as<ArrayNode>(); + const ArrayObj* json_array = json_obj.as<ArrayObj>(); CHECK(json_array && json_array->size() == 2); // Load json[0] => trace { @@ -264,7 +264,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { const ObjectRef& json_obj = json_objs[task_id]; try { - const ArrayNode* arr = json_obj.as<ArrayNode>(); + const ArrayObj* arr = json_obj.as<ArrayObj>(); ICHECK_EQ(arr->size(), 3); workload_idxs[task_id] = Downcast<Integer>(arr->at(0)).IntValue(); targets[task_id] = Target(Downcast<Map<String, ffi::Any>>(arr->at(1))); @@ -296,7 +296,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { const ObjectRef& json_obj = json_objs[task_id]; try { - const ArrayNode* arr = json_obj.as<ArrayNode>(); + const ArrayObj* arr = json_obj.as<ArrayObj>(); ICHECK_EQ(arr->size(), 3); workload_idxs[task_id] = Downcast<Integer>(arr->at(0)).IntValue(); targets[task_id] = Target(Downcast<Map<String, ffi::Any>>(arr->at(1))); diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc index 8f2232db7e..a19022344b 100644 --- a/src/relax/transform/tuning_api/primitives.cc +++ b/src/relax/transform/tuning_api/primitives.cc @@ -64,12 +64,12 @@ Choice Choice::FromJSON(const ObjectRef& json) { String transform_func_key, constr_func_key; Array<Any> transform_func_args, constr_func_args; try { - const ArrayNode* arr = json.as<ArrayNode>(); + const ArrayObj* arr = json.as<ArrayObj>(); ICHECK(arr && arr->size() == 4); const auto* arr0 = arr->at(0).as<ffi::StringObj>(); - const auto* arr1 = arr->at(1).as<ArrayNode>(); + const auto* arr1 = arr->at(1).as<ArrayObj>(); const auto* arr2 = arr->at(2).as<ffi::StringObj>(); - const auto* arr3 = arr->at(3).as<ArrayNode>(); + const auto* arr3 = arr->at(3).as<ArrayObj>(); ICHECK(arr0 && arr1 && arr2 && arr3); transform_func_key = GetRef<String>(arr0); { @@ -123,10 +123,10 @@ Knob Knob::FromJSON(const ObjectRef& json) { String name; Map<String, Choice> choices; try { - const ArrayNode* arr = json.as<ArrayNode>(); + const ArrayObj* arr = json.as<ArrayObj>(); ICHECK(arr && arr->size() == 2); const auto* arr0 = arr->at(0).as<ffi::StringObj>(); - const auto* arr1 = arr->at(1).as<MapNode>(); + const auto* arr1 = arr->at(1).as<MapObj>(); ICHECK(arr0 && arr1); name = GetRef<String>(arr0); for (auto const& x : GetRef<Map<String, ffi::Any>>(arr1)) { @@ -198,12 +198,12 @@ Trace Trace::FromJSON(const ObjectRef& json) { Array<Knob> knobs; Array<String> decisions; try { - const ArrayNode* arr = json.as<ArrayNode>(); + const ArrayObj* arr = json.as<ArrayObj>(); // A trace will have 2 or 3 entries depending on `include_irmod` parameter. ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); - const auto* arr0 = arr->at(0).as<ArrayNode>(); - const auto* arr1 = arr->at(1).as<ArrayNode>(); + const auto* arr0 = arr->at(0).as<ArrayObj>(); + const auto* arr1 = arr->at(1).as<ArrayObj>(); ICHECK(arr0 && arr1); for (const Any& elem : *arr0) { diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 0a3a703a8e..838e1c0f03 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -33,7 +33,7 @@ namespace tvm { namespace runtime { // Array -TVM_REGISTER_OBJECT_TYPE(ArrayNode); +TVM_REGISTER_OBJECT_TYPE(ArrayObj); TVM_REGISTER_GLOBAL("runtime.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) { Array<Any> result; @@ -44,9 +44,9 @@ TVM_REGISTER_GLOBAL("runtime.Array").set_body_packed([](ffi::PackedArgs args, An }); TVM_REGISTER_GLOBAL("runtime.ArrayGetItem") - .set_body_typed([](const ffi::ArrayNode* n, int64_t i) -> Any { return n->at(i); }); + .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }); -TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body_typed([](const ffi::ArrayNode* n) -> int64_t { +TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t { return static_cast<int64_t>(n->size()); }); @@ -69,17 +69,17 @@ TVM_REGISTER_GLOBAL("runtime.Map").set_body_packed([](ffi::PackedArgs args, Any* *ret = data; }); -TVM_REGISTER_GLOBAL("runtime.MapSize").set_body_typed([](const ffi::MapNode* n) -> int64_t { +TVM_REGISTER_GLOBAL("runtime.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t { return static_cast<int64_t>(n->size()); }); TVM_REGISTER_GLOBAL("runtime.MapGetItem") - .set_body_typed([](const ffi::MapNode* n, const Any& k) -> Any { return n->at(k); }); + .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }); TVM_REGISTER_GLOBAL("runtime.MapCount") - .set_body_typed([](const ffi::MapNode* n, const Any& k) -> int64_t { return n->count(k); }); + .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); -TVM_REGISTER_GLOBAL("runtime.MapItems").set_body_typed([](const ffi::MapNode* n) -> Array<Any> { +TVM_REGISTER_GLOBAL("runtime.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array<Any> { Array<Any> rkvs; for (const auto& kv : *n) { rkvs.push_back(kv.first); diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index f7b08c097a..bfc94d1897 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -309,7 +309,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrim */ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional<String> err_ctx) { // a function that lazily get context for error reporting - auto* ptr = arg.as<ffi::ArrayNode>(); + auto* ptr = arg.as<ffi::ArrayObj>(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " << arg->GetTypeKey(); CHECK(static_cast<int64_t>(ptr->size()) == size) @@ -495,8 +495,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem") .set_body_typed([](runtime::Array<Any> arr, int64_t index) { return arr[index]; }); TVM_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") - .set_body_typed([](const ffi::ArrayNode* arr, int64_t index) { - const_cast<ffi::ArrayNode*>(arr)->SetItem(index, nullptr); + .set_body_typed([](const ffi::ArrayObj* arr, int64_t index) { + const_cast<ffi::ArrayObj*>(arr)->SetItem(index, nullptr); }); TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body_packed([](ffi::PackedArgs args, Any* rv) { diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 4f518e4adb..461b64f50e 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -71,7 +71,7 @@ PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector<Any> last_args) Any IndexIntoNestedObject(Any obj, TVMArgs args, int starting_arg_idx) { for (int i = starting_arg_idx; i < args.size(); i++) { // the object must be an Array to be able to index into it - if (!obj.as<ffi::ArrayNode>()) { + if (!obj.as<ffi::ArrayObj>()) { LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; } int index = args[i]; @@ -98,7 +98,7 @@ NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* allo Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { if (src.as<NDArray::ContainerType>()) { return ConvertNDArrayToDevice(Downcast<NDArray>(src), dev, alloc); - } else if (src.as<ffi::ArrayNode>()) { + } else if (src.as<ffi::ArrayObj>()) { std::vector<Any> ret; auto arr = Downcast<ffi::Array<Any>>(src); for (size_t i = 0; i < arr.size(); i++) { @@ -912,7 +912,7 @@ void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); - if (const auto* arr = obj.as<ffi::ArrayNode>()) { + if (const auto* arr = obj.as<ffi::ArrayObj>()) { *rv = static_cast<int>(arr->size()); } else { *rv = -1; @@ -923,7 +923,7 @@ void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); - if (obj.as<ffi::ArrayNode>()) { + if (obj.as<ffi::ArrayObj>()) { LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. " "Please specify another index argument."; return; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 512ca54b42..1a99448711 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -138,13 +138,13 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, } visited_.insert(obj); stack_.push_back(obj); - if (obj->IsInstance<ArrayNode>()) { - const ArrayNode* array = static_cast<const ArrayNode*>(obj); + if (obj->IsInstance<ArrayObj>()) { + const ArrayObj* array = static_cast<const ArrayObj*>(obj); for (Any element : *array) { this->RecursiveVisitAny(&element); } - } else if (obj->IsInstance<MapNode>()) { - const MapNode* map = static_cast<const MapNode*>(obj); + } else if (obj->IsInstance<MapObj>()) { + const MapObj* map = static_cast<const MapObj*>(obj); for (std::pair<Any, Any> kv : *map) { this->RecursiveVisitAny(&kv.first); this->RecursiveVisitAny(&kv.second); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index f1be8ecce5..0316c3b27d 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -75,8 +75,8 @@ ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // N } TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch<ArrayNode>([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast<const ArrayNode*>(node.get()); + .set_dispatch<ArrayObj>([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast<const ArrayObj*>(node.get()); (*p) << '['; for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { @@ -88,8 +88,8 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch<MapNode>([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast<const MapNode*>(node.get()); + .set_dispatch<MapObj>([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast<const MapObj*>(node.get()); (*p) << '{'; for (auto it = op->begin(); it != op->end(); ++it) { if (it != op->begin()) { diff --git a/src/target/target.cc b/src/target/target.cc index be96abc8dd..3fbd7741a5 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -76,7 +76,7 @@ class TargetInternal { static std::string Interpret(const std::string& str); static std::string Uninterpret(const std::string& str); static std::string StringifyAtomicType(const Any& obj); - static std::string StringifyArray(const ArrayNode& array); + static std::string StringifyArray(const ArrayObj& array); static constexpr char quote = '\''; static constexpr char escape = '\\'; @@ -394,7 +394,7 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target return Target(TargetInternal::FromString(interp_str)); - } else if (info.type_index == ArrayNode::RuntimeTypeIndex()) { + } else if (info.type_index == ArrayObj::RuntimeTypeIndex()) { // Parsing array std::vector<ObjectRef> result; for (const std::string& substr : SplitString(interp_str, ',')) { @@ -428,7 +428,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf return opt.value(); } else if (auto str = obj.as<String>()) { return Target(TargetInternal::FromString(str.value())); - } else if (const auto* ptr = obj.as<MapNode>()) { + } else if (const auto* ptr = obj.as<MapObj>()) { for (const auto& kv : *ptr) { if (!kv.first.as<StringObj>()) { TVM_FFI_THROW(TypeError) @@ -440,9 +440,9 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf } TVM_FFI_THROW(TypeError) << "Expect type 'dict' or 'str' to construct Target, but get: " + obj.GetTypeKey(); - } else if (info.type_index == ArrayNode::RuntimeTypeIndex()) { + } else if (info.type_index == ArrayObj::RuntimeTypeIndex()) { // Parsing array - const auto* array = ObjTypeCheck<const ArrayNode*>(obj, "Array"); + const auto* array = ObjTypeCheck<const ArrayObj*>(obj, "Array"); std::vector<ObjectRef> result; for (const Any& e : *array) { try { @@ -453,9 +453,9 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf } } return Array<ObjectRef>(result); - } else if (info.type_index == MapNode::RuntimeTypeIndex()) { + } else if (info.type_index == MapObj::RuntimeTypeIndex()) { // Parsing map - const auto* map = ObjTypeCheck<const MapNode*>(obj, "Map"); + const auto* map = ObjTypeCheck<const MapObj*>(obj, "Map"); std::unordered_map<Any, Any, ffi::AnyHash, ffi::AnyEqual> result; for (const auto& kv : *map) { Any key, val; @@ -502,7 +502,7 @@ std::string TargetInternal::StringifyAtomicType(const Any& obj) { TVM_FFI_UNREACHABLE(); } -std::string TargetInternal::StringifyArray(const ArrayNode& array) { +std::string TargetInternal::StringifyArray(const ArrayObj& array) { std::vector<std::string> elements; for (const Any& item : array) { @@ -531,7 +531,7 @@ Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ffi::Any> std::string value; // skip undefined attrs if (obj == nullptr) continue; - if (const auto* array = obj.as<ArrayNode>()) { + if (const auto* array = obj.as<ArrayObj>()) { value = String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); @@ -889,7 +889,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ffi::Any> config) { bool has_user_keys = config.count(kKeys); if (has_user_keys) { // user provided keys - if (const auto* cfg_keys = config[kKeys].as<ArrayNode>()) { + if (const auto* cfg_keys = config[kKeys].as<ArrayObj>()) { for (const Any& e : *cfg_keys) { if (auto key = e.as<String>()) { keys.push_back(key.value()); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 64e8f46852..b5da43d670 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -309,7 +309,7 @@ Map<String, ffi::Any> GenerateBlockAnnotations(const te::ComputeOp& compute_op, const String& key = pair.first; const Any& value = pair.second; // TensorIR will not allow Tensor data structure - if (value.as<ArrayNode>()) { + if (value.as<ArrayObj>()) { const auto array_value = Downcast<Array<ffi::Any>>(value); annotations.Set(key, array_value.Map(mutate_attr)); } else { diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 5aa0db9535..3f1576ab92 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -381,7 +381,7 @@ Map<String, ffi::Any> IndexDataTypeRewriter::VisitBlockAnnotations( if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) { return new_buffer; } - } else if (obj->IsInstance<ArrayNode>()) { + } else if (obj->IsInstance<ArrayObj>()) { return Downcast<Array<ObjectRef>>(obj).Map(f_mutate_obj); } return obj; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 81566c60a4..1590c51f38 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -649,8 +649,8 @@ CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result, << "ValueError: The number of identities must equal to the number of elements in `results`"; // Change the dtype of input vars to adapt to the dtype of identities - ArrayNode* p_lhs = lhs.CopyOnWrite(); - ArrayNode* p_rhs = rhs.CopyOnWrite(); + ArrayObj* p_lhs = lhs.CopyOnWrite(); + ArrayObj* p_rhs = rhs.CopyOnWrite(); std::unordered_map<const VarNode*, PrimExpr> var_map; var_map.reserve(n_group * 2); for (int i = 0; i < static_cast<int>(n_group); ++i) { @@ -664,7 +664,7 @@ CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result, p_rhs->SetItem(i, r); } - ArrayNode* p_result = result.CopyOnWrite(); + ArrayObj* p_result = result.CopyOnWrite(); for (int i = 0; i < static_cast<int>(n_group); ++i) { p_result->SetItem(i, Substitute(result[i], var_map)); } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 80796a243e..48fac7d5ef 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -941,7 +941,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } return res_expr; } - if (const auto* arr = ann_val.as<ArrayNode>()) { + if (const auto* arr = ann_val.as<ArrayObj>()) { Array<Any> result; result.reserve(arr->size()); for (size_t i = 0; i < arr->size(); i++) { @@ -949,7 +949,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } return std::move(result); } - if (const auto* dict = ann_val.as<MapNode>()) { + if (const auto* dict = ann_val.as<MapObj>()) { Map<String, ffi::Any> result; for (auto it = dict->begin(); it != dict->end(); ++it) { const auto& key = it->first; diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 53deb7457c..924b2ddb00 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -425,7 +425,7 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { } else if (const auto opt_float_imm = obj.as<FloatImm>()) { os.precision(17); os << (*opt_float_imm)->value; - } else if (const auto* array = obj.as<ArrayNode>()) { + } else if (const auto* array = obj.as<ArrayObj>()) { os << '['; bool is_first = true; for (Any e : *array) { @@ -437,7 +437,7 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { AsPythonString(e, os); } os << ']'; - } else if (const auto* dict = obj.as<MapNode>()) { + } else if (const auto* dict = obj.as<MapObj>()) { os << '{'; bool is_first = true; std::vector<std::pair<std::string, std::string>> dict_items; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 3e665a47e5..6a0bba5c81 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1209,7 +1209,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ GlobalVar g_var; const auto* old_func = GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); - MapNode* new_map = new_mod->functions.CopyOnWrite(); + MapObj* new_map = new_mod->functions.CopyOnWrite(); Map<Var, Buffer> new_buffer_map; for (auto [var, buffer] : old_func->buffer_map) { @@ -1533,10 +1533,10 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer GlobalVar g_var; GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); - MapNode* new_map = new_mod->functions.CopyOnWrite(); + MapObj* new_map = new_mod->functions.CopyOnWrite(); PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var))); PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); - MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite(); + MapObj* new_buffer_map = new_func->buffer_map.CopyOnWrite(); for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { if ((*it).second.same_as(old_buffer)) { (*it).second = new_buffer; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7d562712db..b8f4dfd58c 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -77,13 +77,13 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent, + explicit IterMapSimplifyBlockBinding(MapObj* opaque_blocks, Map<Var, Range> loop_var2extent, bool preserve_unit_iters) : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent), preserve_unit_iters_(preserve_unit_iters) {} - static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks, + static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapObj* opaque_blocks, bool preserve_unit_iters) { Map<Var, Range> loop_var2extent; for (const StmtSRef& sref : loop_srefs) { @@ -132,7 +132,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { } /*! \brief The reuse mapping */ - MapNode* opaque_blocks_; + MapObj* opaque_blocks_; /*! \brief The range of loops */ Map<Var, Range> loop_var2extent_; /*! \brief Internal analyzer */ diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ecb857a4c3..f2301a1a77 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -945,7 +945,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } // Ensure the uniqueness of `this->mod` and `this->mod->functions` IRModuleNode* new_mod = this->mod.CopyOnWrite(); - MapNode* new_map = new_mod->functions.CopyOnWrite(); + MapObj* new_map = new_mod->functions.CopyOnWrite(); // Move out the PrimFunc where the sref belong while ensuring uniqueness PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var))); ICHECK(ref_new_func.get() == g_func); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 4c72ce4d1b..8eb95be8fb 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -126,10 +126,10 @@ Array<Any> TranslateInputRVs( } else if (input.as<IntImmNode>() || input.as<FloatImmNode>()) { // Case 3. integer or floating-point number results.push_back(input); - } else if (input.as<ArrayNode>()) { + } else if (input.as<ArrayObj>()) { // Case 4: array results.push_back(TranslateInputRVs(Downcast<Array<Any>>(Any(input)), rv_names)); - } else if (input.as<MapNode>()) { + } else if (input.as<MapObj>()) { // Case 5: dict results.push_back(input); } else if (input.as<IndexMapNode>()) { @@ -166,12 +166,12 @@ Array<Any> TranslateInputRVs(const Array<Any>& inputs, continue; } // Case 4. array - if (input.as<ArrayNode>()) { + if (input.as<ArrayObj>()) { results.push_back(TranslateInputRVs(Downcast<Array<Any>>(input), named_rvs)); continue; } // Case 5. dict - if (input.as<MapNode>()) { + if (input.as<MapObj>()) { results.push_back(input); continue; } @@ -378,10 +378,10 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Array<Any> json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` try { - const ArrayNode* arr = json.as<ArrayNode>(); + const ArrayObj* arr = json.as<ArrayObj>(); ICHECK(arr && arr->size() == 2); - const auto* arr0 = arr->at(0).as<ArrayNode>(); - const auto* arr1 = arr->at(1).as<ArrayNode>(); + const auto* arr0 = arr->at(0).as<ArrayObj>(); + const auto* arr1 = arr->at(1).as<ArrayObj>(); ICHECK(arr0 && arr1); json_insts = GetRef<Array<Any>>(arr0); json_decisions = GetRef<Array<Any>>(arr1); @@ -397,7 +397,7 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { int index = -1; Any decision{nullptr}; try { - const ArrayNode* arr = decision_entry.as<ArrayNode>(); + const ArrayObj* arr = decision_entry.as<ArrayObj>(); ICHECK(arr && arr->size() == 2); auto arr0 = arr->at(0).as<IntImm>(); ICHECK(arr0); @@ -421,7 +421,7 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Array<String> outputs{ObjectPtr<Object>{nullptr}}; // Parse the entry try { - const auto* arr = inst_entry.as<ArrayNode>(); + const auto* arr = inst_entry.as<ArrayObj>(); ICHECK(arr && arr->size() == 4); const auto* arr0 = arr->at(0).as<StringObj>(); kind = InstructionKind::Get(arr0->data); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index c41dd01af5..6e2dfb366a 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRe std::vector<Any> data; for (int i = 0; i < args.size(); ++i) { // Get i-th TVMArray - auto* arr_i = args[i].as<ArrayNode>(); + auto* arr_i = args[i].as<ArrayObj>(); ICHECK(arr_i != nullptr); for (size_t j = 0; j < arr_i->size(); ++j) { // Push back each j-th element of the i-th array
