This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bd7f1f8de0 [TIR] Validate tir::Buffer axis_separators on construction (#17219)
bd7f1f8de0 is described below

commit bd7f1f8de046d598bcf15ea6d7dffc596d5119a4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 5 01:27:37 2024 -0500

    [TIR] Validate tir::Buffer axis_separators on construction (#17219)
    
    * [TIR] Validate tir::Buffer axis_separators on construction
    
    Prior to this commit, the `axis_separators` field of a TIR buffer
    wasn't validated until the `tir.FlattenBuffer` legalization pass.
    Delaying the error until this point makes it difficult to determine
    where the invalid `axis_separators` were initially defined.
    
    This commit updates the `tir::Buffer` constructor to validate the
    `axis_separators` field immediately, allowing these invalid values to
    be caught on construction.
    
    Closes https://github.com/apache/tvm/issues/17215
    
    * Update metaschedule primitive to only set axis_separators of alloc
    
    * Allow axis separators to be increasing, rather than strictly increasing
---
 src/tir/ir/buffer.cc                               | 45 ++++++++++++++--------
 .../schedule/primitive/layout_transformation.cc    | 15 +++++---
 tests/python/tir-base/test_tir_buffer.py           | 12 ++++--
 .../test_tir_schedule_set_axis_separator.py        |  4 +-
 4 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 0256053331..b7c4eb1d42 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -334,24 +334,37 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index,
   return offsets;
 }
 
-Buffer Buffer::GetFlattenedBuffer() const {
-  auto self = operator->();
-
+static void ValidateAxisSeparators(const Array<IntImm>& axis_separators, size_t buffer_dim) {
   // These checks ensure that all output axes contain at least one
   // input axis.
-  for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
-    auto sep = self->axis_separators[i]->value;
-    auto next_sep = self->axis_separators[i + 1]->value;
-    ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
-  }
-  if (self->axis_separators.size()) {
-    auto first_sep = self->axis_separators[0]->value;
-    ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
-                            << "so that first output axis contains at least one input axis";
-    auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
-    ICHECK_LT(last_sep, self->shape.size())
-        << "Last output axis must contain at least one input axis.";
+  for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
+    auto sep = axis_separators[i]->value;
+    auto next_sep = axis_separators[i + 1]->value;
+    CHECK_LE(sep, next_sep) << "ValueError: "
+                            << "Axis separators must be in increasing order, "
+                            << "but axis_separators[" << i << "] = " << sep
+                            << " is greater than axis_separators[" << (i + 1)
+                            << "] = " << next_sep << ".";
+  }
+  if (axis_separators.size()) {
+    auto first_sep = axis_separators[0]->value;
+    CHECK_GE(first_sep, 0) << "ValueError: "
+                           << "All axis separators must be non-negative.  "
+                           << "However, the axis_separators[0] = " << first_sep;
+    auto last_sep = axis_separators[axis_separators.size() - 1]->value;
+    CHECK_LE(last_sep, buffer_dim)
+        << "ValueError: "
+        << "All axis separators must be within the range "
+        << "0 <= sep <= buffer_dim.  "
+        << "However, the last axis_separators[" << (axis_separators.size() - 1)
+        << "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim;
   }
+}
+
+Buffer Buffer::GetFlattenedBuffer() const {
+  auto self = operator->();
+
+  ValidateAxisSeparators(self->axis_separators, self->shape.size());
 
   Array<PrimExpr> output_shape;
   if (self->strides.size()) {
@@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
   ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
       << "Variable " << data->name_hint << " does not point to a primitive.";
 
+  ValidateAxisSeparators(axis_separators, shape.size());
+
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
   n->dtype = dtype;
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index f1e9106a63..8b95e0dc62 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
     if (it != buffer_var_map_.end()) {
       const Buffer& new_source_buffer = it->second;
       Buffer new_target_buffer = match_buffer->buffer;
-      new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
-      if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
-        LOG(WARNING)
-            << "Target buffer in match_buffer doesn't have the same dimensionality as its source "
-               "buffer. `axis_separators` for the target buffer might be incorrect.";
+
+      if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) {
+        new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
+      } else {
+        new_target_buffer.CopyOnWrite()->axis_separators =
+            Array<IntImm>(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0));
+        LOG(WARNING) << "Buffer view " << new_target_buffer
+                     << " has different dimensionality than backing buffer " << new_source_buffer
+                     << ".  The `axis_separators` for " << new_target_buffer
+                     << " have been reset to zeros and might be incorrect.";
       }
       buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
       return MatchBufferRegion(new_target_buffer,
diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py
index 1ab7662b0b..b4b773197b 100644
--- a/tests/python/tir-base/test_tir_buffer.py
+++ b/tests/python/tir-base/test_tir_buffer.py
@@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod():
     A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
 
     def assert_simplified_equal(index_simplified, index_direct):
-        tvm.ir.assert_structural_equal(
-            index_simplified, index_direct
-        ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct)
+        (
+            tvm.ir.assert_structural_equal(index_simplified, index_direct),
+            "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct),
+        )
 
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
@@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators():
     tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32])
 
 
+def test_invalid_axis_separators_raises_exception():
+    with pytest.raises(ValueError):
+        tvm.tir.decl_buffer([1], axis_separators=[1, 2])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
index 76a6ade42f..788e17e771 100644
--- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
@@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
             B_subregion0[()] = A[vi, vj] * T.float32(2)
     for i, j in T.grid(128, 128):
         with T.block("C"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
             C[vi, vj] = B_subregion1[()] + T.float32(1)
 
 

Reply via email to