This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a40a1407b6 [FFI][REFACTOR] Refactor AccessPath to enable full tree repr (#18191)
a40a1407b6 is described below
commit a40a1407b6998b1c8e3072586e697bc1afde6f2e
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Aug 6 22:00:06 2025 -0400
[FFI][REFACTOR] Refactor AccessPath to enable full tree repr (#18191)
This PR refactors AccessPath so it can be used to represent a full tree of
access paths compactly in memory.
It also fixes a bug in the Cython method export.
---
ffi/CMakeLists.txt | 2 +-
ffi/include/tvm/ffi/c_api.h | 2 +-
ffi/include/tvm/ffi/reflection/access_path.h | 297 +++++++++++++++++++++-
ffi/include/tvm/ffi/reflection/registry.h | 24 +-
ffi/src/ffi/extra/reflection_extra.cc | 144 +++++++++++
ffi/src/ffi/extra/structural_equal.cc | 15 +-
ffi/src/ffi/object.cc | 80 ------
ffi/src/ffi/reflection/access_path.cc | 34 ---
ffi/tests/cpp/extra/test_structural_equal_hash.cc | 72 ++----
ffi/tests/cpp/test_reflection.cc | 104 ++++++++
python/tvm/ffi/__init__.py | 4 +
python/tvm/ffi/access_path.py | 179 +++++++++++++
python/tvm/ffi/cython/function.pxi | 13 +-
python/tvm/ffi/cython/object.pxi | 2 +-
src/node/structural_equal.cc | 4 +-
tests/python/ffi/test_access_path.py | 133 ++++++++++
16 files changed, 901 insertions(+), 208 deletions(-)
diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt
index 55fbd1c1bc..af9943476e 100644
--- a/ffi/CMakeLists.txt
+++ b/ffi/CMakeLists.txt
@@ -59,7 +59,6 @@ set(tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
- "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
)
if (TVM_FFI_USE_EXTRA_CXX_API)
@@ -69,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc"
)
endif()
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 11080a21f0..c8d46d4552 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -896,7 +896,7 @@ TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t
type_index);
#endif
//---------------------------------------------------------------
-// The following API defines static object field accessors
+// The following API defines static object attribute accessors
// for language bindings.
//
// They are defined in C++ inline functions for cleaner code.
diff --git a/ffi/include/tvm/ffi/reflection/access_path.h
b/ffi/include/tvm/ffi/reflection/access_path.h
index a4f40f485e..267cb76fc1 100644
--- a/ffi/include/tvm/ffi/reflection/access_path.h
+++ b/ffi/include/tvm/ffi/reflection/access_path.h
@@ -25,24 +25,31 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/error.h>
#include <tvm/ffi/reflection/registry.h>
+#include <vector>
+
namespace tvm {
namespace ffi {
namespace reflection {
enum class AccessKind : int32_t {
- kObjectField = 0,
+ kAttr = 0,
kArrayItem = 1,
kMapItem = 2,
// the following two are used for error reporting when
// the supposed access field is not available
- kArrayItemMissing = 3,
- kMapItemMissing = 4,
+ kAttrMissing = 3,
+ kArrayItemMissing = 4,
+ kMapItemMissing = 5,
};
+class AccessStep;
+
/*!
* \brief Represent a single step in object field, map key, array index access.
*/
@@ -59,16 +66,18 @@ class AccessStepObj : public Object {
*/
Any key;
+ // default constructor to enable auto-serialization
+ AccessStepObj() = default;
AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {}
- static void RegisterReflection() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<AccessStepObj>()
- .def_ro("kind", &AccessStepObj::kind)
- .def_ro("key", &AccessStepObj::key);
- }
+ /*!
+ * \brief Deep check if two steps are equal.
+ * \param other The other step to compare with.
+ * \return True if the two steps are equal, false otherwise.
+ */
+ inline bool StepEqual(const AccessStep& other) const;
- static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep";
+ static constexpr const char* _type_key = "ffi.reflection.AccessStep";
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindConstTreeNode;
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object);
};
@@ -82,8 +91,10 @@ class AccessStep : public ObjectRef {
public:
AccessStep(AccessKind kind, Any key) :
ObjectRef(make_object<AccessStepObj>(kind, key)) {}
- static AccessStep ObjectField(String field_name) {
- return AccessStep(AccessKind::kObjectField, field_name);
+ static AccessStep Attr(String field_name) { return
AccessStep(AccessKind::kAttr, field_name); }
+
+ static AccessStep AttrMissing(String field_name) {
+ return AccessStep(AccessKind::kAttrMissing, field_name);
}
static AccessStep ArrayItem(int64_t index) { return
AccessStep(AccessKind::kArrayItem, index); }
@@ -94,15 +105,273 @@ class AccessStep : public ObjectRef {
static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem,
key); }
- static AccessStep MapItemMissing(Any key) { return
AccessStep(AccessKind::kMapItemMissing, key); }
+ static AccessStep MapItemMissing(Any key = nullptr) {
+ return AccessStep(AccessKind::kMapItemMissing, key);
+ }
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef,
AccessStepObj);
};
-using AccessPath = Array<AccessStep>;
+inline bool AccessStepObj::StepEqual(const AccessStep& other) const {
+ return this->kind == other->kind && AnyEqual()(this->key, other->key);
+}
+
+// forward declaration
+class AccessPath;
+
+/*!
+ * \brief ObjectRef class of AccessPathObj.
+ *
+ * \sa AccessPathObj
+ */
+class AccessPathObj : public Object {
+ public:
+ /*!
+ * \brief The parent of the access path.
+ *
+ * This parent-pointing tree structure is more space efficient when
+ * representing multiple paths that share a common prefix.
+ *
+ * \note Empty for root.
+ */
+ Optional<ObjectRef> parent;
+ /*!
+ * \brief The current of the access path.
+ * \note Empty for root.
+ */
+ Optional<AccessStep> step;
+ /*!
+ * \brief The current depth of the access path, 0 for root
+ */
+ int32_t depth;
+
+ // default constructor to enable auto-serialization
+ AccessPathObj() = default;
+ /*!
+ * \brief Constructor for the access path.
+ * \param parent The parent of the access path.
+ * \param step The current step of the access path.
+ * \param depth The current depth of the access path.
+ */
+ AccessPathObj(Optional<ObjectRef> parent, Optional<AccessStep> step, int32_t
depth)
+ : parent(parent), step(step), depth(depth) {}
+
+ /*!
+ * \brief Get the parent of the access path.
+ * \return The parent of the access path.
+ */
+ inline Optional<AccessPath> GetParent() const;
+
+ /*!
+ * \brief Extend the access path with a new step.
+ * \param step The step to extend the access path with.
+ * \return The extended access path.
+ */
+ inline AccessPath Extend(AccessStep step) const;
+
+ /*!
+ * \brief Extend the access path with an object attribute access.
+ * \param field_name The name of the field to access.
+ * \return The extended access path.
+ */
+ inline AccessPath Attr(String field_name) const;
+
+ /*!
+ * \brief Extend the access path with an object attribute missing access.
+ * \param field_name The name of the field to access.
+ * \return The extended access path.
+ */
+ inline AccessPath AttrMissing(String field_name) const;
+
+ /*!
+ * \brief Extend the access path with an array item access.
+ * \param index The index of the array item to access.
+ * \return The extended access path.
+ */
+ inline AccessPath ArrayItem(int64_t index) const;
+
+ /*!
+ * \brief Extend the access path with an array item missing access.
+ * \param index The index of the array item to access.
+ * \return The extended access path.
+ */
+ inline AccessPath ArrayItemMissing(int64_t index) const;
+
+ /*!
+ * \brief Extend the access path with a map item access.
+ * \param key The key of the map item to access.
+ * \return The extended access path.
+ */
+ inline AccessPath MapItem(Any key) const;
+
+ /*!
+ * \brief Extend the access path with a map item missing access.
+ * \param key The key of the map item to access.
+ * \return The extended access path.
+ */
+ inline AccessPath MapItemMissing(Any key) const;
+
+ /*!
+ * \brief Get the array of steps that corresponds to the access path.
+ * \return The array of steps that corresponds to the access path.
+ */
+ inline Array<AccessStep> ToSteps() const;
+
+ /*!
+ * \brief Check if two paths are equal by deep comparing the steps.
+ * \param other The other path to compare with.
+ * \return True if the two paths are equal, false otherwise.
+ */
+ inline bool PathEqual(const AccessPath& other) const;
+
+ /*!
+ * \brief Check if this path is a prefix of another path.
+ * \param other The other path to compare with.
+ * \return True if this path is a prefix of the other path, false otherwise.
+ */
+ inline bool IsPrefixOf(const AccessPath& other) const;
+
+ static constexpr const char* _type_key = "ffi.reflection.AccessPath";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindConstTreeNode;
+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessPathObj, Object);
+
+ private:
+ static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) {
+ // fast path for same pointer
+ if (lhs == rhs) return true;
+ if (lhs->depth != rhs->depth) return false;
+ // do deep equality checks
+ while (lhs->parent.has_value()) {
+ TVM_FFI_ICHECK(rhs->parent.has_value());
+ TVM_FFI_ICHECK(lhs->step.has_value());
+ TVM_FFI_ICHECK(rhs->step.has_value());
+ if (!(*lhs->step)->StepEqual(*(rhs->step))) {
+ return false;
+ }
+ lhs = static_cast<const AccessPathObj*>(lhs->parent.get());
+ rhs = static_cast<const AccessPathObj*>(rhs->parent.get());
+ // fast path for same pointer
+ if (lhs == rhs) return true;
+ TVM_FFI_ICHECK(lhs != nullptr);
+ TVM_FFI_ICHECK(rhs != nullptr);
+ }
+ return true;
+ }
+};
+
+/*!
+ * \brief ObjectRef class of AccessPath.
+ *
+ * \sa AccessPathObj
+ */
+class AccessPath : public ObjectRef {
+ public:
+ /*!
+ * \brief Create an access path from an iterator range of steps.
+ * \param begin The beginning of the iterator range.
+ * \param end The end of the iterator range.
+ * \return The access path.
+ */
+ template <typename Iter>
+ static AccessPath FromSteps(Iter begin, Iter end) {
+ AccessPath path = AccessPath::Root();
+ for (Iter it = begin; it != end; ++it) {
+ path = path->Extend(*it);
+ }
+ return path;
+ }
+ /*!
+ * \brief Create an access path from an array of steps.
+ * \param steps The array of steps.
+ * \return The access path.
+ */
+ static AccessPath FromSteps(Array<AccessStep> steps) {
+ AccessPath path = AccessPath::Root();
+ for (AccessStep step : steps) {
+ path = path->Extend(step);
+ }
+ return path;
+ }
+
+ /*!
+ * \brief Create a root access path.
+ * \return The root access path.
+ */
+ static AccessPath Root() {
+ return AccessPath(make_object<AccessPathObj>(std::nullopt, std::nullopt,
0));
+ }
+
+ TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef,
AccessPathObj);
+};
+
using AccessPathPair = Tuple<AccessPath, AccessPath>;
+inline Optional<AccessPath> AccessPathObj::GetParent() const {
+ if (auto opt_parent = this->parent.as<AccessPath>()) {
+ return opt_parent;
+ }
+ return std::nullopt;
+}
+
+inline AccessPath AccessPathObj::Extend(AccessStep step) const {
+ return AccessPath(make_object<AccessPathObj>(GetRef<AccessPath>(this), step,
this->depth + 1));
+}
+
+inline AccessPath AccessPathObj::Attr(String field_name) const {
+ return this->Extend(AccessStep::Attr(field_name));
+}
+
+inline AccessPath AccessPathObj::AttrMissing(String field_name) const {
+ return this->Extend(AccessStep::AttrMissing(field_name));
+}
+
+inline AccessPath AccessPathObj::ArrayItem(int64_t index) const {
+ return this->Extend(AccessStep::ArrayItem(index));
+}
+
+inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const {
+ return this->Extend(AccessStep::ArrayItemMissing(index));
+}
+
+inline AccessPath AccessPathObj::MapItem(Any key) const {
+ return this->Extend(AccessStep::MapItem(key));
+}
+
+inline AccessPath AccessPathObj::MapItemMissing(Any key) const {
+ return this->Extend(AccessStep::MapItemMissing(key));
+}
+
+inline Array<AccessStep> AccessPathObj::ToSteps() const {
+ std::vector<AccessStep> reverse_steps;
+ reverse_steps.reserve(this->depth);
+ const AccessPathObj* current = this;
+ while (current->parent.has_value()) {
+ TVM_FFI_ICHECK(current->step.has_value());
+ reverse_steps.push_back(*(current->step));
+ current = static_cast<const AccessPathObj*>(current->parent.get());
+ TVM_FFI_ICHECK(current != nullptr);
+ }
+ return Array<AccessStep>(reverse_steps.rbegin(), reverse_steps.rend());
+}
+
+inline bool AccessPathObj::PathEqual(const AccessPath& other) const {
+ return PathEqual(this, other.get());
+}
+
+inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const {
+ if (this->depth > other->depth) {
+ return false;
+ }
+ const AccessPathObj* rhs_path = other.get();
+ while (rhs_path->depth > this->depth) {
+ TVM_FFI_ICHECK(rhs_path->parent.has_value());
+ rhs_path = static_cast<const AccessPathObj*>(rhs_path->parent.get());
+ }
+ return PathEqual(this, rhs_path);
+}
+
} // namespace reflection
} // namespace ffi
} // namespace tvm
+
#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_
diff --git a/ffi/include/tvm/ffi/reflection/registry.h
b/ffi/include/tvm/ffi/reflection/registry.h
index 14b49395d7..107a6e7759 100644
--- a/ffi/include/tvm/ffi/reflection/registry.h
+++ b/ffi/include/tvm/ffi/reflection/registry.h
@@ -198,7 +198,7 @@ class ReflectionDefBase {
}
}
- template <typename Class, typename Func>
+ template <typename Func>
TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
return ffi::Function::FromTyped(std::forward<Func>(func), name);
}
@@ -258,27 +258,12 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
- RegisterFunc(name, GetMethod_(std::string(name), std::forward<Func>(func)),
+ RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
std::forward<Extra>(extra)...);
return *this;
}
private:
- template <typename Func>
- TVM_FFI_INLINE static Function GetMethod_(std::string name, Func&& func) {
- return ffi::Function::FromTyped(std::forward<Func>(func), name);
- }
-
- template <typename Class, typename R, typename... Args>
- TVM_FFI_INLINE static Function GetMethod_(std::string name, R
(Class::*func)(Args...) const) {
- return GetMethod<Class>(std::string(name), func);
- }
-
- template <typename Class, typename R, typename... Args>
- TVM_FFI_INLINE static Function GetMethod_(std::string name, R
(Class::*func)(Args...)) {
- return GetMethod<Class>(std::string(name), func);
- }
-
template <typename... Extra>
void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
TVMFFIMethodInfo info;
@@ -434,8 +419,7 @@ class ObjectDef : public ReflectionDefBase {
info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
}
// obtain the method function
- Function method =
- GetMethod<Class>(std::string(type_key_) + "." + name,
std::forward<Func>(func));
+ Function method = GetMethod(std::string(type_key_) + "." + name,
std::forward<Func>(func));
info.method = AnyView(method).CopyToTVMFFIAny();
// apply method info traits
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
@@ -467,7 +451,7 @@ class TypeAttrDef : public ReflectionDefBase {
TypeAttrDef& def(const char* name, Func&& func) {
TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
ffi::Function ffi_func =
- GetMethod<Class>(std::string(type_key_) + "." + name,
std::forward<Func>(func));
+ GetMethod(std::string(type_key_) + "." + name,
std::forward<Func>(func));
TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny();
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array,
&value_any));
return *this;
diff --git a/ffi/src/ffi/extra/reflection_extra.cc
b/ffi/src/ffi/extra/reflection_extra.cc
new file mode 100644
index 0000000000..698be63376
--- /dev/null
+++ b/ffi/src/ffi/extra/reflection_extra.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/extra/reflection_extra.cc
+ *
+ * \brief Extra reflection registrations. *
+ */
+#include <tvm/ffi/reflection/access_path.h>
+#include <tvm/ffi/reflection/registry.h>
+
+namespace tvm {
+namespace ffi {
+namespace reflection {
+
+void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
+ int32_t type_index;
+ if (auto opt_type_index = args[0].try_cast<int32_t>()) {
+ type_index = *opt_type_index;
+ } else {
+ String type_key = args[0].cast<String>();
+ TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(),
type_key.size()};
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array,
&type_index));
+ }
+
+ TVM_FFI_ICHECK(args.size() % 2 == 1);
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+
+ if (type_info->metadata == nullptr || type_info->metadata->creator ==
nullptr) {
+ TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+ << "` does not support reflection creation";
+ }
+ TVMFFIObjectHandle handle;
+ TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
+ ObjectPtr<Object> ptr =
+
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+
+ std::vector<String> keys;
+ std::vector<bool> keys_found;
+
+ for (int i = 1; i < args.size(); i += 2) {
+ keys.push_back(args[i].cast<String>());
+ }
+ keys_found.resize(keys.size(), false);
+
+ auto search_field = [&](const TVMFFIByteArray& field_name) {
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (keys_found[i]) continue;
+ if (keys[i].compare(field_name) == 0) {
+ return i;
+ }
+ }
+ return keys.size();
+ };
+
+ auto update_fields = [&](const TVMFFITypeInfo* tinfo) {
+ for (int i = 0; i < tinfo->num_fields; ++i) {
+ const TVMFFIFieldInfo* field_info = tinfo->fields + i;
+ size_t arg_index = search_field(field_info->name);
+ void* field_addr = reinterpret_cast<char*>(ptr.get()) +
field_info->offset;
+ if (arg_index < keys.size()) {
+ AnyView field_value = args[arg_index * 2 + 2];
+ field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&field_value));
+ keys_found[arg_index] = true;
+ } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
+ field_info->setter(field_addr, &(field_info->default_value));
+ } else {
+ TVM_FFI_THROW(TypeError) << "Required field `"
+ << String(field_info->name.data,
field_info->name.size)
+ << "` not set in type `" <<
TypeIndexToTypeKey(type_index) << "`";
+ }
+ }
+ };
+
+ // iterate through acenstors in parent to child order
+ // skip the first one since it is always the root object
+ for (int i = 1; i < type_info->type_depth; ++i) {
+ update_fields(type_info->type_acenstors[i]);
+ }
+ update_fields(type_info);
+
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (!keys_found[i]) {
+ TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index)
+ << "` does not have field `" << keys[i] << "`";
+ }
+ }
+ *ret = ObjectRef(ptr);
+}
+
+inline void AccessStepRegisterReflection() {
+ // register access step reflection here since it is only needed for bindings
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<AccessStepObj>()
+ .def_ro("kind", &AccessStepObj::kind)
+ .def_ro("key", &AccessStepObj::key);
+}
+
+inline void AccessPathRegisterReflection() {
+ // register access path reflection here since it is only needed for bindings
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<AccessPathObj>()
+ .def_ro("parent", &AccessPathObj::parent)
+ .def_ro("step", &AccessPathObj::step)
+ .def_ro("depth", &AccessPathObj::depth)
+ .def_static("_root", &AccessPath::Root)
+ .def("_extend", &AccessPathObj::Extend)
+ .def("_attr", &AccessPathObj::Attr)
+ .def("_array_item", &AccessPathObj::ArrayItem)
+ .def("_map_item", &AccessPathObj::MapItem)
+ .def("_attr_missing", &AccessPathObj::AttrMissing)
+ .def("_array_item_missing", &AccessPathObj::ArrayItemMissing)
+ .def("_map_item_missing", &AccessPathObj::MapItemMissing)
+ .def("_is_prefix_of", &AccessPathObj::IsPrefixOf)
+ .def("_to_steps", &AccessPathObj::ToSteps)
+ .def("_path_equal",
+ [](const AccessPath& self, const AccessPath& other) { return
self->PathEqual(other); });
+}
+
+TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ AccessStepRegisterReflection();
+ AccessPathRegisterReflection();
+ refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs",
MakeObjectFromPackedArgs);
+});
+
+} // namespace reflection
+} // namespace ffi
+} // namespace tvm
diff --git a/ffi/src/ffi/extra/structural_equal.cc
b/ffi/src/ffi/extra/structural_equal.cc
index 97ebbf4072..171fa2f750 100644
--- a/ffi/src/ffi/extra/structural_equal.cc
+++ b/ffi/src/ffi/extra/structural_equal.cc
@@ -185,9 +185,9 @@ class StructEqualHandler {
// record the first mismatching field if we sub-rountine compare
failed
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(
- reflection::AccessStep::ObjectField(String(field_info->name)));
+ reflection::AccessStep::Attr(String(field_info->name)));
mismatch_rhs_reverse_path_->emplace_back(
- reflection::AccessStep::ObjectField(String(field_info->name)));
+ reflection::AccessStep::Attr(String(field_info->name)));
}
// return true to indicate early stop
return true;
@@ -216,9 +216,9 @@ class StructEqualHandler {
if (mismatch_lhs_reverse_path_ != nullptr) {
String field_name_str = field_name.cast<String>();
mismatch_lhs_reverse_path_->emplace_back(
- reflection::AccessStep::ObjectField(field_name_str));
+ reflection::AccessStep::Attr(field_name_str));
mismatch_rhs_reverse_path_->emplace_back(
- reflection::AccessStep::ObjectField(field_name_str));
+ reflection::AccessStep::Attr(field_name_str));
}
}
return success;
@@ -420,8 +420,11 @@ Optional<reflection::AccessPathPair>
StructuralEqual::GetFirstMismatch(const Any
if (handler.CompareAny(lhs, rhs)) {
return std::nullopt;
}
- reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(),
lhs_reverse_path.rend());
- reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(),
rhs_reverse_path.rend());
+ using reflection::AccessPath;
+ reflection::AccessPath lhs_path =
+ AccessPath::FromSteps(lhs_reverse_path.rbegin(),
lhs_reverse_path.rend());
+ reflection::AccessPath rhs_path =
+ AccessPath::FromSteps(rhs_reverse_path.rbegin(),
rhs_reverse_path.rend());
return reflection::AccessPathPair(lhs_path, rhs_path);
}
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 374c0c7c4e..61107cb63f 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -385,86 +385,6 @@ class TypeTable {
Map<String, int64_t> type_attr_name_to_column_index_;
};
-void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
- int32_t type_index;
- if (auto opt_type_index = args[0].try_cast<int32_t>()) {
- type_index = *opt_type_index;
- } else {
- String type_key = args[0].cast<String>();
- TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(),
type_key.size()};
- TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array,
&type_index));
- }
-
- TVM_FFI_ICHECK(args.size() % 2 == 1);
- const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-
- if (type_info->metadata == nullptr || type_info->metadata->creator ==
nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
- << "` does not support reflection creation";
- }
- TVMFFIObjectHandle handle;
- TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
- ObjectPtr<Object> ptr =
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
-
- std::vector<String> keys;
- std::vector<bool> keys_found;
-
- for (int i = 1; i < args.size(); i += 2) {
- keys.push_back(args[i].cast<String>());
- }
- keys_found.resize(keys.size(), false);
-
- auto search_field = [&](const TVMFFIByteArray& field_name) {
- for (size_t i = 0; i < keys.size(); ++i) {
- if (keys_found[i]) continue;
- if (keys[i].compare(field_name) == 0) {
- return i;
- }
- }
- return keys.size();
- };
-
- auto update_fields = [&](const TVMFFITypeInfo* tinfo) {
- for (int i = 0; i < tinfo->num_fields; ++i) {
- const TVMFFIFieldInfo* field_info = tinfo->fields + i;
- size_t arg_index = search_field(field_info->name);
- void* field_addr = reinterpret_cast<char*>(ptr.get()) +
field_info->offset;
- if (arg_index < keys.size()) {
- AnyView field_value = args[arg_index * 2 + 2];
- field_info->setter(field_addr, reinterpret_cast<const
TVMFFIAny*>(&field_value));
- keys_found[arg_index] = true;
- } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
- field_info->setter(field_addr, &(field_info->default_value));
- } else {
- TVM_FFI_THROW(TypeError) << "Required field `"
- << String(field_info->name.data,
field_info->name.size)
- << "` not set in type `" <<
TypeIndexToTypeKey(type_index) << "`";
- }
- }
- };
-
- // iterate through acenstors in parent to child order
- // skip the first one since it is always the root object
- for (int i = 1; i < type_info->type_depth; ++i) {
- update_fields(type_info->type_acenstors[i]);
- }
- update_fields(type_info);
-
- for (size_t i = 0; i < keys.size(); ++i) {
- if (!keys_found[i]) {
- TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index)
- << "` does not have field `" << keys[i] << "`";
- }
- }
- *ret = ObjectRef(ptr);
-}
-
-TVM_FFI_STATIC_INIT_BLOCK({
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs",
MakeObjectFromPackedArgs);
-});
-
} // namespace ffi
} // namespace tvm
diff --git a/ffi/src/ffi/reflection/access_path.cc
b/ffi/src/ffi/reflection/access_path.cc
deleted file mode 100644
index 17b8abb062..0000000000
--- a/ffi/src/ffi/reflection/access_path.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*
- * \file src/ffi/reflection/access_path.cc
- */
-
-#include <tvm/ffi/reflection/access_path.h>
-
-namespace tvm {
-namespace ffi {
-namespace reflection {
-
-TVM_FFI_STATIC_INIT_BLOCK({ AccessStepObj::RegisterReflection(); });
-
-} // namespace reflection
-} // namespace ffi
-} // namespace tvm
diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc
b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
index 8a377f4837..a05c50cc26 100644
--- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
@@ -47,21 +47,23 @@ TEST(StructuralEqualHash, Array) {
// first directly interepret diff,
EXPECT_TRUE(diff_a_c.has_value());
- EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayItem);
- EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayItem);
- EXPECT_EQ((*diff_a_c).get<0>()[0]->key.cast<int64_t>(), 1);
- EXPECT_EQ((*diff_a_c).get<1>()[0]->key.cast<int64_t>(), 1);
- EXPECT_EQ((*diff_a_c).get<0>().size(), 1);
- EXPECT_EQ((*diff_a_c).get<1>().size(), 1);
+ auto lhs_steps = (*diff_a_c).get<0>()->ToSteps();
+ auto rhs_steps = (*diff_a_c).get<1>()->ToSteps();
+ EXPECT_EQ(lhs_steps[0]->kind, refl::AccessKind::kArrayItem);
+ EXPECT_EQ(rhs_steps[0]->kind, refl::AccessKind::kArrayItem);
+ EXPECT_EQ(lhs_steps[0]->key.cast<int64_t>(), 1);
+ EXPECT_EQ(rhs_steps[0]->key.cast<int64_t>(), 1);
+ EXPECT_EQ(lhs_steps.size(), 1);
+ EXPECT_EQ(rhs_steps.size(), 1);
// use structural equal for checking in future parts
// given we have done some basic checks above by directly interepret diff,
Array<int> d = {1, 2};
auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d);
- auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({
+ auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::FromSteps({
refl::AccessStep::ArrayItem(2),
}),
- refl::AccessPath({
+ refl::AccessPath::FromSteps({
refl::AccessStep::ArrayItemMissing(2),
}));
// then use structural equal to check it
@@ -80,12 +82,8 @@ TEST(StructuralEqualHash, Map) {
EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
- auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapItem("c"),
- }),
- refl::AccessPath({
-
refl::AccessStep::MapItem("c"),
- }));
+ auto expected_diff_a_c =
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("c"),
+
refl::AccessPath::Root()->MapItem("c"));
EXPECT_TRUE(diff_a_c.has_value());
EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
}
@@ -101,35 +99,22 @@ TEST(StructuralEqualHash, NestedMapArray) {
EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
- auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapItem("b"),
-
refl::AccessStep::ArrayItem(1),
- }),
- refl::AccessPath({
-
refl::AccessStep::MapItem("b"),
-
refl::AccessStep::ArrayItem(1),
- }));
+ auto expected_diff_a_c =
+
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b")->ArrayItem(1),
+
refl::AccessPath::Root()->MapItem("b")->ArrayItem(1));
EXPECT_TRUE(diff_a_c.has_value());
EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
Map<String, Array<Any>> d = {{"a", {1, 2, 3}}};
auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d);
- auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapItem("b"),
- }),
- refl::AccessPath({
-
refl::AccessStep::MapItemMissing("b"),
- }));
+ auto expected_diff_a_d =
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b"),
+
refl::AccessPath::Root()->MapItemMissing("b"));
EXPECT_TRUE(diff_a_d.has_value());
EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d));
auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a);
- auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapItemMissing("b"),
- }),
- refl::AccessPath({
-
refl::AccessStep::MapItem("b"),
- }));
+ auto expected_diff_d_a =
refl::AccessPathPair(refl::AccessPath::Root()->MapItemMissing("b"),
+
refl::AccessPath::Root()->MapItem("b"));
}
TEST(StructuralEqualHash, FreeVar) {
@@ -157,12 +142,12 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
EXPECT_FALSE(StructuralEqual()(fa, fc));
auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc);
- auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::ObjectField("body"),
+ auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath::FromSteps({
+
refl::AccessStep::Attr("body"),
refl::AccessStep::ArrayItem(1),
}),
- refl::AccessPath({
-
refl::AccessStep::ObjectField("body"),
+ refl::AccessPath::FromSteps({
+
refl::AccessStep::Attr("body"),
refl::AccessStep::ArrayItem(1),
}));
EXPECT_TRUE(diff_fa_fc.has_value());
@@ -183,14 +168,9 @@ TEST(StructuralEqualHash, CustomTreeNode) {
EXPECT_FALSE(StructuralEqual()(fa, fc));
auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc);
- auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayItem(1),
- }),
- refl::AccessPath({
-
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayItem(1),
- }));
+ auto expected_diff_fa_fc =
+
refl::AccessPathPair(refl::AccessPath::Root()->Attr("body")->ArrayItem(1),
+
refl::AccessPath::Root()->Attr("body")->ArrayItem(1));
EXPECT_TRUE(diff_fa_fc.has_value());
EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index 98915c54e1..85da00c132 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -20,6 +20,7 @@
#include <gtest/gtest.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/reflection/registry.h>
@@ -165,4 +166,107 @@ TEST(Reflection, ObjectCreator) {
refl::ObjectCreator creator("test.Int");
EXPECT_EQ(creator(Map<String, Any>({{"value", 1}})).cast<TInt>()->value, 1);
}
+
+TEST(Reflection, AccessPath) {
+ namespace refl = tvm::ffi::reflection;
+
+ // Test basic path construction and ToSteps()
+ refl::AccessPath path = refl::AccessPath::Root()->Attr("body")->ArrayItem(1);
+ auto steps = path->ToSteps();
+ EXPECT_EQ(steps.size(), 2);
+ EXPECT_EQ(steps[0]->kind, refl::AccessKind::kAttr);
+ EXPECT_EQ(steps[1]->kind, refl::AccessKind::kArrayItem);
+ EXPECT_EQ(steps[0]->key.cast<String>(), "body");
+ EXPECT_EQ(steps[1]->key.cast<int64_t>(), 1);
+
+ // Test PathEqual with identical paths
+ refl::AccessPath path2 =
refl::AccessPath::Root()->Attr("body")->ArrayItem(1);
+ EXPECT_TRUE(path->PathEqual(path2));
+ EXPECT_TRUE(path->IsPrefixOf(path2));
+
+ // Test PathEqual with different paths
+ refl::AccessPath path3 =
refl::AccessPath::Root()->Attr("body")->ArrayItem(2);
+ EXPECT_FALSE(path->PathEqual(path3));
+ EXPECT_FALSE(path->IsPrefixOf(path3));
+
+ // Test prefix relationship - path4 extends path, so path should be prefix
of path4
+ refl::AccessPath path4 =
refl::AccessPath::Root()->Attr("body")->ArrayItem(1)->Attr("body");
+ EXPECT_FALSE(path->PathEqual(path4)); // Not equal (different lengths)
+ EXPECT_TRUE(path->IsPrefixOf(path4)); // But path is a prefix of path4
+
+ // Test completely different paths
+ refl::AccessPath path5 =
refl::AccessPath::Root()->ArrayItem(0)->ArrayItem(1)->Attr("body");
+ EXPECT_FALSE(path->PathEqual(path5));
+ EXPECT_FALSE(path->IsPrefixOf(path5));
+
+ // Test Root path
+ refl::AccessPath root = refl::AccessPath::Root();
+ auto root_steps = root->ToSteps();
+ EXPECT_EQ(root_steps.size(), 0);
+ EXPECT_EQ(root->depth, 0);
+ EXPECT_TRUE(root->IsPrefixOf(path));
+ EXPECT_TRUE(root->IsPrefixOf(root));
+ EXPECT_TRUE(root->PathEqual(refl::AccessPath::Root()));
+
+ // Test depth calculations
+ EXPECT_EQ(path->depth, 2);
+ EXPECT_EQ(path4->depth, 3);
+ EXPECT_EQ(root->depth, 0);
+
+ // Test MapItem access
+ refl::AccessPath map_path =
refl::AccessPath::Root()->Attr("data")->MapItem("key1");
+ auto map_steps = map_path->ToSteps();
+ EXPECT_EQ(map_steps.size(), 2);
+ EXPECT_EQ(map_steps[0]->kind, refl::AccessKind::kAttr);
+ EXPECT_EQ(map_steps[1]->kind, refl::AccessKind::kMapItem);
+ EXPECT_EQ(map_steps[0]->key.cast<String>(), "data");
+ EXPECT_EQ(map_steps[1]->key.cast<String>(), "key1");
+
+ // Test MapItemMissing access
+ refl::AccessPath map_missing_path =
refl::AccessPath::Root()->MapItemMissing(42);
+ auto map_missing_steps = map_missing_path->ToSteps();
+ EXPECT_EQ(map_missing_steps.size(), 1);
+ EXPECT_EQ(map_missing_steps[0]->kind, refl::AccessKind::kMapItemMissing);
+ EXPECT_EQ(map_missing_steps[0]->key.cast<int64_t>(), 42);
+
+ // Test ArrayItemMissing access
+ refl::AccessPath array_missing_path =
refl::AccessPath::Root()->ArrayItemMissing(5);
+ auto array_missing_steps = array_missing_path->ToSteps();
+ EXPECT_EQ(array_missing_steps.size(), 1);
+ EXPECT_EQ(array_missing_steps[0]->kind, refl::AccessKind::kArrayItemMissing);
+ EXPECT_EQ(array_missing_steps[0]->key.cast<int64_t>(), 5);
+
+ // Test FromSteps static method - round trip conversion
+ auto original_steps = path->ToSteps();
+ refl::AccessPath reconstructed = refl::AccessPath::FromSteps(original_steps);
+ EXPECT_TRUE(path->PathEqual(reconstructed));
+ EXPECT_EQ(path->depth, reconstructed->depth);
+
+ // Test complex prefix relationships
+ refl::AccessPath short_path = refl::AccessPath::Root()->Attr("x");
+ refl::AccessPath medium_path =
refl::AccessPath::Root()->Attr("x")->ArrayItem(0);
+ refl::AccessPath long_path =
refl::AccessPath::Root()->Attr("x")->ArrayItem(0)->MapItem("z");
+
+ EXPECT_TRUE(short_path->IsPrefixOf(medium_path));
+ EXPECT_TRUE(short_path->IsPrefixOf(long_path));
+ EXPECT_TRUE(medium_path->IsPrefixOf(long_path));
+ EXPECT_FALSE(medium_path->IsPrefixOf(short_path));
+ EXPECT_FALSE(long_path->IsPrefixOf(medium_path));
+ EXPECT_FALSE(long_path->IsPrefixOf(short_path));
+
+ // Test non-prefix relationships
+ refl::AccessPath branch1 = refl::AccessPath::Root()->Attr("x")->ArrayItem(0);
+ refl::AccessPath branch2 = refl::AccessPath::Root()->Attr("x")->ArrayItem(1);
+ EXPECT_FALSE(branch1->IsPrefixOf(branch2));
+ EXPECT_FALSE(branch2->IsPrefixOf(branch1));
+ EXPECT_FALSE(branch1->PathEqual(branch2));
+
+ // Test GetParent functionality
+ auto parent = path4->GetParent();
+ EXPECT_TRUE(parent.has_value());
+ EXPECT_TRUE(parent.value()->PathEqual(path));
+
+ auto root_parent = root->GetParent();
+ EXPECT_FALSE(root_parent.has_value());
+}
} // namespace
diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py
index 43a20e751c..e615e22a0c 100644
--- a/python/tvm/ffi/__init__.py
+++ b/python/tvm/ffi/__init__.py
@@ -31,6 +31,7 @@ from .ndarray import cpu, cuda, rocm, opencl, metal, vpi,
vulkan, ext_dev, hexag
from .ndarray import from_dlpack, NDArray, Shape
from .container import Array, Map
from . import serialization
+from . import access_path
from . import testing
@@ -67,4 +68,7 @@ __all__ = [
"Shape",
"Array",
"Map",
+ "testing",
+ "access_path",
+ "serialization",
]
diff --git a/python/tvm/ffi/access_path.py b/python/tvm/ffi/access_path.py
new file mode 100644
index 0000000000..c4822074eb
--- /dev/null
+++ b/python/tvm/ffi/access_path.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Access path classes."""
+
+from enum import IntEnum
+from typing import List, Any
+from . import core
+from .registry import register_object
+
+
+class AccessKind(IntEnum):
+ ATTR = 0
+ ARRAY_ITEM = 1
+ MAP_ITEM = 2
+ ATTR_MISSING = 3
+ ARRAY_ITEM_MISSING = 4
+ MAP_ITEM_MISSING = 5
+
+
+@register_object("ffi.reflection.AccessStep")
+class AccessStep(core.Object):
+ """Access step container"""
+
+
+@register_object("ffi.reflection.AccessPath")
+class AccessPath(core.Object):
+ """Access path container"""
+
+ def __init__(self) -> None:
+ super().__init__()
+ raise ValueError(
+ "AccessPath can't be initialized directly. "
+ "Use AccessPath.root() to create a path to the root object"
+ )
+
+ @staticmethod
+ def root() -> "AccessPath":
+ """Create a root access path"""
+ return AccessPath._root()
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, AccessPath):
+ return False
+ return self._path_equal(other)
+
+ def __ne__(self, other: Any) -> bool:
+ if not isinstance(other, AccessPath):
+ return True
+ return not self._path_equal(other)
+
+ def is_prefix_of(self, other: "AccessPath") -> bool:
+ """Check if this access path is a prefix of another access path
+
+ Parameters
+ ----------
+ other : AccessPath
+ The access path that this access path may be a prefix of
+
+ Returns
+ -------
+ bool
+ True if this access path is a prefix of the other access path,
False otherwise
+ """
+ return self._is_prefix_of(other)
+
+ def attr(self, attr_key: str) -> "AccessPath":
+ """Create an access path to the attribute of the current object
+
+ Parameters
+ ----------
+ attr_key : str
+ The key of the attribute to access
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._attr(attr_key)
+
+ def attr_missing(self, attr_key: str) -> "AccessPath":
+ """Create an access path that indicates an attribute is missing
+
+ Parameters
+ ----------
+ attr_key : str
+ The key of the missing attribute
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._attr_missing(attr_key)
+
+ def array_item(self, index: int) -> "AccessPath":
+ """Create an access path to the item of the current array
+
+ Parameters
+ ----------
+ index : int
+ The index of the item to access
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._array_item(index)
+
+ def array_item_missing(self, index: int) -> "AccessPath":
+ """Create an access path that indicates an array item is missing
+
+ Parameters
+ ----------
+ index : int
+ The index of the missing item
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._array_item_missing(index)
+
+ def map_item(self, key: Any) -> "AccessPath":
+ """Create an access path to the item of the current map
+
+ Parameters
+ ----------
+ key : Any
+ The key of the item to access
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._map_item(key)
+
+ def map_item_missing(self, key: Any) -> "AccessPath":
+ """Create an access path that indicates a map item is missing
+
+ Parameters
+ ----------
+ key : Any
+ The key of the missing item
+
+ Returns
+ -------
+ AccessPath
+ The extended access path
+ """
+ return self._map_item_missing(key)
+
+ def to_steps(self) -> List["AccessStep"]:
+ """Convert the access path to a list of access steps
+
+ Returns
+ -------
+ List[AccessStep]
+ The list of access steps
+ """
+ return self._to_steps()
diff --git a/python/tvm/ffi/cython/function.pxi
b/python/tvm/ffi/cython/function.pxi
index 8c9df19642..999c2e1338 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -291,6 +291,12 @@ cdef _get_method_from_method_info(const TVMFFIMethodInfo*
method):
return make_ret(result)
+def _member_method_wrapper(method_func):
+ def wrapper(self, *args):
+ return method_func(self, *args)
+ return wrapper
+
+
def _add_class_attrs_by_reflection(int type_index, object cls):
"""Decorate the class attrs by reflection"""
cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index)
@@ -335,8 +341,10 @@ def _add_class_attrs_by_reflection(int type_index, object
cls):
if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod:
method_pyfunc = staticmethod(method_func)
else:
- def method_pyfunc(self, *args):
- return method_func(self, *args)
+ # closures bind late: a nested def capturing method_func directly would
+ # see its value from the final loop iteration, so build the wrapper in a
+ # helper function so that each method binds its own method_func
+ method_pyfunc = _member_method_wrapper(method_func)
if doc is not None:
method_pyfunc.__doc__ = doc
@@ -345,7 +353,6 @@ def _add_class_attrs_by_reflection(int type_index, object
cls):
if hasattr(cls, name):
# skip already defined attributes
continue
-
setattr(cls, name, method_pyfunc)
return cls
diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 7df5f7a19a..dad6bee51b 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -31,7 +31,7 @@ def _set_func_convert_to_object(func):
def __object_repr__(obj):
"""Object repr function that can be overridden by assigning to it"""
- return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
+ return type(obj).__name__ + "(" + str(obj.__ctypes_handle__().value) + ")"
def _new_object(cls):
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index d1954413dc..c6875d3fca 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -55,9 +55,9 @@ Optional<ObjectPathPair> ObjectPathPairFromAccessPathPair(
if (!src.has_value()) return std::nullopt;
auto translate_path = [](ffi::reflection::AccessPath path) {
ObjectPath result = ObjectPath::Root();
- for (const auto& step : path) {
+ for (const auto& step : path->ToSteps()) {
switch (step->kind) {
- case ffi::reflection::AccessKind::kObjectField: {
+ case ffi::reflection::AccessKind::kAttr: {
result = result->Attr(step->key.cast<String>());
break;
}
diff --git a/tests/python/ffi/test_access_path.py
b/tests/python/ffi/test_access_path.py
new file mode 100644
index 0000000000..06fbb64ff2
--- /dev/null
+++ b/tests/python/ffi/test_access_path.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from tvm.ffi.access_path import AccessPath, AccessKind
+
+
+def test_root_path():
+ root = AccessPath.root()
+ assert isinstance(root, AccessPath)
+ steps = root.to_steps()
+ assert len(steps) == 0
+ assert root == AccessPath.root()
+
+
+def test_path_attr():
+ path = AccessPath.root().attr("foo")
+ assert isinstance(path, AccessPath)
+ steps = path.to_steps()
+ assert len(steps) == 1
+ assert steps[0].kind == AccessKind.ATTR
+ assert steps[0].key == "foo"
+ assert path.parent == AccessPath.root()
+
+
+def test_path_array_item():
+ path = AccessPath.root().array_item(2)
+ assert isinstance(path, AccessPath)
+ steps = path.to_steps()
+ assert len(steps) == 1
+ assert steps[0].kind == AccessKind.ARRAY_ITEM
+ assert steps[0].key == 2
+ assert path.parent == AccessPath.root()
+
+
+def test_path_missing_array_element():
+ path = AccessPath.root().array_item_missing(2)
+ assert isinstance(path, AccessPath)
+ steps = path.to_steps()
+ assert len(steps) == 1
+ assert steps[0].kind == AccessKind.ARRAY_ITEM_MISSING
+ assert steps[0].key == 2
+ assert path.parent == AccessPath.root()
+
+
+def test_path_map_item():
+ path = AccessPath.root().map_item("foo")
+ assert isinstance(path, AccessPath)
+ steps = path.to_steps()
+ assert len(steps) == 1
+ assert steps[0].kind == AccessKind.MAP_ITEM
+ assert steps[0].key == "foo"
+ assert path.parent == AccessPath.root()
+
+
+def test_path_missing_map_item():
+ path = AccessPath.root().map_item_missing("foo")
+ assert isinstance(path, AccessPath)
+ steps = path.to_steps()
+ assert len(steps) == 1
+ assert steps[0].kind == AccessKind.MAP_ITEM_MISSING
+ assert steps[0].key == "foo"
+ assert path.parent == AccessPath.root()
+
+
+def test_path_is_prefix_of():
+ # Root is prefix of root
+ assert AccessPath.root().is_prefix_of(AccessPath.root())
+
+ # Root is prefix of any path
+ assert AccessPath.root().is_prefix_of(AccessPath.root().attr("foo"))
+
+ # Non-root is not prefix of root
+ assert not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root())
+
+ # Path is prefix of itself
+ assert
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo"))
+
+ # Different attrs are not prefixes of each other
+ assert not
AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo"))
+
+ # Shorter path is prefix of longer path with same start
+ assert
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2))
+
+ # Longer path is not prefix of shorter path
+ assert (
+ not
AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo"))
+ )
+
+ # Different paths are not prefixes
+ assert (
+ not
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2))
+ )
+
+
+def test_path_equal():
+ # Root equals root
+ assert AccessPath.root() == AccessPath.root()
+
+ # Root does not equal non-root paths
+ assert not (AccessPath.root() == AccessPath.root().attr("foo"))
+
+ # Non-root does not equal root
+ assert not (AccessPath.root().attr("foo") == AccessPath.root())
+
+ # Path equals itself
+ assert AccessPath.root().attr("foo") == AccessPath.root().attr("foo")
+
+ # Different attrs are not equal
+ assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo"))
+
+ # Shorter path does not equal longer path
+ assert not (AccessPath.root().attr("foo") ==
AccessPath.root().attr("foo").array_item(2))
+
+ # Longer path does not equal shorter path
+ assert not (AccessPath.root().attr("foo").array_item(2) ==
AccessPath.root().attr("foo"))
+
+ # Different paths are not equal
+ assert not (AccessPath.root().attr("foo") ==
AccessPath.root().attr("bar").array_item(2))