This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 64a24c4104 [RELAX][LAYOUT] Support multiple axis packing (#18869)
64a24c4104 is described below

commit 64a24c4104110c83c15ac507effbc18845b2d530
Author: Siva <[email protected]>
AuthorDate: Sun Mar 15 11:49:43 2026 +0530

    [RELAX][LAYOUT] Support multiple axis packing (#18869)
    
    Supports layouts such as OIHW[4o4i], where multiple axes are packed
    into a single dimension. This is helpful when handling complex target
    layouts. This PR covers the layout representation and the transforms
    for such layouts.
---
 include/tvm/s_tir/data_layout.h                 |  79 +++++-
 python/tvm/s_tir/data_layout.py                 |   5 +-
 src/s_tir/data_layout.cc                        | 357 +++++++++++++++++-------
 tests/python/s_tir/base/test_tir_data_layout.py |  75 ++++-
 4 files changed, 386 insertions(+), 130 deletions(-)

diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h
index 8d1ad0ca4c..e57ed9bc96 100644
--- a/include/tvm/s_tir/data_layout.h
+++ b/include/tvm/s_tir/data_layout.h
@@ -35,6 +35,8 @@
 #include <utility>
 #include <vector>
 
+#include "tvm/tir/var.h"
+
 namespace tvm {
 namespace tir {
 
@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
     return undef;
   }
 
+  /*!
+   * \brief Packs the Given Array of IterVars into a Single IterVar. Each 
IterVar in the Array
+   *        should represent either a single primal axis or one or more 
subordinate axis
+   * \param iters Array of iter vars to be packed
+   * \return A packed iter var
+   */
+  static IterVar PackIterVar(ffi::Array<IterVar> iters);
+
+  /*!
+   * \brief Unpacks a Packed IterVar into its constituents
+   * \param packed_iter A Packed IterVar containing a single primal axis or 
one or more subordinate
+   *                    axis
+   * \return Constituent IterVars
+   */
+  static ffi::Array<IterVar> UnpackIterVar(IterVar packed_iter);
+
   /*!
    * \brief Returns a sub-layout which is the portion of the object
    *        that starts at dimension \p pos and spans \p len dimensions
@@ -187,9 +205,12 @@ class Layout : public ObjectRef {
   inline size_t ndim_primal() const {
     if (!defined()) return 0;
     size_t ct = 0;
-    for (auto x : operator->()->axes) {
-      if (LayoutAxis::Get(x).IsPrimal()) {
-        ct++;
+    for (auto px : operator->()->axes) {
+      auto iter_vars = UnpackIterVar(px);
+      for (auto x : iter_vars) {
+        if (LayoutAxis::Get(x).IsPrimal()) {
+          ct++;
+        }
       }
     }
     return ct;
@@ -204,10 +225,13 @@ class Layout : public ObjectRef {
     Layout new_src_layout;
     // 1) Find the axis which are missing in the current layout. Make them the 
prefix.
     std::string new_src_layout_str = "";
-    for (auto dst_axis : dst_layout->axes) {
-      if (LayoutAxis::Get(dst_axis).IsPrimal()) {
-        if (!this->Contains(LayoutAxis::Get(dst_axis))) {
-          new_src_layout_str += dst_axis->var->name_hint;
+    for (auto packed_axis : dst_layout->axes) {
+      auto iter_vars = UnpackIterVar(packed_axis);
+      for (auto dst_axis : iter_vars) {
+        if (LayoutAxis::Get(dst_axis).IsPrimal()) {
+          if (!this->Contains(LayoutAxis::Get(dst_axis))) {
+            new_src_layout_str += dst_axis->var->name_hint;
+          }
         }
       }
     }
@@ -221,18 +245,36 @@ class Layout : public ObjectRef {
    * \brief return the index of the input axis.
    *        If it is not found in the layout or the layout is undefined,
    *        return -1.
-   * \param axis the input axis.
+   * \param axis The input axis either a layout axis, or a packed axis
    * \return the index or -1 if not found.
    */
-  inline int32_t IndexOf(const LayoutAxis& axis) const {
+  inline int32_t IndexOf(const std::string& axis) const {
     if (!this->defined()) return -1;
     const auto axes = operator->()->axes;
     for (size_t i = 0; i < axes.size(); ++i) {
-      if (axes[i]->var->name_hint == axis.name()) return 
static_cast<int32_t>(i);
+      if (axes[i]->var->name_hint == axis) return static_cast<int32_t>(i);
     }
     return -1;
   }
 
+  /*!
+   * \brief return the index of the input axis.
+   *        If it is not found in the layout or the layout is undefined,
+   *        return -1.
+   * \param axis the input layout axis.
+   * \return the index or -1 if not found.
+   */
+  inline int32_t IndexOf(const LayoutAxis& axis) const { return 
IndexOf(axis.name()); }
+
+  /*!
+   * \brief return the index of the input axis.
+   *        If it is not found in the layout or the layout is undefined,
+   *        return -1.
+   * \param iter the input iter var.
+   * \return the index or -1 if not found.
+   */
+  inline int32_t IndexOf(const tir::IterVar& iter) const { return 
IndexOf(iter->var->name_hint); }
+
   /*!
    * \brief Get the factor size of the subordinate axis.
    * \param axis the input primal-axis or subordinate-axis.
@@ -249,9 +291,12 @@ class Layout : public ObjectRef {
    */
   bool Contains(const LayoutAxis& axis) const {
     if (!defined()) return false;
-    for (const tir::IterVar var : operator->()->axes) {
-      if (var->var->name_hint == axis.name()) {
-        return true;
+    for (const tir::IterVar packed_var : operator->()->axes) {
+      auto iter_vars = UnpackIterVar(packed_var);
+      for (auto var : iter_vars) {
+        if (var->var->name_hint == axis.name()) {
+          return true;
+        }
       }
     }
     return false;
@@ -265,6 +310,14 @@ class Layout : public ObjectRef {
     return LayoutAxis::Get(axis);
   }
 
+  IterVar PackedAxisAt(int32_t i) const {
+    TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined 
layout.";
+    int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
+    TVM_FFI_ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << 
"Invalid index " << i;
+    const tir::IterVar axis = operator->()->axes[index];
+    return axis;
+  }
+
   /*! \return the string description of the layout */
   inline std::string name() const {
     if (!defined()) return "__undef__";
diff --git a/python/tvm/s_tir/data_layout.py b/python/tvm/s_tir/data_layout.py
index 01dd88e3b0..00d6f0ebb0 100644
--- a/python/tvm/s_tir/data_layout.py
+++ b/python/tvm/s_tir/data_layout.py
@@ -41,7 +41,8 @@ class Layout(Object):
         return _ffi_api.LayoutNdim(self)  # type: ignore
 
     def __contains__(self, axis):
-        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+        # Note: We do a weaker check for packed axis assuming layout is valid
+        return not any(bkt in axis for bkt in "[]") and axis in self.name
 
     def __getitem__(self, index):
         if index >= len(self):
@@ -54,7 +55,7 @@ class Layout(Object):
         Parameters
         ----------
         axis : str
-            The axis name, need to be [a-z,A-Z]
+            The axis name, needs to be [a-z,A-Z] or a packed axis
 
         Returns
         -------
diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc
index 267368975d..fb64ee00ff 100644
--- a/src/s_tir/data_layout.cc
+++ b/src/s_tir/data_layout.cc
@@ -24,9 +24,17 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/object.h>
 #include <tvm/s_tir/data_layout.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/var.h>
 
+#include <algorithm>
 #include <cctype>
 
 namespace tvm {
@@ -78,17 +86,30 @@ Layout::Layout(const ffi::Array<IterVar>& axes) {
   auto node = ffi::make_object<LayoutNode>();
   node->axes = axes;
   std::ostringstream repr;
-  for (const IterVar& axis : axes) {
-    if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
-      TVM_FFI_ICHECK_GT(factor->value, 0);
-      repr << factor->value;
+
+  for (const IterVar& packed_axis : axes) {
+    auto unpacked_axes = UnpackIterVar(packed_axis);
+    bool is_grouped = unpacked_axes.size() > 1;
+
+    if (is_grouped) repr << "[";
+    for (const IterVar& axis : unpacked_axes) {
+      if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
+        TVM_FFI_ICHECK_GT(factor->value, 0);
+        repr << factor->value;
+      } else {
+        TVM_FFI_ICHECK(!is_grouped)
+            << "Only Subordinate Axes with extent is allowed within a packed 
dim";
+      }
+      TVM_FFI_ICHECK_EQ(axis->var.get()->name_hint.size(), 1)
+          << "Invalid layout axis " << axis->var.get()->name_hint;
+      char c = axis->var.get()->name_hint.operator std::string()[0];
+      TVM_FFI_ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
+          << "Invalid layout axis " << c;
+      repr << axis->var.get()->name_hint;
     }
-    TVM_FFI_ICHECK_EQ(axis->var.get()->name_hint.size(), 1)
-        << "Invalid layout axis " << axis->var.get()->name_hint;
-    char c = axis->var.get()->name_hint.operator std::string()[0];
-    TVM_FFI_ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << 
"Invalid layout axis " << c;
-    repr << axis->var.get()->name_hint;
+    if (is_grouped) repr << "]";
   }
+
   node->name = repr.str();
   data_ = std::move(node);
 }
@@ -104,46 +125,93 @@ Layout::Layout(const std::string& name, DataType dtype) { 
 // NOLINT(*)
 
   // parse layout string
   int32_t factor = 0;
+  bool in_packing = false;
+  std::vector<IterVar> unpacked_axes;
+
   for (char c : name) {
     if (c >= 'A' && c <= 'Z') {
       TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid 
factor size "
                                    << factor << " before dimension " << c;
-      std::string shape_name("_shape");
-      shape_name.insert(0, 1, c);
-      IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), 
Var(std::string(1, c), dtype),
-                   tir::kDataPar);
-      node->axes.push_back(axis);
+      IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)),
+                   Var(std::string(1, c), dtype), tir::kDataPar);
+      if (!in_packing) {
+        node->axes.push_back(axis);
+      } else {
+        unpacked_axes.push_back(axis);
+      }
     } else if (c >= 'a' && c <= 'z') {
       TVM_FFI_ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid 
factor size "
                                    << factor << " for dimension " << c;
-      IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), 
Var(std::string(1, c), dtype),
+      std::stringstream name;
+      name << factor << c;
+      IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), 
Var(name.str(), dtype),
                    tir::kDataPar);
-      node->axes.push_back(axis);
+      if (!in_packing) {
+        node->axes.push_back(axis);
+      } else {
+        unpacked_axes.push_back(axis);
+      }
       factor = 0;
     } else if (c >= '0' && c <= '9') {
       TVM_FFI_ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is 
adjacent to a number.";
       factor = factor * 10 + c - '0';
+    } else if (c == '[') {
+      TVM_FFI_ICHECK(!in_packing) << "Invalid layout " << name << ": can't do 
nested packing";
+      in_packing = true;
+    } else if (c == ']') {
+      TVM_FFI_ICHECK(in_packing) << "Invalid layout " << name
+                                 << ": encountered ] without matching bracket";
+      TVM_FFI_ICHECK(unpacked_axes.size() > 1)
+          << "Invalid layout " << name << ": found empty/single packed axis";
+      std::stringstream ss;
+      int64_t extent = 1;
+      for (auto& axis : unpacked_axes) {
+        TVM_FFI_ICHECK(axis->dom->extent.as<IntImmNode>())
+            << "Invalid Layout " << name << ": can't have variable sized node("
+            << axis->var->name_hint << ") within a packed axis";
+        auto axis_name = axis->var->name_hint.operator std::string();
+        auto factor = axis->dom->extent.as<IntImm>().value();
+        ss << axis_name;
+        extent = extent * factor->value;
+      }
+      std::string grouped_name = ss.str();
+      IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)), 
Var(grouped_name, dtype),
+                           tir::kDataPar);
+      node->axes.push_back(grouped_axis);
+
+      in_packing = false;
+      unpacked_axes.clear();
     } else {
       TVM_FFI_THROW(InternalError) << "Invalid layout " << name;
     }
   }
