This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f0bf057e42 [FFI][REFACTOR] Migrate StructuralEqual/Hash to new reflection (#18166)
f0bf057e42 is described below
commit f0bf057e425363608b6a834140691be4d51e4417
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jul 28 11:02:35 2025 -0400
[FFI][REFACTOR] Migrate StructuralEqual/Hash to new reflection (#18166)
This PR migrates StructuralEqual/Hash to the new reflection-based approach.
The original mechanisms are kept for now and will be phased out
in follow-up PRs.
The new mechanism unifies structural equal/hash registration with
the normal reflection registration and also brings a cleaner implementation
for mismatch detection.
---
ffi/include/tvm/ffi/c_api.h | 21 ----
ffi/src/ffi/reflection/structural_equal.cc | 28 ++++--
ffi/src/ffi/reflection/structural_hash.cc | 51 ++++++----
ffi/tests/cpp/testing_object.h | 11 +-
include/tvm/arith/analyzer.h | 2 +
include/tvm/arith/int_solver.h | 3 +
include/tvm/arith/iter_affine_map.h | 3 +
include/tvm/ir/attrs.h | 3 +-
include/tvm/ir/diagnostic.h | 2 +
include/tvm/ir/env_func.h | 6 +-
include/tvm/ir/expr.h | 20 +++-
include/tvm/ir/global_info.h | 2 +
include/tvm/ir/module.h | 10 ++
include/tvm/ir/op.h | 13 +--
include/tvm/ir/source_map.h | 3 +
include/tvm/ir/type.h | 8 ++
include/tvm/relax/distributed/struct_info.h | 2 +
include/tvm/relax/expr.h | 66 +++++++++---
include/tvm/relax/struct_info.h | 2 +-
include/tvm/target/target.h | 2 +-
include/tvm/target/target_kind.h | 11 +-
include/tvm/te/tensor.h | 1 +
include/tvm/tir/buffer.h | 14 +--
include/tvm/tir/expr.h | 9 +-
include/tvm/tir/function.h | 11 +-
include/tvm/tir/index_map.h | 7 +-
include/tvm/tir/stmt.h | 37 +++----
include/tvm/tir/var.h | 6 +-
src/contrib/msc/core/ir/graph.h | 4 +
src/contrib/msc/core/ir/plugin.h | 4 +
src/ir/module.cc | 52 ++++++++++
src/ir/type.cc | 7 +-
src/meta_schedule/module_equality.cc | 33 +++---
src/node/structural_equal.cc | 111 +++++++++++++++++----
src/node/structural_hash.cc | 6 +-
src/relax/ir/expr.cc | 28 ++++++
src/relax/ir/struct_info.cc | 1 +
src/relax/transform/lift_transform_params.cc | 7 +-
tests/python/ir/test_node_reflection.py | 9 ++
.../test_tvmscript_printer_structural_equal.py | 10 ++
40 files changed, 462 insertions(+), 164 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index e2de610a5d..60743b82c6 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -424,27 +424,6 @@ typedef enum {
* is only an unique copy of each value.
*/
kTVMFFISEqHashKindUniqueInstance = 5,
- /*!
- * \brief provide custom __s_equal__ and __s_hash__ functions through
TypeAttrColumn.
- *
- * The function signatures are(defined via ffi::Function)
- *
- * \code
- * bool __s_equal__(
- * ObjectRefType self, ObjectRefType other,
- * ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string
field_name)> cmp,
- * );
- *
- * uint64_t __s_hash__(
- * ObjectRefType self, uint64_t type_key_hash,
- * ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
- * );
- * \endcode
- *
- * Where the extra string field in cmp is the name of the field that is
being compared.
- * The function should be registered through TVMFFITypeRegisterAttr via
reflection::TypeAttrDef.
- */
- kTVMFFISEqHashKindCustomTreeNode = 6,
#ifdef __cplusplus
};
#else
diff --git a/ffi/src/ffi/reflection/structural_equal.cc
b/ffi/src/ffi/reflection/structural_equal.cc
index 03cbdd95be..e44a0c3256 100644
--- a/ffi/src/ffi/reflection/structural_equal.cc
+++ b/ffi/src/ffi/reflection/structural_equal.cc
@@ -29,6 +29,7 @@
#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ffi/string.h>
+#include <cmath>
#include <unordered_map>
namespace tvm {
@@ -49,7 +50,12 @@ class StructEqualHandler {
if (lhs_data->type_index != rhs_data->type_index) {
return false;
}
+
if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ // specially handle nan for float, as there can be multiple
representations of nan
+ if (lhs_data->type_index == TypeIndex::kTVMFFIFloat &&
std::isnan(lhs_data->v_float64)) {
+ return std::isnan(rhs_data->v_float64);
+ }
// this is POD data, we can just compare the value
return lhs_data->v_int64 == rhs_data->v_int64;
}
@@ -90,12 +96,18 @@ class StructEqualHandler {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
if (type_info->metadata == nullptr) {
- return lhs.same_as(rhs);
+ TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
+ << String(type_info->type_key)
+ << "`, so StructuralHash is not supported for
this type";
+ }
+ if (type_info->metadata->structural_eq_hash_kind ==
kTVMFFISEqHashKindUnsupported) {
+ TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `"
+ << String(type_info->type_key)
+ << "`, so StructuralHash is not supported for
this type";
}
- auto structural_eq_hash_kind =
type_info->metadata->structural_eq_hash_kind;
- if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported ||
- structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {
+ auto structural_eq_hash_kind =
type_info->metadata->structural_eq_hash_kind;
+ if (structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {
// use pointer comparison
return lhs.same_as(rhs);
}
@@ -118,8 +130,10 @@ class StructEqualHandler {
}
}
+ static reflection::TypeAttrColumn custom_s_equal =
reflection::TypeAttrColumn("__s_equal__");
+
bool success = true;
- if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+ if (custom_s_equal[type_info->type_index] == nullptr) {
// We recursively compare the fields the object
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo*
field_info) {
// skip fields that are marked as structural eq hash ignore
@@ -153,7 +167,6 @@ class StructEqualHandler {
}
});
} else {
- static reflection::TypeAttrColumn custom_s_equal =
reflection::TypeAttrColumn("__s_equal__");
// run custom equal function defined via __s_equal__ type attribute
if (s_equal_callback_ == nullptr) {
s_equal_callback_ = ffi::Function::FromTyped(
@@ -179,9 +192,6 @@ class StructEqualHandler {
return success;
});
}
- TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
- << "TypeAttr `__s_equal__` is not registered for type `" <<
String(type_info->type_key)
- << "`";
success = custom_s_equal[type_info->type_index]
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
.cast<bool>();
diff --git a/ffi/src/ffi/reflection/structural_hash.cc
b/ffi/src/ffi/reflection/structural_hash.cc
index ba47de5146..e8ffcf6d2a 100644
--- a/ffi/src/ffi/reflection/structural_hash.cc
+++ b/ffi/src/ffi/reflection/structural_hash.cc
@@ -30,6 +30,8 @@
#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ffi/string.h>
+#include <cmath>
+#include <limits>
#include <unordered_map>
#include <utility>
@@ -48,6 +50,13 @@ class StructuralHashHandler {
const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src);
if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+ // specially handle nan for float, as there can be multiple
representations of nan
+ // make sure they map to the same hash value
+ if (src_data->type_index == TypeIndex::kTVMFFIFloat &&
std::isnan(src_data->v_float64)) {
+ TVMFFIAny temp = *src_data;
+ temp.v_float64 = std::numeric_limits<double>::quiet_NaN();
+ return details::StableHashCombine(temp.type_index, temp.v_uint64);
+ }
// this is POD data, we can just hash the value
return details::StableHashCombine(src_data->type_index,
src_data->v_uint64);
}
@@ -83,9 +92,16 @@ class StructuralHashHandler {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
if (type_info->metadata == nullptr) {
- // Fallback to pointer hash
- return std::hash<const Object*>()(obj.get());
+ TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
+ << String(type_info->type_key)
+ << "`, so StructuralHash is not supported for
this type";
}
+ if (type_info->metadata->structural_eq_hash_kind ==
kTVMFFISEqHashKindUnsupported) {
+ TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `"
+ << String(type_info->type_key)
+ << "`, so StructuralHash is not supported for
this type";
+ }
+
auto structural_eq_hash_kind =
type_info->metadata->structural_eq_hash_kind;
if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
// Fallback to pointer hash
@@ -97,9 +113,11 @@ class StructuralHashHandler {
return it->second;
}
+ static reflection::TypeAttrColumn custom_s_hash =
reflection::TypeAttrColumn("__s_hash__");
+
// compute the hash value
uint64_t hash_value = obj->GetTypeKeyHash();
- if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+ if (custom_s_hash[type_info->type_index] == nullptr) {
// go over the content and hash the fields
ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
@@ -119,22 +137,19 @@ class StructuralHashHandler {
}
});
} else {
- static reflection::TypeAttrColumn custom_s_hash =
reflection::TypeAttrColumn("__s_hash__");
- TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
- << "TypeAttr `__s_hash__` is not registered for type `" <<
String(type_info->type_key)
- << "`";
if (s_hash_callback_ == nullptr) {
- s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool
def_region) {
- if (def_region) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
- uint64_t hash_value = HashAny(val);
- std::swap(allow_free_var, map_free_vars_);
- return hash_value;
- } else {
- return HashAny(val);
- }
- });
+ s_hash_callback_ =
+ ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash,
bool def_region) {
+ if (def_region) {
+ bool allow_free_var = true;
+ std::swap(allow_free_var, map_free_vars_);
+ uint64_t hash_value = HashAny(val);
+ std::swap(allow_free_var, map_free_vars_);
+ return details::StableHashCombine(init_hash, hash_value);
+ } else {
+ return details::StableHashCombine(init_hash, HashAny(val));
+ }
+ });
}
hash_value = custom_s_hash[type_info->type_index]
.cast<ffi::Function>()(obj, hash_value,
s_hash_callback_)
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index 63c2b42d4f..3d8b4b23ed 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -227,10 +227,11 @@ class TCustomFuncObj : public Object {
return true;
}
- uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction<uint64_t(AnyView,
bool)> hash) const {
- uint64_t hash_value = type_key_hash;
- hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params,
true));
- hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body,
false));
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const {
+ uint64_t hash_value = init_hash;
+ hash_value = hash(params, hash_value, true);
+ hash_value = hash(body, hash_value, false);
return hash_value;
}
@@ -246,7 +247,7 @@ class TCustomFuncObj : public Object {
}
static constexpr const char* _type_key = "test.CustomFunc";
- static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindCustomTreeNode;
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object);
};
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 78eac07f45..54cbab2586 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -106,6 +106,7 @@ class ConstIntBoundNode : public Object {
*/
static const constexpr int64_t kNegInf = -kPosInf;
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
};
@@ -222,6 +223,7 @@ class ModularSetNode : public Object {
return equal(coeff, other->coeff) && equal(base, other->base);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h
index dd9259cf97..e2f384b696 100644
--- a/include/tvm/arith/int_solver.h
+++ b/include/tvm/arith/int_solver.h
@@ -83,6 +83,7 @@ class IntGroupBoundsNode : public Object {
hash_reduce(upper);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGroupBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
@@ -173,6 +174,7 @@ class IntConstraintsNode : public Object {
hash_reduce(relations);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
@@ -238,6 +240,7 @@ class IntConstraintsTransformNode : public Object {
hash_reduce(dst_to_src);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index b7f0e09e83..3c666b430f 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -116,6 +116,7 @@ class IterMarkNode : public Object {
hash_reduce(extent);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindDAGNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "arith.IterMark";
@@ -176,6 +177,7 @@ class IterSplitExprNode : public IterMapExprNode {
hash_reduce(scale);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.IterSplitExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode);
};
@@ -239,6 +241,7 @@ class IterSumExprNode : public IterMapExprNode {
hash_reduce(base);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "arith.IterSumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode);
};
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 952cea2a30..6a43274cae 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -83,7 +83,7 @@ class AttrFieldInfoNode : public Object {
}
static constexpr const char* _type_key = "ir.AttrFieldInfo";
-
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
@@ -122,6 +122,7 @@ class BaseAttrsNode : public Object {
TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs,
bool allow_unknown = false) = 0;
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "ir.Attrs";
diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h
index 8429ac1a62..e1d7abbead 100644
--- a/include/tvm/ir/diagnostic.h
+++ b/include/tvm/ir/diagnostic.h
@@ -79,6 +79,7 @@ class DiagnosticNode : public Object {
equal(this->message, other->message);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "Diagnostic";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object);
};
@@ -214,6 +215,7 @@ class DiagnosticContextNode : public Object {
return equal(module, other->module) && equal(diagnostics,
other->diagnostics);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "DiagnosticContext";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object);
};
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index 03cf3d625a..c1fdeb6d1c 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -51,7 +51,10 @@ class EnvFuncNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<EnvFuncNode>().def_ro("name", &EnvFuncNode::name);
+ // func do not participate in structural equal and hash.
+ refl::ObjectDef<EnvFuncNode>()
+ .def_ro("name", &EnvFuncNode::name)
+ .def_ro("func", &EnvFuncNode::func,
refl::AttachFieldFlag::SEqHashIgnore());
}
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
@@ -64,6 +67,7 @@ class EnvFuncNode : public Object {
hash_reduce(name);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 9a8e290cb9..cb62cbadf5 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -58,11 +58,14 @@ class BaseExprNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<BaseExprNode>().def_ro("span", &BaseExprNode::span,
refl::DefaultValue(Span()));
+ // span do not participate in structural equal and hash.
+ refl::ObjectDef<BaseExprNode>().def_ro("span", &BaseExprNode::span,
refl::DefaultValue(Span()),
+
refl::AttachFieldFlag::SEqHashIgnore());
}
static constexpr const char* _type_key = "ir.BaseExpr";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 64;
@@ -428,7 +431,8 @@ class RelaxExprNode : public BaseExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<RelaxExprNode>().def_ro("struct_info_",
&RelaxExprNode::struct_info_);
+ refl::ObjectDef<RelaxExprNode>().def_ro("struct_info_",
&RelaxExprNode::struct_info_,
+
refl::AttachFieldFlag::SEqHashIgnore());
}
static constexpr const char* _type_key = "ir.RelaxExpr";
@@ -474,6 +478,17 @@ class GlobalVarNode : public RelaxExprNode {
hash_reduce.FreeVarHashImpl(this);
}
+ bool SEqual(const GlobalVarNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal)
const {
+ return equal(name_hint, other->name_hint, false, "name_hint");
+ }
+
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const {
+ return hash(name_hint, init_hash, false);
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindFreeVar;
static constexpr const char* _type_key = "ir.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode);
};
@@ -711,6 +726,7 @@ class RangeNode : public Object {
}
static constexpr const char* _type_key = "ir.Range";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
index 4a0b9ffdae..57eadf2b29 100644
--- a/include/tvm/ir/global_info.h
+++ b/include/tvm/ir/global_info.h
@@ -43,6 +43,8 @@ using MemoryScope = String;
class GlobalInfoNode : public Object {
public:
static constexpr const char* _type_key = "ir.GlobalInfo";
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object);
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index a5c2477b8d..66c26b0629 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -138,12 +138,21 @@ class IRModuleNode : public Object {
.def_ro("source_map", &IRModuleNode::source_map)
.def_ro("attrs", &IRModuleNode::attrs)
.def_ro("global_infos", &IRModuleNode::global_infos);
+ // register custom structural equal and hash.
+ refl::TypeAttrDef<IRModuleNode>()
+ .def("__s_equal__", &IRModuleNode::SEqual)
+ .def("__s_hash__", &IRModuleNode::SHash);
}
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal)
const;
TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
+ TVM_DLL bool SEqual(const IRModuleNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool,
AnyView)> equal) const;
+ TVM_DLL uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)>
hash) const;
+
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
@@ -237,6 +246,7 @@ class IRModuleNode : public Object {
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
static constexpr const char* _type_key = "ir.IRModule";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index a50f12b167..5903bed8d9 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -95,12 +95,12 @@ class OpNode : public RelaxExprNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<OpNode>()
.def_ro("name", &OpNode::name)
- .def_ro("op_type", &OpNode::op_type)
- .def_ro("description", &OpNode::description)
- .def_ro("arguments", &OpNode::arguments)
- .def_ro("attrs_type_key", &OpNode::attrs_type_key)
- .def_ro("num_inputs", &OpNode::num_inputs)
- .def_ro("support_level", &OpNode::support_level);
+ .def_ro("op_type", &OpNode::op_type,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("description", &OpNode::description,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("arguments", &OpNode::arguments,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("attrs_type_key", &OpNode::attrs_type_key,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("num_inputs", &OpNode::num_inputs,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("support_level", &OpNode::support_level,
refl::AttachFieldFlag::SEqHashIgnore());
}
bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
@@ -113,6 +113,7 @@ class OpNode : public RelaxExprNode {
hash_reduce(name);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindUniqueInstance;
static constexpr const char* _type_key = "ir.Op";
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode);
diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h
index 87b3532786..d53c234690 100644
--- a/include/tvm/ir/source_map.h
+++ b/include/tvm/ir/source_map.h
@@ -59,6 +59,7 @@ class SourceNameNode : public Object {
return equal(name, other->name);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
@@ -118,6 +119,7 @@ class SpanNode : public Object {
equal(end_column, other->end_column);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.Span";
TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object);
};
@@ -233,6 +235,7 @@ class SourceMapObj : public Object {
return equal(source_map, other->source_map);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.SourceMap";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object);
};
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index f879ab5911..a2ab74a3ae 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -80,6 +80,14 @@ class TypeNode : public Object {
*/
mutable Span span;
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ // span do not participate in structural equal and hash.
+ refl::ObjectDef<TypeNode>().def_ro("span", &TypeNode::span,
refl::DefaultValue(Span()),
+ refl::AttachFieldFlag::SEqHashIgnore());
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/include/tvm/relax/distributed/struct_info.h
b/include/tvm/relax/distributed/struct_info.h
index 3b4a4a0d1d..7f843a9f2c 100644
--- a/include/tvm/relax/distributed/struct_info.h
+++ b/include/tvm/relax/distributed/struct_info.h
@@ -61,6 +61,7 @@ class PlacementSpecNode : public Object {
}
static constexpr const char* _type_key = "relax.distributed.PlacementSpec";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindConstTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object);
@@ -119,6 +120,7 @@ class PlacementNode : public Object {
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindConstTreeNode;
static constexpr const char* _type_key = "relax.distributed.Placement";
TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object);
};
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 34aea7981d..06aba8618b 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -57,7 +57,8 @@ class IdNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint);
+ refl::ObjectDef<IdNode>().def_ro("name_hint", &IdNode::name_hint,
+ refl::AttachFieldFlag::SEqHashIgnore());
}
bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
@@ -66,6 +67,7 @@ class IdNode : public Object {
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.FreeVarHashImpl(this); }
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindFreeVar;
static constexpr const char* _type_key = "relax.Id";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -120,6 +122,13 @@ class StructInfoNode : public Object {
*/
mutable Span span;
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<StructInfoNode>().def_ro("span", &StructInfoNode::span,
+
refl::AttachFieldFlag::SEqHashIgnore());
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "ir.StructInfo";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -397,6 +406,10 @@ class VarNode : public LeafExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<VarNode>().def_ro("vid", &VarNode::vid);
+ // customize structural equal and hash to include struct_info_
+ refl::TypeAttrDef<VarNode>()
+ .def("__s_equal__", &VarNode::SEqual)
+ .def("__s_hash__", &VarNode::SHash);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
@@ -409,6 +422,21 @@ class VarNode : public LeafExprNode {
hash_reduce(struct_info_);
}
+ bool SEqual(const VarNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal)
const {
+ return equal(vid, other->vid, false, "vid") &&
+ equal(struct_info_, other->struct_info_, false, "struct_info_");
+ }
+
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const {
+ uint64_t hash_value = init_hash;
+ hash_value = hash(vid, hash_value, false);
+ hash_value = hash(struct_info_, hash_value, false);
+ return hash_value;
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindDAGNode;
static constexpr const char* _type_key = "relax.expr.Var";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -448,6 +476,7 @@ class DataflowVarNode : public VarNode {
hash_reduce(struct_info_);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindDAGNode;
static constexpr const char* _type_key = "relax.expr.DataflowVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -655,18 +684,19 @@ class DataTypeImm : public LeafExpr {
/*! \brief The base class of a variable binding in Relax. */
class BindingNode : public Object {
public:
+ mutable Span span;
/*! \brief The return variable to bound to. */
Var var;
- mutable Span span;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BindingNode>()
- .def_ro("var", &BindingNode::var)
- .def_ro("span", &BindingNode::span);
+ .def_ro("span", &BindingNode::span,
refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef());
}
static constexpr const char* _type_key = "relax.expr.Binding";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object);
@@ -701,9 +731,8 @@ class MatchCastNode : public BindingNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<MatchCastNode>()
- .def_ro("var", &MatchCastNode::var)
.def_ro("value", &MatchCastNode::value)
- .def_ro("struct_info", &MatchCastNode::struct_info);
+ .def_ro("struct_info", &MatchCastNode::struct_info,
refl::AttachFieldFlag::SEqHashDef());
}
bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const;
@@ -734,14 +763,21 @@ class VarBindingNode : public BindingNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<VarBindingNode>()
- .def_ro("var", &VarBindingNode::var)
- .def_ro("value", &VarBindingNode::value);
+ refl::ObjectDef<VarBindingNode>().def_ro("value", &VarBindingNode::value);
+ // customize the SEqual and SHash methods for better error messages
+ refl::TypeAttrDef<VarBindingNode>()
+ .def("__s_equal__", &VarBindingNode::SEqual)
+ .def("__s_hash__", &VarBindingNode::SHash);
}
bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const;
void SHashReduce(SHashReducer hash_reduce) const;
+ bool SEqual(const VarBindingNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal)
const;
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const;
+
static constexpr const char* _type_key = "relax.expr.VarBinding";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -757,12 +793,15 @@ class VarBinding : public Binding {
class BindingBlockNode : public Object {
public:
- mutable Span span;
Array<Binding> bindings;
+ mutable Span span;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<BindingBlockNode>().def_ro("bindings",
&BindingBlockNode::bindings);
+ refl::ObjectDef<BindingBlockNode>()
+ .def_ro("bindings", &BindingBlockNode::bindings)
+ .def_ro("span", &BindingBlockNode::span,
refl::AttachFieldFlag::SEqHashIgnore(),
+ refl::DefaultValue(Span()));
}
bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
@@ -771,6 +810,7 @@ class BindingBlockNode : public Object {
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "relax.expr.BindingBlock";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -906,6 +946,7 @@ class IfNode : public ExprNode {
hash_reduce(struct_info_);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindDAGNode;
static constexpr const char* _type_key = "relax.expr.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
@@ -960,7 +1001,7 @@ class FunctionNode : public BaseFuncNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FunctionNode>()
- .def_ro("params", &FunctionNode::params)
+ .def_ro("params", &FunctionNode::params,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("body", &FunctionNode::body)
.def_ro("ret_struct_info", &FunctionNode::ret_struct_info)
.def_ro("is_pure", &FunctionNode::is_pure);
@@ -983,6 +1024,7 @@ class FunctionNode : public BaseFuncNode {
hash_reduce(struct_info_);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindDAGNode;
static constexpr const char* _type_key = "relax.expr.Function";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index 25a6b1ef4a..cd9b05ab29 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -341,7 +341,7 @@ class FuncStructInfoNode : public StructInfoNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FuncStructInfoNode>()
- .def_ro("params", &FuncStructInfoNode::params)
+ .def_ro("params", &FuncStructInfoNode::params,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("ret", &FuncStructInfoNode::ret)
.def_ro("derive_func", &FuncStructInfoNode::derive_func)
.def_ro("purity", &FuncStructInfoNode::purity);
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 98c984c157..9929791f31 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -177,7 +177,7 @@ class TargetNode : public Object {
void SHashReduce(SHashReducer hash_reduce) const;
static constexpr const char* _type_key = "target.Target";
-
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index c9785820b4..15b5f62cd5 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -81,12 +81,14 @@ class TargetKindNode : public Object {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TargetKindNode>()
.def_ro("name", &TargetKindNode::name)
- .def_ro("default_device_type", &TargetKindNode::default_device_type)
- .def_ro("default_keys", &TargetKindNode::default_keys);
+ .def_ro("default_device_type", &TargetKindNode::default_device_type,
+ refl::AttachFieldFlag::SEqHashIgnore())
+ .def_ro("default_keys", &TargetKindNode::default_keys,
+ refl::AttachFieldFlag::SEqHashIgnore());
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindUniqueInstance;
static constexpr const char* _type_key = "target.TargetKind";
-
TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object);
private:
@@ -134,10 +136,11 @@ class TargetKind : public ObjectRef {
* \return The TargetKind requested
*/
TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef,
TargetKindNode);
/*! \brief Mutable access to the container class */
TargetKindNode* operator->() { return
static_cast<TargetKindNode*>(data_.get()); }
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef,
TargetKindNode);
+
private:
TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>&
GetAttrMapContainer(
const String& attr_name);
diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h
index 15be5bb069..f45a96df63 100644
--- a/include/tvm/te/tensor.h
+++ b/include/tvm/te/tensor.h
@@ -88,6 +88,7 @@ class TensorNode : public DataProducerNode {
TVM_DLL String GetNameHint() const final;
static constexpr const char* _type_key = "te.Tensor";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindConstTreeNode;
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
};
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 6b49f619b5..cb16d2912a 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -115,13 +115,14 @@ class BufferNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BufferNode>()
- .def_ro("data", &BufferNode::data)
+ .def_ro("data", &BufferNode::data, refl::AttachFieldFlag::SEqHashDef())
.def_ro("dtype", &BufferNode::dtype)
- .def_ro("shape", &BufferNode::shape)
- .def_ro("strides", &BufferNode::strides)
- .def_ro("axis_separators", &BufferNode::axis_separators)
- .def_ro("elem_offset", &BufferNode::elem_offset)
- .def_ro("name", &BufferNode::name)
+ .def_ro("shape", &BufferNode::shape,
refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("strides", &BufferNode::strides,
refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("axis_separators", &BufferNode::axis_separators,
+ refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("elem_offset", &BufferNode::elem_offset,
refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("name", &BufferNode::name,
refl::AttachFieldFlag::SEqHashIgnore())
.def_ro("data_alignment", &BufferNode::data_alignment)
.def_ro("offset_factor", &BufferNode::offset_factor)
.def_ro("buffer_type", &BufferNode::buffer_type)
@@ -163,6 +164,7 @@ class BufferNode : public Object {
Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;
static constexpr const char* _type_key = "tir.Buffer";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 9525f88784..3e6a07a6cd 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -833,7 +833,7 @@ class LetNode : public PrimExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetNode>()
- .def_ro("var", &LetNode::var)
+ .def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDef())
.def_ro("value", &LetNode::value)
.def_ro("body", &LetNode::body);
}
@@ -989,11 +989,11 @@ class CommReducerNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CommReducerNode>()
- .def_ro("lhs", &CommReducerNode::lhs)
- .def_ro("rhs", &CommReducerNode::rhs)
+ .def_ro("lhs", &CommReducerNode::lhs,
refl::AttachFieldFlag::SEqHashDef())
+ .def_ro("rhs", &CommReducerNode::rhs,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("result", &CommReducerNode::result)
.def_ro("identity_element", &CommReducerNode::identity_element)
- .def_ro("span", &CommReducerNode::span);
+ .def_ro("span", &CommReducerNode::span,
refl::AttachFieldFlag::SEqHashIgnore());
}
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
@@ -1009,6 +1009,7 @@ class CommReducerNode : public Object {
}
static constexpr const char* _type_key = "tir.CommReducer";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index ead03e9676..2671f98791 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -49,8 +49,6 @@ class PrimFuncNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
Array<tir::Var> params;
- /*! \brief The body of the function */
- tir::Stmt body;
/*! \brief The return type of the function. */
Type ret_type;
/*!
@@ -99,14 +97,16 @@ class PrimFuncNode : public BaseFuncNode {
* flattened alias of the buffer.
*/
Map<tir::Var, Buffer> buffer_map;
+ /*! \brief The body of the function */
+ tir::Stmt body;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PrimFuncNode>()
- .def_ro("params", &PrimFuncNode::params)
- .def_ro("body", &PrimFuncNode::body)
+ .def_ro("params", &PrimFuncNode::params,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("ret_type", &PrimFuncNode::ret_type)
- .def_ro("buffer_map", &PrimFuncNode::buffer_map);
+ .def_ro("buffer_map", &PrimFuncNode::buffer_map)
+ .def_ro("body", &PrimFuncNode::body);
}
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
@@ -123,6 +123,7 @@ class PrimFuncNode : public BaseFuncNode {
hash_reduce(body);
hash_reduce(attrs);
}
+
/*!
* \brief Return the derived function annotation of this function.
*
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index 1cc6fa950a..55d083834d 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -154,9 +154,11 @@ class IndexMapNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IndexMapNode>()
- .def_ro("initial_indices", &IndexMapNode::initial_indices)
+ .def_ro("initial_indices", &IndexMapNode::initial_indices,
+ refl::AttachFieldFlag::SEqHashDef())
.def_ro("final_indices", &IndexMapNode::final_indices)
- .def_ro("inverse_index_map", &IndexMapNode::inverse_index_map);
+ .def_ro("inverse_index_map", &IndexMapNode::inverse_index_map,
+ refl::AttachFieldFlag::SEqHashIgnore());
}
bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
@@ -169,6 +171,7 @@ class IndexMapNode : public Object {
hash_reduce(final_indices);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "tir.IndexMap";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index b89fff0032..9d31d25c39 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -54,6 +54,7 @@ class StmtNode : public Object {
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
static constexpr const char* _type_key = "tir.Stmt";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
@@ -81,7 +82,7 @@ class LetStmtNode : public StmtNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetStmtNode>()
- .def_ro("var", &LetStmtNode::var)
+ .def_ro("var", &LetStmtNode::var, refl::AttachFieldFlag::SEqHashDef())
.def_ro("value", &LetStmtNode::value)
.def_ro("body", &LetStmtNode::body);
}
@@ -371,7 +372,7 @@ class AllocateNode : public StmtNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AllocateNode>()
- .def_ro("buffer_var", &AllocateNode::buffer_var)
+ .def_ro("buffer_var", &AllocateNode::buffer_var,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("dtype", &AllocateNode::dtype)
.def_ro("extents", &AllocateNode::extents)
.def_ro("condition", &AllocateNode::condition)
@@ -460,7 +461,7 @@ class AllocateConstNode : public StmtNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AllocateConstNode>()
- .def_ro("buffer_var", &AllocateConstNode::buffer_var)
+ .def_ro("buffer_var", &AllocateConstNode::buffer_var,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("data", &AllocateConstNode::data)
.def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx)
.def_ro("dtype", &AllocateConstNode::dtype)
@@ -896,7 +897,7 @@ class ForNode : public StmtNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ForNode>()
- .def_ro("loop_var", &ForNode::loop_var)
+ .def_ro("loop_var", &ForNode::loop_var,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("min", &ForNode::min)
.def_ro("extent", &ForNode::extent)
.def_ro("kind", &ForNode::kind)
@@ -1017,6 +1018,7 @@ class BufferRegionNode : public PrimExprConvertibleNode {
TVM_DLL PrimExpr ToPrimExpr() const final;
static constexpr const char* _type_key = "tir.BufferRegion";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, PrimExprConvertibleNode);
@@ -1082,6 +1084,7 @@ class MatchBufferRegionNode : public Object {
}
static constexpr const char* _type_key = "tir.MatchBufferRegion";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
@@ -1130,8 +1133,12 @@ class BlockNode : public StmtNode {
Array<BufferRegion> writes;
/*! \brief The name_hint of the block. */
String name_hint;
- /*! \brief The body of the block. */
- Stmt body;
+ /*! \brief The buffer allocated in the block. */
+ Array<Buffer> alloc_buffers;
+ /*! \brief The match buffer regions. */
+ Array<MatchBufferRegion> match_buffers;
+ /*! \brief The annotation of the block. */
+ Map<String, ffi::Any> annotations;
/*!
* \brief The init statement is executed during the first iteration of
reduction loops in a
* reduction block. The optional init field allows us to represent
initialization and
@@ -1140,25 +1147,21 @@ class BlockNode : public StmtNode {
* Init field is `std::nullopt` if there is no reduction iter_vars
*/
Optional<Stmt> init;
- /*! \brief The buffer allocated in the block. */
- Array<Buffer> alloc_buffers;
- /*! \brief The match buffer regions. */
- Array<MatchBufferRegion> match_buffers;
- /*! \brief The annotation of the block. */
- Map<String, ffi::Any> annotations;
+ /*! \brief The body of the block. */
+ Stmt body;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BlockNode>()
- .def_ro("iter_vars", &BlockNode::iter_vars)
+ .def_ro("iter_vars", &BlockNode::iter_vars,
refl::AttachFieldFlag::SEqHashDef())
.def_ro("reads", &BlockNode::reads)
.def_ro("writes", &BlockNode::writes)
- .def_ro("name_hint", &BlockNode::name_hint)
- .def_ro("body", &BlockNode::body)
- .def_ro("init", &BlockNode::init)
+ .def_ro("name_hint", &BlockNode::name_hint,
refl::AttachFieldFlag::SEqHashIgnore())
.def_ro("alloc_buffers", &BlockNode::alloc_buffers)
.def_ro("match_buffers", &BlockNode::match_buffers)
- .def_ro("annotations", &BlockNode::annotations);
+ .def_ro("annotations", &BlockNode::annotations)
+ .def_ro("init", &BlockNode::init)
+ .def_ro("body", &BlockNode::body);
}
bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 24c7e6944d..021b6c301a 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -64,7 +64,7 @@ class VarNode : public PrimExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<VarNode>()
- .def_ro("name", &VarNode::name_hint)
+ .def_ro("name", &VarNode::name_hint,
refl::AttachFieldFlag::SEqHashIgnore())
.def_ro("type_annotation", &VarNode::type_annotation);
}
@@ -80,6 +80,7 @@ class VarNode : public PrimExprNode {
hash_reduce.FreeVarHashImpl(this);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindFreeVar;
static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
@@ -290,7 +291,7 @@ class IterVarNode : public PrimExprConvertibleNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IterVarNode>()
.def_ro("dom", &IterVarNode::dom)
- .def_ro("var", &IterVarNode::var)
+ .def_ro("var", &IterVarNode::var, refl::AttachFieldFlag::SEqHashDef())
.def_ro("iter_type", &IterVarNode::iter_type)
.def_ro("thread_tag", &IterVarNode::thread_tag);
}
@@ -308,6 +309,7 @@ class IterVarNode : public PrimExprConvertibleNode {
}
static constexpr const char* _type_key = "tir.IterVar";
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, PrimExprConvertibleNode);
diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h
index 2ee4aeb05b..1d6b1046d9 100644
--- a/src/contrib/msc/core/ir/graph.h
+++ b/src/contrib/msc/core/ir/graph.h
@@ -400,6 +400,7 @@ class MSCTensorNode : public Object {
hash_reduce(prims);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.MSCTensor";
TVM_DECLARE_FINAL_OBJECT_INFO(MSCTensorNode, Object);
};
@@ -514,6 +515,7 @@ class BaseJointNode : public Object {
hash_reduce(children);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.BaseJoint";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -600,6 +602,7 @@ class MSCJointNode : public BaseJointNode {
hash_reduce(weights);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.MSCJoint";
TVM_DECLARE_FINAL_OBJECT_INFO(MSCJointNode, BaseJointNode);
};
@@ -833,6 +836,7 @@ class BaseGraphNode : public Object {
hash_reduce(node_names);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.BaseGraph";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h
index b926edf8e5..291a0e196a 100644
--- a/src/contrib/msc/core/ir/plugin.h
+++ b/src/contrib/msc/core/ir/plugin.h
@@ -290,6 +290,7 @@ class PluginAttrNode : public Object {
hash_reduce(describe);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.PluginAttr";
TVM_DECLARE_FINAL_OBJECT_INFO(PluginAttrNode, Object);
};
@@ -371,6 +372,7 @@ class PluginTensorNode : public Object {
hash_reduce(describe);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.PluginTensor";
TVM_DECLARE_FINAL_OBJECT_INFO(PluginTensorNode, Object);
};
@@ -454,6 +456,7 @@ class PluginExternNode : public Object {
hash_reduce(describe);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.PluginExtern";
TVM_DECLARE_FINAL_OBJECT_INFO(PluginExternNode, Object);
};
@@ -565,6 +568,7 @@ class PluginNode : public Object {
hash_reduce(options);
}
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const char* _type_key = "msc.core.Plugin";
TVM_DECLARE_FINAL_OBJECT_INFO(PluginNode, Object);
};
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 9eedbd5e30..f178747246 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -96,6 +96,30 @@ bool IRModuleNode::SEqualReduce(const IRModuleNode* other,
SEqualReducer equal)
return true;
}
+bool IRModuleNode::SEqual(const IRModuleNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool,
AnyView)> equal) const {
+ if (!equal(this->attrs, other->attrs, false, "attrs")) {
+ return false;
+ }
+ if (!equal(this->global_infos, other->global_infos, false, "global_infos")) {
+ return false;
+ }
+
+ // Define remaps for GlobalVar and GlobalTypeVar based on their string name.
+ for (const auto& gv : this->GetGlobalVars()) {
+ if (other->ContainGlobalVar(gv->name_hint)) {
+ if (!equal(gv, other->GetGlobalVar(gv->name_hint), true, "functions"))
return false;
+ }
+ }
+
+ // now check the functions with the GlobalVar remappped
+ if (!equal(this->functions, other->functions, false, "functions")) {
+ return false;
+ }
+
+ return true;
+}
+
void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
// hash the functions.
@@ -127,6 +151,34 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce)
const {
hash_reduce(this->global_infos);
}
+uint64_t IRModuleNode::SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t,
bool)> hash) const {
+ uint64_t hash_value = init_hash;
+ hash_value = hash(this->attrs, hash_value, false);
+ hash_value = hash(this->global_infos, hash_value, false);
+
+ // hash the functions.
+ using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
+ std::vector<KV> temp;
+ for (const auto& kv : this->functions) {
+ temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
+ }
+ // sort by the hash key of the keys.
+ std::sort(temp.begin(), temp.end(),
+ [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) <
std::get<0>(rhs); });
+ hash_value = hash(static_cast<uint64_t>(temp.size()), hash_value, false);
+ // first need to define the GlobalVar in the order of the keys
+ for (size_t i = 0; i < temp.size(); ++i) {
+ hash_value = hash(std::get<1>(temp[i]), hash_value, true);
+ }
+ // hash the name and content
+ for (size_t i = 0; i < temp.size(); ++i) {
+ hash_value = hash(std::get<0>(temp[i]), hash_value, false);
+ hash_value = hash(std::get<2>(temp[i]), hash_value, false);
+ }
+ return hash_value;
+}
+
bool IRModuleNode::ContainGlobalVar(const String& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 4e580356ff..37b251f1f9 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -27,6 +27,7 @@
namespace tvm {
TVM_FFI_STATIC_INIT_BLOCK({
+ TypeNode::RegisterReflection();
PrimTypeNode::RegisterReflection();
PointerTypeNode::RegisterReflection();
TupleTypeNode::RegisterReflection();
@@ -50,8 +51,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
PointerType::PointerType(Type element_type, String storage_scope) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
+ if (storage_scope.empty()) {
+ n->storage_scope = "global";
+ } else {
+ n->storage_scope = std::move(storage_scope);
+ }
n->element_type = std::move(element_type);
- n->storage_scope = std::move(storage_scope);
data_ = std::move(n);
}
diff --git a/src/meta_schedule/module_equality.cc
b/src/meta_schedule/module_equality.cc
index 986233ca49..8901d5fd8d 100644
--- a/src/meta_schedule/module_equality.cc
+++ b/src/meta_schedule/module_equality.cc
@@ -18,6 +18,8 @@
*/
#include "module_equality.h"
+#include <tvm/ffi/reflection/structural_equal.h>
+#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
@@ -37,28 +39,15 @@ class ModuleEqualityStructural : public ModuleEquality {
String GetName() const { return "structural"; }
};
-class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
- public:
- SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {}
-
- protected:
- bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool
map_free_vars,
- const Optional<ObjectPathPair>& current_paths) {
- if (auto lhs_ptr = lhs.as<runtime::NDArray::Container>(),
- rhs_ptr = rhs.as<runtime::NDArray::Container>();
- lhs_ptr && rhs_ptr) {
- SEqualReducer reducer(this, nullptr, map_free_vars);
- return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false);
- }
- return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars,
current_paths);
- }
-};
-
class ModuleEqualityIgnoreNDArray : public ModuleEquality {
public:
- size_t Hash(IRModule mod) const { return
SHashHandlerIgnoreNDArray().Hash(mod, false); }
+ size_t Hash(IRModule mod) const {
+ return tvm::ffi::reflection::StructuralHash::Hash(mod,
/*map_free_vars=*/false,
+
/*skip_ndarray_content=*/true);
+ }
bool Equal(IRModule lhs, IRModule rhs) const {
- return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false);
+ return tvm::ffi::reflection::StructuralEqual::Equal(lhs, rhs,
/*map_free_vars=*/false,
+
/*skip_ndarray_content=*/true);
}
String GetName() const { return "ignore-ndarray"; }
};
@@ -77,8 +66,10 @@ class ModuleEqualityAnchorBlock : public ModuleEquality {
auto anchor_block_lhs = tir::FindAnchorBlock(lhs);
auto anchor_block_rhs = tir::FindAnchorBlock(rhs);
if (anchor_block_lhs && anchor_block_rhs) {
- return
SEqualHandlerIgnoreNDArray().Equal(GetRef<tir::Block>(anchor_block_lhs),
-
GetRef<tir::Block>(anchor_block_rhs), false);
+ return
tvm::ffi::reflection::StructuralEqual::Equal(GetRef<tir::Block>(anchor_block_lhs),
+
GetRef<tir::Block>(anchor_block_rhs),
+
/*map_free_vars=*/false,
+
/*skip_ndarray_content=*/true);
}
return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs);
}
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 43dee2eb3b..5987692a0f 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -20,7 +20,9 @@
* \file src/node/structural_equal.cc
*/
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
@@ -599,34 +601,107 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const
ObjectRef& lhs, const Obje
return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths);
}
+Optional<ObjectPathPair> ObjectPathPairFromAccessPathPair(
+ Optional<ffi::reflection::AccessPathPair> src) {
+ if (!src.has_value()) return std::nullopt;
+ auto translate_path = [](ffi::reflection::AccessPath path) {
+ ObjectPath result = ObjectPath::Root();
+ for (const auto& step : path) {
+ switch (step->kind) {
+ case ffi::reflection::AccessKind::kObjectField: {
+ result = result->Attr(step->key.cast<String>());
+ break;
+ }
+ case ffi::reflection::AccessKind::kArrayIndex: {
+ result = result->ArrayIndex(step->key.cast<int64_t>());
+ break;
+ }
+ case ffi::reflection::AccessKind::kMapKey: {
+ result = result->MapValue(step->key);
+ break;
+ }
+ case ffi::reflection::AccessKind::kArrayIndexMissing: {
+ result = result->MissingArrayElement(step->key.cast<int64_t>());
+ break;
+ }
+ case ffi::reflection::AccessKind::kMapKeyMissing: {
+ result = result->MissingMapEntry();
+ break;
+ }
+ default: {
+ LOG(FATAL) << "Invalid access path kind: " <<
static_cast<int>(step->kind);
+ break;
+ }
+ }
+ }
+ return result;
+ };
+
+ return ObjectPathPair(translate_path((*src).get<0>()),
translate_path((*src).get<1>()));
+}
+
+bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool
assert_mode,
+ bool map_free_vars) {
+ if (assert_mode) {
+ auto first_mismatch = ObjectPathPairFromAccessPathPair(
+ ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars));
+ if (first_mismatch.has_value()) {
+ std::ostringstream oss;
+ oss << "StructuralEqual check failed, caused by lhs";
+ oss << " at " << (*first_mismatch)->lhs_path;
+ {
+ // print lhs
+ PrinterConfig cfg;
+ cfg->syntax_sugar = false;
+ cfg->path_to_underline.push_back((*first_mismatch)->lhs_path);
+ // The TVMScriptPrinter::Script will fallback to Repr printer,
+ // if the root node to print is not supported yet,
+ // e.g. Relax nodes, ArrayObj, MapObj, etc.
+ oss << ":" << std::endl <<
TVMScriptPrinter::Script(lhs.cast<ObjectRef>(), cfg);
+ }
+ oss << std::endl << "and rhs";
+ {
+ // print rhs
+ oss << " at " << (*first_mismatch)->rhs_path;
+ {
+ PrinterConfig cfg;
+ cfg->syntax_sugar = false;
+ cfg->path_to_underline.push_back((*first_mismatch)->rhs_path);
+ // The TVMScriptPrinter::Script will fallback to Repr printer,
+ // if the root node to print is not supported yet,
+ // e.g. Relax nodes, ArrayObj, MapObj, etc.
+ oss << ":" << std::endl <<
TVMScriptPrinter::Script(rhs.cast<ObjectRef>(), cfg);
+ }
+ }
+ TVM_FFI_THROW(ValueError) << oss.str();
+ }
+ return true;
+ } else {
+ return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_vars);
+ }
+}
+
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
- .def("node.StructuralEqual",
- [](const Any& lhs, const Any& rhs, bool assert_mode, bool
map_free_vars) {
- // If we are asserting on failure, then the `defer_fails` option
- // should be enabled, to provide better error messages. For
- // example, if the number of bindings in a `relax::BindingBlock`
- // differs, highlighting the first difference rather than the
- // entire block.
- bool defer_fails = assert_mode;
- Optional<ObjectPathPair> first_mismatch;
- return SEqualHandlerDefault(assert_mode, &first_mismatch,
defer_fails)
- .Equal(lhs, rhs, map_free_vars);
- })
+ .def("node.StructuralEqual", NodeStructuralEqualAdapter)
.def("node.GetFirstStructuralMismatch",
[](const Any& lhs, const Any& rhs, bool map_free_vars) {
- Optional<ObjectPathPair> first_mismatch;
- bool equal =
- SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs,
rhs, map_free_vars);
- ICHECK(equal == !first_mismatch.defined());
- return first_mismatch;
+ /*
+ Optional<ObjectPathPair> first_mismatch;
+ bool equal =
+ SEqualHandlerDefault(false, &first_mismatch,
true).Equal(lhs, rhs, map_free_vars);
+ ICHECK(equal == !first_mismatch.defined());
+ return first_mismatch;
+ */
+ return ObjectPathPairFromAccessPathPair(
+ ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars));
});
});
bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs,
bool map_free_params) const {
- return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs,
map_free_params);
+ return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_params);
}
bool NDArrayEqual(const runtime::NDArray::Container* lhs, const
runtime::NDArray::Container* rhs,
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 2a60754f7b..6fb8d36784 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -22,6 +22,7 @@
#include <dmlc/memory_io.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/object_path.h>
@@ -296,13 +297,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("node.StructuralHash",
[](const Any& object, bool map_free_vars) -> int64_t {
- uint64_t hashed_value =
SHashHandlerDefault().Hash(object, map_free_vars);
- return static_cast<int64_t>(hashed_value);
+ return ffi::reflection::StructuralHash::Hash(object,
map_free_vars);
});
});
uint64_t StructuralHash::operator()(const ObjectRef& object) const {
- return SHashHandlerDefault().Hash(object, false);
+ return ffi::reflection::StructuralHash::Hash(object, false);
}
void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool
map_free_vars) {
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 6005497a8e..c905b87305 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -517,6 +517,34 @@ void VarBindingNode::SHashReduce(SHashReducer hash_reduce)
const {
}
}
+bool VarBindingNode::SEqual(const VarBindingNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool,
AnyView)> equal) const {
+ if (value->IsInstance<FunctionNode>()) {
+ // Recursive function definitions may reference the bound variable
+ // within the value being bound. In these cases, the
+ // var comparison must occur first to define the var, to ensure it is
+ // defined at point of use.
+ return equal(var, other->var, true, "var") && equal(value, other->value,
false, "value");
+ } else {
+ // In all other cases, visit the bound value before the variable
+ // it is bound to, in order to provide better error messages.
+ return equal(value, other->value, false, "value") && equal(var,
other->var, true, "var");
+ }
+}
+
+uint64_t VarBindingNode::SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t,
bool)> hash) const {
+ uint64_t hash_value = init_hash;
+ if (value->IsInstance<FunctionNode>()) {
+ hash_value = hash(var, hash_value, true);
+ hash_value = hash(value, hash_value, false);
+ } else {
+ hash_value = hash(value, hash_value, false);
+ hash_value = hash(var, hash_value, true);
+ }
+ return hash_value;
+}
+
TVM_REGISTER_NODE_TYPE(BindingBlockNode);
BindingBlock::BindingBlock(Array<Binding> bindings, Span span) {
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index 00414fe339..499dd47bb7 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -31,6 +31,7 @@ namespace tvm {
namespace relax {
TVM_FFI_STATIC_INIT_BLOCK({
+ StructInfoNode::RegisterReflection();
ObjectStructInfoNode::RegisterReflection();
PrimStructInfoNode::RegisterReflection();
ShapeStructInfoNode::RegisterReflection();
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index 1a426ec5da..83d978f27d 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
@@ -540,8 +541,8 @@ class ParamRemapper : private ExprFunctor<void(const Expr&,
const Expr&)> {
} else {
var_remap_.Set(GetRef<Var>(lhs_var), rhs_var);
}
- CHECK(structural_equal.Equal(lhs_var->struct_info_, rhs_var->struct_info_,
- /*map_free_vars=*/true))
+ CHECK(tvm::ffi::reflection::StructuralEqual::Equal(lhs_var->struct_info_,
rhs_var->struct_info_,
+ /*map_free_vars=*/true))
<< "The struct info of the parameters should be the same for all
target functions";
auto lhs_tir_vars =
DefinableTIRVarsInStructInfo(GetStructInfo(GetRef<Var>(lhs_var)));
auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr));
@@ -555,8 +556,6 @@ class ParamRemapper : private ExprFunctor<void(const Expr&,
const Expr&)> {
}
}
- SEqualHandlerDefault structural_equal{/*assert_mode=*/false,
/*first_mismatch=*/nullptr,
- /*defer_fail=*/false};
Map<Var, Expr> var_remap_;
Map<tir::Var, PrimExpr> tir_var_remap_;
};
diff --git a/tests/python/ir/test_node_reflection.py
b/tests/python/ir/test_node_reflection.py
index 741e61b2eb..be00bc3a47 100644
--- a/tests/python/ir/test_node_reflection.py
+++ b/tests/python/ir/test_node_reflection.py
@@ -181,6 +181,15 @@ def test_ndarray_dict():
tvm.ir.assert_structural_equal(m1, m2)
+def test_free_var_equal():
+ x = tvm.tir.Var("x", dtype="int32")
+ y = tvm.tir.Var("y", dtype="int32")
+ z = tvm.tir.Var("z", dtype="int32")
+ v1 = x + y
+ v1 = y + z
+ tvm.ir.assert_structural_equal(x, z, map_free_vars=True)
+
+
def test_alloc_const():
dev = tvm.cpu(0)
dtype = "float32"
diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
index 58d9402e6f..bbf95801ed 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
@@ -45,6 +45,9 @@ def test_prim_func_buffer_map():
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 256))
+ func1 = func1.with_attr("global_symbol", "main")
+ func2 = func2.with_attr("global_symbol", "main")
+
with pytest.raises(ValueError) as ve:
assert_structural_equal(func1, func2)
assert _error_message(ve.value) == _expected_result(
@@ -109,8 +112,12 @@ def test_allocate():
a_data = T.allocate((256, 128), dtype="float32")
a = T.decl_buffer((256, 128), dtype="float32", data=a_data)
+ func1 = func1.with_attr("global_symbol", "main")
+ func2 = func2.with_attr("global_symbol", "main")
+
with pytest.raises(ValueError) as ve:
assert_structural_equal(func1, func2)
+
assert _error_message(ve.value) == _expected_result(
func1,
func2,
@@ -132,6 +139,9 @@ def test_for():
with T.block():
pass
+ func1 = func1.with_attr("global_symbol", "main")
+ func2 = func2.with_attr("global_symbol", "main")
+
with pytest.raises(ValueError) as ve:
assert_structural_equal(func1, func2)
assert _error_message(ve.value) == _expected_result(