gemini-code-assist[bot] commented on code in PR #18869:
URL: https://github.com/apache/tvm/pull/18869#discussion_r2881702073


##########
include/tvm/s_tir/data_layout.h:
##########
@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
     return undef;
   }
 
+  /*!
+   * \brief Packs the Given Array of IterVars into a Single IterVar. Each 
IterVar in the Array
+   *        should represent either a single primal axis or one or more 
subordinate axis
+   * \param iters Array of iter vars to be packed
+   * \return A packed iter var
+   */

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The documentation for `PackIterVar` states that each IterVar in the array may represent "either a single primal axis or one or more subordinate axis". However, the implementation in `src/s_tir/data_layout.cc` only supports packing subordinate axes, and it rejects anything else with the check `Packed Axis can contain only Subordinate Axes`. Please resolve this discrepancy by updating the documentation to state that only subordinate axes can be packed, matching the implementation's behavior.



##########
src/s_tir/data_layout.cc:
##########
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>* 
index_rule, ffi::Array<PrimExpr>*
     return false;
   }
 
-  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
-    const auto& store_axis = dst_layout[i];
-    const IterVar& store_axis_impl = dst_layout->axes[i];
-    PrimExpr index_store(0);
-
-    for (size_t j = 0; j < src_layout.ndim(); ++j) {
-      const auto& orig_axis = src_layout[j];
-      const IterVar& orig_axis_impl = src_layout->axes[j];
-      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
-        if (orig_axis.IsPrimal()) {
-          PrimExpr orig_var = orig_axis_impl->var;
-          const int32_t factor = src_layout.FactorOf(orig_axis);
-          if (factor > 0) {
-            orig_var = orig_var * factor;
-          }
-          index_store = index_store + orig_var;
-        } else {
-          PrimExpr factor(1);
-          for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
-            if (LayoutAxis::Get(orig_axis_impl) == 
LayoutAxis::Get(src_layout->axes[k])) {
-              factor = factor * src_layout->axes[k]->dom->extent;
+  std::vector<bool> exists(128, false);
+  PrimExpr norm_indexes[128];
+  for (auto& it : norm_indexes) it = PrimExpr(0);
+
+  for (size_t i = 0; i < src_layout.ndim(); i++) {
+    auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+    auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+    if (src_unpacked_axes.size() == 1 && 
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+      const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+      int64_t offset = src_layout.FactorOf(prim_axis);
+      if (offset == -1)
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] + 
src_layout.PackedAxisAt(i);
+      else
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] +
+            src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+      exists[prim_axis.name()[0]] = true;
+    } else {
+      int64_t value = 1;
+      std::vector<int> index_divs(src_unpacked_axes.size());
+      for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+        index_divs[j] = value;
+        const auto* extent = 
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+        TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset 
Calculation";
+        index_divs.push_back(value);
+        value = value * extent->value;
+      }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
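   There appears to be a bug in this loop. The vector `index_divs` is constructed with `src_unpacked_axes.size()` elements at line 330, yet inside the loop both `index_divs[j] = value;` and `index_divs.push_back(value);` are executed. As a result, the vector ends up twice the intended size, with every divisor stored twice.

   After the subsequent `std::reverse`, the first half of the vector happens to hold the intended values. Code that reads only the first `src_unpacked_axes.size()` entries therefore works by coincidence, while anything that relies on `index_divs.size()` would be wrong. The `push_back` call should be removed.

   (Separately, `index_divs` is a `std::vector<int>` while `value` is an `int64_t`, so large extent products are silently narrowed. Consider using `std::vector<int64_t>`.)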
   
   ```cpp
         for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
           index_divs[j] = value;
           const auto* extent = src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
           TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset Calculation";
           value = value * extent->value;
         }
   ```



##########
python/tvm/s_tir/data_layout.py:
##########
@@ -41,7 +41,8 @@ def __len__(self):
         return _ffi_api.LayoutNdim(self)  # type: ignore
 
     def __contains__(self, axis):
-        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+        # Note: We do a weaker check for packed axis assuming layout is valid
+        return not any(bkt in axis for bkt in "[]") and axis in self.name

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current implementation of `__contains__` uses a substring check (`axis in self.name`). This can return incorrect results and is inconsistent with `index_of`:
   - For the layout `NCHW16c`, `__contains__("16")` returns `True`, even though "16" is not an axis.
   - For `OIHW[4o4i]`, `"i" in layout` is `True`, while `layout.index_of("i")` is `-1`.

   The substring check is too broad and produces false positives. A more robust implementation should check against the layout's actual axes, so that `in` and `index_of` always agree. One option is to delegate to `index_of`, e.g. `return self.index_of(axis) != -1`, assuming `index_of` also handles packed-axis strings such as `"[4o4i]"`.



##########
src/s_tir/data_layout.cc:
##########
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>* 
index_rule, ffi::Array<PrimExpr>*
     return false;
   }
 
-  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
-    const auto& store_axis = dst_layout[i];
-    const IterVar& store_axis_impl = dst_layout->axes[i];
-    PrimExpr index_store(0);
-
-    for (size_t j = 0; j < src_layout.ndim(); ++j) {
-      const auto& orig_axis = src_layout[j];
-      const IterVar& orig_axis_impl = src_layout->axes[j];
-      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
-        if (orig_axis.IsPrimal()) {
-          PrimExpr orig_var = orig_axis_impl->var;
-          const int32_t factor = src_layout.FactorOf(orig_axis);
-          if (factor > 0) {
-            orig_var = orig_var * factor;
-          }
-          index_store = index_store + orig_var;
-        } else {
-          PrimExpr factor(1);
-          for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
-            if (LayoutAxis::Get(orig_axis_impl) == 
LayoutAxis::Get(src_layout->axes[k])) {
-              factor = factor * src_layout->axes[k]->dom->extent;
+  std::vector<bool> exists(128, false);
+  PrimExpr norm_indexes[128];
+  for (auto& it : norm_indexes) it = PrimExpr(0);
+
+  for (size_t i = 0; i < src_layout.ndim(); i++) {
+    auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+    auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+    if (src_unpacked_axes.size() == 1 && 
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+      const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+      int64_t offset = src_layout.FactorOf(prim_axis);
+      if (offset == -1)
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] + 
src_layout.PackedAxisAt(i);
+      else
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] +
+            src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+      exists[prim_axis.name()[0]] = true;
+    } else {
+      int64_t value = 1;
+      std::vector<int> index_divs(src_unpacked_axes.size());
+      for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+        index_divs[j] = value;
+        const auto* extent = 
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+        TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset 
Calculation";
+        index_divs.push_back(value);
+        value = value * extent->value;
+      }
+      std::reverse(index_divs.begin(), index_divs.end());
+
+      for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+        const int extent = 
src_unpacked_axes[j]->dom->extent.as<IntImmNode>()->value;
+        const LayoutAxis& store_axis_impl = 
LayoutAxis::Get(src_unpacked_axes[j]);
+        const LayoutAxis& sub_axis = store_axis_impl.ToSubordinate(); /* Not 
Needed */

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The comment `/* Not Needed */` suggests this variable `sub_axis` is unused. 
If it's indeed not needed, it should be removed to improve code clarity and 
avoid confusion.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to