junrushao1994 commented on a change in pull request #9871:
URL: https://github.com/apache/tvm/pull/9871#discussion_r780719575
##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -1376,5 +1376,333 @@ void CheckStorageScope(const ScheduleState& self,
String storage_scope) {
}
}
+/******** Tensorize Comparator ********/
+
+bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
+ if (n.same_as(other)) return true;
+ if (n->type_index() != other->type_index()) return false;
+ bool equal = StmtComparator::VisitStmt(n, other);
+ if (!equal && assert_mode_)
+ LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" <<
other;
+ return equal;
+}
+
+bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
+ const auto* rhs = other.as<ForNode>();
+ if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+ if (!VisitExpr(op->min, rhs->min)) return false;
+ if (!VisitExpr(op->extent, rhs->extent)) return false;
+ if (!VisitStmt(op->body, rhs->body)) return false;
+ if (op->kind != rhs->kind) return false;
+ if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return
false;
+ if (op->thread_binding.defined() &&
+ !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value()))
+ return false;
+ return CompareAnnotationMap(op->annotations, rhs->annotations);
+}
+
+bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other)
{
+ const auto* rhs = other.as<SeqStmtNode>();
+ return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt);
+}
+
+bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BufferStoreNode>();
+ return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BlockRealizeNode>();
+ // Skip Compare binding values if the block is scope block (the outermost
one).
+ if (!is_scope_block) {
+ size_t offset = op->iter_values.size() - rhs->iter_values.size();
+ if (rhs->iter_values.size() > op->iter_values.size()) return false;
+ if (is_inner_block) {
+ // weak pattern matching for the inner block (the son of the scope block)
+ // where the pattern is v + iter <=> expr + iter
+ for (size_t i = 0; i < rhs->iter_values.size(); ++i) {
+ PrimExpr lhs_expr, rhs_expr;
+ Optional<Var> lhs_iter, rhs_iter;
+ auto detect = [](const PrimExpr& binding) -> std::pair<PrimExpr,
Optional<Var>> {
+ arith::PVar<PrimExpr> expr;
+ arith::PVar<Var> iter;
+ if (iter.Match(binding)) {
+ return std::make_pair(0, iter.Eval());
+ } else if ((expr + iter).Match(binding)) {
+ return std::make_pair(expr.Eval(), iter.Eval());
+ } else if ((iter + expr).Match(binding)) {
+ return std::make_pair(expr.Eval(), iter.Eval());
+ } else {
+ return std::make_pair(expr.Eval(), NullOpt);
+ }
+ };
+ std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]);
+ std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]);
+ CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) <<
"Incompatible binding";
+ if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value());
+ if (is_zero(rhs_expr)) {
+ CHECK(is_zero(lhs_expr)) << "Incompatible binding";
+ } else {
+ const auto* bv = rhs_expr.as<VarNode>();
+ if (!bv) {
+ VisitExpr(lhs_expr, rhs_expr);
+ } else {
+ auto it = equal_map_.find(GetRef<Var>(bv));
+ if (it == equal_map_.end()) {
+ equal_map_[GetRef<Var>(bv)] = lhs_expr;
+ } else {
+ CHECK(it->second->IsInstance<PrimExprNode>());
+ VisitExpr(lhs_expr, Downcast<PrimExpr>(it->second));
+ }
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < rhs->iter_values.size(); ++i) {
+ if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i]))
return false;
+ }
+ const Block& block = op->block;
+ for (size_t i = 0; i < offset; ++i) {
+ Var block_var = Downcast<Var>(op->iter_values[i]);
+ auto it = equal_map_.find(block_var);
+ equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ?
block_var : it->second);
+ }
+ }
+ }
+
+ return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block,
rhs->block);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
+ const auto* rhs = other.as<BlockNode>();
+ // Check block equality.
+ // All iter vars and buffer regions including the order shoudl match.
+ // When checking iter vars, DefEqual is used to remap variables.
+ // Only the inner most several axis are compared. Other iter vars are added
to extra_block_vars.
+ if (op->iter_vars.size() < rhs->iter_vars.size()) return false;
+
+ size_t offset = op->iter_vars.size() - rhs->iter_vars.size();
+ for (size_t i = 0; i < rhs->iter_vars.size(); ++i) {
+ auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i];
+ // Skip iter dom
+ if (!DefEqual(lhs_var->var, rhs_var->var)) {
+ return false;
+ }
+ if (lhs_var->iter_type != rhs_var->iter_type) {
+ return false;
+ }
+ }
+
+ if (is_scope_block) {
+ for (size_t i = 0; i < offset; ++i) {
+ extra_block_vars_.push_back(op->iter_vars[i]);
+ }
+ }
+
+ if (!is_scope_block) {
+ if (!CompareArray(op->writes, rhs->writes,
&TensorizeComparator::CompareBufferRegion)) {
+ return false;
+ }
+ if (!CompareArray(op->reads, rhs->reads,
&TensorizeComparator::CompareBufferRegion)) {
+ return false;
+ }
+ if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+ return false;
+ }
+ if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
&TensorizeComparator::CompareBuffer)) {
+ return false;
+ }
+ }
+ is_scope_block = false;
+ return VisitStmt(op->body, rhs->body);
+}
+
+// Exprs
+#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName)
\
+ bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr&
other) { \
+ const auto* rhs = other.as<OpName>();
\
+ return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b);
\
+ }
+
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode);
+
+bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<IntImmNode>();
+ return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<FloatImmNode>();
+ return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<CastNode>();
+ return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value,
rhs->value);
+}
+
+bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other)
{
+ const auto* rhs = other.as<VarNode>();
+ auto lhs = GetRef<Var>(op);
+ if (lhs.same_as(other)) return true;
+ if (!CompareType(op->dtype, rhs->dtype)) return false;
+ auto it = equal_map_.find(lhs);
+ return it != equal_map_.end() && it->second.same_as(other);
+}
+
+bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<BufferLoadNode>();
+ return CompareBufferAccess(op, rhs);
+}
+
+bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs)
{
+ if (lhs.same_as(rhs)) return true;
+ if (lhs->type_index() != rhs->type_index()) return false;
+ auto it = equal_map_.find(lhs);
+ // If there is already a mapping
+ if (it != equal_map_.end()) return it->second.same_as(rhs);
+ equal_map_[lhs] = rhs;
+ return true;
+}
+
+bool TensorizeComparator::CompareAnnotation(const std::pair<String,
ObjectRef>& lhs,
+ const std::pair<String,
ObjectRef>& rhs) {
+ if (lhs.first != rhs.first) return false;
+ if (!lhs.second.same_as(rhs.second)) return false;
+ return VisitExpr(Downcast<PrimExpr>(lhs.second),
Downcast<PrimExpr>(rhs.second));
+}
+
+bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>&
lhs,
+ const Map<String, ObjectRef>&
rhs) {
+ if (lhs.same_as(rhs)) return true;
+ if (lhs.size() != rhs.size()) return false;
+
+ auto sort_map =
+ [](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String,
ObjectRef>> {
+ std::vector<std::pair<String, ObjectRef>> ret;
+ ret.reserve(map.size());
+ for (const auto& pair : map) {
+ ret.emplace_back(pair);
+ }
+ sort(ret.begin(), ret.end());
+ return ret;
+ };
+
+ auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs);
Review comment:
nit: no need to abuse auto
```suggestion
std::vector<std::pair<String, ObjectRef>> lhs_array = sort_map(lhs);
std::vector<std::pair<String, ObjectRef>> rhs_array = sort_map(rhs);
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]