This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dfd525bda5 Revert "[TensorIR][Visitor] Visit buffer members in
`match_buffer`'s in block visitor functions (#15153)" (#15816)
dfd525bda5 is described below
commit dfd525bda5acdecb148220e32a7028e48797b10d
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Sep 26 05:26:31 2023 -0700
Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in
block visitor functions (#15153)" (#15816)
* Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in
block visitor functions (#15153)"
---
src/tir/ir/stmt_functor.cc | 32 ++--------------
.../test_tir_transform_unify_thread_binding.py | 43 ----------------------
2 files changed, 4 insertions(+), 71 deletions(-)
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 7d1fe9f8dd..1c15f95826 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) {
VisitArray(op->reads, fvisit_buffer_region);
VisitArray(op->writes, fvisit_buffer_region);
VisitArray(op->match_buffers,
- [this, fvisit_buffer_region](const MatchBufferRegion&
match_buffer_region) {
+ [fvisit_buffer_region](const MatchBufferRegion&
match_buffer_region) {
fvisit_buffer_region(match_buffer_region->source);
- this->VisitExpr(match_buffer_region->buffer->elem_offset);
- VisitArray(match_buffer_region->buffer->strides,
- [this](const PrimExpr& e) { this->VisitExpr(e); });
- VisitArray(match_buffer_region->buffer->shape,
- [this](const PrimExpr& e) { this->VisitExpr(e); });
- VisitArray(match_buffer_region->buffer->axis_separators,
- [this](const IntImm& e) { this->VisitExpr(e); });
});
if (op->init.defined()) {
this->VisitStmt(op->init.value());
@@ -245,28 +238,11 @@ class StmtMutator::Internal {
static Array<MatchBufferRegion> Mutate(StmtMutator* self, const
Array<MatchBufferRegion>& arr) {
auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
- const Buffer& buffer = match_buffer_region->buffer;
Array<Range> region = Mutate(self, match_buffer_region->source->region);
- PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset);
- Array<PrimExpr> strides = Mutate(self, buffer->strides);
- Array<PrimExpr> shape = Mutate(self, buffer->shape);
- Array<IntImm> axis_separators =
- MutateArray(self, buffer->axis_separators,
- [self](const IntImm& e) { return
Downcast<IntImm>(self->VisitExpr(e)); });
-
- if (elem_offset.same_as(buffer->elem_offset) &&
strides.same_as(buffer->strides) &&
- shape.same_as(buffer->shape) &&
axis_separators.same_as(buffer->axis_separators)) {
- if (region.same_as(match_buffer_region->source->region)) {
- return match_buffer_region;
- } else {
- return MatchBufferRegion(buffer,
-
BufferRegion(match_buffer_region->source->buffer, region));
- }
+ if (region.same_as(match_buffer_region->source->region)) {
+ return match_buffer_region;
} else {
- Buffer new_buffer(buffer->data, buffer->dtype, shape, strides,
elem_offset, buffer->name,
- buffer->data_alignment, buffer->offset_factor,
buffer->buffer_type,
- axis_separators, buffer->span);
- return MatchBufferRegion(new_buffer,
+ return MatchBufferRegion(match_buffer_region->buffer,
BufferRegion(match_buffer_region->source->buffer, region));
}
};
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index d42adfcee4..9ee8643312 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b:
T.handle, c: T.handle) -
)
-@T.prim_func
-def match_buffer_with_elem_offset(
- A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset:
T.int32
-) -> None:
- for i in T.thread_binding(0, 4, "blockIdx.x"):
- for j in range(2):
- with T.block():
- T.writes(A[I[i], offset, j * 4 : j * 4 + 4])
- sub_A = T.match_buffer(
- A[I[i], offset, j * 4 : j * 4 + 4],
- (4),
- elem_offset=I[i] * 80 + offset * 8 + j * 4,
- )
- for ji in range(0, 4):
- sub_A[j * 4 + ji] = 1
-
-
-@T.prim_func
-def unified_match_buffer_with_elem_offset(
- A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset:
T.int32
-) -> None:
- for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
- for j in range(2):
- with T.block(""):
- T.reads(I[blockIdx_x])
- T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4])
- sub_A = T.match_buffer(
- A[I[blockIdx_x], offset, j * 4 : j * 4 + 4],
- (4,),
- elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4,
- )
- for ji in range(4):
- i = T.int32()
- sub_A_1 = T.Buffer(
- (4,), data=sub_A.data, elem_offset=I[i] * 80 + offset
* 8 + j * 4
- )
- sub_A_1[j * 4 + ji] = T.float32(1)
-
-
def test_thread_x():
_check(element_wise_thread_x, unified_element_wise_thread_x)
@@ -327,10 +288,6 @@ def test_implicit_block():
_check(element_wise_implicit_block, unified_element_wise_implicit_block)
-def test_match_buffer_with_elem_offset():
- _check(match_buffer_with_elem_offset,
unified_match_buffer_with_elem_offset)
-
-
def test_inner_binding_with_annotation():
@T.prim_func
def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B:
T.Buffer((64,), "float32")):