masahi commented on code in PR #11740:
URL: https://github.com/apache/tvm/pull/11740#discussion_r898969882


##########
src/tir/schedule/ir_comparator.cc:
##########
@@ -355,5 +355,135 @@ void TensorizeComparator::EmitError(const std::string& 
error_message) {
   error_messages_.push_back(error_message);
 }
 
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const 
PrimExpr& other) {
+  return false;
+}
+
+bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt& 
other) {
+  return false;
+}
+
+template <typename T, typename F>
+bool AutoTensorizeExtractor::CompareArray(const Array<T>& lhs, const Array<T>& 
rhs, F cmp) {
+  if (lhs.same_as(rhs)) return true;
+  if (lhs.size() != rhs.size()) return false;
+  for (size_t i = 0; i < lhs.size(); ++i) {
+    if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+  }
+  return true;
+}
+
+bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt& 
other) {
+  const auto* rhs = other.as<BlockNode>();
+  // Check block equality.
+  // All iter vars and buffer regions including the order should match.
+  // When checking iter vars, DefEqual is used to remap variables.
+  if (!is_scope_block) {
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, 
&AutoTensorizeExtractor::CompareIterVar)) {
+      return false;
+    }
+    if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+      return false;
+    }
+    if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+                      &AutoTensorizeExtractor::CompareBuffer)) {
+      return false;
+    }
+    for (const IterVar& block_iter : op->iter_vars) {
+      inner_iter_dom_map_.Set(block_iter->var, 
arith::IntSet::FromRange(block_iter->dom));
+    }
+  } else {
+    auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) 
-> bool {
+      for (const auto& iter : op->iter_vars) {
+        analyzer_.Bind(iter->var, iter->dom);
+        if (iter->iter_type == IterVarType::kDataPar ||
+            iter->iter_type == IterVarType::kCommReduce) {
+          iters.push_back(iter);
+        } else {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!collect_iter(op, lhs_iters_)) {
+      return false;
+    }
+    if (!collect_iter(rhs, rhs_iters_)) {
+      return false;
+    }
+  }
+  is_scope_block = false;
+  return VisitStmt(op->body, rhs->body);
+}
+
+bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer& 
rhs) {
+  if (lhs.same_as(rhs)) return true;
+  auto it = rhs_buffer_map_.find(rhs);
+  bool equal;
+  if (it != rhs_buffer_map_.end()) {
+    equal = (*it).second.same_as(lhs);
+  } else {
+    // Remap both buffer itself and buffer data, skip buffer shape and scope
+    equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+    if (equal) {
+      rhs_buffer_map_[rhs] = lhs;
+      lhs_buffer_map_[lhs] = rhs;
+    }
+  }
+  return equal;
+}
+
+bool AutoTensorizeExtractor::VisitStmt_(const BufferStoreNode* op, const Stmt& 
other) {
+  const auto* rhs = other.as<BufferStoreNode>();
+  return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool AutoTensorizeExtractor::VisitExpr_(const BufferLoadNode* op, const 
PrimExpr& other) {
+  const auto* rhs = other.as<BufferLoadNode>();
+  return CompareBufferAccess(op, rhs);
+}

Review Comment:
   The two overrides above look identical to the base-class implementations.



##########
src/tir/schedule/ir_comparator.cc:
##########
@@ -355,5 +355,135 @@ void TensorizeComparator::EmitError(const std::string& 
error_message) {
   error_messages_.push_back(error_message);
 }
 
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const 
PrimExpr& other) {
+  return false;
+}
+
+bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt& 
other) {
+  return false;
+}
+
+template <typename T, typename F>
+bool AutoTensorizeExtractor::CompareArray(const Array<T>& lhs, const Array<T>& 
rhs, F cmp) {

Review Comment:
   This looks identical to the base-class implementation.



##########
src/tir/schedule/ir_comparator.h:
##########
@@ -110,6 +110,48 @@ class TensorizeComparator : public ExprComparator, public 
StmtComparator {
   std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_;
 };
 
+/*! \brief IR comparator for auto tensorization. Extract correspondence 
between the IR of the
+ *         workload and the tensor intrin.
+ */
+class AutoTensorizeExtractor : public TensorizeComparator {

Review Comment:
   The name of this class is a bit confusing: it is derived from "Comparator", 
and its main job does seem to be comparing things, yet it is named "Extractor".
   
   To me, the "extract" aspect of this class seems like a by-product. How about 
just calling it "AutoTensorizeComparator"?
   
   Also, it would be better to document how this class differs from 
`TensorizeComparator` in terms of IR comparison; some of the overrides look 
identical to the base-class ones. 



##########
src/tir/schedule/analysis.h:
##########
@@ -707,6 +707,55 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const 
tir::ScheduleState& self,
                                                 const tir::StmtSRef& 
block_sref,
                                                 const tir::PrimFunc& 
desc_func);
 
+/*!\brief Necessary information used to perform transformations for 
tensorization */
+class AutoTensorizeMappingInfoNode : public Object {
+ public:
+  /*! \brief Possible mappings to apply to block iters */
+  Array<IndexMap> mappings;
+
+  /* Additional information from AutoTensorizeExtractor */
+
+  /*! \brief Mapping from LHS buffer to RHS buffer */
+  Map<Buffer, Buffer> lhs_buffer_map;
+  /*! \brief Buffer indices on RHS */
+  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
+  /*! \brief Block iters on LHS */
+  Array<IterVar> lhs_iters;
+  /*! \brief Block iters on RHS */
+  Array<IterVar> rhs_iters;
+

Review Comment:
   I get what you mean, but "LHS" and "RHS" are never defined; please document what each side refers to.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to