This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 148737b1ff [IR] Compact Functor vtable (#17731)
148737b1ff is described below
commit 148737b1ffe194645cb25fb810296c2edc8ef345
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 11 10:40:29 2025 -0400
[IR] Compact Functor vtable (#17731)
This PR adds a Finalize routine that optionally compacts the functor vtable
dynamically.
It also updates child_slots for key types to make sure the IR node type
indices stay within range, so that the compaction actually takes effect.
---
include/tvm/arith/iter_affine_map.h | 2 +-
include/tvm/ir/expr.h | 4 ++--
include/tvm/ir/type_functor.h | 1 +
include/tvm/node/functor.h | 28 +++++++++++++++++++++++++++-
include/tvm/relax/dataflow_pattern.h | 2 ++
include/tvm/relax/dataflow_pattern_functor.h | 2 +-
include/tvm/relax/expr.h | 4 ++--
include/tvm/relax/expr_functor.h | 1 +
include/tvm/relax/struct_info_functor.h | 1 +
include/tvm/tir/expr_functor.h | 1 +
include/tvm/tir/stmt_functor.h | 1 +
src/ir/attr_functor.h | 1 +
src/relax/ir/py_expr_functor.cc | 3 +++
src/runtime/object.cc | 10 +++++++++-
14 files changed, 53 insertions(+), 8 deletions(-)
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index 53c5b32dd2..d2a6f9a745 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -69,7 +69,7 @@ class IterMapExprNode : public PrimExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "arith.IterMapExpr";
- static constexpr const uint32_t _type_child_slots = 3;
+ static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
};
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b3b4e8ab32..53af269756 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -58,7 +58,7 @@ class BaseExprNode : public Object {
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
- static constexpr const uint32_t _type_child_slots = 62;
+ static constexpr const uint32_t _type_child_slots = 64;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
@@ -104,7 +104,7 @@ class PrimExprNode : public BaseExprNode {
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
static constexpr const char* _type_key = "PrimExpr";
- static constexpr const uint32_t _type_child_slots = 38;
+ static constexpr const uint32_t _type_child_slots = 40;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};
diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h
index 2c145e480b..858226354c 100644
--- a/include/tvm/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h
index 58d59c81cb..82ea37566e 100644
--- a/include/tvm/node/functor.h
+++ b/include/tvm/node/functor.h
@@ -26,6 +26,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
+#include <cstring>
#include <type_traits>
#include <utility>
#include <vector>
@@ -72,6 +73,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
/*! \brief internal function table */
std::vector<FPointer> func_;
+ /*! \brief start range of func index */
+ uint32_t begin_type_index_{0};
public:
/*! \brief the result type of this functor */
@@ -83,6 +86,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
*/
bool can_dispatch(const ObjectRef& n) const {
uint32_t type_index = n->type_index();
+ if (type_index < begin_type_index_) return false;
+ type_index -= begin_type_index_;
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
@@ -94,7 +99,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
R operator()(const ObjectRef& n, Args... args) const {
ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on
type "
<< n->GetTypeKey();
- return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
+ return (*func_[n->type_index() - begin_type_index_])(n,
std::forward<Args>(args)...);
}
/*!
* \brief set the dispatcher for type TNode
@@ -109,6 +114,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
func_.resize(tindex + 1, nullptr);
}
ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key <<
" is already set";
+ ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after
calling Finalize";
func_[tindex] = f;
return *this;
}
@@ -122,9 +128,29 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
+ ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after
calling Finalize";
func_[tindex] = nullptr;
return *this;
}
+ /*!
+ * \brief Finalize the functor after calling sequence of set_dispatch
+ * This function will attempt to find the min type index that is not null
+ * and optimize the space of the func table so it is more compact
+ */
+ void Finalize() {
+ ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once";
+ while (begin_type_index_ < func_.size() && func_[begin_type_index_] ==
nullptr) {
+ ++begin_type_index_;
+ }
+ // shift up the function value
+ size_t new_ftable_size = func_.size() - begin_type_index_;
+ if (begin_type_index_ != 0) {
+ std::memmove(func_.data(), func_.data() + begin_type_index_,
+ new_ftable_size * sizeof(FPointer));
+ }
+ func_.resize(new_ftable_size);
+ func_.shrink_to_fit();
+ }
};
#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto&
__make_functor##_##ClsName
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index df9fdcad97..b3bbebd0e0 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -91,6 +91,7 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const
PatternSeq& rhs);
class DFPatternNode : public Object {
public:
static constexpr const char* _type_key = "DFPatternNode";
+ static constexpr const uint32_t _type_child_slots = 21;
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
};
@@ -373,6 +374,7 @@ class VarPatternNode : public DFPatternNode {
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relax.dpl.VarPattern";
+ static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode);
};
diff --git a/include/tvm/relax/dataflow_pattern_functor.h
b/include/tvm/relax/dataflow_pattern_functor.h
index bbdda44213..fb67f3cc4a 100644
--- a/include/tvm/relax/dataflow_pattern_functor.h
+++ b/include/tvm/relax/dataflow_pattern_functor.h
@@ -135,12 +135,12 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
-
RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index fb6f0e40b1..330ff7e8da 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -119,7 +119,7 @@ class StructInfoNode : public Object {
static constexpr const char* _type_key = "StructInfo";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
- static constexpr const uint32_t _type_child_slots = 5;
+ static constexpr const uint32_t _type_child_slots = 7;
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object);
};
@@ -416,7 +416,7 @@ class VarNode : public LeafExprNode {
static constexpr const char* _type_key = "relax.expr.Var";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
- static constexpr const uint32_t _type_child_slots = 2;
+ static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode);
};
diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
index cdc09c4431..4904b02960 100644
--- a/include/tvm/relax/expr_functor.h
+++ b/include/tvm/relax/expr_functor.h
@@ -176,6 +176,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode);
RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode);
RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/include/tvm/relax/struct_info_functor.h
b/include/tvm/relax/struct_info_functor.h
index 8418b48dc1..2ce5627547 100644
--- a/include/tvm/relax/struct_info_functor.h
+++ b/include/tvm/relax/struct_info_functor.h
@@ -108,6 +108,7 @@ class StructInfoFunctor<R(const StructInfo& n, Args...)> {
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode);
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode);
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 3f66164b42..7a9cf91a65 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -193,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index c5b20f8ec0..e9a41468d3 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -126,6 +126,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index 12b4f6f65b..008e63fffc 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -139,6 +139,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(CastNode);
ATTR_FUNCTOR_DISPATCH(CallNode);
ATTR_FUNCTOR_DISPATCH(SelectNode);
+ vtable.Finalize();
return vtable;
}
};
diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc
index a7ac245610..eb286b4ef6 100644
--- a/src/relax/ir/py_expr_functor.cc
+++ b/src/relax/ir/py_expr_functor.cc
@@ -161,6 +161,7 @@ class PyExprVisitorNode : public Object, public ExprVisitor
{
PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_);
PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
+ vtable.Finalize();
return vtable;
}
};
@@ -414,6 +415,7 @@ class PyExprMutatorNode : public Object, public ExprMutator
{
PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_);
PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
+ vtable.Finalize();
return vtable;
}
@@ -437,6 +439,7 @@ class PyExprMutatorNode : public Object, public ExprMutator
{
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode);
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode);
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode);
+ post_order_vtable.Finalize();
return post_order_vtable;
}
};
diff --git a/src/runtime/object.cc b/src/runtime/object.cc
index 05bfd6d1cf..85ec4f0360 100644
--- a/src/runtime/object.cc
+++ b/src/runtime/object.cc
@@ -170,10 +170,17 @@ class TypeContext {
void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
+ // expected child slots compute the expected slots
+ // based on the current child slot setting
+ std::vector<int> expected_child_slots(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
+ if (static_cast<uint32_t>(expected_child_slots[it->index] + 1) <
it->num_slots) {
+ expected_child_slots[it->index] = it->num_slots - 1;
+ }
+ expected_child_slots[it->parent_index] +=
expected_child_slots[it->index] + 1;
}
}
@@ -182,7 +189,8 @@ class TypeContext {
std::cerr << '[' << info.index << "] " << info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
- << "\tnum_children=" << num_children[info.index] <<
std::endl;
+ << "\tnum_children=" << num_children[info.index]
+ << "\texpected_child_slots=" <<
expected_child_slots[info.index] << std::endl;
}
}
}