This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 64a24c4104 [RELAX][LAYOUT] Support multiple axis packing (#18869)
64a24c4104 is described below
commit 64a24c4104110c83c15ac507effbc18845b2d530
Author: Siva <[email protected]>
AuthorDate: Sun Mar 15 11:49:43 2026 +0530
[RELAX][LAYOUT] Support multiple axis packing (#18869)
Adds support for layouts such as OIHW[4o4i], where multiple axes are packed into a single dimension.
This is helpful when handling complex target layouts.
This PR covers the layout representation and the transforms for such layouts.
---
include/tvm/s_tir/data_layout.h | 79 +++++-
python/tvm/s_tir/data_layout.py | 5 +-
src/s_tir/data_layout.cc | 357 +++++++++++++++++-------
tests/python/s_tir/base/test_tir_data_layout.py | 75 ++++-
4 files changed, 386 insertions(+), 130 deletions(-)
diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h
index 8d1ad0ca4c..e57ed9bc96 100644
--- a/include/tvm/s_tir/data_layout.h
+++ b/include/tvm/s_tir/data_layout.h
@@ -35,6 +35,8 @@
#include <utility>
#include <vector>
+#include "tvm/tir/var.h"
+
namespace tvm {
namespace tir {
@@ -158,6 +160,22 @@ class Layout : public ObjectRef {
return undef;
}
+ /*!
+ * \brief Packs the given array of IterVars into a single IterVar. Each IterVar in the array
+ * should represent either a single primal axis or one or more subordinate axes.
+ * \param iters Array of iter vars to be packed
+ * \return A packed iter var
+ */
+ static IterVar PackIterVar(ffi::Array<IterVar> iters);
+
+ /*!
+ * \brief Unpacks a Packed IterVar into its constituents
+ * \param packed_iter A packed IterVar containing a single primal axis or one or more
+ * subordinate axes
+ * \return Constituent IterVars
+ */
+ static ffi::Array<IterVar> UnpackIterVar(IterVar packed_iter);
+
/*!
* \brief Returns a sub-layout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
@@ -187,9 +205,12 @@ class Layout : public ObjectRef {
inline size_t ndim_primal() const {
if (!defined()) return 0;
size_t ct = 0;
- for (auto x : operator->()->axes) {
- if (LayoutAxis::Get(x).IsPrimal()) {
- ct++;
+ for (auto px : operator->()->axes) {
+ auto iter_vars = UnpackIterVar(px);
+ for (auto x : iter_vars) {
+ if (LayoutAxis::Get(x).IsPrimal()) {
+ ct++;
+ }
}
}
return ct;
@@ -204,10 +225,13 @@ class Layout : public ObjectRef {
Layout new_src_layout;
// 1) Find the axis which are missing in the current layout. Make them the
prefix.
std::string new_src_layout_str = "";
- for (auto dst_axis : dst_layout->axes) {
- if (LayoutAxis::Get(dst_axis).IsPrimal()) {
- if (!this->Contains(LayoutAxis::Get(dst_axis))) {
- new_src_layout_str += dst_axis->var->name_hint;
+ for (auto packed_axis : dst_layout->axes) {
+ auto iter_vars = UnpackIterVar(packed_axis);
+ for (auto dst_axis : iter_vars) {
+ if (LayoutAxis::Get(dst_axis).IsPrimal()) {
+ if (!this->Contains(LayoutAxis::Get(dst_axis))) {
+ new_src_layout_str += dst_axis->var->name_hint;
+ }
}
}
}
@@ -221,18 +245,36 @@ class Layout : public ObjectRef {
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
* return -1.
- * \param axis the input axis.
+ * \param axis The name of the input axis: either a single layout axis or a packed axis
* \return the index or -1 if not found.
*/
- inline int32_t IndexOf(const LayoutAxis& axis) const {
+ inline int32_t IndexOf(const std::string& axis) const {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
- if (axes[i]->var->name_hint == axis.name()) return
static_cast<int32_t>(i);
+ if (axes[i]->var->name_hint == axis) return static_cast<int32_t>(i);
}
return -1;
}
+ /*!
+ * \brief return the index of the input axis.
+ * If it is not found in the layout or the layout is undefined,
+ * return -1.
+ * \param axis the input layout axis.
+ * \return the index or -1 if not found.
+ */
+ inline int32_t IndexOf(const LayoutAxis& axis) const { return
IndexOf(axis.name()); }
+
+ /*!
+ * \brief return the index of the input axis.
+ * If it is not found in the layout or the layout is undefined,
+ * return -1.
+ * \param iter the input iter var.
+ * \return the index or -1 if not found.
+ */
+ inline int32_t IndexOf(const tir::IterVar& iter) const { return
IndexOf(iter->var->name_hint); }
+
/*!
* \brief Get the factor size of the subordinate axis.
* \param axis the input primal-axis or subordinate-axis.
@@ -249,9 +291,12 @@ class Layout : public ObjectRef {
*/
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
- for (const tir::IterVar var : operator->()->axes) {
- if (var->var->name_hint == axis.name()) {
- return true;
+ for (const tir::IterVar packed_var : operator->()->axes) {
+ auto iter_vars = UnpackIterVar(packed_var);
+ for (auto var : iter_vars) {
+ if (var->var->name_hint == axis.name()) {
+ return true;
+ }
}
}
return false;
@@ -265,6 +310,14 @@ class Layout : public ObjectRef {
return LayoutAxis::Get(axis);
}
+ IterVar PackedAxisAt(int32_t i) const {
+ TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined
layout.";
+ int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
+ TVM_FFI_ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) <<
"Invalid index " << i;
+ const tir::IterVar axis = operator->()->axes[index];
+ return axis;
+ }
+
/*! \return the string description of the layout */
inline std::string name() const {
if (!defined()) return "__undef__";
diff --git a/python/tvm/s_tir/data_layout.py b/python/tvm/s_tir/data_layout.py
index 01dd88e3b0..00d6f0ebb0 100644
--- a/python/tvm/s_tir/data_layout.py
+++ b/python/tvm/s_tir/data_layout.py
@@ -41,7 +41,8 @@ class Layout(Object):
return _ffi_api.LayoutNdim(self) # type: ignore
def __contains__(self, axis):
- return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+ # Note: this is only a weak substring check against the layout name; it assumes the layout is valid
+ return not any(bkt in axis for bkt in "[]") and axis in self.name
def __getitem__(self, index):
if index >= len(self):
@@ -54,7 +55,7 @@ class Layout(Object):
Parameters
----------
axis : str
- The axis name, need to be [a-z,A-Z]
+ The axis name, needs to be [a-z,A-Z] or a packed axis
Returns
-------
diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc
index 267368975d..fb64ee00ff 100644
--- a/src/s_tir/data_layout.cc
+++ b/src/s_tir/data_layout.cc
@@ -24,9 +24,17 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/object.h>
#include <tvm/s_tir/data_layout.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/var.h>
+#include <algorithm>
#include <cctype>
namespace tvm {
@@ -78,17 +86,30 @@ Layout::Layout(const ffi::Array<IterVar>& axes) {
auto node = ffi::make_object<LayoutNode>();
node->axes = axes;
std::ostringstream repr;
- for (const IterVar& axis : axes) {
- if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
- TVM_FFI_ICHECK_GT(factor->value, 0);
- repr << factor->value;
+
+ for (const IterVar& packed_axis : axes) {
+ auto unpacked_axes = UnpackIterVar(packed_axis);
+ bool is_grouped = unpacked_axes.size() > 1;
+
+ if (is_grouped) repr << "[";
+ for (const IterVar& axis : unpacked_axes) {
+ if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
+ TVM_FFI_ICHECK_GT(factor->value, 0);
+ repr << factor->value;
+ } else {
+ TVM_FFI_ICHECK(!is_grouped)
+ << "Only Subordinate Axes with extent is allowed within a packed
dim";
+ }
+ TVM_FFI_ICHECK_EQ(axis->var.get()->name_hint.size(), 1)
+ << "Invalid layout axis " << axis->var.get()->name_hint;
+ char c = axis->var.get()->name_hint.operator std::string()[0];
+ TVM_FFI_ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
+ << "Invalid layout axis " << c;
+ repr << axis->var.get()->name_hint;
}
- TVM_FFI_ICHECK_EQ(axis->var.get()->name_hint.size(), 1)
- << "Invalid layout axis " << axis->var.get()->name_hint;
- char c = axis->var.get()->name_hint.operator std::string()[0];
- TVM_FFI_ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) <<
"Invalid layout axis " << c;
- repr << axis->var.get()->name_hint;
+ if (is_grouped) repr << "]";
}
+
node->name = repr.str();
data_ = std::move(node);
}
@@ -104,46 +125,93 @@ Layout::Layout(const std::string& name, DataType dtype) {
// NOLINT(*)
// parse layout string
int32_t factor = 0;
+ bool in_packing = false;
+ std::vector<IterVar> unpacked_axes;
+
for (char c : name) {
if (c >= 'A' && c <= 'Z') {
TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid
factor size "
<< factor << " before dimension " << c;
- std::string shape_name("_shape");
- shape_name.insert(0, 1, c);
- IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)),
Var(std::string(1, c), dtype),
- tir::kDataPar);
- node->axes.push_back(axis);
+ IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)),
+ Var(std::string(1, c), dtype), tir::kDataPar);
+ if (!in_packing) {
+ node->axes.push_back(axis);
+ } else {
+ unpacked_axes.push_back(axis);
+ }
} else if (c >= 'a' && c <= 'z') {
TVM_FFI_ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid
factor size "
<< factor << " for dimension " << c;
- IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)),
Var(std::string(1, c), dtype),
+ std::stringstream name;
+ name << factor << c;
+ IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)),
Var(name.str(), dtype),
tir::kDataPar);
- node->axes.push_back(axis);
+ if (!in_packing) {
+ node->axes.push_back(axis);
+ } else {
+ unpacked_axes.push_back(axis);
+ }
factor = 0;
} else if (c >= '0' && c <= '9') {
TVM_FFI_ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is
adjacent to a number.";
factor = factor * 10 + c - '0';
+ } else if (c == '[') {
+ TVM_FFI_ICHECK(!in_packing) << "Invalid layout " << name << ": can't do
nested packing";
+ in_packing = true;
+ } else if (c == ']') {
+ TVM_FFI_ICHECK(in_packing) << "Invalid layout " << name
+ << ": encountered ] without matching bracket";
+ TVM_FFI_ICHECK(unpacked_axes.size() > 1)
+ << "Invalid layout " << name << ": found empty/single packed axis";
+ std::stringstream ss;
+ int64_t extent = 1;
+ for (auto& axis : unpacked_axes) {
+ TVM_FFI_ICHECK(axis->dom->extent.as<IntImmNode>())
+ << "Invalid Layout " << name << ": can't have variable sized node("
+ << axis->var->name_hint << ") within a packed axis";
+ auto axis_name = axis->var->name_hint.operator std::string();
+ auto factor = axis->dom->extent.as<IntImm>().value();
+ ss << axis_name;
+ extent = extent * factor->value;
+ }
+ std::string grouped_name = ss.str();
+ IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)),
Var(grouped_name, dtype),
+ tir::kDataPar);
+ node->axes.push_back(grouped_axis);
+
+ in_packing = false;
+ unpacked_axes.clear();
} else {
TVM_FFI_THROW(InternalError) << "Invalid layout " << name;
}
}
+ TVM_FFI_ICHECK(in_packing == false)
+ << "Invalid Layout " << name << ": haven't terminated the packing
sequence";
// validate layout
- std::vector<bool> exist_axis(256, false);
- for (const IterVar& v : node->axes) {
- auto axis_str = v->var.get()->name_hint.operator std::string();
- TVM_FFI_ICHECK_EQ(axis_str.size(), 1);
- char axis = axis_str[0];
- TVM_FFI_ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <=
'Z'));
- exist_axis[axis] = true;
+ std::vector<int> axis_cnt(256, 0);
+ for (const IterVar& pv : node->axes) {
+ for (const IterVar& v : UnpackIterVar(pv)) {
+ auto axis_str = v->var.get()->name_hint.operator std::string();
+ TVM_FFI_ICHECK_EQ(axis_str.size(), 1);
+ char axis = axis_str[0];
+ TVM_FFI_ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <=
'Z'));
+ axis_cnt[axis] += 1;
+ }
}
- for (const IterVar& v : node->axes) {
- char axis = v->var.get()->name_hint.operator std::string()[0];
- if (axis >= 'a' && axis <= 'z') {
- TVM_FFI_ICHECK(exist_axis[axis - 'a' + 'A'])
- << "Invalid layout " << name << ": missing axis " <<
std::toupper(axis);
+ for (const IterVar& pv : node->axes) {
+ for (const IterVar& v : UnpackIterVar(pv)) {
+ char axis = v->var.get()->name_hint.operator std::string()[0];
+ if (axis >= 'a' && axis <= 'z') {
+ TVM_FFI_ICHECK(axis_cnt[axis - 'a' + 'A'])
+ << "Invalid layout " << name << ": missing axis " <<
std::toupper(axis);
+ TVM_FFI_ICHECK(axis_cnt[axis] == 1)
+ << "Invalid layout " << name << ": found more than one subordinate
"
+ << std::toupper(axis);
+ }
}
}
+
data_ = std::move(node);
}
@@ -159,27 +227,46 @@ Layout Layout::SubLayout(size_t pos, size_t len) const {
return Layout(new_layout);
}
-Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t
factor) const {
- if (!defined()) return Layout::Undef();
- const std::string& name = operator->()->name;
- const auto axes = operator->()->axes;
- TVM_FFI_ICHECK(target_pos <= this->ndim())
- << "Invalid split position " << target_pos << " for layout " << name;
- TVM_FFI_ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " <<
axis;
- TVM_FFI_ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist
in " << name;
- TVM_FFI_ICHECK(!this->Contains(axis.ToSubordinate()))
- << "Axis " << axis << " has already been split in " << name;
- TVM_FFI_ICHECK(factor > 0) << "Invalid split size " << factor;
- ffi::Array<IterVar> new_layout;
- for (size_t i = 0; i <= this->ndim(); ++i) {
- if (i == target_pos) {
- new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)),
- Var(axis.ToSubordinate().name()),
tir::kDataPar));
+ffi::Array<IterVar> Layout::UnpackIterVar(IterVar packed_iter) {
+ ffi::Array<IterVar> result;
+ int64_t factor = 0, final_factor = 1;
+
+ std::string name(packed_iter->var->name_hint.c_str());
+ DataType dtype = packed_iter->var.dtype();
+
+ for (auto ch : name) {
+ if (ch >= '0' && ch <= '9') {
+ factor = factor * 10 + (ch - '0');
+ } else if (ch >= 'a' && ch <= 'z') {
+ TVM_FFI_ICHECK(factor != 0) << "Invalid Factor Size";
+ result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)),
+ Var(std::string(1, ch), dtype), tir::kDataPar));
+ final_factor *= factor;
+ factor = 0;
+ } else if (ch >= 'A' && ch <= 'Z') {
+ TVM_FFI_ICHECK(factor == 0) << "Can't have non-zero factors for primal
axis";
+ result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch),
dtype)),
+ Var(std::string(1, ch), dtype), tir::kDataPar));
}
- if (i == this->ndim()) break;
- new_layout.push_back(axes[i]);
}
- return Layout(new_layout);
+
+ return result;
+}
+
+IterVar Layout::PackIterVar(ffi::Array<IterVar> iter_vars) {
+ std::stringstream name;
+ size_t extent = 1;
+
+ DataType dtype = iter_vars[0]->dom->extent.as<PrimExpr>().value()->dtype;
+ for (auto itvar : iter_vars) {
+ TVM_FFI_ICHECK(itvar->dom->extent.as<IntImm>())
+ << "Packed Axis can contain only Subordinate Axes";
+ name << itvar->dom->extent.as<IntImm>().value() << itvar->var->name_hint;
+ extent = extent * itvar->dom->extent.as<IntImm>().value()->value;
+ }
+
+ return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)),
Var(name.str(), dtype),
+ tir::kDataPar);
}
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
@@ -188,12 +275,13 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
int32_t factor = 1;
bool has_sub = false;
- for (const IterVar& itvar : operator->()->axes) {
- if (sub == LayoutAxis::Get(itvar)) {
- has_sub = true;
- int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
- TVM_FFI_ICHECK(val);
- factor *= val;
+ for (const IterVar& packed_itvar : operator->()->axes) {
+ for (auto itvar : UnpackIterVar(packed_itvar)) {
+ if (sub == LayoutAxis::Get(itvar)) {
+ has_sub = true;
+ int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
+ factor *= val;
+ }
}
}
factor = has_sub ? factor : -1;
@@ -218,63 +306,120 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>*
index_rule, ffi::Array<PrimExpr>*
return false;
}
- for (size_t i = 0; i < dst_layout.ndim(); ++i) {
- const auto& store_axis = dst_layout[i];
- const IterVar& store_axis_impl = dst_layout->axes[i];
- PrimExpr index_store(0);
-
- for (size_t j = 0; j < src_layout.ndim(); ++j) {
- const auto& orig_axis = src_layout[j];
- const IterVar& orig_axis_impl = src_layout->axes[j];
- if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
- if (orig_axis.IsPrimal()) {
- PrimExpr orig_var = orig_axis_impl->var;
- const int32_t factor = src_layout.FactorOf(orig_axis);
- if (factor > 0) {
- orig_var = orig_var * factor;
- }
- index_store = index_store + orig_var;
- } else {
- PrimExpr factor(1);
- for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
- if (LayoutAxis::Get(orig_axis_impl) ==
LayoutAxis::Get(src_layout->axes[k])) {
- factor = factor * src_layout->axes[k]->dom->extent;
+ std::vector<bool> exists(128, false);
+ PrimExpr norm_indexes[128];
+ for (auto& it : norm_indexes) it = PrimExpr(0);
+
+ for (size_t i = 0; i < src_layout.ndim(); i++) {
+ auto factor = src_layout.PackedAxisAt(i)->dom->extent;
+ auto src_unpacked_axes = Layout::UnpackIterVar(src_layout.PackedAxisAt(i));
+
+ if (src_unpacked_axes.size() == 1 &&
LayoutAxis::Get(src_unpacked_axes[0]).IsPrimal()) {
+ const auto& prim_axis = LayoutAxis::Get(src_unpacked_axes[0]);
+ int64_t offset = src_layout.FactorOf(prim_axis);
+ if (offset == -1)
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
src_layout.PackedAxisAt(i);
+ else
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] +
+ src_layout.PackedAxisAt(i) * src_layout.FactorOf(prim_axis);
+ exists[prim_axis.name()[0]] = true;
+ } else {
+ int64_t value = 1;
+ std::vector<int> index_divs(src_unpacked_axes.size());
+ for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+ index_divs[j] = value;
+ const auto* extent =
src_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(extent) << "Expected Integer Extents for Offset
Calculation";
+ index_divs.push_back(value);
+ value = value * extent->value;
+ }
+ std::reverse(index_divs.begin(), index_divs.end());
+
+ for (size_t j = 0; j < src_unpacked_axes.size(); j++) {
+ const int extent =
src_unpacked_axes[j]->dom->extent.as<IntImmNode>()->value;
+ const LayoutAxis& store_axis_impl =
LayoutAxis::Get(src_unpacked_axes[j]);
+ const LayoutAxis& sub_axis = store_axis_impl.ToSubordinate(); /* Not
Needed */
+ const LayoutAxis& prim_axis = store_axis_impl.ToPrimal();
+
+ PrimExpr factor_ij = indexdiv(src_layout.PackedAxisAt(i),
index_divs[j]);
+ if (j != 0) factor_ij = indexmod(factor_ij, extent);
+
+ for (size_t k = i; k < src_layout.ndim(); k++) {
+ size_t l = 0;
+ if (k == i) l = j + 1;
+
+ auto inter_unpacked_axes =
Layout::UnpackIterVar(src_layout.PackedAxisAt(k));
+ for (; l < inter_unpacked_axes.size(); l++) {
+ const LayoutAxis& axis = LayoutAxis::Get(inter_unpacked_axes[l]);
+ if (axis == sub_axis) {
+ const auto* sub_extent =
inter_unpacked_axes[l]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for
Offset Calculation";
+ factor_ij = factor_ij * IntImm(sub_extent->dtype,
sub_extent->value);
}
}
- index_store = index_store + orig_axis_impl->var * factor;
}
+
+ norm_indexes[prim_axis.name()[0] - 'A'] =
+ norm_indexes[prim_axis.name()[0] - 'A'] + factor_ij;
}
}
- if (tir::is_zero(index_store)) {
- LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" <<
dst_layout.name()
- << "' is not convertible.";
- return false;
- }
+ }
+
+ arith::Analyzer ana;
- PrimExpr shape_store = index_store;
- if (store_axis.IsPrimal()) {
- const int32_t factor = dst_layout.FactorOf(store_axis);
- if (factor > 0) {
- shape_store = shapediv(index_store, PrimExpr(factor));
- index_store = indexdiv(index_store, PrimExpr(factor));
+ for (size_t i = 0; i < dst_layout.ndim(); i++) {
+ const auto dst_unpacked_axes =
Layout::UnpackIterVar(dst_layout.PackedAxisAt(i));
+
+ if (dst_unpacked_axes.size() == 1 &&
LayoutAxis::Get(dst_unpacked_axes[0]).IsPrimal()) {
+ const auto& prim_axis = LayoutAxis::Get(dst_unpacked_axes[0]);
+ if (!exists[prim_axis.name()[0]]) return false;
+ int64_t offset = dst_layout.FactorOf(prim_axis);
+ if (offset != -1) {
+ index_rule->push_back(
+ indexdiv(norm_indexes[prim_axis.name()[0] - 'A'],
dst_layout.FactorOf(prim_axis)));
+ shape_rule->push_back(
+ indexdiv(norm_indexes[prim_axis.name()[0] - 'A'] +
(dst_layout.FactorOf(prim_axis) - 1),
+ dst_layout.FactorOf(prim_axis)));
+ } else {
+ index_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']);
+ shape_rule->push_back(norm_indexes[prim_axis.name()[0] - 'A']);
}
} else {
- PrimExpr stride(1);
- PrimExpr factor(1);
- for (size_t j = i; j < dst_layout.ndim(); ++j) {
- if (LayoutAxis::Get(store_axis_impl) ==
LayoutAxis::Get(dst_layout->axes[j])) {
- stride = stride * dst_layout->axes[j]->dom->extent;
- if (j > i) {
- factor = factor * dst_layout->axes[j]->dom->extent;
+ PrimExpr factor(0);
+ for (size_t j = 0; j < dst_unpacked_axes.size(); j++) {
+ const auto& prim_axis =
LayoutAxis::Get(dst_unpacked_axes[j]).ToPrimal();
+ const auto& sub_axis =
LayoutAxis::Get(dst_unpacked_axes[j]).ToSubordinate();
+ const auto* extent =
dst_unpacked_axes[j]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(extent) << "Expected extent to be IntImmNode";
+
+ size_t divfactor = 1;
+ for (size_t k = i; k < dst_layout.ndim(); k++) {
+ size_t l = 0;
+ if (k == i) l = j + 1;
+
+ const auto inter_unpacked_axes =
Layout::UnpackIterVar(dst_layout.PackedAxisAt(k));
+ for (; l < inter_unpacked_axes.size(); l++) {
+ const auto& axis = LayoutAxis::Get(inter_unpacked_axes[l]);
+ if (sub_axis == axis) {
+ const auto* sub_extent =
inter_unpacked_axes[l]->dom->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for
Offset Calculation";
+ divfactor = divfactor * sub_extent->value;
+ }
}
}
+
+ factor = factor + indexmod(indexdiv(norm_indexes[prim_axis.name()[0] -
'A'], divfactor),
+ extent->value);
+ for (size_t k = j + 1; k < dst_unpacked_axes.size(); k++) {
+ factor = factor *
dst_unpacked_axes[k]->dom->extent.as<IntImm>().value();
+ }
}
- shape_store = indexdiv(indexmod(index_store, stride), factor);
- index_store = indexdiv(indexmod(index_store, stride), factor);
+ ana.Simplify(factor);
+ index_rule->push_back(factor);
+ shape_rule->push_back(factor);
}
-
- index_rule->push_back(index_store);
- shape_rule->push_back(shape_store);
}
std::stringstream ss;
@@ -289,7 +434,7 @@ inline bool GetStoreRule(ffi::Array<PrimExpr>* index_rule,
ffi::Array<PrimExpr>*
ss << r << ", ";
}
ss << "]" << std::endl;
- VLOG(1) << std::endl << ss.str();
+ VLOG(1) << ss.str() << std::endl;
return true;
}
@@ -341,7 +486,8 @@ inline ffi::Array<PrimExpr> TransformShape(const
ffi::Array<PrimExpr>& src_shape
for (size_t i = 0; i < src_shape.size(); ++i) {
PrimExpr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
- if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
+ auto layout = Layout::UnpackIterVar(orig_axis);
+ if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImmNode>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
@@ -366,7 +512,8 @@ inline ffi::Array<PrimExpr> TransformShape(const
ffi::Array<PrimExpr>& src_shape
for (size_t i = 0; i < transform_rule.size(); ++i) {
PrimExpr rule = transform_rule[i];
IterVar axis = target_axis[i];
- if (!LayoutAxis::Get(axis).IsPrimal()) {
+ auto layout = Layout::UnpackIterVar(axis);
+ if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) {
result.push_back(axis->dom->extent);
} else {
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
@@ -435,9 +582,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef()
.def("s_tir.Layout", [](std::string name, DataType dtype) { return
Layout(name, dtype); })
.def("s_tir.LayoutIndexOf",
- [](Layout layout, std::string axis) -> int {
- return layout.IndexOf(LayoutAxis::Get(axis));
- })
+ [](Layout layout, std::string axis) -> int { return
layout.IndexOf(axis); })
.def("s_tir.LayoutFactorOf",
[](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::Get(axis));
@@ -445,8 +590,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("s_tir.LayoutNdim", [](Layout layout) -> int { return
layout.ndim(); })
.def("s_tir.LayoutGetItem",
[](Layout layout, int idx) -> std::string {
- const LayoutAxis& axis = layout[idx];
- return axis.name();
+ const auto& axis = layout.PackedAxisAt(idx);
+ return axis->var->name_hint;
})
.def("s_tir.BijectiveLayout",
[](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
diff --git a/tests/python/s_tir/base/test_tir_data_layout.py
b/tests/python/s_tir/base/test_tir_data_layout.py
index 8f57a99b3e..32c7f9c9d1 100644
--- a/tests/python/s_tir/base/test_tir_data_layout.py
+++ b/tests/python/s_tir/base/test_tir_data_layout.py
@@ -19,7 +19,8 @@
import pytest
import tvm
-import tvm.error
+import tvm.testing
+from tvm.error import InternalError
from tvm.topi.utils import get_const_tuple
@@ -36,7 +37,7 @@ def test_layout():
assert layout.index_of("C") == 1
assert layout.index_of("H") == 2
assert layout.index_of("W") == 3
- assert layout.index_of("c") == 4
+ assert layout.index_of("16c") == 4
assert layout.index_of("O") == -1
assert "N" in layout
@@ -50,8 +51,50 @@ def test_layout():
assert layout[1] == "C"
assert layout[2] == "H"
assert layout[3] == "W"
- assert layout[4] == "c"
- assert layout[-1] == "c"
+ assert layout[4] == "16c"
+
+ layout = tvm.s_tir.layout("OIHW[4o4i]")
+ assert layout is not None
+ assert isinstance(layout, tvm.s_tir.Layout)
+
+ assert layout.factor_of("o") == 4
+ assert layout.factor_of("i") == 4
+ assert layout.factor_of("H") == -1
+ assert layout.factor_of("W") == -1
+ assert layout.factor_of("N") == -1
+
+ assert layout.index_of("O") == 0
+ assert layout.index_of("I") == 1
+ assert layout.index_of("H") == 2
+ assert layout.index_of("W") == 3
+ assert layout.index_of("4o4i") == 4
+ assert layout.index_of("i") == -1
+ assert layout.index_of("o") == -1
+
+ assert "O" in layout
+ assert "I" in layout
+ assert "H" in layout
+ assert "W" in layout
+ assert "4o4i" in layout
+ assert "i" in layout
+ assert "o" in layout
+
+ assert layout[0] == "O"
+ assert layout[1] == "I"
+ assert layout[2] == "H"
+ assert layout[3] == "W"
+ assert layout[4] == "4o4i"
+
+ with pytest.raises(InternalError):
+ layout = tvm.s_tir.layout("[N4o]C")
+ with pytest.raises(InternalError):
+ layout = tvm.s_tir.layout("[O4o]")
+ with pytest.raises(InternalError):
+ layout = tvm.s_tir.layout("C4o")
+ with pytest.raises(InternalError):
+ layout = tvm.s_tir.layout("OI[4o4i][]")
+ with pytest.raises(InternalError):
+ layout = tvm.s_tir.layout("C4c[4c]")
def test_layout_dtype():
@@ -85,6 +128,8 @@ def test_bilayout_convertible():
assert tvm.s_tir.bijective_layout("__undef__", "__undef__") is None
assert tvm.s_tir.bijective_layout("", "NCHW") is None
assert tvm.s_tir.bijective_layout("NCHW", "") is None
+ assert tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]") is not None
+ assert tvm.s_tir.bijective_layout("OIHW[2o4i]", "OIHW") is not None
assert tvm.s_tir.bijective_layout("", "") is None
# convertible
assert tvm.s_tir.bijective_layout("NCHW", "NCHW16c") is not None
@@ -100,6 +145,14 @@ def test_bilayout_shape():
src_shape = bilayout.backward_shape(dst_shape)
assert get_const_tuple(src_shape) == (1, 32, 7, 7)
+ bilayout = tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]")
+
+ dst_shape = bilayout.forward_shape((64, 28, 7, 7))
+ assert get_const_tuple(dst_shape) == (16, 7, 7, 7, 16)
+
+ src_shape = bilayout.backward_shape((2, 11, 4, 4, 16))
+ assert get_const_tuple(src_shape) == (8, 44, 4, 4)
+
def test_bilayout_index():
bilayout = tvm.s_tir.bijective_layout("NCHW", "NCHW16c")
@@ -110,10 +163,14 @@ def test_bilayout_index():
src_index = bilayout.backward_index([0, 1, 6, 6, 2])
assert get_const_tuple(src_index) == (0, 18, 6, 6)
+ bilayout = tvm.s_tir.bijective_layout("OIHW", "OIHW[4o4i]")
+
+ dst_index = bilayout.forward_index((63, 29, 7, 7))
+ assert get_const_tuple(dst_index) == (15, 7, 7, 7, 13)
+
+ src_index = bilayout.backward_index((4, 7, 4, 4, 13))
+ assert get_const_tuple(src_index) == (19, 29, 4, 4)
+
if __name__ == "__main__":
- test_layout()
- test_layout_dtype()
- test_bilayout_convertible()
- test_bilayout_shape()
- test_bilayout_index()
+ tvm.testing.main()