This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5711c35ae0 [TIR Pass] decouple flatten buffer to lower opaque block
pass and flatten buffer. (#12172)
5711c35ae0 is described below
commit 5711c35ae01b1dec3527726149efc6fe1a4bc7c6
Author: Fred.Jia <[email protected]>
AuthorDate: Wed Jul 27 14:22:18 2022 +0800
[TIR Pass] decouple flatten buffer to lower opaque block pass and flatten
buffer. (#12172)
---
include/tvm/tir/transform.h | 11 +-
python/tvm/script/tir/scope_handler.py | 10 +-
python/tvm/tir/transform/transform.py | 16 +-
src/driver/driver_api.cc | 1 +
src/meta_schedule/postproc/verify_gpu_code.cc | 1 +
src/tir/transforms/flatten_buffer.cc | 135 ++---------
.../{flatten_buffer.cc => lower_opaque_block.cc} | 139 ++---------
tests/python/unittest/test_tir_buffer.py | 1 +
.../unittest/test_tir_transform_flatten_buffer.py | 261 +++++----------------
.../test_tir_transform_inject_ptx_async_copy.py | 2 +
.../unittest/test_tir_transform_loop_partition.py | 1 +
...py => test_tir_transform_lower_opaque_block.py} | 131 ++++-------
12 files changed, 177 insertions(+), 532 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 005bf84103..c758a00b3f 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -457,9 +457,14 @@ TVM_DLL Pass LegalizePackedCalls();
TVM_DLL Pass LowerMatchBuffer();
/*!
- * \brief Flatten the multi-dimensional BufferLoad and BufferStore
- * to single dimensional Load/Store. Also remove Block to
- * ensure that the flattened TIR can not be scheduled again.
+ * \brief Remove the block to ensure that the TIR cannot be scheduled again.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerOpaqueBlock();
+
+/*!
+ * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single
dimensional
+ * BufferLoad/BufferStore, for TIR that does not contain opaque blocks.
* \return The pass.
*/
TVM_DLL Pass FlattenBuffer();
diff --git a/python/tvm/script/tir/scope_handler.py
b/python/tvm/script/tir/scope_handler.py
index 76fbf26eea..92aaf8b4d9 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -111,16 +111,10 @@ class Allocate(WithScopeHandler):
condition = tvm.runtime.convert(condition)
scope = tvm.runtime.convert(scope)
- # Currently, allocate nodes should only occur after buffer
- # flattening has been applied. This can be simplified in
- # the future by having the AllocateNode hold a buffer
- # object directly.
- flattened = self.buffer.get_flattened_buffer()
-
return tvm.tir.Allocate(
self.buffer.data,
- flattened.dtype,
- flattened.shape,
+ self.buffer.dtype,
+ self.buffer.shape,
condition,
self.body,
annotations=annotations,
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 2a4ff6618a..6cc7b2e1f8 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -769,10 +769,20 @@ def LowerMatchBuffer():
return _ffi_api.LowerMatchBuffer() # type: ignore
+def LowerOpaqueBlock():
+ """Remove the block to ensure that the TIR cannot be scheduled again.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerOpaqueBlock() # type: ignore
+
+
def FlattenBuffer():
- """Flatten the multi-dimensional BufferLoad and BufferStore
- to single dimensional Load/Store. Also remove Block to
- ensure that the flattened TIR can not be scheduled again.
+ """Flatten the multi-dimensional BufferLoad and BufferStore to single
dimensional
+ BufferLoad/BufferStore, for TIR that does not contain opaque blocks.
Returns
-------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 0446347eca..6f4fb618d3 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool
disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+ pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::BF16Legalize());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 857b732c98..dfe2c5a06a 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+ pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/flatten_buffer.cc
index 21de191db0..dcc23a72b2 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -21,32 +21,17 @@
* \file flatten_buffer.cc
*/
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/function.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include "../../support/utils.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
-PrimExpr BufferArea(const Buffer& buffer) {
- if (buffer->strides.size()) {
- ICHECK(buffer->shape.size() == buffer->strides.size());
- return buffer->strides[0] * buffer->shape[0];
- }
- PrimExpr area = Integer(1);
- for (const PrimExpr& dim : buffer->shape) {
- area = area * dim;
- }
- return area;
-}
-
/*!
* \brief Transform multi-dimension BufferLoad/BufferStore into
device-supported dimension
+ * for TIR that does not contain opaque blocks.
*/
class BufferFlattener : public StmtExprMutator {
public:
@@ -68,76 +53,25 @@ class BufferFlattener : public StmtExprMutator {
}
}
- Stmt VisitStmt_(const BlockRealizeNode* op) final {
- // We have convert blocks into opaque blocks in previous passes.
- ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in
FlattenBuffer. Please "
- "call pass ConvertBlocksToOpaque
before.";
- // Step 1. Visit the body
- Block new_block = Downcast<Block>(this->VisitStmt(op->block));
- PrimExpr predicate = this->VisitExpr(op->predicate);
- // Step 2. Transform the `predicate` to if-then-else
- Stmt body = new_block->body;
- if (!is_one(predicate)) {
- body = IfThenElse(predicate, std::move(body));
- }
- // Step 3. Handle allocations in reverse order
- for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
- Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
- body = Allocate(buffer->data, buffer->dtype, buffer->shape,
const_true(), std::move(body));
- }
- return body;
- }
-
- Stmt VisitStmt_(const ForNode* op) final {
- // Step 1. Update unit loop info.
- PrimExpr min = this->VisitExpr(op->min);
- PrimExpr extent = this->VisitExpr(op->extent);
- if (is_one(extent) && op->annotations.empty()) {
- // handling unit loop
- unit_loop_vars_[op->loop_var] = min;
- }
- // Step 2. Visit recursively
- Stmt body = this->VisitStmt(op->body);
- // Step 3. Create new For loop accordingly
- if (op->kind == ForKind::kThreadBinding) {
- // Case 1. Thread binding
- ICHECK(op->thread_binding.defined());
- String thread_tag = op->thread_binding.value()->thread_tag;
- body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
- } else if (is_one(extent) && op->annotations.empty()) {
- // Case 2. Unit loop
- return body;
- } else {
- // Case 3. An ordinary loop
- body = For(op->loop_var, std::move(min), std::move(extent), op->kind,
std::move(body));
- }
- // Step 4. Handle annotations
- std::set<std::string> ordered_ann_keys;
- for (const auto& annotation : op->annotations) {
- ordered_ann_keys.insert(annotation.first);
- }
- for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend();
++it) {
- const std::string& ann_key = *it;
- const ObjectRef& ann_value = op->annotations.at(ann_key);
- if (attr::IsPragmaKey(ann_key)) {
- body =
- AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key,
ann_value), std::move(body));
- }
+ Stmt VisitStmt_(const AllocateNode* op) final {
+ Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+ // TODO(Lunderberg): Move the handling of boolean into a
+ // dedicated pass.
+ if (alloc->dtype == DataType::Bool()) {
+ auto writer = alloc.CopyOnWrite();
+ writer->dtype = DataType::Int(8);
}
- return body;
- }
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- Var var = GetRef<Var>(op);
- auto it = unit_loop_vars_.find(var);
- if (it == unit_loop_vars_.end()) {
- return std::move(var);
+ // Handle multi-dimension allocations
+ if (alloc->extents.size() == 1) {
+ return std::move(alloc);
} else {
- PrimExpr expr = it->second;
- if (expr.dtype() != var.dtype()) {
- expr = tvm::cast(var.dtype(), std::move(expr));
+ Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
+ for (size_t i = 0; i < alloc->extents.size(); i++) {
+ flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
}
- return expr;
+ auto n = alloc.CopyOnWrite();
+ n->extents = flat_extent;
+ return std::move(alloc);
}
}
@@ -146,7 +80,6 @@ class BufferFlattener : public StmtExprMutator {
if (it != buffer_remap_.end()) {
return it->second;
}
-
auto flattened = buf.GetFlattenedBuffer();
// TODO(Lunderberg): Move the handling of boolean into a
@@ -208,40 +141,6 @@ class BufferFlattener : public StmtExprMutator {
return node;
}
- static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String
thread_tag,
- Stmt body) {
- IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
- /*var=*/std::move(var),
- /*iter_type=*/IterVarType::kThreadIndex,
- /*thread_tag=*/thread_tag);
- String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
- thread_tag == "vthread.y" || thread_tag == "vthread.z")
- ? attr::virtual_thread
- : attr::thread_extent;
- return AttrStmt(/*node=*/std::move(iter_var),
- /*attr_key=*/std::move(attr_key),
- /*value=*/std::move(extent),
- /*body=*/std::move(body));
- }
-
- /*! \brief Convert attr value from annotation map into PrimExpr. */
- PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
- if (!obj.defined()) {
- return PrimExpr();
- } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
- return GetRef<PrimExpr>(expr);
- } else if (const StringObj* str = obj.as<StringObj>()) {
- return std::move(StringImm(str->data));
- } else {
- LOG(FATAL) << "Illegal attribute of key " << key << ", value type " <<
obj->GetTypeKey()
- << " not supported";
- return PrimExpr();
- }
- }
-
- /*! \brief Record the loop_var and loop start value of unit loops, whose
extent is one. */
- std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
unit_loop_vars_;
-
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_;
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/lower_opaque_block.cc
similarity index 56%
copy from src/tir/transforms/flatten_buffer.cc
copy to src/tir/transforms/lower_opaque_block.cc
index 21de191db0..69d8787aa1 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -18,56 +18,22 @@
*/
/*!
- * \file flatten_buffer.cc
+ * \file lower_opaque_block.cc
*/
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/function.h>
-#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include "../../support/utils.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
-PrimExpr BufferArea(const Buffer& buffer) {
- if (buffer->strides.size()) {
- ICHECK(buffer->shape.size() == buffer->strides.size());
- return buffer->strides[0] * buffer->shape[0];
- }
- PrimExpr area = Integer(1);
- for (const PrimExpr& dim : buffer->shape) {
- area = area * dim;
- }
- return area;
-}
-
/*!
- * \brief Transform multi-dimension BufferLoad/BufferStore into
device-supported dimension
+ * \brief Remove Block to ensure that the TIR cannot be scheduled again.
*/
-class BufferFlattener : public StmtExprMutator {
- public:
- static PrimFunc Flatten(PrimFunc func) {
- Map<Var, Buffer> preflattened_buffer_map =
- Merge(func->buffer_map, func->preflattened_buffer_map);
- auto pass = BufferFlattener(func->buffer_map);
- auto writer = func.CopyOnWrite();
- writer->body = pass.VisitStmt(func->body);
- writer->preflattened_buffer_map = preflattened_buffer_map;
- writer->buffer_map = pass.updated_extern_buffer_map_;
- return func;
- }
-
+class OpaqueBlockLower : public StmtExprMutator {
private:
- explicit BufferFlattener(const Map<Var, Buffer>& extern_buffer_map) {
- for (const auto& kv : extern_buffer_map) {
- updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second));
- }
- }
-
Stmt VisitStmt_(const BlockRealizeNode* op) final {
// We have convert blocks into opaque blocks in previous passes.
ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in
FlattenBuffer. Please "
@@ -82,8 +48,16 @@ class BufferFlattener : public StmtExprMutator {
}
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
- Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
- body = Allocate(buffer->data, buffer->dtype, buffer->shape,
const_true(), std::move(body));
+ const Buffer& buffer = new_block->alloc_buffers[i - 1];
+ Array<PrimExpr> new_shape = buffer->shape;
+ if (buffer->strides.size()) {
+ ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
+ for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
+ ICHECK(is_zero(floormod(buffer->strides[i - 1],
buffer->strides[i])));
+ new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+ }
+ }
+ body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(),
std::move(body));
}
return body;
}
@@ -141,73 +115,6 @@ class BufferFlattener : public StmtExprMutator {
}
}
- Buffer GetFlattenedBuffer(Buffer buf) {
- auto it = buffer_remap_.find(buf);
- if (it != buffer_remap_.end()) {
- return it->second;
- }
-
- auto flattened = buf.GetFlattenedBuffer();
-
- // TODO(Lunderberg): Move the handling of boolean into a
- // dedicated pass.
- if (flattened->dtype == DataType::Bool()) {
- auto writer = flattened.CopyOnWrite();
- writer->dtype = DataType::Int(8);
- }
-
- buffer_remap_[buf] = flattened;
- return flattened;
- }
-
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
- bool store_returns_bool = (op->value.dtype() == DataType::Bool());
- store = VisitBufferAccess(store);
-
- // Handle casts from the value's dtype to the dtype of the
- // backing array.
- // TODO(Lunderberg): Move the handling of boolean into a
- // dedicated pass.
- if (store_returns_bool) {
- ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
- << "Expected int8 backing array for boolean tensor";
- auto writer = store.CopyOnWrite();
- writer->value = tvm::cast(DataType::Int(8), store->value);
- return store;
- }
- return store;
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- bool load_returns_bool = (op->dtype == DataType::Bool());
- BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- load = VisitBufferAccess(load);
- // Handle casts from dtype of the backing array to value's dtype.
- // TODO(Lunderberg): Move the handling of boolean into a
- // dedicated pass.
- if (load_returns_bool) {
- ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
- << "Expected int8 backing array for boolean tensor";
- load.CopyOnWrite()->dtype = DataType::Int(8);
- return tvm::cast(DataType::Bool(), load);
- } else {
- return std::move(load);
- }
- }
-
- template <typename Node>
- Node VisitBufferAccess(Node node) {
- ICHECK(node->buffer.defined());
- auto flattened_indices = node->buffer->ElemOffset(node->indices);
- Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);
-
- auto writer = node.CopyOnWrite();
- writer->buffer = flattened_buffer;
- writer->indices = flattened_indices;
- return node;
- }
-
static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String
thread_tag,
Stmt body) {
IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
@@ -241,18 +148,14 @@ class BufferFlattener : public StmtExprMutator {
/*! \brief Record the loop_var and loop start value of unit loops, whose
extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
unit_loop_vars_;
-
- /*! \brief Map of buffers being remapped. */
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_;
-
- /*! \brief The updated external buffer map. */
- Map<Var, Buffer> updated_extern_buffer_map_;
};
-PrimFunc FlattenBuffer(PrimFunc f) {
+PrimFunc LowerOpaqueBlock(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
- return BufferFlattener::Flatten(f);
+ auto fptr = f.CopyOnWrite();
+ fptr->body = OpaqueBlockLower()(std::move(fptr->body));
+ return f;
} else {
return f;
}
@@ -260,14 +163,14 @@ PrimFunc FlattenBuffer(PrimFunc f) {
namespace transform {
-Pass FlattenBuffer() {
+Pass LowerOpaqueBlock() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- return FlattenBuffer(std::move(f));
+ return LowerOpaqueBlock(std::move(f));
};
- return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
+TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock);
} // namespace transform
} // namespace tir
diff --git a/tests/python/unittest/test_tir_buffer.py
b/tests/python/unittest/test_tir_buffer.py
index 10e827978c..d250fada6a 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -115,6 +115,7 @@ def test_buffer_vload_nullptr():
[
tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
tvm.tir.transform.CompactBufferAllocation(),
+ tvm.tir.transform.LowerOpaqueBlock(),
tvm.tir.transform.FlattenBuffer(),
]
)(mod)
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py
b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index f1a33a4fb2..ea9c604e71 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te, tir
+import tvm.testing
+from tvm import te
from tvm.script import tir as T
@@ -28,24 +29,15 @@ def _check(original, transformed):
@T.prim_func
-def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
+def elementwise_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
- with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer([1, 16], "float32", scope="global")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i, j])
- T.writes(B[0, j])
- B[0, j] = A[i, j] + 1.0
- for j in range(0, 16):
- with T.block():
- T.reads(B[0, j])
- T.writes(C[i, j])
- C[i, j] = B[0, j] * 2.0
+ for i in T.serial(0, 16):
+ B_new = T.allocate([1, 16], "float32", "global")
+ for j in T.serial(0, 16):
+ B_new[0, j] = A[i, j] + 1.0
+ for j in T.serial(0, 16):
+ C[i, j] = B_new[0, j] * 2.0
@T.prim_func
@@ -63,26 +55,22 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) ->
None:
@T.prim_func
-def compacted_gpu_func(a: T.handle, c: T.handle) -> None:
+def gpu_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.thread_binding(0, 4, thread="blockIdx.x"):
- for i1 in T.thread_binding(0, 2, thread="threadIdx.x"):
- for i2 in T.thread_binding(0, 2, thread="vthread"):
- with T.block():
- T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16])
- T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16])
- B = T.alloc_buffer([1, 16], "float32", scope="local")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 4 + i1 * 2 + i2, j])
- T.writes(B[0, j])
- B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
- for j in range(0, 16):
- with T.block():
- T.reads(B[0, j])
- T.writes(C[i0 * 4 + i1 * 2 + i2, j])
- C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
+
+ i0 = T.env_thread("blockIdx.x")
+ i1 = T.env_thread("threadIdx.x")
+ i2 = T.env_thread("vthread")
+
+ T.launch_thread(i0, 4)
+ T.launch_thread(i1, 2)
+ T.launch_thread(i2, 2)
+ B = T.allocate([1, 16], "float32", "local")
+ for j in range(0, 16):
+ B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
+ for j in range(0, 16):
+ C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
@T.prim_func
@@ -107,25 +95,16 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
@T.prim_func
-def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32)
-> None:
+def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
A = T.match_buffer(a, (n, m), "float32")
C = T.match_buffer(c, (n, m), "float32")
for i in range(0, n):
- with T.block():
- T.reads(A[i, m])
- T.writes(C[i, m])
- B = T.alloc_buffer((m,), "float32", scope="global")
- for j in range(0, m):
- with T.block():
- T.reads(A[i, j])
- T.writes(B[j])
- B[j] = A[i, j] + 1.0
- for j in range(0, m):
- with T.block():
- T.reads(B[j])
- T.writes(C[i, j])
- C[i, j] = B[j] * 2.0
+ B = T.allocate([m], "float32", "global")
+ for j in range(0, m):
+ B[j] = A[i, j] + 1.0
+ for j in range(0, m):
+ C[i, j] = B[j] * 2.0
@T.prim_func
@@ -144,105 +123,44 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n:
T.int32, m: T.int32) ->
@T.prim_func
-def compacted_predicate_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- C = T.match_buffer(c, (32), "float32")
-
- for i, j in T.grid(5, 7):
- with T.block():
- T.reads(A[i * 7 + j])
- T.writes(C[i * 7 + j])
- T.where(i * 7 + j < 32)
- C[i * 7 + j] = A[i * 7 + j] + 1.0
-
-
[email protected]_func
-def flattened_predicate_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- C = T.match_buffer(c, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(C, (32), "float32", data=C.data)
-
- for i, j in T.grid(5, 7):
- if i * 7 + j < 32:
- C[i * 7 + j] = A[i * 7 + j] + 1.0
-
-
[email protected]_func
-def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- C = T.match_buffer(c, (32), "float32")
-
- for x, y, z in T.grid(4, 1, 8):
- with T.block():
- T.reads(A[x * 8 + y * 8 + z])
- T.writes(C[x * 8 + y * 8 + z])
- C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0
-
-
[email protected]_func
-def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- C = T.match_buffer(c, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(C, (32), "float32", data=C.data)
-
- for x, z in T.grid(4, 8):
- C[x * 8 + z] = A[x * 8 + z] + 1.0
-
-
[email protected]_func
-def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- D = T.match_buffer(d, (32), "float32")
+def multi_alloc_func(a: T.handle, d: T.handle) -> None:
+ A = T.match_buffer(a, (4, 32), "float32")
+ D = T.match_buffer(d, (4, 32), "float32")
- for i in range(0, 32):
- with T.block():
- T.reads(A[i])
- T.writes(D[i])
- B = T.alloc_buffer((32,), scope="global")
- C = T.alloc_buffer((32,), scope="global")
- B[i] = A[i] + 1.0
- C[i] = A[i] + B[i]
- D[i] = C[i] * 2.0
+ for i, j in T.grid(4, 32):
+ B = T.allocate((4, 32), "float32", scope="global")
+ C = T.allocate((4, 32), "float32", scope="global")
+ B[i, j] = A[i, j] + 1.0
+ C[i, j] = A[i, j] + B[i, j]
+ D[i, j] = C[i, j] * 2.0
@T.prim_func
def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, (32), "float32")
- D = T.match_buffer(d, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(D, (32), "float32", data=D.data)
+ A = T.match_buffer(a, (128), "float32")
+ D = T.match_buffer(d, (128), "float32")
+ T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
+ T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
- for i in range(0, 32):
- B = T.allocate((32,), "float32", "global")
- C = T.allocate((32,), "float32", "global")
- B[i] = A[i] + 1.0
- C[i] = A[i] + B[i]
- D[i] = C[i] * 2.0
+ for i, j in T.grid(4, 32):
+ B = T.allocate((128), "float32", "global")
+ C = T.allocate((128), "float32", "global")
+ B[i * 32 + j] = A[i * 32 + j] + 1.0
+ C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
+ D[i * 32 + j] = C[i * 32 + j] * 2.0
@T.prim_func
-def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
+def strided_buffer_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
- for i0 in range(0, 4):
- with T.block():
- T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16])
- T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16])
- B = T.alloc_buffer([4, 16], "float32", strides=[17, 1],
scope="global")
- for i1 in range(0, 4):
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 4 + i1, j])
- T.writes(B[i1, j])
- B[i1, j] = A[i0 * 4 + i1, j] + 1.0
- for i1 in range(0, 4):
- for j in range(0, 16):
- with T.block():
- T.reads(B[i1, j])
- T.writes(C[i0 * 4 + i1, j])
- C[i0 * 4 + i1, j] = B[i1, j] * 2.0
+ for i0 in T.serial(4):
+ B = T.allocate([4, 17], "float32", "global")
+ B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data,
strides=[17, 1])
+ for i1, j in T.grid(4, 16):
+ B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
+ for i1, j in T.grid(4, 16):
+ C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
@T.prim_func
@@ -261,20 +179,10 @@ def flattened_strided_buffer_func(a: T.handle, c:
T.handle) -> None:
C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
[email protected]_func
-def annotated_loops(a: T.handle) -> None:
- A = T.match_buffer(a, (16,), "float32")
- for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1,
"pragma_3": 0.0}):
- A[i] = 0.0
-
-
@T.prim_func
def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"])
-> None:
for i0 in T.serial(10):
- with T.block("b"):
- T.reads(a[i0])
- T.writes(b[i0])
- b[i0] = a[i0]
+ b[i0] = a[i0]
@T.prim_func
@@ -286,41 +194,24 @@ def boolean_handling_after(a: T.Buffer[10, "int8"], b:
T.Buffer[10, "int8"]) ->
b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
[email protected]_func
-def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) ->
None:
- T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
- T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
- # body
- for i0 in T.serial(10):
- b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
def test_elementwise():
- _check(compacted_elementwise_func, flattened_elementwise_func)
+ _check(elementwise_func, flattened_elementwise_func)
def test_gpu_workload():
- _check(compacted_gpu_func, flattened_gpu_func)
+ _check(gpu_func, flattened_gpu_func)
def test_symbolic_shape():
- _check(compacted_symbolic_func, flattened_symbolic_func)
-
-
-def test_predicate():
- _check(compacted_predicate_func, flattened_predicate_func)
-
-
-def test_unit_loops():
- _check(compacted_unit_loop_func, flattened_unit_loop_func)
+ _check(symbolic_func, flattened_symbolic_func)
def test_multi_alloc():
- _check(compacted_multi_alloc_func, flattened_multi_alloc_func)
+ _check(multi_alloc_func, flattened_multi_alloc_func)
def test_strided_buffer():
- _check(compacted_strided_buffer_func, flattened_strided_buffer_func)
+ _check(strided_buffer_func, flattened_strided_buffer_func)
def test_lower_te():
@@ -332,35 +223,9 @@ def test_lower_te():
tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do
nothing on TE
-def test_annotated_loops():
- mod = tvm.IRModule.from_expr(annotated_loops)
- mod = tvm.tir.transform.FlattenBuffer()(mod)
- # _check(annotated_loops, compacted_annotated_loops)
- attr1 = mod["main"].body
- attr2 = attr1.body
- attr3 = attr2.body
- assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
- assert attr2.attr_key == "pragma_2"
- tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
- assert attr3.attr_key == "pragma_3"
- tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32",
0.0))
-
-
def test_boolean_handling():
_check(boolean_handling_before, boolean_handling_after)
- # mod = tvm.IRModule.from_expr(boolean_handling_before)
- # mod = tvm.tir.transform.FlattenBuffer()(mod)
- # print(mod.script())
if __name__ == "__main__":
- test_elementwise()
- test_gpu_workload()
- test_symbolic_shape()
- test_predicate()
- test_unit_loops()
- test_multi_alloc()
- test_strided_buffer()
- test_lower_te()
- test_annotated_loops()
- test_boolean_handling()
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index d7e13f40aa..1a906b2fb6 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -127,6 +127,7 @@ def test_inject_async_copy():
f = generate_global_to_shared_vectorized_copy(dtype, vec_size)
mod = tvm.IRModule.from_expr(f)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
if vec_size > 1:
mod = tvm.tir.transform.VectorizeLoop()(mod)
@@ -154,6 +155,7 @@ def test_inject_async_copy_shared_dyn():
f = ptx_global_to_shared_dyn_copy_fp16x8
mod = tvm.IRModule.from_expr(f)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.VectorizeLoop()(mod)
mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py
b/tests/python/unittest/test_tir_transform_loop_partition.py
index 6cfe96664d..86f2b6696b 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -611,6 +611,7 @@ def concat_func_3(
def test_condition_mutually_exclusive():
mod = IRModule.from_expr(concat_func_3)
with tvm.transform.PassContext(config={"tir.LoopPartition":
{"partition_const_loop": True}}):
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py
b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
similarity index 67%
copy from tests/python/unittest/test_tir_transform_flatten_buffer.py
copy to tests/python/unittest/test_tir_transform_lower_opaque_block.py
index f1a33a4fb2..9b18c407c4 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te, tir
+import tvm.testing
+from tvm import te
from tvm.script import tir as T
def _check(original, transformed):
func = original
mod = tvm.IRModule.from_expr(func)
- mod = tvm.tir.transform.FlattenBuffer()(mod)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed, True)
@@ -49,17 +50,15 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) ->
None:
@T.prim_func
-def flattened_elementwise_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, 256, "float32")
- C = T.match_buffer(c, 256, "float32")
- T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
- T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
for i in T.serial(0, 16):
- B_new = T.allocate([16], "float32", "global")
+ B_new = T.allocate([1, 16], "float32", "global")
for j in T.serial(0, 16):
- B_new[j] = A[((i * 16) + j)] + 1.0
+ B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
- C[((i * 16) + j)] = B_new[j] * 2.0
+ C[i, j] = B_new[0, j] * 2.0
@T.prim_func
@@ -86,11 +85,9 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None:
@T.prim_func
-def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, 256, "float32")
- C = T.match_buffer(c, 256, "float32")
- T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
- T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
i0 = T.env_thread("blockIdx.x")
i1 = T.env_thread("threadIdx.x")
@@ -99,11 +96,11 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
- B = T.allocate([16], "float32", "local")
+ B = T.allocate([1, 16], "float32", "local")
for j in range(0, 16):
- B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
+ B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
for j in range(0, 16):
- C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
+ C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
@T.prim_func
@@ -129,18 +126,16 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n:
T.int32, m: T.int32) ->
@T.prim_func
-def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
- A = T.match_buffer(a, n * m, "float32")
- C = T.match_buffer(c, n * m, "float32")
- T.preflattened_buffer(A, (n, m), "float32", data=A.data)
- T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+ A = T.match_buffer(a, (n, m), "float32")
+ C = T.match_buffer(c, (n, m), "float32")
for i in range(0, n):
B = T.allocate([m], "float32", "global")
for j in range(0, m):
- B[j] = A[i * m + j] + 1.0
+ B[j] = A[i, j] + 1.0
for j in range(0, m):
- C[i * m + j] = B[j] * 2.0
+ C[i, j] = B[j] * 2.0
@T.prim_func
@@ -157,11 +152,9 @@ def compacted_predicate_func(a: T.handle, c: T.handle) ->
None:
@T.prim_func
-def flattened_predicate_func(a: T.handle, c: T.handle) -> None:
+def transformed_predicate_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32), "float32")
C = T.match_buffer(c, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(C, (32), "float32", data=C.data)
for i, j in T.grid(5, 7):
if i * 7 + j < 32:
@@ -181,11 +174,9 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) ->
None:
@T.prim_func
-def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None:
+def transformed_unit_loop_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32), "float32")
C = T.match_buffer(c, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(C, (32), "float32", data=C.data)
for x, z in T.grid(4, 8):
C[x * 8 + z] = A[x * 8 + z] + 1.0
@@ -208,11 +199,9 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle)
-> None:
@T.prim_func
-def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
+def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
A = T.match_buffer(a, (32), "float32")
D = T.match_buffer(d, (32), "float32")
- T.preflattened_buffer(A, (32), "float32", data=A.data)
- T.preflattened_buffer(D, (32), "float32", data=D.data)
for i in range(0, 32):
B = T.allocate((32,), "float32", "global")
@@ -246,19 +235,17 @@ def compacted_strided_buffer_func(a: T.handle, c:
T.handle) -> None:
@T.prim_func
-def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (256,), "float32")
- C = T.match_buffer(c, (256,), "float32")
- T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
- T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
- for i0 in T.serial(0, 4):
- B_new = T.allocate([68], "float32", "global")
- for i1 in T.serial(0, 4):
- for j in T.serial(0, 16):
- B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
- for i1 in T.serial(0, 4):
- for j in T.serial(0, 16):
- C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
+def transformed_strided_buffer_func(
+ A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
+) -> None:
+ # body
+ for i0 in T.serial(4):
+ B = T.allocate([4, 17], "float32", "global")
+ B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data,
strides=[17, 1])
+ for i1, j in T.grid(4, 16):
+ B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
+ for i1, j in T.grid(4, 16):
+ C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2)
@T.prim_func
@@ -278,49 +265,38 @@ def boolean_handling_before(a: T.Buffer[10, "bool"], b:
T.Buffer[10, "bool"]) ->
@T.prim_func
-def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"])
-> None:
- T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
- T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
- # body
- for i0 in T.serial(10):
- b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
[email protected]_func
-def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) ->
None:
- T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
- T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+def boolean_handling_after(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"])
-> None:
# body
for i0 in T.serial(10):
- b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+ b[i0] = a[i0]
def test_elementwise():
- _check(compacted_elementwise_func, flattened_elementwise_func)
+ _check(compacted_elementwise_func, transformed_elementwise_func)
def test_gpu_workload():
- _check(compacted_gpu_func, flattened_gpu_func)
+ _check(compacted_gpu_func, transformed_gpu_func)
def test_symbolic_shape():
- _check(compacted_symbolic_func, flattened_symbolic_func)
+ _check(compacted_symbolic_func, transformed_symbolic_func)
def test_predicate():
- _check(compacted_predicate_func, flattened_predicate_func)
+ _check(compacted_predicate_func, transformed_predicate_func)
def test_unit_loops():
- _check(compacted_unit_loop_func, flattened_unit_loop_func)
+ _check(compacted_unit_loop_func, transformed_unit_loop_func)
def test_multi_alloc():
- _check(compacted_multi_alloc_func, flattened_multi_alloc_func)
+ _check(compacted_multi_alloc_func, transformed_multi_alloc_func)
def test_strided_buffer():
- _check(compacted_strided_buffer_func, flattened_strided_buffer_func)
+ _check(compacted_strided_buffer_func, transformed_strided_buffer_func)
def test_lower_te():
@@ -328,14 +304,13 @@ def test_lower_te():
y = te.compute((1,), lambda i: x[i] + 2)
s = te.create_schedule(y.op)
orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
- mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
- tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do
nothing on TE
+ mod = tvm.tir.transform.LowerOpaqueBlock()(orig_mod)
+ tvm.ir.assert_structural_equal(mod, orig_mod) # LowerOpaqueBlock should
do nothing on TE
def test_annotated_loops():
mod = tvm.IRModule.from_expr(annotated_loops)
- mod = tvm.tir.transform.FlattenBuffer()(mod)
- # _check(annotated_loops, compacted_annotated_loops)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
attr1 = mod["main"].body
attr2 = attr1.body
attr3 = attr2.body
@@ -348,19 +323,7 @@ def test_annotated_loops():
def test_boolean_handling():
_check(boolean_handling_before, boolean_handling_after)
- # mod = tvm.IRModule.from_expr(boolean_handling_before)
- # mod = tvm.tir.transform.FlattenBuffer()(mod)
- # print(mod.script())
if __name__ == "__main__":
- test_elementwise()
- test_gpu_workload()
- test_symbolic_shape()
- test_predicate()
- test_unit_loops()
- test_multi_alloc()
- test_strided_buffer()
- test_lower_te()
- test_annotated_loops()
- test_boolean_handling()
+ tvm.testing.main()