This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5821c1240e [TIR] Add DeclBuffer IR node and functors (#12300)
5821c1240e is described below
commit 5821c1240e18a137dfb0afe8810e97e50fd869ca
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 5 14:59:01 2022 -0700
[TIR] Add DeclBuffer IR node and functors (#12300)
* [TIR] Add DeclBuffer node
* [TIR] Add IR functors for DeclBuffer
* [TVMScript] Add printer and parser for DeclBuffer
* Update printer
* Update printer
* Add test case
* lint
* fix
---
include/tvm/tir/stmt.h | 34 ++++++++++
include/tvm/tir/stmt_functor.h | 4 ++
python/tvm/script/tir/__init__.pyi | 12 ++++
python/tvm/script/tir/scope_handler.py | 81 +++++++++++++++++++++++
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/stmt.py | 20 ++++++
src/printer/text_printer.h | 1 +
src/printer/tir_text_printer.cc | 12 ++++
src/printer/tvmscript_printer.cc | 19 ++++++
src/target/source/codegen_c.cc | 2 +
src/target/source/codegen_c.h | 1 +
src/tir/ir/stmt.cc | 23 +++++++
src/tir/ir/stmt_functor.cc | 14 ++++
tests/cpp/ir_functor_test.cc | 6 +-
tests/python/unittest/test_tvmscript_roundtrip.py | 13 ++++
15 files changed, 241 insertions(+), 2 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5dd4103e82..5be1b9626d 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -681,6 +681,40 @@ class AllocateConst : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
};
+/*! \brief Declare a buffer that can be used in the body */
+class DeclBufferNode : public StmtNode {
+ public:
+ /*! \brief The buffer being declared */
+ Buffer buffer;
+ /*! \brief The body to be executed */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("buffer", &buffer);
+ v->Visit("body", &body);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
+ return equal(buffer, other->buffer) && equal(body, other->body);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(buffer);
+ hash_reduce(body);
+ }
+
+ static constexpr const char* _type_key = "tir.DeclBuffer";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode);
+};
+
+/*! \brief Managed reference to DeclBufferNode */
+class DeclBuffer : public Stmt {
+ public:
+ TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
+ TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode);
+};
+
/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index fce2e1d671..49b1f28e5d 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const DeclBufferNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
@@ -116,6 +117,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
+ IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
@@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
+ void VisitStmt_(const DeclBufferNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
@@ -260,6 +263,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
+ Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index f03c5c06da..a62fb102be 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -187,6 +187,18 @@ def match_buffer(
buffer_type: str = "default",
axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
+def decl_buffer(
+ shape: Sequence[Union[PrimExpr, int]],
+ dtype: str = "float32",
+ data: Var = None,
+ strides: Optional[Sequence[int]] = None,
+ elem_offset: Optional[int] = None,
+ scope: str = "global",
+ align: int = -1,
+ offset_factor: int = 0,
+ buffer_type: str = "default",
+ axis_separators: Optional[List[int]] = None,
+) -> Buffer: ...
def buffer_decl(
shape: Sequence[Union[PrimExpr, int]],
dtype: str = "float32",
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 92aaf8b4d9..da7545c9a9 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -224,6 +224,87 @@ class AllocateConst(WithScopeHandler):
context.update_symbol(name, self.buffer, node)
+@register
+class DeclBuffer(WithScopeHandler):
+ """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type, axis_separators)
+ Example
+ -------
+ .. code-block:: python
+ A = T.decl_buffer((128, 128), dtype="float32")
+ """
+
+ def __init__(self):
+ def decl_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="global",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ axis_separators=None,
+ span=None,
+ ):
+ return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
+
+ super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define buffer vars in symbol table
+ if isinstance(node, synr.ast.With):
+ vars = WithScopeHandler.get_optional_vars(node, context)
+ if len(vars) != 1:
+ context.report_error(f"Unexpected number of vars: 1 vs.
{len(vars)}", node.span)
+ name = vars[0].id.name
+ var_span = vars[0].id.span
+ elif isinstance(node, synr.ast.Assign):
+ if len(node.lhs) != 1:
+ context.report_error(f"Unexpected number of vars: 1 vs.
{len(node.lhs)}", node.span)
+ name = node.lhs[0].id.name
+ var_span = node.lhs[0].id.span
+ else:
+ raise Exception("Internal Bug")
+
+ def setup_buffer(
+ shape,
+ dtype,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ span: Span = None,
+ ):
+ self.buffer = tvm.tir.decl_buffer(
+ shape=shape,
+ dtype=dtype,
+ data=data,
+ strides=strides,
+ elem_offset=elem_offset,
+ scope=scope,
+ data_alignment=align,
+ offset_factor=offset_factor,
+ buffer_type=buffer_type,
+ axis_separators=axis_separators,
+ span=span,
+ )
+
+ setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+ context.update_symbol(name, self.buffer, node)
+
+
@register
class LaunchThread(WithScopeHandler):
"""With scope handler T.launch_thread(env_var, extent)"""
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index a3798ccab4..c64b7dfe71 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -36,6 +36,7 @@ from .stmt import (
Allocate,
AllocateConst,
AttrStmt,
+ DeclBuffer,
)
from .stmt import ProducerRealize, SeqStmt
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 063439e068..4847e377de 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -377,6 +377,26 @@ class AllocateConst(Stmt):
)
+@tvm._ffi.register_object("tir.DeclBuffer")
+class DeclBuffer(Stmt):
+ """DeclBuffer node.
+
+ Parameters
+ ----------
+ buffer: Buffer
+ The buffer being declared.
+
+ body: Stmt
+ The body statement to be executed.
+
+ span: Optional[Span]
+ The location of this DeclBuffer in the source code.
+ """
+
+ def __init__(self, buffer, body, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body,
span)
+
+
@tvm._ffi.register_object("tir.AttrStmt")
class AttrStmt(Stmt):
"""AttrStmt node.
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index 05a00e3305..2dc0997f82 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -353,6 +353,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const ProducerRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
+ Doc VisitStmt_(const DeclBufferNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index fe829016b6..894a9cec1e 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -557,6 +557,18 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
return doc;
}
+Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
+ Doc doc;
+ doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data)
<< ", "
+ << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) <<
")" << Doc::NewLine();
+ if (op->body->IsInstance<SeqStmtNode>()) {
+ doc << PrintBody(op->body);
+ } else {
+ doc << ";" << Doc::NewLine() << Print(op->body);
+ }
+ return doc;
+}
+
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index f2abf5c78d..6708922444 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -245,6 +245,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
+ Doc VisitStmt_(const DeclBufferNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
@@ -1161,6 +1162,24 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
return doc;
}
+Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) {
+ const Buffer& buffer = op->buffer;
+ buf_not_in_headers_.insert(buffer.get());
+ Doc buffer_name = Print(op->buffer);
+ Doc func_call;
+ func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) <<
")";
+
+ Doc doc;
+ if (current_num_ != num_child_ - 1) {
+ doc << "with " << func_call << " as " << buffer_name << ":";
+ doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+ } else {
+ doc << buffer_name << " = " << func_call << Doc::NewLine();
+ doc << PrintBody(op->body);
+ }
+ return doc;
+}
+
Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3ad7882d79..3fe7fa50d3 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -661,6 +661,8 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
this->PrintStmt(op->body);
}
+void CodeGenC::VisitStmt_(const DeclBufferNode* op) {
this->PrintStmt(op->body); }
+
void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead.";
}
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 696ec62c58..0af24dfdc0 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -166,6 +166,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
+ void VisitStmt_(const DeclBufferNode* op) override;
/*!
* \brief Print expr representing the thread tag
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 2b337520a2..524204f3d3 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -508,6 +508,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->body);
});
+// DeclBuffer
+DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
+ ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>();
+ node->buffer = std::move(buffer);
+ node->body = std::move(body);
+ node->span = std::move(span);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt
body, Span span) {
+ return DeclBuffer(buffer, body, span);
+});
+
+TVM_REGISTER_NODE_TYPE(DeclBufferNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const DeclBufferNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "decl_buffer " << op->buffer << "\n";
+ p->stream << op->body;
+ });
+
// ProducerRealize
ProducerRealize::ProducerRealize(DataProducer producer, Region bounds,
PrimExpr condition,
Stmt body, String storage_scope, Span span) {
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index c0abf953ee..c75eb52f92 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -63,6 +63,8 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) {
this->VisitStmt(op->body);
}
+void StmtVisitor::VisitStmt_(const DeclBufferNode* op) {
this->VisitStmt(op->body); }
+
void StmtVisitor::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use
BufferStoreNode instead.";
}
@@ -336,6 +338,18 @@ Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) {
}
}
+Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
+ Stmt body = this->VisitStmt(op->body);
+
+ if (body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ auto n = CopyOnWrite(op);
+ n->body = std::move(body);
+ return Stmt(n);
+ }
+}
+
Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 33b145d3a4..2909915c32 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -181,6 +181,7 @@ TEST(IRF, StmtVisitor) {
DataType dtype = DataType::Float(32);
Var buf_var("b", PointerType(PrimType(dtype)));
Buffer buffer = decl_buffer({16});
+ body = DeclBuffer(buffer, std::move(body));
BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
@@ -309,6 +310,7 @@ TEST(IRF, StmtMutator) {
DataType dtype = DataType::Float(32);
Var buf_var("b", PointerType(PrimType(dtype)));
Buffer buffer = decl_buffer({16});
+ body = DeclBuffer(buffer, std::move(body));
BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
// construct block and block_realize
@@ -318,8 +320,8 @@ TEST(IRF, StmtMutator) {
body = v(std::move(block_realize));
// the body should be changed
Block new_block = body.as<BlockRealizeNode>()->block;
- ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x));
- ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x));
+
ICHECK(new_block->body.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
+
ICHECK(new_block->init.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
ICHECK(new_block->reads[0]->region[0]->min.same_as(x));
ICHECK(new_block->writes[0]->region[0]->min.same_as(x));
ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x));
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 528357339c..0a2cec6011 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3303,6 +3303,18 @@ def void_ptr():
return func
+def decl_buffer():
+ @T.prim_func
+ def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16),
"float32"]) -> None:
+ A_flattened = T.decl_buffer(data=A.data, shape=(256,), dtype="float32")
+ B_flattened = T.decl_buffer(data=B.data, shape=(256,), dtype="float32")
+ C_alias = T.decl_buffer(data=A_flattened.data, shape=(256,),
dtype="float32")
+ for i in range(256):
+ B_flattened[i] = A_flattened[i] + C_alias[i] + T.float32(1.0)
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
@@ -3342,6 +3354,7 @@ ir_generator = tvm.testing.parameter(
buffer_ramp_access_as_slice_index,
let_expression,
void_ptr,
+ decl_buffer,
)