+  TVM_FFI_ICHECK(in_packing == false)
+      << "Invalid Layout " << name << ": haven't terminated the packing 
sequence";
 
   // validate layout
-  std::vector<bool> exist_axis(256, false);
-  for (const IterVar& v : node->axes) {
-    auto axis_str = v->var.get()->name_hint.operator std::string();
-    TVM_FFI_ICHECK_EQ(axis_str.size(), 1);
-    char axis = axis_str[0];
-    TVM_FFI_ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 
'Z'));
-    exist_axis[axis] = true;
+  std::vector<int> axis_cnt(256, 0);
+  for (const IterVar& pv : node->axes) {
+    for (const IterVar& v : UnpackIterVar(pv)) {
+      auto axis_str = v->var.get()->name_hint.operator std::string();
+      TVM_FFI_ICHECK_EQ(axis_str.size(), 1);
+      char axis = axis_str[0];
+      TVM_FFI_ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 
'Z'));
+      axis_cnt[axis] += 1;
+    }
   }
-  for (const IterVar& v : node->axes) {
-    char axis = v->var.get()->name_hint.operator std::string()[0];
-    if (axis >= 'a' && axis <= 'z') {
-      TVM_FFI_ICHECK(exist_axis[axis - 'a' + 'A'])
-          << "Invalid layout " << name << ": missing axis " << 
std::toupper(axis);
+  for (const IterVar& pv : node->axes) {
+    for (const IterVar& v : UnpackIterVar(pv)) {
+      char axis = v->var.get()->name_hint.operator std::string()[0];
+      if (axis >= 'a' && axis <= 'z') {
+        TVM_FFI_ICHECK(axis_cnt[axis - 'a' + 'A'])
+            << "Invalid layout " << name << ": missing axis " << 
std::toupper(axis);
+        TVM_FFI_ICHECK(axis_cnt[axis] == 1)
+            << "Invalid layout " << name << ": found more than one subordinate 
"
+            << std::toupper(axis);
+      }
     }
   }
+
   data_ = std::move(node);
 }
 
@@ -159,27 +227,46 @@ Layout Layout::SubLayout(size_t pos, size_t len) const {
   return Layout(new_layout);
 }
 
-Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t 
factor) const {
-  if (!defined()) return Layout::Undef();
-  const std::string& name = operator->()->name;
-  const auto axes = operator->()->axes;
-  TVM_FFI_ICHECK(target_pos <= this->ndim())
-      << "Invalid split position " << target_pos << " for layout " << name;
-  TVM_FFI_ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << 
axis;
-  TVM_FFI_ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist 
in " << name;
-  TVM_FFI_ICHECK(!this->Contains(axis.ToSubordinate()))
-      << "Axis " << axis << " has already been split in " << name;
-  TVM_FFI_ICHECK(factor > 0) << "Invalid split size " << factor;
-  ffi::Array<IterVar> new_layout;
-  for (size_t i = 0; i <= this->ndim(); ++i) {
-    if (i == target_pos) {
-      new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)),
-                                   Var(axis.ToSubordinate().name()), 
tir::kDataPar));
+ffi::Array<IterVar> Layout::UnpackIterVar(IterVar packed_iter) {
+  ffi::Array<IterVar> result;
+  int64_t factor = 0, final_factor = 1;
+
+  std::string name(packed_iter->var->name_hint.c_str());
+  DataType dtype = packed_iter->var.dtype();
+
+  for (auto ch : name) {
+    if (ch >= '0' && ch <= '9') {
+      factor = factor * 10 + (ch - '0');
+    } else if (ch >= 'a' && ch <= 'z') {
+      TVM_FFI_ICHECK(factor != 0) << "Invalid Factor Size";
+      result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)),
+                               Var(std::string(1, ch), dtype), tir::kDataPar));
+      final_factor *= factor;
+      factor = 0;
+    } else if (ch >= 'A' && ch <= 'Z') {
+      TVM_FFI_ICHECK(factor == 0) << "Can't have non-zero factors for primal 
axis";
+      result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch), 
dtype)),
+                               Var(std::string(1, ch), dtype), tir::kDataPar));
     }
