This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dfd525bda5 Revert "[TensorIR][Visitor] Visit buffer members in 
`match_buffer`'s in block visitor functions (#15153)" (#15816)
dfd525bda5 is described below

commit dfd525bda5acdecb148220e32a7028e48797b10d
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Sep 26 05:26:31 2023 -0700

    Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in 
block visitor functions (#15153)" (#15816)
    
    * Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in 
block visitor functions (#15153)"
---
 src/tir/ir/stmt_functor.cc                         | 32 ++--------------
 .../test_tir_transform_unify_thread_binding.py     | 43 ----------------------
 2 files changed, 4 insertions(+), 71 deletions(-)

diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 7d1fe9f8dd..1c15f95826 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) {
   VisitArray(op->reads, fvisit_buffer_region);
   VisitArray(op->writes, fvisit_buffer_region);
   VisitArray(op->match_buffers,
-             [this, fvisit_buffer_region](const MatchBufferRegion& 
match_buffer_region) {
+             [fvisit_buffer_region](const MatchBufferRegion& 
match_buffer_region) {
                fvisit_buffer_region(match_buffer_region->source);
-               this->VisitExpr(match_buffer_region->buffer->elem_offset);
-               VisitArray(match_buffer_region->buffer->strides,
-                          [this](const PrimExpr& e) { this->VisitExpr(e); });
-               VisitArray(match_buffer_region->buffer->shape,
-                          [this](const PrimExpr& e) { this->VisitExpr(e); });
-               VisitArray(match_buffer_region->buffer->axis_separators,
-                          [this](const IntImm& e) { this->VisitExpr(e); });
              });
   if (op->init.defined()) {
     this->VisitStmt(op->init.value());
@@ -245,28 +238,11 @@ class StmtMutator::Internal {
 
   static Array<MatchBufferRegion> Mutate(StmtMutator* self, const 
Array<MatchBufferRegion>& arr) {
     auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
-      const Buffer& buffer = match_buffer_region->buffer;
       Array<Range> region = Mutate(self, match_buffer_region->source->region);
-      PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset);
-      Array<PrimExpr> strides = Mutate(self, buffer->strides);
-      Array<PrimExpr> shape = Mutate(self, buffer->shape);
-      Array<IntImm> axis_separators =
-          MutateArray(self, buffer->axis_separators,
-                      [self](const IntImm& e) { return 
Downcast<IntImm>(self->VisitExpr(e)); });
-
-      if (elem_offset.same_as(buffer->elem_offset) && 
strides.same_as(buffer->strides) &&
-          shape.same_as(buffer->shape) && 
axis_separators.same_as(buffer->axis_separators)) {
-        if (region.same_as(match_buffer_region->source->region)) {
-          return match_buffer_region;
-        } else {
-          return MatchBufferRegion(buffer,
-                                   
BufferRegion(match_buffer_region->source->buffer, region));
-        }
+      if (region.same_as(match_buffer_region->source->region)) {
+        return match_buffer_region;
       } else {
-        Buffer new_buffer(buffer->data, buffer->dtype, shape, strides, 
elem_offset, buffer->name,
-                          buffer->data_alignment, buffer->offset_factor, 
buffer->buffer_type,
-                          axis_separators, buffer->span);
-        return MatchBufferRegion(new_buffer,
+        return MatchBufferRegion(match_buffer_region->buffer,
                                  
BufferRegion(match_buffer_region->source->buffer, region));
       }
     };
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py 
b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index d42adfcee4..9ee8643312 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b: 
T.handle, c: T.handle) -
                     )
 
 
-@T.prim_func
-def match_buffer_with_elem_offset(
-    A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: 
T.int32
-) -> None:
-    for i in T.thread_binding(0, 4, "blockIdx.x"):
-        for j in range(2):
-            with T.block():
-                T.writes(A[I[i], offset, j * 4 : j * 4 + 4])
-                sub_A = T.match_buffer(
-                    A[I[i], offset, j * 4 : j * 4 + 4],
-                    (4),
-                    elem_offset=I[i] * 80 + offset * 8 + j * 4,
-                )
-                for ji in range(0, 4):
-                    sub_A[j * 4 + ji] = 1
-
-
-@T.prim_func
-def unified_match_buffer_with_elem_offset(
-    A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: 
T.int32
-) -> None:
-    for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
-        for j in range(2):
-            with T.block(""):
-                T.reads(I[blockIdx_x])
-                T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4])
-                sub_A = T.match_buffer(
-                    A[I[blockIdx_x], offset, j * 4 : j * 4 + 4],
-                    (4,),
-                    elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4,
-                )
-                for ji in range(4):
-                    i = T.int32()
-                    sub_A_1 = T.Buffer(
-                        (4,), data=sub_A.data, elem_offset=I[i] * 80 + offset 
* 8 + j * 4
-                    )
-                    sub_A_1[j * 4 + ji] = T.float32(1)
-
-
 def test_thread_x():
     _check(element_wise_thread_x, unified_element_wise_thread_x)
 
@@ -327,10 +288,6 @@ def test_implicit_block():
     _check(element_wise_implicit_block, unified_element_wise_implicit_block)
 
 
-def test_match_buffer_with_elem_offset():
-    _check(match_buffer_with_elem_offset, 
unified_match_buffer_with_elem_offset)
-
-
 def test_inner_binding_with_annotation():
     @T.prim_func
     def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: 
T.Buffer((64,), "float32")):

Reply via email to