vinx13 commented on a change in pull request #9871:
URL: https://github.com/apache/tvm/pull/9871#discussion_r788220637



##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const 
StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& 
other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not 
match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : 
assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  // Buffer indices mapping
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, 
ObjectPtrEqual> buffer_indices_;
+  std::vector<IterVar> extra_block_vars_;
+  // variable remap if any
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_;
+
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+  bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+
+  bool VisitExpr_(const AddNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const SubNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MulNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const DivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const ModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const EQNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const NENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const AndNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const OrNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MinNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const CastNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const VarNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+
+  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
+  virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs);
+  bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs);
+  bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
+                         const std::pair<String, ObjectRef>& rhs);
+  bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const 
Map<String, ObjectRef>& rhs);
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs);
+  template <typename T, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+  bool CompareRange(const Range& lhs, const Range& rhs);
+  bool CompareType(const DataType& lhs, const DataType& rhs);
+
+ protected:

Review comment:
       There might be some use cases for subclassing this class, such as
auto-tensorization.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to