wrongtest commented on code in PR #11287:
URL: https://github.com/apache/tvm/pull/11287#discussion_r881811402


##########
src/arith/iter_affine_map.cc:
##########
@@ -1263,140 +1285,135 @@ IterSumExpr 
IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
   }
 }
 
+/*! \brief Find approximate least common multiplier. */
+PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, 
Analyzer* analyzer) {
+  auto fsplit = [](const PrimExpr& e) -> std::pair<PrimExpr, int64_t> {
+    if (const IntImmNode* imm = e.as<IntImmNode>()) {
+      return {1, imm->value};
+    }
+    PVar<PrimExpr> pv;
+    PVar<IntImm> pc;
+    if ((pv * pc).Match(e) || (pc * pv).Match(e)) {
+      return {pv.Eval(), pc.Eval()->value};
+    } else {
+      return {e, 1};
+    }
+  };
+  auto p1 = fsplit(a);
+  auto p2 = fsplit(b);
+  auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
+  if (analyzer->CanProveEqual(p1.first, p2.first)) {
+    return p1.first * const_lcm;
+  } else {
+    return (p1.first * p2.first) * const_lcm;
+  }
+}
+
 std::pair<IterSplitExpr, PrimExpr> 
IterMapRewriter::PadDividendToDivisor(IterSplitExpr split,
                                                                          
PrimExpr base,
                                                                          
PrimExpr divisor) {
   // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor
   // If FloorMod: (((source//lower_factor) % extent) + base) % divisor
 
-  PrimExpr lookup_key = split;
-
-  auto modified_divisor = [&]() {
-    if (update_iterator_padding_) {
-      return divisor;
-    }
-
-    auto it = padded_iter_map_.find(lookup_key);
-    if (it == padded_iter_map_.end()) {
-      return divisor;
-    }
-
-    const std::vector<PrimExpr>& divisors = it->second.divisors;
-    PrimExpr largest_divisor = divisor;
-    for (const auto& other : divisors) {
-      if (CanProveDivisible(other, largest_divisor)) {
-        // New one is bigger, use it
-        largest_divisor = other;
-      } else if (CanProveDivisible(largest_divisor, other)) {
-        // Current is bigger, keep it
-      } else {
-        ErrorLogger(this) << "Iterator appears in multiple terms with 
incompatible divisors "
-                          << tvm::PrettyPrint(largest_divisor) << " and "
-                          << tvm::PrettyPrint(other);
-      }
-    }
-    return largest_divisor;
-  }();
-
-  divisor = modified_divisor;
-
   // First, adding any padding that is on the lower side of a
-  // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0
-  // when iter==0.
-
-  PrimExpr left_pad;
-
-  if (is_zero(base)) {
-    // Padding on the left is unnecessary if base is known to be zero.
-    left_pad = make_zero(base->dtype);
-  } else {
-    left_pad = analyzer_->Simplify(floormod(base, divisor));
-  }
+  // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0
+  // when iter == 0.
+  PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor));
 
   // Next, adding any padding that is on the upper side of a
-  // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, 
divisor) == 0
-  // when iter==extent.
-
+  // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, 
divisor) == 0
+  // when iter == extent.
   PrimExpr right_edge = left_pad + split->extent;
   PrimExpr right_pad;
-
   if (CanProveDivisible(right_edge, divisor)) {
-    // Padding on the right is unnecessary if the extent is a multiple of
-    // the divisor.
     right_pad = 0;
   } else {
     right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
   }
 
-  if (is_zero(left_pad) && is_zero(right_pad)) {
-    return {split, left_pad};
-  }
-
+  const IterMark& mark = split->source;
   if (update_iterator_padding_) {
     // In the first pass, the primary goal is to collect all the divisors
-    // that may be used for padding.  These will impact the divisor used
-    // to determine padding in the second pass.
-    IterPaddingInfo& info = padded_iter_map_[lookup_key];
-
-    info.divisors.push_back(divisor);
-
-    PrimExpr padded_extent = left_pad + split->extent + right_pad;
-
-    IterSumExpr as_sum({split}, left_pad);
-    IterMark mark(as_sum, padded_extent);
-    IterSplitExpr new_split(mark);
-
-    return {new_split, left_pad};
+    // that may be used for padding. These will impact the divisor used
+    // to determine padding in the second pass. We try add padding to
+    // split's source iteraton mark thus all splits under the same mark will
+    // share the same padded source iteration.
+    auto& info = padded_iter_map_[mark];
+    info.padding_factor =
+        ApproxLeastCommonMultiple(info.padding_factor, divisor * 
split->lower_factor, analyzer_);
+
+    // If the split itself require no padding, return directly.
+    if (is_zero(left_pad) && is_zero(right_pad)) {
+      return {split, 0};
+    }
+
+    // Update padding requirement on the lower side of the source iter mark.
+    PrimExpr mark_left_pad = left_pad * split->lower_factor;
+    info.left_pad = max(info.left_pad, mark_left_pad);

Review Comment:
   We need to check what happens if different splits require incompatible left pads.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to