This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch deepequal
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f55232d2674cee8f20d36c1ccb5d9b2c73dc88c6
Author: tqchen <[email protected]>
AuthorDate: Wed Jul 16 10:00:08 2025 -0400

    [TIR] Decouple DeepEqual from StructuralEqual
    
    This PR decouples deep equal from the structural equal implementation
    by providing a more direct implementation through an ExprFunctor.
    
    DeepEqual is used at the heart of arith simplification as a subroutine.
    It performs a more direct nested check, without the var remapping that
    structural equal does, for efficiency reasons. It also does not need to
    trace the mismatching path, since failed comparisons are expected to
    happen often.
    
    This step will likely improve deep equal efficiency because of the more
    direct approach, and gives us the opportunity to simplify future
    refactors of structural equal so that it can focus on struct path tracing.
---
 src/tir/analysis/deep_equal.cc | 179 +++++++++++++++++++++++++++++++++++------
 1 file changed, 154 insertions(+), 25 deletions(-)

diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index bb105be02a..368e0de3e4 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -25,48 +25,177 @@
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/node/object_path.h>
 #include <tvm/node/reflection.h>
-#include <tvm/node/structural_equal.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
 
 namespace tvm {
 namespace tir {
 
-class DeepCmpSEqualHandler : public SEqualReducer::Handler {
+#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode)                              \
+  bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final {      \
+    const auto* prhs = rhs.as<OpNode>();                                \
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) && \
+           VisitExpr(plhs->b, prhs->b);                                 \
+  }
+
+#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode)                           \
+  bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final {   \
+    const auto* prhs = rhs.as<OpNode>();                             \
+    return plhs->dtype == prhs->dtype && plhs->value == prhs->value; \
+  }
+
+class ExprDeepEqualChecker : private ExprFunctor<bool(const PrimExpr&, const PrimExpr&)> {
  public:
-  // use direct recursion.
-  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
-                    const Optional<ObjectPathPair>&) final {
+  static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) {
+    // quick path: trivial cases and IntImm are decided without constructing a checker
     if (lhs.same_as(rhs)) return true;
     if (!lhs.defined() && rhs.defined()) return false;
     if (!rhs.defined() && lhs.defined()) return false;
     if (lhs->type_index() != rhs->type_index()) return false;
-    return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
-           !fail_;
+    if (auto* plhs = lhs.as<IntImmNode>()) {
+      auto* prhs = rhs.as<IntImmNode>();
+      return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
+    }
+    return ExprDeepEqualChecker().VisitExpr(lhs, rhs);
   }
 
-  void DeferFail(const ObjectPathPair&) final { fail_ = true; }
-  bool IsFailDeferralEnabled() final { return false; }
+ private:
+  bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (!VisitExpr(lhs[i], rhs[i])) return false;
+    }
+    return true;
+  }
 
-  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; }
-  void MarkGraphNode() final {}
+  bool ArrayDeepEqual(const Array<IterVar>& lhs, const Array<IterVar>& rhs) {
+    // for iter var, we require pointer equality of every element
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (!lhs[i].same_as(rhs[i])) return false;  // any mismatch fails the array
+    }
+    return true;
+  }
 
- private:
-  // reflection vtable
-  ReflectionVTable* vtable_ = ReflectionVTable::Global();
-  bool fail_ = false;
+  bool OptionalDeepEqual(const Optional<PrimExpr>& lhs, const Optional<PrimExpr>& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (lhs.defined() && !rhs.defined()) return false;
+    return VisitExpr(*lhs, *rhs);
+  }
+
+  bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (!rhs.defined() && lhs.defined()) return false;
+    if (lhs->type_index() != rhs->type_index()) return false;  // VisitExpr_ needs same type
+    return ExprFunctor::VisitExpr(lhs, rhs);  // dispatch to the matching VisitExpr_
+  }
+
+  bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final {
+    // for var, we require pointer equality
+    return plhs == rhs.get();
+  }
+
+  bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final {
+    // for var, we require pointer equality
+    return plhs == rhs.get();
+  }
+
+  bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<BufferLoadNode>();
+    // we run pointer comparison of the buffer
+    return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices) &&
+           OptionalDeepEqual(plhs->predicate, prhs->predicate);
+  }
+
+  bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ProducerLoadNode>();
+    // run shallow pointer comparison of the producer
+    return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices);
+  }
+
+  bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<LetNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) &&
+           VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body);
+  }
+
+  bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<CallNode>();
+    return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
+           ArrayDeepEqual(plhs->args, prhs->args);
+  }
+
+  bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ReduceNode>();
+    return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) &&
+           ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) &&
+           ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) &&
+           plhs->value_index == prhs->value_index;
+  }
+
+  bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<CastNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value);
+  }
+
+  bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<NotNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a);
+  }
+
+  bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<SelectNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) &&
+           VisitExpr(plhs->true_value, prhs->true_value) &&
+           VisitExpr(plhs->false_value, prhs->false_value);
+  }
+
+  bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<RampNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) &&
+           VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes);
+  }
+
+  bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<ShuffleNode>();
+    return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) &&
+           ArrayDeepEqual(plhs->indices, prhs->indices);
+  }
+
+  bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final {
+    const auto* prhs = rhs.as<BroadcastNode>();
+    return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) &&
+           VisitExpr(plhs->lanes, prhs->lanes);
+  }
+
+  DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(NENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(LENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(GENode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode)
+  DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode)
+  DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode)
+  DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode)
+  DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode)
 };
 
 bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
-  // quick path
-  if (lhs.same_as(rhs)) return true;
-  if (!lhs.defined() && rhs.defined()) return false;
-  if (!rhs.defined() && lhs.defined()) return false;
-  if (lhs->type_index() != rhs->type_index()) return false;
-  if (auto* plhs = lhs.as<IntImmNode>()) {
-    auto* prhs = rhs.as<IntImmNode>();
-    return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
-  }
-  return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt);
+  return ExprDeepEqualChecker::Check(lhs, rhs);  // quick paths are handled inside Check()
 }
 
 TVM_FFI_STATIC_INIT_BLOCK({

Reply via email to