This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89f9573dd8 [TIR] Decouple DeepEqual from StructuralEqual (#18151)
89f9573dd8 is described below
commit 89f9573dd84758f3abb608fcf666c3d34a1c3489
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jul 17 10:01:17 2025 -0400
[TIR] Decouple DeepEqual from StructuralEqual (#18151)
This PR decouples the deep equal implementation from structural equal
by providing a more direct implementation through an expression functor.
DeepEqual is used at the heart of arith simplification as a subroutine,
and for efficiency reasons it performs direct nested checking without the
var remapping that structural equal does. It also does not need to trace
the mismatching comparison, since the failing path is also expected to be hit often.
This step will likely improve deep equal efficiency because of the more
direct approach, and it gives us the opportunity to simplify a future
refactor of structural equal so that it can focus on struct path
tracing.
---
src/tir/analysis/deep_equal.cc | 179 +++++++++++++++++++++++++++++++++++------
1 file changed, 154 insertions(+), 25 deletions(-)
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index bb105be02a..fe22d152cb 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -25,48 +25,177 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
-#include <tvm/node/structural_equal.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
namespace tvm {
namespace tir {
-class DeepCmpSEqualHandler : public SEqualReducer::Handler {
+// Defines the VisitExpr_ overload for a binary node (Add, Sub, EQ, And, ...): two nodes are
+// equal iff their dtypes match and both operands `a` and `b` are deep-equal. The unchecked
+// rhs.as<OpNode>() is safe because VisitExpr only dispatches here after verifying that lhs
+// and rhs have the same type index.
+#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode)                                          \
+  bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final {                  \
+    const auto* prhs = rhs.as<OpNode>();                                            \
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) &&             \
+           VisitExpr(plhs->b, prhs->b);                                             \
+  }
+
+// Defines the VisitExpr_ overload for an immediate node (IntImm, FloatImm, StringImm): two
+// nodes are equal iff their dtypes match and their `value` fields compare equal with ==.
+#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode)                                          \
+  bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final {                  \
+    const auto* prhs = rhs.as<OpNode>();                                            \
+    return plhs->dtype == prhs->dtype && plhs->value == prhs->value;                \
+  }
+
+/*!
+ * \brief Deep equality checker for PrimExpr, implemented directly on top of ExprFunctor.
+ *
+ * Unlike StructuralEqual, this checker performs no variable remapping and does not record
+ * where a mismatch happened. The following references are compared by pointer identity:
+ * Var/SizeVar nodes, the buffer of BufferLoad, the producer of ProducerLoad, the op of
+ * Call, and the combiner and axis iter vars of Reduce. Everything else is compared by
+ * value or recursively.
+ */
+class ExprDeepEqualChecker : private ExprFunctor<bool(const PrimExpr&, const PrimExpr&)> {
  public:
-  // use direct recursion.
-  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
-                    const Optional<ObjectPathPair>&) final {
+  /*!
+   * \brief Check whether two expressions are deep-equal.
+   * \param lhs The left-hand expression; may be undefined.
+   * \param rhs The right-hand expression; may be undefined.
+   * \return true if both refer to the same object (including both undefined) or are
+   *         deep-equal, false otherwise.
+   */
+  static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) {
+    // quick path without constructing the object
     if (lhs.same_as(rhs)) return true;
     if (!lhs.defined() && rhs.defined()) return false;
     if (!rhs.defined() && lhs.defined()) return false;
     if (lhs->type_index() != rhs->type_index()) return false;
-    return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
-           !fail_;
+    // IntImm fast path: compared here without constructing a checker or dispatching
+    // through the functor.
+    if (auto* plhs = lhs.as<IntImmNode>()) {
+      auto* prhs = rhs.as<IntImmNode>();
+      return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
+    }
+    return ExprDeepEqualChecker().VisitExpr(lhs, rhs);
   }
-  void DeferFail(const ObjectPathPair&) final { fail_ = true; }
-  bool IsFailDeferralEnabled() final { return false; }
-
-  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; }
-  void MarkGraphNode() final {}
+
+  /*!
+   * \brief Recursive comparison used for every sub-expression.
+   *
+   * Performs the identity, definedness and type-index checks before dispatching, so every
+   * VisitExpr_ overload below may assume `rhs` holds the same node type as its `plhs`.
+   */
+  bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (!rhs.defined() && lhs.defined()) return false;
+    if (lhs->type_index() != rhs->type_index()) return false;
+    return ExprFunctor::VisitExpr(lhs, rhs);
+  }
  private:
-  // reflection vtable
-  ReflectionVTable* vtable_ = ReflectionVTable::Global();
-  bool fail_ = false;
+  // Element-wise deep equality of two expression arrays; sizes must match.
+  bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (!VisitExpr(lhs[i], rhs[i])) return false;
+    }
+    return true;
+  }
+
+  bool ArrayDeepEqual(const Array<IterVar>& lhs, const Array<IterVar>& rhs) {
+    // for iter var, we require pointer equality
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); i++) {
+      // Any mismatching iter var makes the arrays unequal.
+      if (!lhs[i].same_as(rhs[i])) return false;
+    }
+    return true;
+  }
+
+  // Deep equality of optional expressions: two absent values are equal, absent vs. present
+  // is not, and two present values are compared recursively.
+  bool OptionalDeepEqual(const Optional<PrimExpr>& lhs, const Optional<PrimExpr>& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (lhs.defined() && !rhs.defined()) return false;
+    return VisitExpr(*lhs, *rhs);
+  }
+
+  bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final {
+    // for var, we require pointer equality
+    return plhs == rhs.get();
+  }
+
+  bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final {
+    // for var, we require pointer equality
+    return plhs == rhs.get();
+  }
+
+  bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<BufferLoadNode>();
+    // we run pointer comparison of the buffer
+    return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices) &&
+           OptionalDeepEqual(plhs->predicate, prhs->predicate);
+  }
+
+  bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ProducerLoadNode>();
+    // run shallow pointer comparison of the producer
+    return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices);
+  }
+
+  bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<LetNode>();
+    // The bound vars go through VisitExpr, i.e. pointer equality: no var remapping is done,
+    // so two lets that bind distinct Var objects are never equal.
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) &&
+           VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body);
+  }
+
+  bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<CallNode>();
+    // The callee op is compared by pointer identity; arguments are compared deeply.
+    return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
+           ArrayDeepEqual(plhs->args, prhs->args);
+  }
+
+  bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ReduceNode>();
+    // The combiner and the axis iter vars are compared by pointer identity.
+    return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) &&
+           ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) &&
+           ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) &&
+           plhs->value_index == prhs->value_index;
+  }
+
+  bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<CastNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value);
+  }
+
+  bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<NotNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a);
+  }
+
+  bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<SelectNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) &&
+           VisitExpr(plhs->true_value, prhs->true_value) &&
+           VisitExpr(plhs->false_value, prhs->false_value);
+  }
+
+  bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<RampNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) &&
+           VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes);
+  }
+
+  bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ShuffleNode>();
+    return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices);
+  }
+
+  bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<BroadcastNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) &&
+           VisitExpr(plhs->lanes, prhs->lanes);
+  }
+
+  DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(NENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(LENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(GENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode)
+  // Immediates are compared exactly: FloatImm values use ==, with no tolerance.
+  // NOTE(review): the removed SEqualReduce-based path may have compared floats with a
+  // tolerance; confirm that exact comparison is the intended semantics here.
+  DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode)
+  DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode)
+  DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode)
 };
+
+#undef DEFINE_DEEP_EQUAL_BIN_EXPR
+#undef DEFINE_DEEP_EQUAL_IMM_EXPR
+
bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const
{
- // quick path
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() && rhs.defined()) return false;
- if (!rhs.defined() && lhs.defined()) return false;
- if (lhs->type_index() != rhs->type_index()) return false;
- if (auto* plhs = lhs.as<IntImmNode>()) {
- auto* prhs = rhs.as<IntImmNode>();
- return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
- }
- return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt);
+ return ExprDeepEqualChecker::Check(lhs, rhs);
}
TVM_FFI_STATIC_INIT_BLOCK({