This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bcfd0afb5f [FFI] Provide Field Visit bridge so we can do gradual
transition (#18091)
bcfd0afb5f is described below
commit bcfd0afb5f62edf094d0d8e1e2122fb8fe906038
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Jun 25 09:51:46 2025 -0400
[FFI] Provide Field Visit bridge so we can do gradual transition (#18091)
This PR provides functions that adapt the old VisitAttrs reflection utilities
to use the new reflection mechanism when available.
These adapters allow us to gradually transition object definitions
from the old VisitAttrs-based mechanism to the new mechanism.
- For all objects:
  - Replace VisitAttrs with a static void RegisterReflection() that registers
    the fields.
  - Call T::RegisterReflection() in a TVM_FFI_STATIC_INIT_BLOCK in the .cc file.
- For subclasses of AttrsNode<T>: subclass AttrsNodeReflAdapter<T> instead, and
  - Do the same steps as above, replacing TVM_DECLARE_ATTRS.
  - Provide an explicit declaration of _type_key and
    TVM_FFI_DECLARE_FINAL_OBJECT_INFO.
We will send follow-up PRs to carry out the gradual transition. Once the transition
is complete, we will remove AttrVisitor and go only through the new
mechanism.
---
ffi/include/tvm/ffi/reflection/reflection.h | 30 ++++
include/tvm/ir/attrs.h | 61 ++++++++
include/tvm/relax/attrs/ccl.h | 64 +++++----
src/contrib/msc/core/ir/graph_builder.cc | 4 +-
src/contrib/msc/core/ir/graph_builder.h | 54 ++++++-
src/node/reflection.cc | 74 ++++++++--
src/node/serialization.cc | 155 +++++++++++++++++++--
src/node/structural_equal.cc | 6 +-
.../backend/contrib/codegen_json/codegen_json.h | 52 ++++++-
src/relax/op/ccl/ccl.cc | 6 +
src/script/printer/ir_docsifier.cc | 19 ++-
src/script/printer/relax/call.cc | 27 +++-
src/support/ffi_testing.cc | 25 ++--
tests/python/ir/test_ir_attrs.py | 12 +-
14 files changed, 521 insertions(+), 68 deletions(-)
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h
b/ffi/include/tvm/ffi/reflection/reflection.h
index d53a4817ad..04d96857cb 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -404,6 +404,8 @@ inline Function GetMethod(std::string_view type_key, const
char* method_name) {
*/
template <typename Callback>
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
+ using ResultType = decltype(callback(type_info->fields));
+ static_assert(std::is_same_v<ResultType, void>, "Callback must return void");
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
@@ -417,6 +419,34 @@ inline void ForEachFieldInfo(const TypeInfo* type_info,
Callback callback) {
}
}
+/*!
+ * \brief Visit each field info of the type info, running a callback that may stop early.
+ *
+ * Fields of the ancestor types are visited first (parent to child order, skipping the
+ * root object), followed by the fields declared on type_info itself. Iteration stops at
+ * the first field for which the callback returns true.
+ *
+ * \tparam Callback The callback type. It is invoked with a pointer to each field info
+ *                  and returns bool; returning true requests an early stop.
+ *
+ * \param type_info The type info.
+ * \param callback_with_early_stop The callback function.
+ * \return true if the callback requested an early stop, false if every field was visited.
+ *
+ * \note Because the fields of both the type and its ancestors are visited, this function
+ *       can be used for searching (e.g. by name or by offset).
+ */
+template <typename Callback>
+inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info,
+ Callback callback_with_early_stop) {
+ // iterate through ancestors in parent to child order
+ // skip the first one since it is always the root object
+ for (int i = 1; i < type_info->type_depth; ++i) {
+ const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+ for (int j = 0; j < parent_info->num_fields; ++j) {
+ if (callback_with_early_stop(parent_info->fields + j)) return true;
+ }
+ }
+ // then the fields declared directly on this type
+ for (int i = 0; i < type_info->num_fields; ++i) {
+ if (callback_with_early_stop(type_info->fields + i)) return true;
+ }
+ return false;
+}
+
} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 6378d6f74a..a409822512 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -47,6 +47,7 @@
#include <dmlc/common.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
@@ -970,5 +971,65 @@ inline void BaseAttrsNode::PrintDocString(std::ostream&
os) const { // NOLINT(*
}
}
+/*!
+ * \brief Adapter for AttrsNode with the new reflection API.
+ *
+ * We will phase out the old AttrsNode in the future in favor of the new reflection API.
+ * This adapter allows us to gradually migrate to the new reflection API.
+ *
+ * The legacy AttrVisitor-based entry points (InitByPackedArgs, VisitNonDefaultAttrs,
+ * VisitAttrs) abort with LOG(FATAL). Structural equality and hashing instead iterate
+ * the fields recorded in the FFI type info of DerivedType.
+ *
+ * \tparam DerivedType The final attribute type.
+ */
+template <typename DerivedType>
+class AttrsNodeReflAdapter : public BaseAttrsNode {
+ public:
+ void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
+ LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init";
+ }
+ void VisitNonDefaultAttrs(AttrVisitor* v) final {
+ LOG(FATAL) << "`" << DerivedType::_type_key
+ << "` uses new reflection mechanism for visit non default attrs";
+ }
+ void VisitAttrs(AttrVisitor* v) final {
+ LOG(FATAL) << "`" << DerivedType::_type_key
+ << "` uses new reflection mechanism for visit attrs";
+ }
+
+ /*!
+ * \brief Structural equality over every reflected field (ancestor fields first).
+ * \param other The attrs node to compare against.
+ * \param equal The reducer used to compare each pair of field values.
+ * \return true if all fields compare equal; comparison stops at the first mismatch.
+ */
+ bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
+ bool success = true;
+ ffi::reflection::ForEachFieldInfoWithEarlyStop(
+ type_info, [&](const TVMFFIFieldInfo* field_info) {
+ ffi::reflection::FieldGetter field_getter(field_info);
+ ffi::Any field_value = field_getter(self());
+ ffi::Any other_field_value = field_getter(other);
+ if (!equal.AnyEqual(field_value, other_field_value)) {
+ success = false;
+ return true;  // stop at the first mismatch
+ }
+ return false;
+ });
+ return success;
+ }
+
+ /*!
+ * \brief Structural hash over every reflected field (ancestor fields first).
+ * \param hash_reducer The reducer each field value is fed into.
+ */
+ void SHashReduce(SHashReducer hash_reducer) const {
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
+ ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
+ ffi::reflection::FieldGetter field_getter(field_info);
+ ffi::Any field_value = field_getter(self());
+ hash_reducer(field_value);
+ });
+ }
+
+ Array<AttrFieldInfo> ListFieldInfo() const final {
+ // Field info of adapted attrs lives in the new reflection registry; this legacy
+ // listing is intentionally empty.
+ // TODO(tvm-team): populate from the reflected fields if legacy callers need it.
+ return Array<AttrFieldInfo>();
+ }
+
+ private:
+ /*! \brief Downcast this node to DerivedType (constness is removed via const_cast). */
+ DerivedType* self() const {
+ return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
+ }
+};
+
} // namespace tvm
#endif // TVM_IR_ATTRS_H_
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index de043f92be..ffc8301e9f 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -24,54 +24,70 @@
#ifndef TVM_RELAX_ATTRS_CCL_H_
#define TVM_RELAX_ATTRS_CCL_H_
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/relax/expr.h>
namespace tvm {
namespace relax {
/*! \brief Attributes used in allreduce operators */
-struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
+struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
String op_type;
bool in_group;
- TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
- TVM_ATTR_FIELD(op_type).describe(
- "The type of reduction operation to be applied to the input data. Now only sum is "
- "supported.");
- TVM_ATTR_FIELD(in_group).describe(
- "Whether the reduction operation performs in group or globally or in group as default.");
+ /*! \brief Register the op_type and in_group fields, with descriptions, via ffi::reflection. */
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<AllReduceAttrs>()
+ .def_ro("op_type", &AllReduceAttrs::op_type,
+ "The type of reduction operation to be applied to the input data. Now only sum is "
+ "supported.")
+ .def_ro("in_group", &AllReduceAttrs::in_group,
+ "Whether the reduction operation performs in group or globally or in group as "
+ "default.");
}
+
+ static constexpr const char* _type_key = "relax.attrs.AllReduceAttrs";
+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllReduceAttrs, BaseAttrsNode);
}; // struct AllReduceAttrs
/*! \brief Attributes used in allgather operators */
-struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
+struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
int num_workers;
bool in_group;
- TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
- TVM_ATTR_FIELD(num_workers)
- .describe(
- "The number of workers, also the number of parts the given buffer should be chunked "
- "into.");
- TVM_ATTR_FIELD(in_group).describe(
- "Whether the allgather operation performs in group or globally or in group as default.");
+ /*! \brief Register the num_workers and in_group fields, with descriptions, via ffi::reflection. */
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<AllGatherAttrs>()
+ .def_ro("num_workers", &AllGatherAttrs::num_workers,
+ "The number of workers, also the number of parts the given buffer should be "
+ "chunked into.")
+ .def_ro("in_group", &AllGatherAttrs::in_group,
+ "Whether the allgather operation performs in group or globally or in group as "
+ "default.");
}
+
+ static constexpr const char* _type_key = "relax.attrs.AllGatherAttrs";
+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllGatherAttrs, BaseAttrsNode);
}; // struct AllGatherAttrs
/*! \brief Attributes used in scatter operators */
-struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
+struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter<ScatterCollectiveAttrs> {
int num_workers;
int axis;
- TVM_DECLARE_ATTRS(ScatterCollectiveAttrs, "relax.attrs.ScatterCollectiveAttrs") {
- TVM_ATTR_FIELD(num_workers)
- .describe(
- "The number of workers, also the number of parts the given buffer should be chunked "
- "into.");
- TVM_ATTR_FIELD(axis).describe(
- "The axis of the tensor to be scattered. The tensor will be chunked along "
- "this axis.");
+ /*! \brief Register the num_workers and axis fields, with descriptions, via ffi::reflection. */
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<ScatterCollectiveAttrs>()
+ .def_ro("num_workers", &ScatterCollectiveAttrs::num_workers,
+ "The number of workers, also the number of parts the given buffer should be "
+ "chunked into.")
+ .def_ro("axis", &ScatterCollectiveAttrs::axis,
+ "The axis of the tensor to be scattered. The tensor will be chunked along "
+ "this axis.");
}
+
+ static constexpr const char* _type_key = "relax.attrs.ScatterCollectiveAttrs";
+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterCollectiveAttrs, BaseAttrsNode);
}; // struct ScatterCollectiveAttrs
} // namespace relax
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index 2550f5652f..69903b26a9 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -50,7 +50,7 @@ void FuncAttrGetter::VisitExpr_(const CallNode* op) {
if (op->attrs.defined()) {
Map<String, String> attrs;
AttrGetter getter(&attrs);
- const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
+ getter(op->attrs);
for (const auto& pair : attrs) {
if (attrs_.count(pair.first)) {
int cnt = 1;
@@ -350,7 +350,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>& bin
attrs = FuncAttrGetter().GetAttrs(call_node->op);
} else if (call_node->attrs.defined()) {
AttrGetter getter(&attrs);
- const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
+ getter(call_node->attrs);
}
} else if (const auto* const_node = expr.as<ConstantNode>()) {
if (const_node->is_scalar()) {
diff --git a/src/contrib/msc/core/ir/graph_builder.h
b/src/contrib/msc/core/ir/graph_builder.h
index 4eac043497..b2689ee7b7 100644
--- a/src/contrib/msc/core/ir/graph_builder.h
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -25,6 +25,7 @@
#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_
#include <dmlc/json.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/runtime/ndarray.h>
@@ -106,7 +107,7 @@ struct MSCRBuildConfig {
}
};
-class AttrGetter : public AttrVisitor {
+class AttrGetter : private AttrVisitor {
public:
/*!
* \brief Get the attributes as Map<String, String>
@@ -114,6 +115,57 @@ class AttrGetter : public AttrVisitor {
*/
explicit AttrGetter(Map<String, String>* attrs) : attrs_(attrs) {}
+ /*!
+ * \brief Stringify every field of attrs into the output map.
+ * \param attrs The attributes to read; must be defined (attrs-> is dereferenced).
+ */
+ void operator()(const Attrs& attrs) {
+ // dispatch between new reflection and old reflection
+ const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index());
+ if (attrs_tinfo->extra_info != nullptr) {
+ tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs);
+ this->VisitAny(String(field_info->name), field_value);
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ const_cast<BaseAttrsNode*>(attrs.get())->VisitAttrs(this);
+ }
+ }
+
+ private:
+ /*!
+ * \brief Stringify one field value read through the new reflection and store it under key.
+ * \param key The field name.
+ * \param value The field value.
+ */
+ void VisitAny(String key, Any value) {
+ switch (value.type_index()) {
+ case kTVMFFINone: {
+ attrs_->Set(key, "");
+ break;
+ }
+ case kTVMFFIBool: {
+ attrs_->Set(key, std::to_string(value.cast<bool>()));
+ break;
+ }
+ case kTVMFFIInt: {
+ attrs_->Set(key, std::to_string(value.cast<int64_t>()));
+ break;
+ }
+ case kTVMFFIFloat: {
+ attrs_->Set(key, std::to_string(value.cast<double>()));
+ break;
+ }
+ case kTVMFFIDataType: {
+ attrs_->Set(key, runtime::DLDataTypeToString(value.cast<DLDataType>()));
+ // break is required: falling through would try to cast the DataType to String
+ break;
+ }
+ case kTVMFFIStr: {
+ attrs_->Set(key, value.cast<String>());
+ break;
+ }
+ default: {
+ if (value.type_index() >= kTVMFFIStaticObjectBegin) {
+ attrs_->Set(key, StringUtils::ToString(value.cast<ObjectRef>()));
+ } else {
+ LOG(FATAL) << "Unsupported type: " << value.GetTypeKey();
+ }
+ break;
+ }
+ }
+ }
+
void Visit(const char* key, double* value) final { attrs_->Set(key,
std::to_string(*value)); }
void Visit(const char* key, int64_t* value) final { attrs_->Set(key,
std::to_string(*value)); }
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index 2290403d37..70e94b0440 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -22,6 +22,7 @@
* \file node/reflection.cc
*/
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/attrs.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
@@ -104,8 +105,22 @@ ffi::Any ReflectionVTable::GetAttr(Object* self, const
String& field_name) const
ret = self->GetTypeKey();
success = true;
} else if (!self->IsInstance<DictAttrsNode>()) {
- VisitAttrs(self, &getter);
- success = getter.found_ref_object || ret != nullptr;
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index());
+ success = false;
+ // use new reflection mechanism
+ if (type_info->extra_info != nullptr) {
+ ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo*
field_info) {
+ if (field_name.compare(field_info->name) == 0) {
+ ffi::reflection::FieldGetter field_getter(field_info);
+ ret = field_getter(self);
+ success = true;
+ }
+ });
+ } else {
+ // legacy reflection mechanism, will be phased out in the future
+ VisitAttrs(self, &getter);
+ success = getter.found_ref_object || ret != nullptr;
+ }
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
@@ -149,7 +164,16 @@ std::vector<std::string>
ReflectionVTable::ListAttrNames(Object* self) const {
dir.names = &names;
if (!self->IsInstance<DictAttrsNode>()) {
- VisitAttrs(self, &dir);
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index());
+ if (type_info->extra_info != nullptr) {
+ // use new reflection mechanism
+ ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo*
field_info) {
+ names.push_back(std::string(field_info->name.data,
field_info->name.size));
+ });
+ } else {
+ // legacy reflection mechanism, will be phased out in the future
+ VisitAttrs(self, &dir);
+ }
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
@@ -288,8 +312,20 @@ void NodeListAttrNames(ffi::PackedArgs args, ffi::Any*
ret) {
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) {
+ // dispatch between new reflection and old reflection
auto type_key = args[0].cast<std::string>();
- *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1));
+ int32_t type_index = -1;
+ TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()};
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+ if (type_info->extra_info != nullptr) {
+ // new reflection: the FFI-level constructor receives the full args (type key included)
+ ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs").CallPacked(args, rv);
+ return;
+ }
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1));
}
TVM_FFI_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr);
@@ -332,13 +368,31 @@ class GetAttrKeyByAddressVisitor : public AttrVisitor {
} // anonymous namespace
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
- GetAttrKeyByAddressVisitor visitor(attr_address);
- ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
- const char* key = visitor.GetKey();
- if (key == nullptr) {
- return std::nullopt;
+ // NOTE: reflection dispatch for both new and legacy reflection mechanism
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(object->type_index());
+ if (tinfo->extra_info != nullptr) {
+ Optional<String> result;
+ // new reflection: compare attr_address with each field's address (object base + offset).
+ // Only the address is needed, so the field value is never read.
+ ffi::reflection::ForEachFieldInfoWithEarlyStop(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ const void* field_addr = reinterpret_cast<const char*>(object) + field_info->offset;
+ if (field_addr == attr_address) {
+ result = String(field_info->name);
+ return true;
+ }
+ return false;
+ });
+ return result;
} else {
- return String(key);
+ // TODO(tvm-team): remove this path once all objects are transitioned to the new reflection
+ GetAttrKeyByAddressVisitor visitor(attr_address);
+ ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
+ const char* key = visitor.GetKey();
+ if (key == nullptr) {
+ return std::nullopt;
+ } else {
+ return String(key);
+ }
}
}
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 986a2d0445..08fc32ad3a 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -24,6 +24,7 @@
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/attrs.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
@@ -62,7 +63,7 @@ inline std::string Base64Encode(std::string s) {
}
// indexer to index all the nodes
-class NodeIndexer : public AttrVisitor {
+class NodeIndexer : private AttrVisitor {
public:
std::unordered_map<Any, size_t, ffi::AnyHash, ffi::AnyEqual>
node_index_{{Any(nullptr), 0}};
std::vector<Any> node_list_{Any(nullptr)};
@@ -133,10 +134,26 @@ class NodeIndexer : public AttrVisitor {
Object* n = const_cast<Object*>(opt_object.value());
// if the node already have repr bytes, no need to visit Attrs.
if (!reflection_->GetReprBytes(n, nullptr)) {
- reflection_->VisitAttrs(n, this);
+ this->VisitObjectFields(n);
}
}
}
+
+ /*!
+ * \brief Index the object/tensor-valued fields of obj (new or legacy reflection).
+ * \param obj The object whose fields are indexed.
+ */
+ void VisitObjectFields(Object* obj) {
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
+ if (field_value.type_index() == ffi::TypeIndex::kTVMFFINDArray) {
+ // Route NDArray fields through the Visit(NDArray*) overload, as the legacy
+ // VisitAttrs path and JSONAttrGetter::VisitObjectFields do.
+ runtime::NDArray value = field_value.cast<runtime::NDArray>();
+ std::string field_name(field_info->name.data, field_info->name.size);
+ this->Visit(field_name.c_str(), &value);
+ } else if (field_value.as<Object>()) {
+ // only make index for ObjectRef
+ this->MakeIndex(field_value);
+ }
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ reflection_->VisitAttrs(obj, this);
+ }
+ }
};
// use map so attributes are ordered.
@@ -211,7 +228,7 @@ struct JSONNode {
// Helper class to populate the json node
// using the existing index.
-class JSONAttrGetter : public AttrVisitor {
+class JSONAttrGetter : private AttrVisitor {
public:
const std::unordered_map<Any, size_t, ffi::AnyHash, ffi::AnyEqual>*
node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
@@ -296,7 +313,7 @@ class JSONAttrGetter : public AttrVisitor {
// do not need to print additional things once we have repr bytes.
if (!reflection_->GetReprBytes(n, &(node_->repr_bytes))) {
// recursively index normal object.
- reflection_->VisitAttrs(n, this);
+ this->VisitObjectFields(n);
}
} else {
// handling primitive types
@@ -327,9 +344,59 @@ class JSONAttrGetter : public AttrVisitor {
}
}
}
+
+ /*!
+ * \brief Serialize the fields of obj into node_->attrs (new or legacy reflection).
+ * \param obj The object whose fields are serialized.
+ */
+ void VisitObjectFields(Object* obj) {
+ // dispatch between new reflection and old reflection
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
+ // TVMFFIByteArray is size-delimited; make a NUL-terminated copy for the
+ // const char* based Visit overloads.
+ std::string field_name(field_info->name.data, field_info->name.size);
+ const char* key = field_name.c_str();
+ switch (field_value.type_index()) {
+ case ffi::TypeIndex::kTVMFFINone: {
+ node_->attrs[field_name] = "null";
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIBool:
+ case ffi::TypeIndex::kTVMFFIInt: {
+ int64_t value = field_value.cast<int64_t>();
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIFloat: {
+ double value = field_value.cast<double>();
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIDataType: {
+ DataType value(field_value.cast<DLDataType>());
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFINDArray: {
+ runtime::NDArray value = field_value.cast<runtime::NDArray>();
+ this->Visit(key, &value);
+ break;
+ }
+ default: {
+ if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ // named field_obj so it does not shadow the obj parameter
+ ObjectRef field_obj = field_value.cast<ObjectRef>();
+ this->Visit(key, &field_obj);
+ } else {
+ LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
+ }
+ break;
+ }
+ }
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ reflection_->VisitAttrs(obj, this);
+ }
+ }
};
-class FieldDependencyFinder : public AttrVisitor {
+class FieldDependencyFinder : private AttrVisitor {
public:
JSONNode* jnode_;
ReflectionVTable* reflection_ = ReflectionVTable::Global();
@@ -385,14 +452,31 @@ class FieldDependencyFinder : public AttrVisitor {
jnode_ = jnode;
if (auto opt_object = node.as<const Object*>()) {
Object* n = const_cast<Object*>(opt_object.value());
- reflection_->VisitAttrs(n, this);
+ this->VisitObjectFields(n);
+ }
+ }
+
+ /*!
+ * \brief Visit the object-valued fields of obj to collect dependencies (new or legacy reflection).
+ * \param obj The object whose fields are inspected.
+ */
+ void VisitObjectFields(Object* obj) {
+ // dispatch between new reflection and old reflection
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
+ if (auto opt_field_obj = field_value.as<ObjectRef>()) {
+ // named field_obj so it does not shadow the obj parameter
+ ObjectRef field_obj = *std::move(opt_field_obj);
+ // NUL-terminated copy of the size-delimited field name
+ std::string field_name(field_info->name.data, field_info->name.size);
+ this->Visit(field_name.c_str(), &field_obj);
+ }
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ reflection_->VisitAttrs(obj, this);
}
}
};
// Helper class to set the attributes of a node
// from given json node.
-class JSONAttrSetter : public AttrVisitor {
+class JSONAttrSetter : private AttrVisitor {
public:
const std::vector<Any>* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
@@ -543,7 +627,62 @@ class JSONAttrSetter : public AttrVisitor {
if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(n,
nullptr)) {
return;
}
- reflection_->VisitAttrs(n, this);
+ this->SetObjectFields(n);
+ }
+ }
+
+ /*!
+ * \brief Populate the fields of obj from the JSON node (new or legacy reflection).
+ * \param obj The object whose fields are set.
+ */
+ void SetObjectFields(Object* obj) {
+ // dispatch between new reflection and old reflection
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ // fields are only written here, so their current values are not read
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ this->SetObjectField(obj, field_info);
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ reflection_->VisitAttrs(obj, this);
+ }
+ }
+
+ /*!
+ * \brief Read one field from the JSON node, chosen by its static type, and store it into obj.
+ * \param obj The object being populated.
+ * \param field_info The reflected field to set.
+ */
+ void SetObjectField(Object* obj, const TVMFFIFieldInfo* field_info) {
+ ffi::reflection::FieldSetter setter(field_info);
+ // TVMFFIByteArray is size-delimited; make a NUL-terminated copy for Visit.
+ std::string field_name(field_info->name.data, field_info->name.size);
+ const char* key = field_name.c_str();
+ switch (field_info->field_static_type_index) {
+ case ffi::TypeIndex::kTVMFFIBool:
+ case ffi::TypeIndex::kTVMFFIInt: {
+ Optional<int64_t> value;
+ this->Visit(key, &value);
+ setter(obj, value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIFloat: {
+ Optional<double> value;
+ this->Visit(key, &value);
+ setter(obj, value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIDataType: {
+ DataType value;
+ this->Visit(key, &value);
+ setter(obj, value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFINDArray: {
+ runtime::NDArray value;
+ this->Visit(key, &value);
+ setter(obj, value);
+ break;
+ }
+ default: {
+ if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ ObjectRef value;
+ this->Visit(key, &value);
+ setter(obj, value);
+ } else {
+ LOG(FATAL) << "Unsupported type: " << field_info->field_static_type_index;
+ }
+ break;
+ }
}
}
};
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 6b19fb5355..d1163269a8 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -193,9 +193,13 @@ bool SEqualReducer::AnyEqual(const ffi::Any& lhs, const
ffi::Any& rhs,
if (paths) {
return operator()(lhs.cast<ObjectRef>(), rhs.cast<ObjectRef>(),
paths.value());
} else {
- return operator()(lhs.cast<ObjectRef>(), rhs.cast<ObjectRef>());
+ ObjectRef lhs_obj = lhs.cast<ObjectRef>();
+ ObjectRef rhs_obj = rhs.cast<ObjectRef>();
+ bool result = operator()(lhs_obj, rhs_obj);
+ return result;
}
}
+
if (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 ==
ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64) {
return true;
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h
b/src/relax/backend/contrib/codegen_json/codegen_json.h
index f7df28bf71..9aa693f58b 100644
--- a/src/relax/backend/contrib/codegen_json/codegen_json.h
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -26,6 +26,7 @@
#include <dmlc/any.h>
#include <dmlc/json.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/node/reflection.h>
#include <tvm/relax/struct_info.h>
#include <tvm/tir/op.h>
@@ -56,7 +57,7 @@ using JSONGraphObjectPtr = std::shared_ptr<JSONGraphNode>;
* \brief Helper class to extract all attributes of a certain op and save them
* into text format.
*/
-class OpAttrExtractor : public AttrVisitor {
+class OpAttrExtractor : private AttrVisitor {
public:
explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {}
@@ -150,11 +151,58 @@ class OpAttrExtractor : public AttrVisitor {
void Extract(Object* node) {
if (node) {
- reflection_->VisitAttrs(node, this);
+ this->VisitObjectFields(node);
}
}
private:
+ /*!
+ * \brief Record every field of obj as a node attribute (new or legacy reflection).
+ * \param obj The attrs object to extract from.
+ */
+ void VisitObjectFields(Object* obj) {
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
+ Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
+ // TVMFFIByteArray is size-delimited; make a NUL-terminated copy for the
+ // const char* based APIs below.
+ std::string field_name(field_info->name.data, field_info->name.size);
+ const char* key = field_name.c_str();
+ switch (field_value.type_index()) {
+ case ffi::TypeIndex::kTVMFFINone: {
+ SetNodeAttr(key, {""});
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIBool:
+ case ffi::TypeIndex::kTVMFFIInt: {
+ int64_t value = field_value.cast<int64_t>();
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIFloat: {
+ double value = field_value.cast<double>();
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIDataType: {
+ DataType value(field_value.cast<DLDataType>());
+ this->Visit(key, &value);
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFINDArray: {
+ runtime::NDArray value = field_value.cast<runtime::NDArray>();
+ this->Visit(key, &value);
+ break;
+ }
+ default: {
+ if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ // named field_obj so it does not shadow the obj parameter
+ ObjectRef field_obj = field_value.cast<ObjectRef>();
+ this->Visit(key, &field_obj);
+ } else {
+ LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey();
+ }
+ break;
+ }
+ }
+ });
+ } else {
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ reflection_->VisitAttrs(obj, this);
+ }
+ }
+
JSONGraphObjectPtr node_;
ReflectionVTable* reflection_ = ReflectionVTable::Global();
};
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 2f6314221e..c73cf672ab 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -27,6 +27,12 @@ namespace relax {
/* relax.ccl.allreduce */
TVM_REGISTER_NODE_TYPE(AllReduceAttrs);
+// Register the reflected fields of the ccl attrs that derive from AttrsNodeReflAdapter.
+TVM_FFI_STATIC_INIT_BLOCK({
+ AllReduceAttrs::RegisterReflection();
+ AllGatherAttrs::RegisterReflection();
+ ScatterCollectiveAttrs::RegisterReflection();
+});
+
Expr allreduce(Expr x, String op_type, bool in_group) {
ObjectPtr<AllReduceAttrs> attrs = make_object<AllReduceAttrs>();
attrs->op_type = std::move(op_type);
diff --git a/src/script/printer/ir_docsifier.cc
b/src/script/printer/ir_docsifier.cc
index 8c72eb4ef3..d906c1baf5 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -17,6 +17,8 @@
* under the License.
*/
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
+#include <tvm/node/reflection.h>
#include <tvm/runtime/logging.h>
#include <tvm/script/printer/ir_docsifier.h>
@@ -104,9 +106,22 @@ void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
ffi::TypedFunction<bool(ObjectRef)>
is_var) {
- class Visitor : public AttrVisitor {
+ class Visitor : private AttrVisitor {
public:
- inline void operator()(ObjectRef obj) { Visit("", &obj); }
+ /*!
+ * \brief Walk obj, dispatching between the new and the legacy reflection.
+ * \param obj The object to walk.
+ *
+ * NOTE(review): the new-reflection branch passes only the fields of obj to
+ * RecursiveVisitAny, whereas the legacy branch hands obj itself to Visit("").
+ * Confirm this matches how Visit treats the root (e.g. the is_var check).
+ * NOTE(review): obj is dereferenced before any null check; presumably callers
+ * never pass an undefined ObjectRef -- verify.
+ */
+ void operator()(ObjectRef obj) {
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
+ if (tinfo->extra_info != nullptr) {
+ // visit fields with the new reflection
+ ffi::reflection::ForEachFieldInfo(tinfo, [&](const
TVMFFIFieldInfo* field_info) {
+ Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
+ this->RecursiveVisitAny(&field_value);
+ });
+ } else {
+ // NOTE: legacy VisitAttrs mechanism
+ // TODO(tvm-team): remove this once all objects are transitioned to the new reflection
+ this->Visit("", &obj);
+ }
+ }
private:
void RecursiveVisitAny(ffi::Any* value) {
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 82c2083044..3dd6cab052 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/distributed/struct_info.h>
@@ -25,12 +26,30 @@ namespace tvm {
namespace script {
namespace printer {
-class AttrPrinter : public tvm::AttrVisitor {
+class AttrPrinter : private AttrVisitor {
public:
explicit AttrPrinter(ObjectPath p, const IRDocsifier& d, Array<String>* keys,
Array<ExprDoc>* values)
: p(std::move(p)), d(d), keys(keys), values(values) {}
+ void operator()(const tvm::Attrs& attrs) {
+ // NOTE: reflection dispatch for both new and legacy reflection mechanism
+ const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index());
+ if (attrs_tinfo->extra_info != nullptr) {
+ LOG(INFO) << "Using new reflection to print attrs" <<
String(attrs_tinfo->type_key);
+ // new printing mechanism using the new reflection
+ ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const
TVMFFIFieldInfo* field_info) {
+ String field_name = String(field_info->name);
+ Any field_value = ffi::reflection::FieldGetter(field_info)(attrs);
+ keys->push_back(field_name);
+ values->push_back(d->AsDoc<ExprDoc>(field_value, p->Attr(field_name)));
+ });
+ } else {
+ const_cast<BaseAttrsNode*>(attrs.get())->VisitAttrs(this);
+ }
+ }
+
+ private:
void Visit(const char* key, double* value) final {
keys->push_back(key);
values->push_back(LiteralDoc::Float(*value, p->Attr(key)));
@@ -235,8 +254,7 @@ Optional<ExprDoc> PrintHintOnDevice(const relax::Call& n,
const ObjectPath& n_p,
Array<ExprDoc> kwargs_values;
ICHECK(n->attrs.defined());
if (n->attrs.as<relax::HintOnDeviceAttrs>()) {
- AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values);
- const_cast<BaseAttrsNode*>(n->attrs.get())->VisitAttrs(&printer);
+ AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs);
args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values));
}
return Relax(d, "hint_on_device")->Call(args);
@@ -355,8 +373,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
d->AsDoc<ExprDoc>(kv.second,
n_p->Attr("attrs")->Attr(kv.first)));
}
} else {
- AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys,
&kwargs_values);
- const_cast<BaseAttrsNode*>(n->attrs.get())->VisitAttrs(&printer);
+ AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys,
&kwargs_values)(n->attrs);
}
}
// Step 4. Print type_args
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index d0d9a35db8..106e7f985b 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/runtime/module.h>
@@ -34,23 +35,29 @@
namespace tvm {
// Attrs used to python API
-struct TestAttrs : public AttrsNode<TestAttrs> {
+struct TestAttrs : public AttrsNodeReflAdapter<TestAttrs> {
int axis;
String name;
Array<PrimExpr> padding;
TypedEnvFunc<int(int)> func;
- TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
-
TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe(
- "axis field");
- TVM_ATTR_FIELD(name).describe("name");
- TVM_ATTR_FIELD(padding).describe("padding of
input").set_default(Array<PrimExpr>({0, 0}));
- TVM_ATTR_FIELD(func)
- .describe("some random env function")
- .set_default(TypedEnvFunc<int(int)>(nullptr));
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TestAttrs>()
+ .def_ro("axis", &TestAttrs::axis, "axis field", refl::DefaultValue(10))
+ .def_ro("name", &TestAttrs::name, "name")
+ .def_ro("padding", &TestAttrs::padding, "padding of input",
+ refl::DefaultValue(Array<PrimExpr>({0, 0})))
+ .def_ro("func", &TestAttrs::func, "some random env function",
+ refl::DefaultValue(TypedEnvFunc<int(int)>(nullptr)));
}
+
+ static constexpr const char* _type_key = "attrs.TestAttrs";
+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestAttrs, BaseAttrsNode);
};
+TVM_FFI_STATIC_INIT_BLOCK({ TestAttrs::RegisterReflection(); });
+
TVM_REGISTER_NODE_TYPE(TestAttrs);
TVM_FFI_REGISTER_GLOBAL("testing.GetShapeSize").set_body_typed([](ffi::Shape
shape) {
diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py
index ce8ac3e4ba..d61538ac25 100644
--- a/tests/python/ir/test_ir_attrs.py
+++ b/tests/python/ir/test_ir_attrs.py
@@ -20,18 +20,22 @@ import tvm.ir._ffi_api
def test_make_attrs():
- with pytest.raises(AttributeError):
+ with pytest.raises(TypeError):
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
- with pytest.raises(AttributeError):
- x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
-
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
+ x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+ y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 5))
+ z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 5))
+ assert not tvm.ir.structural_equal(x, y)
+ assert tvm.ir.structural_equal(x, x)
+ assert tvm.ir.structural_equal(y, z)
+
def test_dict_attrs():
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,
0))