This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6e08d90425 [REFACTOR][TIR] Phaseout BufferRealize (#18763)
6e08d90425 is described below
commit 6e08d904251f0c7b777ac6be5329a097e8f8aad0
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Feb 12 09:13:18 2026 -0500
[REFACTOR][TIR] Phaseout BufferRealize (#18763)
This PR phases out BufferRealize, which is a legacy node from the TE schedule
and is no longer needed here.
---
include/tvm/script/ir_builder/tir/frame.h | 46 --
include/tvm/script/ir_builder/tir/ir.h | 10 -
include/tvm/tir/stmt.h | 55 ---
include/tvm/tir/stmt_functor.h | 4 -
python/tvm/script/ir_builder/tir/frame.py | 5 -
python/tvm/script/ir_builder/tir/ir.py | 29 --
python/tvm/script/parser/core/evaluator.py | 2 +-
python/tvm/tir/__init__.py | 1 -
python/tvm/tir/functor.py | 37 --
python/tvm/tir/stmt.py | 41 --
python/tvm/tir/transform/transform.py | 11 -
src/relax/transform/dataflow_inplace.cc | 7 -
.../analysis/sblock_buffer_access_lca_detector.cc | 6 -
.../plan_update_buffer_allocation_location.cc | 5 -
src/script/ir_builder/tir/frame.cc | 9 -
src/script/ir_builder/tir/ir.cc | 10 -
src/script/printer/tir/stmt.cc | 54 ---
src/tir/backend/adreno/texture_flatten.cc | 44 --
src/tir/ir/py_functor.cc | 12 -
src/tir/ir/stmt.cc | 15 -
src/tir/ir/stmt_functor.cc | 25 -
src/tir/ir/tir_visitor_with_path.cc | 8 -
src/tir/ir/tir_visitor_with_path.h | 1 -
src/tir/transform/inject_rolling_buffer.cc | 327 -------------
src/tir/transform/unsupported_dtype_legalize.cc | 16 -
.../test_hexagon/test_2d_physical_buffers.py | 2 +-
.../test_tir_analysis_verify_well_formed.py | 31 --
tests/python/tir-base/test_tir_nodes.py | 3 -
.../test_tir_transform_inject_rolling_buffer.py | 179 -------
.../tvmscript/test_tvmscript_error_report.py | 10 +-
.../tvmscript/test_tvmscript_ir_builder_tir.py | 21 -
.../python/tvmscript/test_tvmscript_printer_tir.py | 16 -
tests/python/tvmscript/test_tvmscript_roundtrip.py | 515 ---------------------
33 files changed, 5 insertions(+), 1552 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/frame.h
b/include/tvm/script/ir_builder/tir/frame.h
index 224326c0c8..d626df38c2 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -425,52 +425,6 @@ class LaunchThreadFrame : public TIRFrame {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LaunchThreadFrame, TIRFrame,
LaunchThreadFrameNode);
};
-/*!
- * \brief A frame that represents realization.
- *
- * \sa RealizeFrame
- */
-class RealizeFrameNode : public TIRFrameNode {
- public:
- /*! \brief The region of buffer access. */
- tvm::tir::BufferRegion buffer_slice;
- /*! \brief The storage scope associated with this realization. */
- ffi::String storage_scope;
- /*! \brief The condition expression. */
- PrimExpr condition;
-
- static void RegisterReflection() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<RealizeFrameNode>()
- .def_ro("buffer_slice", &RealizeFrameNode::buffer_slice)
- .def_ro("storage_scope", &RealizeFrameNode::storage_scope)
- .def_ro("condition", &RealizeFrameNode::condition);
- }
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.RealizeFrame",
RealizeFrameNode,
- TIRFrameNode);
-
- public:
- /*!
- * \brief The method called when exiting RAII scope.
- * \sa tvm::support::With
- */
- void ExitWithScope() final;
-};
-
-/*!
- * \brief Managed reference to RealizeFrameNode.
- *
- * \sa RealizeFrameNode
- */
-class RealizeFrame : public TIRFrame {
- public:
- explicit RealizeFrame(ObjectPtr<RealizeFrameNode> data) :
TIRFrame(ffi::UnsafeInit{}) {
- TVM_FFI_ICHECK(data != nullptr);
- data_ = std::move(data);
- }
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RealizeFrame, TIRFrame,
RealizeFrameNode);
-};
-
/*!
* \brief A frame represents the allocate.
*
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index 62aad84645..694bd57aab 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -303,16 +303,6 @@ AssertFrame Assert(PrimExpr condition, ffi::String
message);
LetFrame LetStmt(PrimExpr value, ffi::Optional<Type> type_annotation =
std::nullopt,
ffi::Optional<Var> var = std::nullopt);
-/*!
- * \brief The realization.
- * \param buffer_slice The region of buffer access.
- * \param storage_scope The storage scope associated with this realization.
- * \param condition The condition expression.
- * \return The result RealizeFrame.
- */
-RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String
storage_scope,
- PrimExpr condition);
-
/*!
* \brief The allocate node.
* \param extents The extents of the allocate.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 86703346cb..4d0029803b 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -227,57 +227,6 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};
-/*!
- * \brief Annotate the region where the buffer need to
- * be read and write in the body.
- * We only need to allocate the space for the corresponding region.
- *
- * \note There should be at most one BufferRealize for each buffer.
- * BufferRealize is not necessary for external buffers,
- * since they are assumed to be fully allocated.
- *
- * \sa BufferLoad, BufferStore
- */
-class BufferRealizeNode : public StmtNode {
- public:
- /*! \brief The buffer variable. */
- Buffer buffer;
- /*! \brief Bounds to be realized */
- ffi::Array<Range> bounds;
- /*! \brief Only realize if condition holds. */
- PrimExpr condition;
- /*! \brief The body of realization. */
- Stmt body;
-
- static void RegisterReflection() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<BufferRealizeNode>()
- .def_ro("buffer", &BufferRealizeNode::buffer)
- .def_ro("bounds", &BufferRealizeNode::bounds)
- .def_ro("condition", &BufferRealizeNode::condition)
- .def_ro("body", &BufferRealizeNode::body);
- }
-
- BufferRealizeNode() = default;
- BufferRealizeNode(Buffer buffer, ffi::Array<Range> bounds, PrimExpr
condition, Stmt body,
- Span span = Span())
- : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition),
body(body) {}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRealize", BufferRealizeNode,
StmtNode);
-};
-
-/*!
- * \brief Managed reference to BufferRealizeNode.
- * \sa BufferRealizeNode
- */
-class BufferRealize : public Stmt {
- public:
- TVM_DLL explicit BufferRealize(Buffer buffer, ffi::Array<Range> bounds,
PrimExpr condition,
- Stmt body, Span span = Span());
-
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BufferRealize, Stmt,
BufferRealizeNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
-};
-
/*!
* \brief Allocate a buffer that can be used in body.
*/
@@ -989,8 +938,6 @@ constexpr const char* extern_scope = "extern_scope";
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
-/*! \brief Mark storage scope of realization */
-constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation device for global malloc in host. */
constexpr const char* device_id = "device_id";
/*! \brief The device type. */
@@ -1034,8 +981,6 @@ constexpr const char* double_buffer_scope =
"double_buffer_scope";
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
-/*! \brief Mark realization for rolling buffer optimization */
-constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 431308f859..c7933c64f5 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -90,7 +90,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const DeclBufferNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const BufferRealizeNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
@@ -116,7 +115,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
- IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(SBlockNode);
IR_STMT_FUNCTOR_DISPATCH(SBlockRealizeNode);
vtable.Finalize();
@@ -153,7 +151,6 @@ class TVM_DLL StmtVisitor : protected
StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const DeclBufferNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
- void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
@@ -250,7 +247,6 @@ class TVM_DLL StmtMutator : protected
StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
- Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
diff --git a/python/tvm/script/ir_builder/tir/frame.py
b/python/tvm/script/ir_builder/tir/frame.py
index fd9ae24407..785e43a454 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -62,11 +62,6 @@ class LetFrame(TIRFrame):
return self.var
-@_register_object("script.ir_builder.tir.RealizeFrame")
-class RealizeFrame(TIRFrame):
- ...
-
-
@_register_object("script.ir_builder.tir.AllocateFrame")
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 24a8fce800..660f517c46 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1020,34 +1020,6 @@ def let(
return let_expr(v, value, body)
-def realize(
- buffer_slice: BufferRegion,
- storage_scope: str,
- condition: PrimExpr = True,
-) -> frame.RealizeFrame:
- """Create a realization.
-
- Parameters
- ----------
- buffer_slice : BufferRegion
- The region of buffer access.
-
- storage_scope : str
- The storage scope associated with this realization.
-
- condition: PrimExpr
- The condition expression, the default is True.
-
- Returns
- -------
- res : frame.RealizeFrame
- The result RealizeFrame.
- """
- return _ffi_api.Realize( # type: ignore[attr-defined] # pylint:
disable=no-member
- buffer_slice, storage_scope, condition
- )
-
-
def allocate(
extents: List[PrimExpr],
dtype: str,
@@ -2145,7 +2117,6 @@ __all__ = float_types + [
"thread_binding",
"grid",
"Assert",
- "realize",
"allocate",
"attr",
"While",
diff --git a/python/tvm/script/parser/core/evaluator.py
b/python/tvm/script/parser/core/evaluator.py
index 7668fa99e6..c69d2a328b 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -174,7 +174,7 @@ class ExprEvaluator:
if (
isinstance(node, doc.Call)
and hasattr(node.func, "attr")
- and node.func.attr not in ["reads", "writes", "match_buffer",
"realize"]
+ and node.func.attr not in ["reads", "writes", "match_buffer"]
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare,
doc.BoolOp, doc.IfExp)):
if isinstance(node, doc.BinOp):
args = [node.left, node.right]
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index d7053fc862..29a597980b 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -30,7 +30,6 @@ from .expr import Call, CallEffectKind, Let, IterVar,
CommReducer
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import (
BufferStore,
- BufferRealize,
Allocate,
AttrStmt,
DeclBuffer,
diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py
index 935ddcb2e2..1bdd01ccbc 100644
--- a/python/tvm/tir/functor.py
+++ b/python/tvm/tir/functor.py
@@ -64,7 +64,6 @@ from .stmt import (
AttrStmt,
SBlock,
SBlockRealize,
- BufferRealize,
BufferStore,
DeclBuffer,
Evaluate,
@@ -166,7 +165,6 @@ class _PyStmtExprVisitor(tvm_ffi.core.Object):
f_visit_allocate: Callable = None,
f_visit_decl_buffer: Callable = None,
f_visit_buffer_store: Callable = None,
- f_visit_buffer_realize: Callable = None,
f_visit_assert_stmt: Callable = None,
f_visit_seq_stmt: Callable = None,
f_visit_evaluate: Callable = None,
@@ -221,7 +219,6 @@ class _PyStmtExprVisitor(tvm_ffi.core.Object):
f_visit_allocate,
f_visit_decl_buffer,
f_visit_buffer_store,
- f_visit_buffer_realize,
f_visit_assert_stmt,
f_visit_seq_stmt,
f_visit_evaluate,
@@ -285,7 +282,6 @@ class PyStmtExprVisitor:
"visit_allocate_",
"visit_decl_buffer_",
"visit_buffer_store_",
- "visit_buffer_realize_",
"visit_assert_stmt_",
"visit_seq_stmt_",
"visit_evaluate_",
@@ -452,19 +448,6 @@ class PyStmtExprVisitor:
print("visit_buffer_store_", op)
_ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type:
ignore
- def visit_buffer_realize_(self, op: BufferRealize) -> None:
- """Visit BufferRealize.
- Users can customize this function to overwrite VisitStmt_(const
BufferRealizeNode* op)
- on the C++ side.
-
- Parameters
- ----------
- op : BufferRealize
- The BufferRealize to be visited.
- """
- print("visit_buffer_realize_", op)
- _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type:
ignore
-
def visit_assert_stmt_(self, op: AssertStmt) -> None:
"""Visit AssertStmt.
Users can customize this function to overwrite VisitStmt_(const
AssertStmtNode* op)
@@ -984,7 +967,6 @@ class _PyStmtExprMutator(tvm_ffi.core.Object):
f_visit_allocate: Callable = None,
f_visit_decl_buffer: Callable = None,
f_visit_buffer_store: Callable = None,
- f_visit_buffer_realize: Callable = None,
f_visit_assert_stmt: Callable = None,
f_visit_seq_stmt: Callable = None,
f_visit_evaluate: Callable = None,
@@ -1039,7 +1021,6 @@ class _PyStmtExprMutator(tvm_ffi.core.Object):
f_visit_allocate,
f_visit_decl_buffer,
f_visit_buffer_store,
- f_visit_buffer_realize,
f_visit_assert_stmt,
f_visit_seq_stmt,
f_visit_evaluate,
@@ -1103,7 +1084,6 @@ class PyStmtExprMutator:
"visit_allocate_",
"visit_decl_buffer_",
"visit_buffer_store_",
- "visit_buffer_realize_",
"visit_assert_stmt_",
"visit_seq_stmt_",
"visit_evaluate_",
@@ -1316,23 +1296,6 @@ class PyStmtExprMutator:
"""
return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)
# type: ignore
- def visit_buffer_realize_(self, op: BufferRealize) -> Stmt:
- """Visit BufferRealize.
- Users can customize this function to overwrite VisitStmt_(const
BufferRealizeNode* op)
- on the C++ side.
-
- Parameters
- ----------
- op : BufferRealize
- The BufferRealize to be visited.
-
- Returns
- -------
- result : Stmt
- The mutated Stmt.
- """
- return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)
# type: ignore
-
def visit_assert_stmt_(self, op: AssertStmt) -> Stmt:
"""Visit AssertStmt.
Users can customize this function to overwrite VisitStmt_(const
AssertStmtNode* op)
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 9dfcf1f18a..1574d86b4a 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -259,47 +259,6 @@ class BufferStore(Stmt):
)
-@tvm_ffi.register_object("tir.BufferRealize")
-class BufferRealize(Stmt):
- """Buffer realize node.
-
- Parameters
- ----------
- buffer : Buffer
- The buffer.
-
- bounds : List[Range]
- The value we to be stored.
-
- condition : PrimExpr
- The realize condition.
-
- body : Stmt
- The body of the statement.
-
- span : Optional[Span]
- The location of the stmt in the source code.
- """
-
- buffer: Buffer
- bounds: List[Range]
- condition: PrimExpr
- body: Stmt
- span: Optional[Span]
-
- def __init__(
- self,
- buffer: Buffer,
- bounds: List[Range],
- condition: PrimExpr,
- body: Stmt,
- span: Optional[Span] = None,
- ) -> None:
- self.__init_handle_by_constructor__(
- _ffi_api.BufferRealize, buffer, bounds, condition, body, span #
type: ignore
- )
-
-
@tvm_ffi.register_object("tir.Allocate")
class Allocate(Stmt):
"""Allocate node.
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 676e24031e..7de12d5301 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -67,17 +67,6 @@ def VectorizeLoop(enable_vectorize: bool = True):
return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore
-def InjectRollingBuffer():
- """Inject rolling buffer statements.
-
- Returns
- -------
- fpass : tvm.transform.Pass
- The result pass
- """
- return _ffi_api.InjectRollingBuffer() # type: ignore
-
-
def StorageRewrite():
"""Rewrite storage allocation pattern.
diff --git a/src/relax/transform/dataflow_inplace.cc
b/src/relax/transform/dataflow_inplace.cc
index 0b9eeb8341..cf6b690ae3 100644
--- a/src/relax/transform/dataflow_inplace.cc
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -716,13 +716,6 @@ tir::Stmt RemapBuffers(const tir::Stmt& stmt,
return node;
}
- tir::Stmt VisitStmt_(const tir::BufferRealizeNode* op) final {
- auto node =
Downcast<tir::BufferRealize>(tir::StmtExprMutator::VisitStmt_(op));
- auto* node_cow = node.CopyOnWrite();
- node_cow->buffer = AttemptRemap(node->buffer);
- return node;
- }
-
tir::Stmt VisitStmt_(const tir::DeclBufferNode* op) final {
auto node =
Downcast<tir::DeclBuffer>(tir::StmtExprMutator::VisitStmt_(op));
auto* node_cow = node.CopyOnWrite();
diff --git a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
index 67ee0dbe69..da0dab4a97 100644
--- a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
+++ b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
@@ -261,12 +261,6 @@ class LCADetector : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const BufferRealizeNode* op) final {
- buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
- UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
- StmtExprVisitor::VisitStmt_(op);
- }
-
// Works for Load/Store and opaque access.
void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }
diff --git a/src/s_tir/transform/plan_update_buffer_allocation_location.cc
b/src/s_tir/transform/plan_update_buffer_allocation_location.cc
index 1e0bed7c98..528f43bade 100644
--- a/src/s_tir/transform/plan_update_buffer_allocation_location.cc
+++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc
@@ -202,11 +202,6 @@ class BufferAllocationLocator : public StmtExprMutator {
return Stmt(n);
}
- Stmt VisitStmt_(const BufferRealizeNode* op) final {
- ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in
TensorIR.";
- throw;
- }
-
Stmt InjectOpaqueBlock(Stmt body, const ffi::Array<Buffer>& alloc_buffers) {
ICHECK(!alloc_buffers.empty());
SBlock opaque_block(/*iter_vars=*/{},
diff --git a/src/script/ir_builder/tir/frame.cc
b/src/script/ir_builder/tir/frame.cc
index e6c8b5d0b1..abc981de1f 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -37,7 +37,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
AssertFrameNode::RegisterReflection();
LetFrameNode::RegisterReflection();
LaunchThreadFrameNode::RegisterReflection();
- RealizeFrameNode::RegisterReflection();
AllocateFrameNode::RegisterReflection();
AttrFrameNode::RegisterReflection();
WhileFrameNode::RegisterReflection();
@@ -135,14 +134,6 @@ void LetFrameNode::ExitWithScope() {
AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts)));
}
-void RealizeFrameNode::ExitWithScope() {
- TIRFrameNode::ExitWithScope();
- AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
- tvm::tir::StringImm(storage_scope),
- tvm::tir::BufferRealize(buffer_slice->buffer,
buffer_slice->region,
- condition,
AsStmt(stmts))));
-}
-
void LaunchThreadFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 9f11743c05..cd8838f701 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -499,15 +499,6 @@ LaunchThreadFrame LaunchThread(ffi::String thread_tag,
PrimExpr extent) {
return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
}
-RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String
storage_scope,
- PrimExpr condition) {
- ObjectPtr<RealizeFrameNode> n = ffi::make_object<RealizeFrameNode>();
- n->buffer_slice = buffer_slice;
- n->storage_scope = storage_scope;
- n->condition = condition;
- return RealizeFrame(n);
-}
-
AllocateFrame Allocate(ffi::Array<PrimExpr> extents, DataType dtype,
ffi::String storage_scope,
ffi::Optional<PrimExpr> condition,
ffi::Optional<ffi::Map<ffi::String, Any>> annotations) {
@@ -735,7 +726,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("script.ir_builder.tir.LetStmt", LetStmt)
.def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt)
.def("script.ir_builder.tir.Allocate", Allocate)
- .def("script.ir_builder.tir.Realize", Realize)
.def("script.ir_builder.tir.Attr", Attr)
.def("script.ir_builder.tir.While", While)
.def("script.ir_builder.tir.If", If)
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 9df4764dd7..bf9d2253ce 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -251,35 +251,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
});
-ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt,
ffi::Optional<ExprDoc> value, //
- AccessPath p, IRDocsifier d) {
- ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer"));
- {
- ffi::Array<Doc> bounds;
- bounds.reserve(stmt->bounds.size());
- for (int i = 0, n = stmt->bounds.size(); i < n; ++i) {
- Range range = stmt->bounds[i];
- AccessPath range_p = p->Attr("bounds")->ArrayItem(i);
- bounds.push_back(
- SliceDoc(d->AsDoc<ExprDoc>(range->min, range_p->Attr("min")),
- d->AsDoc<ExprDoc>(range->min + range->extent,
range_p->Attr("extent")), //
- std::nullopt));
- }
- buffer = buffer[bounds];
- }
- ffi::Array<ExprDoc> args{buffer};
- ffi::Array<ffi::String> kwargs_keys;
- ffi::Array<ExprDoc> kwargs_values;
- if (value.defined()) {
- args.push_back(value.value());
- }
- if (!tir::is_one(stmt->condition)) {
- kwargs_keys.push_back("condition");
- kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition,
p->Attr("condition")));
- }
- return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
-}
-
void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath&
iter_var_p,
const IRDocsifier& d) {
Frame f = FindLowestVarDef(iter_var->var, d).value();
@@ -313,16 +284,6 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt&
attr_stmt, const AccessPath& at
});
}
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::BufferRealize>( //
- "", [](tir::BufferRealize stmt, AccessPath p, IRDocsifier d) -> Doc {
- bool concise = AllowConciseScoping(d, stmt);
- ExprDoc rhs = DocsifyBufferRealize(stmt.get(), std::nullopt, p, d);
- With<TIRFrame> f(d, stmt);
- AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
- return DoConciseScoping(std::nullopt, rhs, &(*f)->stmts, concise);
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc {
@@ -332,19 +293,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ffi::Optional<tir::Var> define_var = std::nullopt;
tir::Stmt body = stmt->body;
AccessPath body_p = stmt_p->Attr("body");
- if (stmt->attr_key == "realize_scope") {
- if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>())
{
- // TODO(tqchen): add any.same_as(ObjectRef)
- if (realize->buffer.same_as(stmt->node.cast<ObjectRef>())) {
- rhs = DocsifyBufferRealize(
- realize,
- /*value=*/d->AsDoc<ExprDoc>(stmt->value,
stmt_p->Attr("value")),
- /*p=*/stmt_p->Attr("body"), d);
- body = realize->body;
- body_p = stmt_p->Attr("body")->Attr("body");
- }
- }
- }
if (stmt->attr_key == "thread_extent" || stmt->attr_key ==
"virtual_thread") {
if (stmt->node.as<tir::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
@@ -374,8 +322,6 @@ TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR);
-
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/tir/backend/adreno/texture_flatten.cc
b/src/tir/backend/adreno/texture_flatten.cc
index 5e8a922c30..f7e5b16fba 100644
--- a/src/tir/backend/adreno/texture_flatten.cc
+++ b/src/tir/backend/adreno/texture_flatten.cc
@@ -91,50 +91,6 @@ class TextureFlattener : public TextureLoweringBase {
IRVisitorWithAnalyzer* bound_analyzer)
: TextureLoweringBase(extern_buffer_map, bound_analyzer) {}
- Stmt VisitStmt_(const BufferRealizeNode* op) final {
- if (extern_buf_.count(op->buffer)) {
- return this->VisitStmt(op->body);
- }
-
- std::string storage_scope = GetStorageScope(op->buffer);
- Var buffer_var(op->buffer->data->name_hint,
- PointerType(PrimType(op->buffer->dtype),
ffi::String(storage_scope)));
- let_binding_.insert({op->buffer->data, buffer_var});
-
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<BufferRealizeNode>();
-
- // Rewrite any buffer realizations with storage scope to 2d texture
allocations
- if (IsTextureStorage(storage_scope)) {
- Stmt body = this->VisitStmt(op->body);
- ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently
supported";
- const int bits = op->buffer->dtype.bits(),
- lanes =
static_cast<int>(op->bounds.back()->extent.as<IntImmNode>()->value);
- const int channel_size = bits * lanes;
- ICHECK(channel_size == 128 || channel_size == 64)
- << "Invalid Channel Size: " << channel_size << " bits";
-
- struct ShapeFromRange {
- const ffi::Array<Range>& bounds;
- PrimExpr operator[](size_t i) const { return bounds[i]->extent; }
- };
- size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(),
storage_scope);
- auto texture =
- ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds},
op->bounds.size(), axis);
- ffi::Array<PrimExpr> args;
- args.push_back(StringImm(storage_scope));
- args.push_back(IntImm(DataType::Int(64), 3)); // 2D-Array
- args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
- {texture.width, texture.height, texture.depth}));
- args.push_back(IntImm(DataType::Int(64), channel_size));
-
- stmt = LetStmt(buffer_var, Call(buffer_var.dtype(),
builtin::nd_mem_alloc_with_scope(), args),
- body);
- }
-
- return stmt;
- }
-
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BufferStoreNode>();
diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc
index 9ab37903bb..9efb4557d7 100644
--- a/src/tir/ir/py_functor.cc
+++ b/src/tir/ir/py_functor.cc
@@ -186,8 +186,6 @@ class PyStmtExprVisitorNode : public Object, public
StmtExprVisitor {
ffi::Function f_visit_decl_buffer{nullptr};
/*! \brief The packed function to the `VisitStmt_(const BufferStoreNode*
op)` function. */
ffi::Function f_visit_buffer_store{nullptr};
- /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode*
op)` function. */
- ffi::Function f_visit_buffer_realize{nullptr};
/*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)`
function. */
ffi::Function f_visit_assert_stmt{nullptr};
/*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)`
function. */
@@ -230,7 +228,6 @@ class PyStmtExprVisitorNode : public Object, public
StmtExprVisitor {
PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate);
PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
- PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate);
@@ -322,7 +319,6 @@ class PyStmtExprVisitorNode : public Object, public
StmtExprVisitor {
PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode);
PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode);
PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode);
- PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode);
PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode);
PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode);
PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode);
@@ -352,7 +348,6 @@ class PyStmtExprVisitor : public ObjectRef {
ffi::Function
f_visit_allocate, //
ffi::Function
f_visit_decl_buffer, //
ffi::Function
f_visit_buffer_store, //
- ffi::Function
f_visit_buffer_realize, //
ffi::Function
f_visit_assert_stmt, //
ffi::Function
f_visit_seq_stmt, //
ffi::Function
f_visit_evaluate, //
@@ -403,7 +398,6 @@ class PyStmtExprVisitor : public ObjectRef {
n->f_visit_allocate = std::move(f_visit_allocate);
n->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
n->f_visit_buffer_store = std::move(f_visit_buffer_store);
- n->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
n->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
n->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
n->f_visit_evaluate = std::move(f_visit_evaluate);
@@ -547,8 +541,6 @@ class PyStmtExprMutatorNode : public Object, public
StmtExprMutator {
ffi::Function f_visit_decl_buffer{nullptr};
/*! \brief The packed function to the `VisitStmt_(const BufferStoreNode*
op)` function. */
ffi::Function f_visit_buffer_store{nullptr};
- /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode*
op)` function. */
- ffi::Function f_visit_buffer_realize{nullptr};
/*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)`
function. */
ffi::Function f_visit_assert_stmt{nullptr};
/*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)`
function. */
@@ -591,7 +583,6 @@ class PyStmtExprMutatorNode : public Object, public
StmtExprMutator {
PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate);
PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
- PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate);
@@ -683,7 +674,6 @@ class PyStmtExprMutatorNode : public Object, public
StmtExprMutator {
PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode);
PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode);
PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode);
- PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode);
PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode);
PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode);
PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode);
@@ -714,7 +704,6 @@ class PyStmtExprMutator : public ObjectRef {
ffi::Function
f_visit_allocate, //
ffi::Function
f_visit_decl_buffer, //
ffi::Function
f_visit_buffer_store, //
- ffi::Function
f_visit_buffer_realize, //
ffi::Function
f_visit_assert_stmt, //
ffi::Function
f_visit_seq_stmt, //
ffi::Function
f_visit_evaluate, //
@@ -765,7 +754,6 @@ class PyStmtExprMutator : public ObjectRef {
n->f_visit_allocate = std::move(f_visit_allocate);
n->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
n->f_visit_buffer_store = std::move(f_visit_buffer_store);
- n->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
n->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
n->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
n->f_visit_evaluate = std::move(f_visit_evaluate);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index b7f28e0aaf..6f6b2f7e14 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -38,7 +38,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
AttrStmtNode::RegisterReflection();
AssertStmtNode::RegisterReflection();
BufferStoreNode::RegisterReflection();
- BufferRealizeNode::RegisterReflection();
AllocateNode::RegisterReflection();
DeclBufferNode::RegisterReflection();
SeqStmtNode::RegisterReflection();
@@ -467,20 +466,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}
-// BufferRealize
-BufferRealize::BufferRealize(Buffer buffer, ffi::Array<Range> bounds, PrimExpr
condition, Stmt body,
- Span span) {
- data_ = ffi::make_object<BufferRealizeNode>(buffer, bounds, condition, body,
span);
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer,
ffi::Array<Range> bounds,
- PrimExpr condition, Stmt body,
Span span) {
- return BufferRealize(buffer, bounds, condition, body, span);
- });
-}
-
// BufferRegion
PrimExpr BufferRegionNode::ToPrimExpr() const {
// Auto convert to PrimExpr if it is a single point load
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 06db54af5d..ef91f128bc 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -70,15 +70,6 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
-void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
- VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
- this->VisitExpr(op->condition);
- this->VisitStmt(op->body);
-}
-
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
@@ -350,22 +341,6 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
}
}
-Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
- Region bounds = Internal::Mutate(this, op->bounds);
- PrimExpr condition = this->VisitExpr(op->condition);
- Stmt body = this->VisitStmt(op->body);
-
- if (bounds.same_as(op->bounds) && condition.same_as(op->condition) &&
body.same_as(op->body)) {
- return ffi::GetRef<Stmt>(op);
- } else {
- auto n = CopyOnWrite(op);
- n->bounds = std::move(bounds);
- n->condition = std::move(condition);
- n->body = std::move(body);
- return Stmt(n);
- }
-}
-
Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
ffi::Array<Stmt> seq = Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index ba79ab856d..3aa91484b9 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -243,14 +243,6 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode*
op, AccessPath path)
Visit(op->indices, path->Attr("indices"));
}
-void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, AccessPath
path) {
- Visit(op->condition, path->Attr("condition"));
- Visit(op->bounds, path->Attr("bounds"));
- auto context = WithDefIfUndefined(op->buffer->data,
path->Attr("buffer")->Attr("data"));
- Visit(op->buffer, path->Attr("buffer"));
- Visit(op->body, path->Attr("body"));
-}
-
void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, AccessPath path)
{
Visit(op->condition, path->Attr("condition"));
Visit(op->then_case, path->Attr("then_case"));
diff --git a/src/tir/ir/tir_visitor_with_path.h
b/src/tir/ir/tir_visitor_with_path.h
index df28015495..a68a5e2852 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -108,7 +108,6 @@ class TIRVisitorWithPath
void VisitStmt_(const AllocateNode* op, ffi::reflection::AccessPath path)
override;
void VisitStmt_(const DeclBufferNode* op, ffi::reflection::AccessPath path)
override;
void VisitStmt_(const BufferStoreNode* op, ffi::reflection::AccessPath path)
override;
- void VisitStmt_(const BufferRealizeNode* op, ffi::reflection::AccessPath
path) override;
void VisitStmt_(const AssertStmtNode* op, ffi::reflection::AccessPath path)
override;
void VisitStmt_(const SeqStmtNode* op, ffi::reflection::AccessPath path)
override;
void VisitStmt_(const EvaluateNode* op, ffi::reflection::AccessPath path)
override;
diff --git a/src/tir/transform/inject_rolling_buffer.cc
b/src/tir/transform/inject_rolling_buffer.cc
deleted file mode 100644
index c3b41e0589..0000000000
--- a/src/tir/transform/inject_rolling_buffer.cc
+++ /dev/null
@@ -1,327 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file inject_rolling_buffer.cc
- * \brief Inject rolling buffer statements.
-
- Rolling buffers are buffers where one of the dimensions has been made into
- a circular buffer. Two optimizations are implemented in order to accomplish
- this: sliding window and storage folding. In particular, the sliding window
- optimization is applied to the entire buffer (to avoid recomputing
elements)
- and storage folding is then applied to just the rolling dimension.
-
- Rolling buffers must be inside a loop with only part of the buffer used per
- iteration. The outermost axis will be rolled over.
-
- For more information, see the RFC:
-
https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
- */
-#include <tvm/arith/analyzer.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
-
-#include "ir_utils.h"
-
-namespace tvm {
-namespace tir {
-
-using arith::IntSet;
-
-struct RollingBufferInfo {
- int rolling_axis;
- int rolling_extent;
- std::vector<int> axis_overlaps;
- std::vector<ffi::Optional<Var>> axis_iter_vars;
-};
-
-class RollingBufferInjector : public StmtExprMutator {
- std::vector<For> for_loops{};
- std::set<Buffer> rolling_buffers{};
- std::map<Buffer, BufferRealize> buffer_to_buffer_realize{};
- std::map<Buffer, std::vector<AttrStmt>> buffer_to_attrs{};
- std::map<Buffer, RollingBufferInfo> rolling_buffer_to_info{};
- // The actual key type is Var, ObjectRef has been used because
- // of the ambiguous overload for 'operator<'
- std::map<ObjectRef, std::vector<BufferRealize>> hoist_buffer_to_for{};
-
- public:
- RollingBufferInjector() {}
-
- Stmt Inject(Stmt stmt) { return ConvertSSA(operator()(std::move(stmt))); }
-
- Stmt VisitStmt_(const ForNode* op) final {
- // Manage the stack of iter_vars
- for_loops.push_back(ffi::GetRef<For>(op));
-
- auto stmt{StmtExprMutator::VisitStmt_(op)};
- op = stmt.as<ForNode>();
-
- // Manage the stack of iter_vars
- for_loops.pop_back();
-
- auto it{hoist_buffer_to_for.find(op->loop_var)};
- if (it != hoist_buffer_to_for.end()) {
- // If the loop corresponds to an iter_var that needs a BufferRealize
- // hoisting to its scope, perform the hoisting
- Stmt body{ffi::GetRef<For>(op)};
- for (auto realise : it->second) {
- auto attrs{buffer_to_attrs[realise->buffer]};
- Stmt new_realize{BufferRealize(realise->buffer, realise->bounds,
realise->condition, body,
- realise->span)};
- // The attributes attached to the BufferRealize need hoisting too
- for (auto attr : attrs) {
- if (attr->attr_key == attr::rolling_buffer_scope) {
- continue;
- }
- new_realize = AttrStmt(attr->node, attr->attr_key, attr->value,
new_realize, attr->span);
- }
- body = new_realize;
- }
- return body;
- } else {
- return stmt;
- }
- }
-
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (auto opt = op->node.as<Buffer>()) {
- auto buffer = opt.value();
- // Keep a dictionary associating attribute statements with the buffers
- // they reference. We'll need this if the buffer gets hoisted and we
- // need to hoist all of its attributes at the same time.
- buffer_to_attrs[buffer].push_back(ffi::GetRef<AttrStmt>(op));
-
- if (op->attr_key == attr::rolling_buffer_scope &&
Downcast<IntImm>(op->value)->value) {
- // If the attribute is indicating that a buffer should be a rolling
- // buffer, then update the rolling_buffers set to include the buffer
- rolling_buffers.insert(buffer);
-
- auto it{buffer_to_buffer_realize.find(buffer)};
- ICHECK(it != buffer_to_buffer_realize.end())
- << "Rolling buffer injection failed: no BufferRealize found";
- BufferRealize buffer_realize = it->second;
-
- // If a BufferRealize has been identified as needing to be made into
- // a rolling buffer, begin the analysis.
- std::vector<ffi::Optional<Var>> bound_iter_vars{};
- std::vector<int> bound_overlaps{};
- // We use the bound information of the BufferRealize to calculate
- // how we can legally roll
- auto stride{0};
- auto divisor{1};
- ffi::Optional<Var> iter_var{};
- for (auto bound : buffer_realize->bounds) {
- divisor = 1;
- if (auto floor_div = bound->min.as<FloorDivNode>()) {
- // Handle the case of fractional strides
- // They take this form: floordiv(hh.outer, 2)
- // Strip the floordiv and keep track of the divisor
- divisor = Downcast<IntImm>(floor_div->b)->value;
- bound = Range::FromMinExtent(floor_div->a, bound->extent,
bound->span);
- }
- if (bound->min.as<IntImmNode>()) {
- // If the bound is an int, we can't roll over it
- iter_var = nullptr;
- } else if (auto var = bound->min.as<VarNode>()) {
- // If the bound is just a Var, that implies the stride is 1
- iter_var = ffi::GetRef<Var>(var);
- stride = 1;
- } else {
- // Otherwise, it's the iter var multiplied by the stride
- // If not we're in unknown behaviour, so assert
- auto mul = bound->min.as<MulNode>();
- ICHECK(mul) << "Rolling buffer injection failed: the buffer
striding is unsupported";
- auto a = mul->a.as<VarNode>();
- ICHECK(a) << "Rolling buffer injection failed: the buffer striding
is unsupported";
- auto b = mul->b.as<IntImmNode>();
- ICHECK(b) << "Rolling buffer injection failed: the buffer striding
is unsupported";
- iter_var = ffi::GetRef<Var>(a);
- stride = b->value;
- }
- stride = std::ceil(static_cast<float>(stride) / divisor);
- bound_iter_vars.push_back(iter_var);
- if (iter_var) {
- bound_overlaps.push_back(Downcast<IntImm>(bound->extent)->value -
stride);
- } else {
- bound_overlaps.push_back(0);
- }
- }
- // Pick the outermost iter_var that's mentioned in the bounds
- // to be the rolling axis
- ffi::Optional<Var> roll_iter_var{};
- int roll_axis{1};
- for (auto loop : for_loops) {
- auto loop_var{loop->loop_var};
- iter_var = loop_var;
-
- auto it{std::find_if(
- bound_iter_vars.begin(), bound_iter_vars.end(),
- [&](ffi::Optional<Var> var) { return var && (var.get() ==
loop_var.get()); })};
-
- if (it != bound_iter_vars.end()) {
- auto i{std::distance(bound_iter_vars.begin(), it)};
- roll_iter_var = loop_var;
- roll_axis = i;
- break;
- }
- }
- // We must have found an axis to roll over
- ICHECK(roll_iter_var) << "Rolling buffer injection failed: no rolling
axis found";
- ICHECK(roll_axis != -1) << "Rolling buffer injection failed: no
rolling axis found";
-
- RollingBufferInfo rolling_buffer_info = {
- roll_axis,
-
static_cast<int>(Downcast<IntImm>(buffer_realize->bounds[roll_axis]->extent)->value),
- bound_overlaps,
- bound_iter_vars,
- };
- rolling_buffer_to_info[buffer] = rolling_buffer_info;
- ffi::Array<Range> new_bounds{};
- auto shape{buffer->shape};
- for (size_t i{0}; i < shape.size(); ++i) {
- auto extent{shape[i]};
- if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
- new_bounds.push_back(Range(0, rolling_buffer_info.rolling_extent));
- } else {
- new_bounds.push_back(Range(0, extent));
- }
- }
- BufferRealize new_realize{BufferRealize(buffer, new_bounds,
buffer_realize->condition,
- buffer_realize->body,
buffer_realize->span)};
- hoist_buffer_to_for[iter_var.value()].push_back(new_realize);
- }
- }
-
- auto stmt{StmtExprMutator::VisitStmt_(op)};
- op = stmt.as<AttrStmtNode>();
-
- if (auto opt = op->node.as<Buffer>(); opt &&
rolling_buffers.count(opt.value())) {
- // Remove the attribute statements attached to rolling buffers
- // because they will have been hoisted to the relevant rolling
- // scope
- return op->body;
- } else {
- return stmt;
- }
- }
-
- Stmt VisitStmt_(const BufferRealizeNode* op) final {
- buffer_to_buffer_realize.insert({op->buffer,
ffi::GetRef<BufferRealize>(op)});
-
- auto stmt{StmtExprMutator::VisitStmt_(op)};
- op = stmt.as<BufferRealizeNode>();
-
- if (rolling_buffers.count(op->buffer)) {
- // Remove the original BufferRealize for rolling buffers
- // because they will have been hoisted to the relevant rolling
- // scope
- return op->body;
- } else {
- return stmt;
- }
- }
-
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- auto stmt{StmtExprMutator::VisitStmt_(op)};
- op = stmt.as<BufferStoreNode>();
-
- auto it{rolling_buffer_to_info.find(op->buffer)};
- if (it != rolling_buffer_to_info.end()) {
- auto rolling_buffer_info{it->second};
- std::vector<PrimExpr> indices{};
- // First modify the access indices to use modulo arithmetic
- // for the rolling axis
- for (size_t i{0}; i < op->indices.size(); ++i) {
- auto index{op->indices[i]};
- if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
- indices.push_back(FloorMod(index,
rolling_buffer_info.rolling_extent));
- } else {
- indices.push_back(index);
- }
- }
- ICHECK(!op->predicate.defined()) << "Predicated buffer store is not
currently supported in "
- "the inject rolling buffer pass.";
- Stmt buffer_store = BufferStore(op->buffer, op->value, indices,
op->predicate, op->span);
- // Then wrap the BufferStores in some Ifs to avoid recomputing elements
- for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) {
- auto iter_var{rolling_buffer_info.axis_iter_vars[i]};
- if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) {
- Var var{iter_var.value()};
- const ffi::Map<Var, IntSet> dmap{std::make_pair(var,
IntSet::Interval(0, 0))};
- auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()};
- auto condition = Or(LT(var, 1), GE(term_2,
rolling_buffer_info.axis_overlaps[i]));
- buffer_store = IfThenElse(likely(condition), buffer_store);
- }
- }
- return buffer_store;
- } else {
- return stmt;
- }
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- auto expr{StmtExprMutator::VisitExpr_(op)};
- op = expr.as<BufferLoadNode>();
-
- auto it{rolling_buffer_to_info.find(op->buffer)};
- if (it != rolling_buffer_to_info.end()) {
- auto rolling_buffer_info{it->second};
- std::vector<PrimExpr> indices{};
- // Modify the access indices to use modulo arithmetic
- // for the rolling axis
- for (size_t i{0}; i < op->indices.size(); ++i) {
- auto index{op->indices[i]};
- if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
- indices.push_back(FloorMod(index,
rolling_buffer_info.rolling_extent));
- } else {
- indices.push_back(index);
- }
- }
- ICHECK(!op->predicate.defined())
- << "Predicated buffer load is not currently supported in inject
rolling buffer pass.";
- return BufferLoad(op->buffer, indices, op->predicate, op->span);
- } else {
- return expr;
- }
- }
-}; // namespace tir
-
-namespace transform {
-
-Pass InjectRollingBuffer() {
- auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- n->body = RollingBufferInjector().Inject(std::move(n->body));
- return f;
- };
- return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InjectRollingBuffer",
InjectRollingBuffer);
-}
-
-} // namespace transform
-
-} // namespace tir
-} // namespace tvm
diff --git a/src/tir/transform/unsupported_dtype_legalize.cc
b/src/tir/transform/unsupported_dtype_legalize.cc
index 74a69dfbc3..0ae17b5484 100644
--- a/src/tir/transform/unsupported_dtype_legalize.cc
+++ b/src/tir/transform/unsupported_dtype_legalize.cc
@@ -391,18 +391,6 @@ class ComputeLegalizer : public StmtExprMutator {
return ret;
}
- Stmt VisitStmt_(const BufferRealizeNode* op) final {
- Stmt ret = StmtExprMutator::VisitStmt_(op);
- op = ret.as<BufferRealizeNode>();
-
- Buffer new_buf = GetRemappedBuffer(op->buffer);
- if (new_buf.same_as(op->buffer)) {
- return ret;
- } else {
- return BufferRealize(new_buf, op->bounds, op->condition, op->body);
- }
- }
-
Stmt VisitStmt_(const DeclBufferNode* op) final {
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<DeclBufferNode>();
@@ -627,10 +615,6 @@ class StorageLegalizer : public StmtExprMutator {
return ret;
}
- Stmt VisitStmt_(const BufferRealizeNode* op) final {
- LOG(FATAL) << "Do not expect buffer realize";
- }
-
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<BufferLoadNode>();
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
index f513bacdbb..7cae20fbb0 100644
--- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -170,7 +170,7 @@ def extract_buffers(stmt):
buffers = []
def visitor(node):
- if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore,
tvm.tir.BufferRealize)):
+ if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore)):
buffers.append(node.buffer)
post_order_visit(stmt, visitor)
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index 1297cec61f..8f7b0232fa 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -314,37 +314,6 @@ def test_block_match_buffer_defines_symbolic_variables():
tvm.tir.analysis.verify_well_formed(mod)
-def test_buffer_realize_on_external_buffer_is_annotation():
- """A T.realize statement on an existing buffer annotates the region used"""
-
- @I.ir_module
- class mod:
- @T.prim_func
- def func(A: T.Buffer(256, "int32")):
- T.realize(A[0:16], "global")
-
- for i in range(16):
- A[i] = 1
-
- tvm.tir.analysis.verify_well_formed(mod)
-
-
-def test_buffer_realize_is_allocation():
- """A T.realize statement on an fresh buffer allocates the buffer"""
-
- @I.ir_module
- class mod:
- @T.prim_func
- def func():
- A = T.Buffer(256, "int32")
- T.realize(A[0:16], "global")
-
- for i in range(16):
- A[i] = 1
-
- tvm.tir.analysis.verify_well_formed(mod)
-
-
def test_error_message_without_previous_definition_location():
"""Test case 1: Error message without 'It was first defined at'
diff --git a/tests/python/tir-base/test_tir_nodes.py
b/tests/python/tir-base/test_tir_nodes.py
index 85cd726dda..b0bd5ac891 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -368,9 +368,6 @@ def test_buffer_load_store():
s = tvm.tir.BufferStore(b, 0.1, [0])
assert isinstance(s, tvm.tir.BufferStore)
- s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)], True,
tvm.tir.Evaluate(0))
- assert isinstance(s, tvm.tir.BufferRealize)
-
def test_intimm_cond():
x = tvm.runtime.convert(1)
diff --git
a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
deleted file mode 100644
index 4dd1380c8f..0000000000
--- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import numpy as np
-import pytest
-import tvm
-import tvm.script
-from tvm import te, topi
-from tvm.script import tir as T
-
-
[email protected]_module
-class PreRollingBuffer:
- @T.prim_func
- def main(
- A: T.handle,
- tensor: T.handle,
- tensor_2: T.Buffer(
- [1, 10, 12, 16],
- dtype="int8",
- elem_offset=0,
- align=64,
- offset_factor=1,
- ),
- ) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- A_1 = T.match_buffer(
- A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
- )
- tensor_1 = T.match_buffer(
- tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
- )
- # body
- T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
- for ax1_outer in T.serial(0, 2):
- T.realize(tensor_2[0:1, (ax1_outer * 4) : ((ax1_outer * 4) + 6),
0:12, 0:16], "")
- T.attr(tensor_2, "rolling_buffer_scope", True)
- for ax1 in T.serial(0, 6):
- for ax2 in T.serial(0, 12):
- for ax3 in T.serial(0, 16):
- tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3] =
T.int8(0)
- for dh in T.serial(0, 3):
- for dw in T.serial(0, 3):
- tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3]
= T.max(
- tensor_2[0, (ax1 + (ax1_outer * 4)), ax2,
ax3],
- A_1[0, ((ax1 + (ax1_outer * 4)) + dh),
(ax2 + dw), ax3],
- )
- for ax1_inner in T.serial(0, 4):
- for ax2_inner in T.serial(0, 8):
- for ax3_inner in T.serial(0, 16):
- tensor_1[
- 0,
- (ax1_inner + (ax1_outer * 4)),
- ax2_inner,
- ax3_inner,
- ] = T.int8(0)
- for dh_1 in T.serial(0, 3):
- for dw_1 in T.serial(0, 5):
- tensor_1[
- 0,
- (ax1_inner + (ax1_outer * 4)),
- ax2_inner,
- ax3_inner,
- ] = T.max(
- tensor_1[
- 0,
- (ax1_inner + (ax1_outer * 4)),
- ax2_inner,
- ax3_inner,
- ],
- tensor_2[
- 0,
- ((ax1_inner + (ax1_outer * 4)) + dh_1),
- (ax2_inner + dw_1),
- ax3_inner,
- ],
- )
-
-
[email protected]_module
-class PostRollingBuffer:
- @T.prim_func
- def main(
- A: T.handle,
- tensor: T.handle,
- tensor_2: T.Buffer(
- [1, 10, 12, 16],
- dtype="int8",
- elem_offset=0,
- align=64,
- offset_factor=1,
- ),
- ) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- A_1 = T.match_buffer(
- A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
- )
- tensor_1 = T.match_buffer(
- tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
- )
- # body
- T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
- T.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "")
- for ax1_outer in T.serial(0, 2):
- for ax1 in T.serial(0, 6):
- for ax2 in T.serial(0, 12):
- for ax3 in T.serial(0, 16):
- if T.likely(((ax1_outer < 1) or (ax1 >= 2)),
dtype="bool"):
- tensor_2[
- 0,
- T.floormod((ax1 + (ax1_outer * 4)), 6),
- ax2,
- ax3,
- ] = T.int8(0)
- for dh in T.serial(0, 3):
- for dw in T.serial(0, 3):
- if T.likely(((ax1_outer < 1) or (ax1 >= 2)),
dtype="bool"):
- tensor_2[
- 0, T.floormod((ax1 + (ax1_outer * 4)),
6), ax2, ax3
- ] = T.max(
- tensor_2[
- 0, T.floormod((ax1 + (ax1_outer *
4)), 6), ax2, ax3
- ],
- A_1[0, ((ax1 + (ax1_outer * 4)) + dh),
(ax2 + dw), ax3],
- )
- for ax1_inner in T.serial(0, 4):
- for ax2_inner in T.serial(0, 8):
- for ax3_inner in T.serial(0, 16):
- tensor_1[
- 0,
- (ax1_inner + (ax1_outer * 4)),
- ax2_inner,
- ax3_inner,
- ] = T.int8(0)
- for dh_1 in T.serial(0, 3):
- for dw_1 in T.serial(0, 5):
- tensor_1[
- 0,
- (ax1_inner + (ax1_outer * 4)),
- ax2_inner,
- ax3_inner,
- ] = T.max(
- tensor_1[
- 0, (ax1_inner + (ax1_outer * 4)),
ax2_inner, ax3_inner
- ],
- tensor_2[
- 0,
- T.floormod(((ax1_inner + (ax1_outer *
4)) + dh_1), 6),
- (ax2_inner + dw_1),
- ax3_inner,
- ],
- )
-
-
-def test_rolling_buffer_ir_transform():
- mod = PreRollingBuffer
- mod = tvm.tir.transform.InjectRollingBuffer()(mod)
- script = mod.script()
- mod = tvm.script.from_source(script)
- tvm.ir.assert_structural_equal(mod["main"], PostRollingBuffer["main"],
True)
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py
b/tests/python/tvmscript/test_tvmscript_error_report.py
index c1ed2425cd..3b48f06eb2 100644
--- a/tests/python/tvmscript/test_tvmscript_error_report.py
+++ b/tests/python/tvmscript/test_tvmscript_error_report.py
@@ -76,26 +76,22 @@ def test_undefined_buffer():
def undefined_buffer(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
- T.attr(A, "realize_scope", "")
- T.realize(C[0:16, 0:16], "") # error
for i in T.serial(16):
for j in T.serial(0, 16):
- A[i, j] = 0.0
+ C[i, j] = 0.0 # error
- check_error(undefined_buffer, 5)
+ check_error(undefined_buffer, 6)
def test_unsupported_function_call():
def unsupported_function_call(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
- T.attr(A, "realize_scope", "")
- T.realize(A[0:16, 0:16], "")
for i in T.const_range(16): # error
for j in T.serial(0, 16):
A[i, j] = 0.0
- check_error(unsupported_function_call, 6)
+ check_error(unsupported_function_call, 4)
def test_missing_type_annotation():
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index 9f98e10e7b..7395e15333 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -315,27 +315,6 @@ def test_ir_builder_tir_let():
assert_structural_equal(let_actual, let_expected, map_free_vars=True)
-def test_ir_builder_tir_realize():
- buffer_a = T.Buffer((128, 128), "float32")
- with IRBuilder() as ib:
- with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
- T.evaluate(0)
-
- # the buffer realization generated by IRBuilder
- realize_actual = ib.get()
-
- # the expected buffer realization
- buffer_realize = tir.BufferRealize(
- buffer_a, [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)], True,
tir.Evaluate(0)
- )
- expected_realize = tir.AttrStmt(
- buffer_a, "realize_scope", tir.StringImm("test_storage_scope"),
buffer_realize
- )
-
- # Check if the generated ir is expected
- assert_structural_equal(realize_actual, expected_realize,
map_free_vars=True)
-
-
def test_ir_builder_tir_thread():
with IRBuilder() as ib:
with T.prim_func():
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index f8eea3aee3..64057e9ca0 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -429,22 +429,6 @@ T.evaluate(0)
)
-def test_buffer_realize():
- with IRBuilder() as ib:
- a = tir.decl_buffer((128, 128), "float32", name="A")
- with T.realize(a[0:128, 0:128], "test_storage_scope", True):
- T.evaluate(0)
- obj = ib.get()
- _assert_print(
- obj,
- """
-A = T.Buffer((128, 128))
-with T.realize(A[0:128, 0:128], "test_storage_scope"):
- T.evaluate(0)
-""",
- )
-
-
def test_var():
a = tir.Var("a", "float32")
_assert_print(
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 1ea5bcb24c..8313bf7e75 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -26,64 +26,6 @@ from tvm.script import tir as T, ir as I, relax as R
import numpy as np
-def opt_gemm_normalize():
- @tvm.script.ir_module(check_well_formed=False)
- class Module:
- # packedB is treated as undefined
- @T.prim_func
- def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- # buffer definition
- C_global = T.Buffer([1024, 1024], elem_offset=0, align=64,
offset_factor=1)
- packedB = T.Buffer([32, 1024, 32], elem_offset=0, align=64,
offset_factor=1)
- A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=64,
offset_factor=1)
- B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64,
offset_factor=1)
- C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=64,
offset_factor=1)
- # body
- T.realize(packedB[0:32, 0:1024, 0:32], "")
- for x in T.parallel(0, 32):
- for y in T.serial(0, 1024):
- for z in T.vectorized(0, 32):
- packedB[x, y, z] = B_1[y, ((x * 32) + z)]
- T.realize(C_1[0:1024, 0:1024], "")
- for x_outer in T.parallel(0, 32):
- for y_outer in T.serial(0, 32):
- T.realize(
- C_global[
- (x_outer * 32) : ((x_outer * 32) + 32),
- (y_outer * 32) : ((y_outer * 32) + 32),
- ],
- "global",
- )
- for x_c_init in T.serial(0, 32):
- for y_c_init in T.vectorized(0, 32):
- C_global[
- (x_c_init + (x_outer * 32)), (y_c_init +
(y_outer * 32))
- ] = T.float32(0)
- for k_outer in T.serial(0, 256):
- for x_c in T.serial(0, 32):
- for k_inner in T.unroll(0, 4):
- for y_c in T.vectorized(0, 32):
- C_global[
- (x_c + (x_outer * 32)), (y_c +
(y_outer * 32))
- ] = C_global[(x_c + (x_outer * 32)), (y_c
+ (y_outer * 32))] + (
- A_1[(x_c + (x_outer * 32)), (k_inner +
(k_outer * 4))]
- * packedB[
- T.floordiv((y_c + (y_outer * 32)),
32),
- (k_inner + (k_outer * 4)),
- T.floormod((y_c + (y_outer * 32)),
32),
- ]
- )
- for x_inner in T.serial(0, 32):
- for y_inner in T.serial(0, 32):
- C_1[(x_inner + (x_outer * 32)), (y_inner +
(y_outer * 32))] = C_global[
- (x_inner + (x_outer * 32)), (y_inner +
(y_outer * 32))
- ]
-
- return Module
-
-
def opt_gemm_lower():
@tvm.script.ir_module
class Module:
@@ -475,461 +417,6 @@ def opt_gemm_mod_host():
return Module
-def opt_conv_tensorcore_normalize():
- @T.prim_func(check_well_formed=False)
- def func(A: T.handle, W: T.handle, Conv: T.handle) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
- # var definition
- bx = T.env_thread("blockIdx.x")
- by = T.env_thread("blockIdx.y")
- bz = T.env_thread("blockIdx.z")
- tx = T.env_thread("threadIdx.x")
- ty = T.env_thread("threadIdx.y")
- tz = T.env_thread("threadIdx.z")
- # buffer definition
- Apad_shared = T.Buffer(
- [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0,
align=64, offset_factor=1
- )
- Apad_shared_wmma_matrix_a = T.Buffer(
- [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0,
align=64, offset_factor=1
- )
- BA = T.Buffer([16, 16], dtype="float16", scope="wmma.matrix_a",
align=32, offset_factor=256)
- BB = T.Buffer([16, 16], dtype="float16", scope="wmma.matrix_b",
align=32, offset_factor=256)
- BC = T.Buffer([16, 16], scope="wmma.accumulator", align=32,
offset_factor=256)
- Conv_wmma_accumulator = T.Buffer(
- [16, 14, 14, 32, 16, 16], elem_offset=0, align=64, offset_factor=1
- )
- W_shared = T.Buffer(
- [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=64,
offset_factor=1
- )
- W_shared_wmma_matrix_b = T.Buffer(
- [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=64,
offset_factor=1
- )
- buffer = T.Buffer([16, 16], dtype="float16", scope="shared", align=32,
offset_factor=256)
- buffer_1 = T.Buffer(
- [16, 16], dtype="float16", scope="wmma.matrix_a", align=32,
offset_factor=256
- )
- buffer_2 = T.Buffer([16, 16], dtype="float16", scope="shared",
align=32, offset_factor=256)
- buffer_3 = T.Buffer(
- [16, 16], dtype="float16", scope="wmma.matrix_b", align=32,
offset_factor=256
- )
- buffer_4 = T.Buffer([16, 16], scope="wmma.accumulator", align=32,
offset_factor=256)
- buffer_5 = T.Buffer([16, 16], align=32, offset_factor=256)
- A_1 = T.match_buffer(
- A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0,
align=64, offset_factor=1
- )
- W_1 = T.match_buffer(
- W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0,
align=64, offset_factor=1
- )
- Conv_1 = T.match_buffer(
- Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=64,
offset_factor=1
- )
- # body
- T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "")
- T.launch_thread(bz, 196)
- T.launch_thread(bx, 2)
- T.launch_thread(by, 4)
- T.launch_thread(ty, 4)
- T.launch_thread(tz, 2)
- T.realize(
- Conv_wmma_accumulator[
- ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2),
- T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1),
- T.floormod(bz, 14) : (T.floormod(bz, 14) + 1),
- ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4),
- 0:16,
- 0:16,
- ],
- "wmma.accumulator",
- )
- for n_c_init in T.serial(0, 2):
- for o_c_init in T.serial(0, 4):
- T.attr(
- [BC, Conv_wmma_accumulator],
- "buffer_bind_scope",
- T.tvm_tuple(
- (n_c_init + ((bx * 8) + (ty * 2))),
- 1,
- T.floordiv(bz, 14),
- 1,
- T.floormod(bz, 14),
- 1,
- (o_c_init + ((by * 8) + (tz * 4))),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.evaluate(
- T.tvm_fill_fragment(
- BC.data,
- 16,
- 16,
- 16,
- T.floordiv(BC.elem_offset, 256),
- T.float32(0),
- dtype="handle",
- )
- )
-
- for ic_outer in T.serial(0, 8):
- for kh in T.serial(0, 3):
- T.realize(
- Apad_shared[
- (bx * 8) : ((bx * 8) + 8),
- (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh)
+ 1),
- T.floormod(bz, 14) : (T.floormod(bz, 14) + 3),
- (ic_outer * 2) : ((ic_outer * 2) + 2),
- 0:16,
- 0:16,
- ],
- "shared",
- )
- for ax2 in T.serial(0, 3):
- for ax3 in T.serial(0, 2):
- for ax4_ax5_fused_outer in T.serial(0, 8):
- T.launch_thread(tx, 32)
- Apad_shared[
- ((tz + (ty * 2)) + (bx * 8)),
- (T.floordiv(bz, 14) + kh),
- (ax2 + T.floormod(bz, 14)),
- (ax3 + (ic_outer * 2)),
- T.floordiv((tx + (ax4_ax5_fused_outer * 32)),
16),
- T.floormod((tx + (ax4_ax5_fused_outer * 32)),
16),
- ] = T.if_then_else(
- (
- (
- (
- ((T.floordiv(bz, 14) + kh) >= 1)
- and (((T.floordiv(bz, 14) + kh) -
1) < 14)
- )
- and ((ax2 + T.floormod(bz, 14)) >= 1)
- )
- and (((ax2 + T.floormod(bz, 14)) - 1) < 14)
- ),
- A_1[
- ((tz + (ty * 2)) + (bx * 8)),
- ((T.floordiv(bz, 14) + kh) - 1),
- ((ax2 + T.floormod(bz, 14)) - 1),
- (ax3 + (ic_outer * 2)),
- T.floordiv((tx + (ax4_ax5_fused_outer *
32)), 16),
- T.floormod((tx + (ax4_ax5_fused_outer *
32)), 16),
- ],
- T.float16(0),
- dtype="float16",
- )
- T.realize(
- W_shared[
- kh : (kh + 1),
- 0:3,
- (ic_outer * 2) : ((ic_outer * 2) + 2),
- (by * 8) : ((by * 8) + 8),
- 0:16,
- 0:16,
- ],
- "shared",
- )
- for ax1 in T.serial(0, 3):
- for ax2_1 in T.serial(0, 2):
- T.launch_thread(tx, 32)
- for ax4_ax5_fused_inner in T.vectorized(0, 8):
- W_shared[
- kh,
- ax1,
- (ax2_1 + (ic_outer * 2)),
- ((tz + (ty * 2)) + (by * 8)),
- T.floordiv((ax4_ax5_fused_inner + (tx * 8)),
16),
- T.floormod((ax4_ax5_fused_inner + (tx * 8)),
16),
- ] = W_1[
- kh,
- ax1,
- (ax2_1 + (ic_outer * 2)),
- ((tz + (ty * 2)) + (by * 8)),
- T.floordiv((ax4_ax5_fused_inner + (tx * 8)),
16),
- T.floormod((ax4_ax5_fused_inner + (tx * 8)),
16),
- ]
- for ic_inner in T.serial(0, 2):
- for kw in T.serial(0, 3):
- T.realize(
- Apad_shared_wmma_matrix_a[
- ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2))
+ 2),
- (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz,
14) + kh) + 1),
- (kw + T.floormod(bz, 14)) : ((kw +
T.floormod(bz, 14)) + 1),
- ((ic_outer * 2) + ic_inner) : (((ic_outer * 2)
+ ic_inner) + 1),
- 0:16,
- 0:16,
- ],
- "wmma.matrix_a",
- )
- for ax0 in T.serial(0, 2):
- T.attr(
- [buffer, Apad_shared],
- "buffer_bind_scope",
- T.tvm_tuple(
- (ax0 + ((bx * 8) + (ty * 2))),
- 1,
- (T.floordiv(bz, 14) + kh),
- 1,
- (kw + T.floormod(bz, 14)),
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.attr(
- [buffer_1, Apad_shared_wmma_matrix_a],
- "buffer_bind_scope",
- T.tvm_tuple(
- (ax0 + ((bx * 8) + (ty * 2))),
- 1,
- (T.floordiv(bz, 14) + kh),
- 1,
- (kw + T.floormod(bz, 14)),
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.evaluate(
- T.tvm_load_matrix_sync(
- buffer_1.data,
- 16,
- 16,
- 16,
- T.floordiv(buffer_1.elem_offset, 256),
- T.tvm_access_ptr(
- T.type_annotation(dtype="float16"),
- buffer.data,
- buffer.elem_offset,
- 256,
- 1,
- dtype="handle",
- ),
- 16,
- "row_major",
- dtype="handle",
- )
- )
- T.realize(
- W_shared_wmma_matrix_b[
- kh : (kh + 1),
- kw : (kw + 1),
- ((ic_outer * 2) + ic_inner) : (((ic_outer * 2)
+ ic_inner) + 1),
- ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4))
+ 4),
- 0:16,
- 0:16,
- ],
- "wmma.matrix_b",
- )
- for ax3_1 in T.serial(0, 4):
- T.attr(
- [buffer_2, W_shared],
- "buffer_bind_scope",
- T.tvm_tuple(
- kh,
- 1,
- kw,
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- (ax3_1 + ((by * 8) + (tz * 4))),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.attr(
- [buffer_3, W_shared_wmma_matrix_b],
- "buffer_bind_scope",
- T.tvm_tuple(
- kh,
- 1,
- kw,
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- (ax3_1 + ((by * 8) + (tz * 4))),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.evaluate(
- T.tvm_load_matrix_sync(
- buffer_3.data,
- 16,
- 16,
- 16,
- T.floordiv(buffer_3.elem_offset, 256),
- T.tvm_access_ptr(
- T.type_annotation(dtype="float16"),
- buffer_2.data,
- buffer_2.elem_offset,
- 256,
- 1,
- dtype="handle",
- ),
- 16,
- "row_major",
- dtype="handle",
- )
- )
- for n_c in T.serial(0, 2):
- for o_c in T.serial(0, 4):
- T.attr(
- [BA, Apad_shared_wmma_matrix_a],
- "buffer_bind_scope",
- T.tvm_tuple(
- (n_c + ((bx * 8) + (ty * 2))),
- 1,
- (T.floordiv(bz, 14) + kh),
- 1,
- (T.floormod(bz, 14) + kw),
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.attr(
- [BB, W_shared_wmma_matrix_b],
- "buffer_bind_scope",
- T.tvm_tuple(
- kh,
- 1,
- kw,
- 1,
- ((ic_outer * 2) + ic_inner),
- 1,
- (o_c + ((by * 8) + (tz * 4))),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.attr(
- [BC, Conv_wmma_accumulator],
- "buffer_bind_scope",
- T.tvm_tuple(
- (n_c + ((bx * 8) + (ty * 2))),
- 1,
- T.floordiv(bz, 14),
- 1,
- T.floormod(bz, 14),
- 1,
- (o_c + ((by * 8) + (tz * 4))),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.evaluate(
- T.tvm_mma_sync(
- BC.data,
- T.floordiv(BC.elem_offset, 256),
- BA.data,
- T.floordiv(BA.elem_offset, 256),
- BB.data,
- T.floordiv(BB.elem_offset, 256),
- BC.data,
- T.floordiv(BC.elem_offset, 256),
- dtype="handle",
- )
- )
- for n_inner in T.serial(0, 2):
- for o_inner in T.serial(0, 4):
- T.attr(
- [buffer_4, Conv_wmma_accumulator],
- "buffer_bind_scope",
- T.tvm_tuple(
- ((((bx * 4) + ty) * 2) + n_inner),
- 1,
- T.floordiv(bz, 14),
- 1,
- T.floormod(bz, 14),
- 1,
- ((((by * 2) + tz) * 4) + o_inner),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.attr(
- [buffer_5, Conv_1],
- "buffer_bind_scope",
- T.tvm_tuple(
- ((((bx * 4) + ty) * 2) + n_inner),
- 1,
- T.floordiv(bz, 14),
- 1,
- T.floormod(bz, 14),
- 1,
- ((((by * 2) + tz) * 4) + o_inner),
- 1,
- 0,
- 16,
- 0,
- 16,
- dtype="handle",
- ),
- )
- T.evaluate(
- T.tvm_store_matrix_sync(
- buffer_4.data,
- 16,
- 16,
- 16,
- T.floordiv(buffer_4.elem_offset, 256),
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"),
- buffer_5.data,
- buffer_5.elem_offset,
- 256,
- 2,
- dtype="handle",
- ),
- 16,
- "row_major",
- dtype="handle",
- )
- )
-
- return func
-
-
def opt_conv_tensorcore_lower():
@T.prim_func
def func(
@@ -4105,10 +3592,8 @@ def relax_float_symbolic_var():
ir_generator = tvm.testing.parameter(
launch_env_thread,
- opt_gemm_normalize,
opt_gemm_lower,
opt_gemm_mod_host,
- opt_conv_tensorcore_normalize,
opt_conv_tensorcore_lower,
opt_conv_tensorcore_mod_host,
vthread_func,