-    if (i == this->ndim()) break;
-    new_layout.push_back(axes[i]);
   }
-  return Layout(new_layout);
+
+  return result;
+}
+
+IterVar Layout::PackIterVar(ffi::Array<IterVar> iter_vars) {
+  std::stringstream name;
+  size_t extent = 1;
+
+  DataType dtype = iter_vars[0]->dom->extent.as<PrimExpr>().value()->dtype;
+  for (auto itvar : iter_vars) {
+    TVM_FFI_ICHECK(itvar->dom->extent.as<IntImm>())
+        << "Packed Axis can contain only Subordinate Axes";
+    name << itvar->dom->extent.as<IntImm>().value() << itvar->var->name_hint;
+    extent = extent * itvar->dom->extent.as<IntImm>().value()->value;
+  }
+
+  return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)), 
Var(name.str(), dtype),
+                 tir::kDataPar);
 }
 
 int32_t Layout::FactorOf(const LayoutAxis& axis) const {
@@ -188,12 +275,13 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
 
   int32_t factor = 1;
   bool has_sub = false;
-  for (const IterVar& itvar : operator->()->axes) {
-    if (sub == LayoutAxis::Get(itvar)) {
-      has_sub = true;
-      int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
-      TVM_FFI_ICHECK(val);
-      factor *= val;
+  for (const IterVar& packed_itvar : operator->()->axes) {
+    for (auto itvar : UnpackIterVar(packed_itvar)) {
+      if (sub == LayoutAxis::Get(itvar)) {
+        has_sub = true;
+        int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
+        factor *= val;
+      }
     }
   }
   factor = has_sub ? factor : -1;
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>* 
index_rule, ffi::Array<PrimExpr>*
     return false;
   }
 
-  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
-    const auto& store_axis = dst_layout[i];
-    const IterVar& store_axis_impl = dst_layout->axes[i];
-    PrimExpr index_store(0);
-
-    for (size_t j = 0; j < src_layout.ndim(); ++j) {
-      const auto& orig_axis = src_layout[j];
-      const IterVar& orig_axis_impl = src_layout->axes[j];
-      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
-        if (orig_axis.IsPrimal()) {
-          PrimExpr orig_var = orig_axis_impl->var;
-          const int32_t factor = src_layout.FactorOf(orig_axis);
-          if (factor > 0) {
-            orig_var = orig_var * factor;
-          }
-          index_store = index_store + orig_var;
-        } else {
-          PrimExpr factor(1);
-          for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
-            if (LayoutAxis::Get(orig_axis_impl) == 
LayoutAxis::Get(src_layout->axes[k])) {
-              factor = factor * src_layout->axes[k]->dom->extent;
+  std::vector<bool> exists(128, false);
+  PrimExpr norm_indexes[128];
+  for (auto& it : norm_indexes) it = PrimExpr(0);
+
+  for (size_t i = 0; i < src_layout.ndim(); i++) {
+    auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+    auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+    if (src_unpacked_axes.size() == 1 && 
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+      const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+      int64_t offset = src_layout.FactorOf(prim_axis);
+      if (offset == -1)
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] + 
src_layout.PackedAxisAt(i);
+      else
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] +
+            src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+      exists[prim_axis.name()[0]] = true;
+    } else {
+      int64_t value = 1;
+      std::vector<int> index_divs(src_unpacked_axes.size());
+      for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+        index_divs[j] = value;
+        const auto* extent = 
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+        TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset 
Calculation";
+        index_divs.push_back(value);
+        value = value * extent->value;
+      }
+      std::reverse(index_divs.begin(), index_divs.end());
+
+      for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+        const int extent = 
src_unpacked_axes[j]->dom->extent.as<IntImmNode>()->value;
+        const LayoutAxis& store_axis_impl = 
LayoutAxis::Get(src_unpacked_axes[j]);
+        const LayoutAxis& sub_axis = store_axis_impl.ToSubordinate(); /* Not 
Needed */
+        const LayoutAxis& prim_axis = store_axis_impl.ToPrimal();
+
+        PrimExpr factor_ij = indexdiv(src_layout.PackedAxisAt(i), 
index_divs[j]);
+        if (j != 0) factor_ij = indexmod(factor_ij, extent);
+
+        for (size_t k = i; k < src_layout.ndim(); k++) {
+          size_t l = 0;
+          if (k == i) l = j + 1;
+
+          auto inter_unpacked_axes = 
Layout::UnpackIterVar(src_layout.PackedAxisAt(k));
+          for (; l < inter_unpacked_axes.size(); l++) {
+            const LayoutAxis& axis = LayoutAxis::Get(inter_unpacked_axes[l]);
+            if (axis == sub_axis) {
+              const auto* sub_extent = 
inter_unpacked_axes[l]->dom->extent.as<IntImmNode>();
+              TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for 
Offset Calculation";
+              factor_ij = factor_ij * IntImm(sub_extent->dtype, 
sub_extent->value);
             }
           }
-          index_store = index_store + orig_axis_impl->var * factor;
         }
+
+        norm_indexes[prim_axis.name()[0] - 'A'] =
+            norm_indexes[prim_axis.name()[0] - 'A'] + factor_ij;
       }
     }
-    if (tir::is_zero(index_store)) {
-      LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << 
dst_layout.name()
-                   << "' is not convertible.";
-      return false;
-    }
+  }
+
+  arith::Analyzer ana;
 
-    PrimExpr shape_store = index_store;
-    if (store_axis.IsPrimal()) {
-      const int32_t factor = dst_layout.FactorOf(store_axis);
-      if (factor > 0) {
-        shape_store = shapediv(index_store, PrimExpr(factor));
-        index_store = indexdiv(index_store, PrimExpr(factor));
+  for (size_t i = 0; i < dst_layout.ndim(); i++) {
+    const auto dst_unpacked_axes = 
Layout::UnpackIterVar(dst_layout.PackedAxisAt(i));
+
+    if (dst_unpacked_axes.size() == 1 && 
LayoutAxis::Get(dst_unpacked_axes[0]).IsPrimal()) {
+      const auto& prim_axis = LayoutAxis::Get(dst_unpacked_axes[0]);
+      if (!exists[prim_axis.name()[0]]) return false;
+      int64_t offset = dst_layout.FactorOf(prim_axis);
+      if (offset != -1) {
+        index_rule->push_back(
+            indexdiv(norm_indexes[prim_axis.name()[0] - 'A'], 
dst_layout.FactorOf(prim_axis)));
+        shape_rule->push_back(
+            indexdiv(norm_indexes[prim_axis.name()[0] - 'A'] + 
(dst_layout.FactorOf(prim_axis) - 1),
+                     dst_layout.FactorOf(prim_axis)));
+      } else {
+        index_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']);
+        shape_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']);
       }
     } else {
-      PrimExpr stride(1);
-      PrimExpr factor(1);
-      for (size_t j = i; j < dst_layout.ndim(); ++j) {
-        if (LayoutAxis::Get(store_axis_impl) == 
LayoutAxis::Get(dst_layout->axes[j])) {
-          stride = stride * dst_layout->axes[j]->dom->extent;
-          if (j > i) {
-            factor = factor * dst_layout->axes[j]->dom->extent;
+      PrimExpr factor(0);
+      for (size_t j = 0; j < dst_unpacked_axes.size(); j++) {
+        const auto& prim_axis = 
LayoutAxis::Get(dst_unpacked_axes[j]).ToPrimal();
+        const auto& sub_axis = 
LayoutAxis::Get(dst_unpacked_axes[j]).ToSubordinate();
+        const auto* extent = 
dst_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+        TVM_FFI_ICHECK(extent) << "Expected extent to be IntImmNode";
+
+        size_t divfactor = 1;
+        for (size_t k = i; k < dst_layout.ndim(); k++) {
+          size_t l = 0;
+          if (k == i) l = j + 1;
+
+          const auto inter_unpacked_axes = 
Layout::UnpackIterVar(dst_layout.PackedAxisAt(k));
+          for (; l < inter_unpacked_axes.size(); l++) {
+            const auto& axis = LayoutAxis::Get(inter_unpacked_axes[l]);
+            if (sub_axis == axis) {
+              const auto* sub_extent = 
inter_unpacked_axes[l]->dom->extent.as<IntImmNode>();
+              TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for 
Offset Calculation";
+              divfactor = divfactor * sub_extent->value;
+            }
           }
         }
+
+        factor = factor + indexmod(indexdiv(norm_indexes[prim_axis.name()[0] - 
'A'], divfactor),
+                                   extent->value);
+        for (size_t k = j + 1; k < dst_unpacked_axes.size(); k++) {
+          factor = factor * 
dst_unpacked_axes[k]->dom->extent.as<IntImm>().value();
+        }
       }
-      shape_store = indexdiv(indexmod(index_store, stride), factor);
-      index_store = indexdiv(indexmod(index_store, stride), factor);
+      ana.Simplify(factor);
+      index_rule->push_back(factor);
+      shape_rule->push_back(factor);
     }
-
-    index_rule->push_back(index_store);
-    shape_rule->push_back(shape_store);
   }
 
   std::stringstream ss;
@@ -289,7 +434,7 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>* index_rule, 
ffi::Array<PrimExpr>*
     ss << r << ", ";
   }
   ss << "]" << std::endl;
-  VLOG(1) << std::endl << ss.str();
+  VLOG(1) << ss.str() << std::endl;
 
   return true;
 }
@@ -341,7 +486,8 @@ inline ffi::Array<PrimExpr> TransformShape(const 
ffi::Array<PrimExpr>& src_shape
   for (size_t i = 0; i < src_shape.size(); ++i) {
     PrimExpr orig_shape = src_shape[i];
     IterVar orig_axis = src_axis[i];
-    if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
+    auto layout = Layout::UnpackIterVar(orig_axis);
+    if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) {
       if (orig_shape.defined()) {
         const auto* orig_shape_const = orig_shape.as<IntImmNode>();
         const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
@@ -366,7 +512,8 @@ inline ffi::Array<PrimExpr> TransformShape(const 
ffi::Array<PrimExpr>& src_shape
   for (size_t i = 0; i < transform_rule.size(); ++i) {
     PrimExpr rule = transform_rule[i];
     IterVar axis = target_axis[i];
-    if (!LayoutAxis::Get(axis).IsPrimal()) {
+    auto layout = Layout::UnpackIterVar(axis);
+    if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) {
       result.push_back(axis->dom->extent);
     } else {
       result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
@@ -435,9 +582,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef()
       .def("s_tir.Layout", [](std::string name, DataType dtype) { return 
Layout(name, dtype); })
       .def("s_tir.LayoutIndexOf",
-           [](Layout layout, std::string axis) -> int {
-             return layout.IndexOf(LayoutAxis::Get(axis));
-           })
+           [](Layout layout, std::string axis) -> int { return 
layout.IndexOf(axis); })
       .def("s_tir.LayoutFactorOf",
            [](Layout layout, std::string axis) -> int {
              return layout.FactorOf(LayoutAxis::Get(axis));
@@ -445,8 +590,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def("s_tir.LayoutNdim", [](Layout layout) -> int { return 
layout.ndim(); })
       .def("s_tir.LayoutGetItem",
            [](Layout layout, int idx) -> std::string {
-             const LayoutAxis& axis = layout[idx];
-             return axis.name();
+             const auto& axis = layout.PackedAxisAt(idx);
+             return axis->var->name_hint;
            })
       .def("s_tir.BijectiveLayout",
            [](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
diff --git a/tests/python/s_tir/base/test_tir_data_layout.py 
b/tests/python/s_tir/base/test_tir_data_layout.py
index 8f57a99b3e..32c7f9c9d1 100644
--- a/tests/python/s_tir/base/test_tir_data_layout.py
+++ b/tests/python/s_tir/base/test_tir_data_layout.py
@@ -19,7 +19,8 @@
 import pytest
 
 import tvm
-import tvm.error
+import tvm.testing
+from tvm.error import InternalError
 from tvm.topi.utils import get_const_tuple
 
 
@@ -36,7 +37,7 @@ def test_layout():
     assert layout.index_of("C") == 1
     assert layout.index_of("H") == 2
     assert layout.index_of("W") == 3
-    assert layout.index_of("c") == 4
+    assert layout.index_of("16c") == 4
     assert layout.index_of("O") == -1
 
     assert "N" in layout
@@ -50,8 +51,50 @@ def test_layout():
     assert layout[1] == "C"
     assert layout[2] == "H"
     assert layout[3] == "W"
-    assert layout[4] == "c"
-    assert layout[-1] == "c"
+    assert layout[4] == "16c"
+
+    layout = tvm.s_tir.layout("OIHW[4o4i]")
+    assert layout is not None
+    assert isinstance(layout, tvm.s_tir.Layout)
+
+    assert layout.factor_of("o") == 4
+    assert layout.factor_of("i") == 4
+    assert layout.factor_of("H") == -1
+    assert layout.factor_of("W") == -1
+    assert layout.factor_of("N") == -1
+
+    assert layout.index_of("O") == 0
+    assert layout.index_of("I") == 1
+    assert layout.index_of("H") == 2
+    assert layout.index_of("W") == 3
+    assert layout.index_of("4o4i") == 4
+    assert layout.index_of("i") == -1
+    assert layout.index_of("o") == -1
+
+    assert "O" in layout
+    assert "I" in layout
+    assert "H" in layout
+    assert "W" in layout
+    assert "4o4i" in layout
+    assert "i" in layout
+    assert "o" in layout
+
+    assert layout[0] == "O"
+    assert layout[1] == "I"
+    assert layout[2] == "H"
+    assert layout[3] == "W"
+    assert layout[4] == "4o4i"
+
+    with pytest.raises(InternalError):
+        layout = tvm.s_tir.layout("[N4o]C")
+    with pytest.raises(InternalError):
+        layout = tvm.s_tir.layout("[O4o]")
+    with pytest.raises(InternalError):
+        layout = tvm.s_tir.layout("C4o")
+    with pytest.raises(InternalError):
+        layout = tvm.s_tir.layout("OI[4o4i][]")
+    with pytest.raises(InternalError):
+        layout = tvm.s_tir.layout("C4c[4c]")
 
 
 def test_layout_dtype():
@@ -85,6 +128,8 @@ def test_bilayout_convertible():
     assert tvm.s_tir.bijective_layout("__undef__", "__undef__") is None
     assert tvm.s_tir.bijective_layout("", "NCHW") is None
     assert tvm.s_tir.bijective_layout("NCHW", "") is None
+    assert tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]") is not None
+    assert tvm.s_tir.bijective_layout("OIHW[2o4i]", "OIHW") is not None
     assert tvm.s_tir.bijective_layout("", "") is None
     # convertible
     assert tvm.s_tir.bijective_layout("NCHW", "NCHW16c") is not None
@@ -100,6 +145,14 @@ def test_bilayout_shape():
     src_shape = bilayout.backward_shape(dst_shape)
     assert get_const_tuple(src_shape) == (1, 32, 7, 7)
 
+    bilayout = tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]")
+
+    dst_shape = bilayout.forward_shape((64, 28, 7, 7))
+    assert get_const_tuple(dst_shape) == (16, 7, 7, 7, 16)
+
+    src_shape = bilayout.backward_shape((2, 11, 4, 4, 16))
+    assert get_const_tuple(src_shape) == (8, 44, 4, 4)
+
 
 def test_bilayout_index():
     bilayout = tvm.s_tir.bijective_layout("NCHW", "NCHW16c")
@@ -110,10 +163,14 @@ def test_bilayout_index():
     src_index = bilayout.backward_index([0, 1, 6, 6, 2])
     assert get_const_tuple(src_index) == (0, 18, 6, 6)
 
+    bilayout = tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]")
+
+    dst_index = bilayout.forward_index((63, 29, 7, 7))
+    assert get_const_tuple(dst_index) == (15, 7, 7, 7, 13)
+
+    src_index = bilayout.backward_index((4, 7, 4, 4, 13))
+    assert get_const_tuple(src_index) == (19, 29, 4, 4)
+
 
 if __name__ == "__main__":
-    test_layout()
-    test_layout_dtype()
-    test_bilayout_convertible()
-    test_bilayout_shape()
-    test_bilayout_index()
+    tvm.testing.main()

Reply via email to