vinx13 commented on code in PR #11740:
URL: https://github.com/apache/tvm/pull/11740#discussion_r899516179
##########
src/tir/schedule/ir_comparator.cc:
##########
@@ -355,5 +355,135 @@ void TensorizeComparator::EmitError(const std::string& error_message) {
error_messages_.push_back(error_message);
}
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const
PrimExpr& other) {
+ return false;
+}
+
+bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt&
other) {
+ return false;
+}
+
+template <typename T, typename F>
+bool AutoTensorizeExtractor::CompareArray(const Array<T>& lhs, const Array<T>&
rhs, F cmp) {
+ if (lhs.same_as(rhs)) return true;
+ if (lhs.size() != rhs.size()) return false;
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+ }
+ return true;
+}
+
+bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BlockNode>();
+ // Check block equality.
+ // All iter vars and buffer regions including the order should match.
+ // When checking iter vars, DefEqual is used to remap variables.
+ if (!is_scope_block) {
+ if (!CompareArray(op->iter_vars, rhs->iter_vars,
&AutoTensorizeExtractor::CompareIterVar)) {
+ return false;
+ }
+ if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+ return false;
+ }
+ if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+ &AutoTensorizeExtractor::CompareBuffer)) {
+ return false;
+ }
+ for (const IterVar& block_iter : op->iter_vars) {
+ inner_iter_dom_map_.Set(block_iter->var,
arith::IntSet::FromRange(block_iter->dom));
+ }
+ } else {
+ auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters)
-> bool {
+ for (const auto& iter : op->iter_vars) {
+ analyzer_.Bind(iter->var, iter->dom);
+ if (iter->iter_type == IterVarType::kDataPar ||
+ iter->iter_type == IterVarType::kCommReduce) {
+ iters.push_back(iter);
+ } else {
+ return false;
+ }
+ }
+ return true;
+ };
+ if (!collect_iter(op, lhs_iters_)) {
+ return false;
+ }
+ if (!collect_iter(rhs, rhs_iters_)) {
+ return false;
+ }
+ }
+ is_scope_block = false;
+ return VisitStmt(op->body, rhs->body);
+}
+
+bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer&
rhs) {
+ if (lhs.same_as(rhs)) return true;
+ auto it = rhs_buffer_map_.find(rhs);
+ bool equal;
+ if (it != rhs_buffer_map_.end()) {
+ equal = (*it).second.same_as(lhs);
+ } else {
+ // Remap both buffer itself and buffer data, skip buffer shape and scope
+ equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+ if (equal) {
+ rhs_buffer_map_[rhs] = lhs;
+ lhs_buffer_map_[lhs] = rhs;
+ }
+ }
+ return equal;
+}
+
+bool AutoTensorizeExtractor::VisitStmt_(const BufferStoreNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BufferStoreNode>();
+ return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool AutoTensorizeExtractor::VisitExpr_(const BufferLoadNode* op, const
PrimExpr& other) {
+ const auto* rhs = other.as<BufferLoadNode>();
+ return CompareBufferAccess(op, rhs);
+}
Review Comment:
Yes. They call the member function template `CompareBufferAccess`, and C++ does
not allow member function templates to be virtual, so I have to duplicate these
two overrides.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]