This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5821c1240e [TIR] Add DeclBuffer IR node and functors (#12300)
5821c1240e is described below

commit 5821c1240e18a137dfb0afe8810e97e50fd869ca
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 5 14:59:01 2022 -0700

    [TIR] Add DeclBuffer IR node and functors (#12300)
    
    * [TIR] Add DeclBuffer node
    
    * [TIR] Add IR functors for DeclBuffer
    
    * [TVMScript] Add printer and parser for DeclBuffer
    
    * Update printer
    
    * Update printer
    
    * Add test case
    
    * lint
    
    * fix
---
 include/tvm/tir/stmt.h                            | 34 ++++++++++
 include/tvm/tir/stmt_functor.h                    |  4 ++
 python/tvm/script/tir/__init__.pyi                | 12 ++++
 python/tvm/script/tir/scope_handler.py            | 81 +++++++++++++++++++++++
 python/tvm/tir/__init__.py                        |  1 +
 python/tvm/tir/stmt.py                            | 20 ++++++
 src/printer/text_printer.h                        |  1 +
 src/printer/tir_text_printer.cc                   | 12 ++++
 src/printer/tvmscript_printer.cc                  | 19 ++++++
 src/target/source/codegen_c.cc                    |  2 +
 src/target/source/codegen_c.h                     |  1 +
 src/tir/ir/stmt.cc                                | 23 +++++++
 src/tir/ir/stmt_functor.cc                        | 14 ++++
 tests/cpp/ir_functor_test.cc                      |  6 +-
 tests/python/unittest/test_tvmscript_roundtrip.py | 13 ++++
 15 files changed, 241 insertions(+), 2 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5dd4103e82..5be1b9626d 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -681,6 +681,40 @@ class AllocateConst : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
 };
 
+/*! \brief Declare a buffer that can be used in the body */
+class DeclBufferNode : public StmtNode {
+ public:
+  /*! \brief The buffer being declared */
+  Buffer buffer;
+  /*! \brief The body to be executed */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer", &buffer);
+    v->Visit("body", &body);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
+    return equal(buffer, other->buffer) && equal(body, other->body);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(buffer);
+    hash_reduce(body);
+  }
+
+  static constexpr const char* _type_key = "tir.DeclBuffer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode);
+};
+
+/*! \brief Managed reference to DeclBufferNode */
+class DeclBuffer : public Stmt {
+ public:
+  TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
+  TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode);
+};
+
 /*!
  * \brief The container of seq statement.
  *        Represent a sequence of statements.
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index fce2e1d671..49b1f28e5d 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -116,6 +117,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(WhileNode);
     IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
     IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
+    IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
     IR_STMT_FUNCTOR_DISPATCH(StoreNode);
     IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
     IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
@@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
   void VisitStmt_(const WhileNode* op) override;
   void VisitStmt_(const AllocateNode* op) override;
   void VisitStmt_(const AllocateConstNode* op) override;
+  void VisitStmt_(const DeclBufferNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
   void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const BufferRealizeNode* op) override;
@@ -260,6 +263,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
   Stmt VisitStmt_(const WhileNode* op) override;
   Stmt VisitStmt_(const AllocateNode* op) override;
   Stmt VisitStmt_(const AllocateConstNode* op) override;
+  Stmt VisitStmt_(const DeclBufferNode* op) override;
   Stmt VisitStmt_(const StoreNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
   Stmt VisitStmt_(const BufferRealizeNode* op) override;
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index f03c5c06da..a62fb102be 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -187,6 +187,18 @@ def match_buffer(
     buffer_type: str = "default",
     axis_separators: Optional[List[int]] = None,
 ) -> Buffer: ...
+def decl_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
+) -> Buffer: ...
 def buffer_decl(
     shape: Sequence[Union[PrimExpr, int]],
     dtype: str = "float32",
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 92aaf8b4d9..da7545c9a9 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -224,6 +224,87 @@ class AllocateConst(WithScopeHandler):
         context.update_symbol(name, self.buffer, node)
 
 
+@register
+class DeclBuffer(WithScopeHandler):
+    """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
+                                offset_factor, buffer_type, axis_separators)
+    Example
+    -------
+    .. code-block:: python
+        A = T.decl_buffer((128, 128), dtype="float32")
+    """
+
+    def __init__(self):
+        def decl_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="global",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            axis_separators=None,
+            span=None,
+        ):
+            return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
+
+        super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define buffer vars in symbol table
+        if isinstance(node, synr.ast.With):
+            vars = WithScopeHandler.get_optional_vars(node, context)
+            if len(vars) != 1:
+                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
+            name = vars[0].id.name
+            var_span = vars[0].id.span
+        elif isinstance(node, synr.ast.Assign):
+            if len(node.lhs) != 1:
+                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
+            name = node.lhs[0].id.name
+            var_span = node.lhs[0].id.span
+        else:
+            raise Exception("Internal Bug")
+
+        def setup_buffer(
+            shape,
+            dtype,
+            data,
+            strides,
+            elem_offset,
+            scope,
+            align,
+            offset_factor,
+            buffer_type,
+            axis_separators,
+            span: Span = None,
+        ):
+            self.buffer = tvm.tir.decl_buffer(
+                shape=shape,
+                dtype=dtype,
+                data=data,
+                strides=strides,
+                elem_offset=elem_offset,
+                scope=scope,
+                data_alignment=align,
+                offset_factor=offset_factor,
+                buffer_type=buffer_type,
+                axis_separators=axis_separators,
+                span=span,
+            )
+
+        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+        context.update_symbol(name, self.buffer, node)
+
+
 @register
 class LaunchThread(WithScopeHandler):
     """With scope handler T.launch_thread(env_var, extent)"""
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index a3798ccab4..c64b7dfe71 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -36,6 +36,7 @@ from .stmt import (
     Allocate,
     AllocateConst,
     AttrStmt,
+    DeclBuffer,
 )
 
 from .stmt import ProducerRealize, SeqStmt
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 063439e068..4847e377de 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -377,6 +377,26 @@ class AllocateConst(Stmt):
         )
 
 
+@tvm._ffi.register_object("tir.DeclBuffer")
+class DeclBuffer(Stmt):
+    """DeclBuffer node.
+
+    Parameters
+    ----------
+    buffer: Buffer
+        The buffer being declared.
+
+    body: Stmt
+        The body statement to be executed.
+
+    span: Optional[Span]
+        The location of this DeclBuffer in the source code.
+    """
+
+    def __init__(self, buffer, body, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span)
+
+
 @tvm._ffi.register_object("tir.AttrStmt")
 class AttrStmt(Stmt):
     """AttrStmt node.
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index 05a00e3305..2dc0997f82 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -353,6 +353,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitStmt_(const ProducerRealizeNode* op) override;
   Doc VisitStmt_(const AllocateNode* op) override;
   Doc VisitStmt_(const AllocateConstNode* op) override;
+  Doc VisitStmt_(const DeclBufferNode* op) override;
   Doc VisitStmt_(const IfThenElseNode* op) override;
   Doc VisitStmt_(const SeqStmtNode* op) override;
   Doc VisitStmt_(const EvaluateNode* op) override;
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index fe829016b6..894a9cec1e 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -557,6 +557,18 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
   return doc;
 }
 
+Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
+  Doc doc;
+  doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
+      << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
+  if (op->body->IsInstance<SeqStmtNode>()) {
+    doc << PrintBody(op->body);
+  } else {
+    doc << ";" << Doc::NewLine() << Print(op->body);
+  }
+  return doc;
+}
+
 Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
   Doc doc;
   doc << "if " << Print(op->condition) << PrintBody(op->then_case);
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index f2abf5c78d..6708922444 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -245,6 +245,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitStmt_(const BufferRealizeNode* op) override;
   Doc VisitStmt_(const AllocateNode* op) override;
   Doc VisitStmt_(const AllocateConstNode* op) override;
+  Doc VisitStmt_(const DeclBufferNode* op) override;
   Doc VisitStmt_(const IfThenElseNode* op) override;
   Doc VisitStmt_(const SeqStmtNode* op) override;
   Doc VisitStmt_(const ForNode* op) override;
@@ -1161,6 +1162,24 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
   return doc;
 }
 
+Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) {
+  const Buffer& buffer = op->buffer;
+  buf_not_in_headers_.insert(buffer.get());
+  Doc buffer_name = Print(op->buffer);
+  Doc func_call;
+  func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")";
+
+  Doc doc;
+  if (current_num_ != num_child_ - 1) {
+    doc << "with " << func_call << " as " << buffer_name << ":";
+    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+  } else {
+    doc << buffer_name << " = " << func_call << Doc::NewLine();
+    doc << PrintBody(op->body);
+  }
+  return doc;
+}
+
 Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
   Doc doc;
   doc << "if " << Print(op->condition) << ":";
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3ad7882d79..3fe7fa50d3 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -661,6 +661,8 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
   this->PrintStmt(op->body);
 }
 
+void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); }
+
 void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Unexpected deprecated LoadNode.  Use BufferLoadNode instead.";
 }
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 696ec62c58..0af24dfdc0 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -166,6 +166,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   void VisitStmt_(const EvaluateNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
   void VisitStmt_(const AllocateConstNode* op) override;
+  void VisitStmt_(const DeclBufferNode* op) override;
 
   /*!
    * \brief Print expr representing the thread tag
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 2b337520a2..524204f3d3 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -508,6 +508,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->Print(op->body);
     });
 
+// DeclBuffer
+DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
+  ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>();
+  node->buffer = std::move(buffer);
+  node->body = std::move(body);
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) {
+  return DeclBuffer(buffer, body, span);
+});
+
+TVM_REGISTER_NODE_TYPE(DeclBufferNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const DeclBufferNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "decl_buffer " << op->buffer << "\n";
+      p->stream << op->body;
+    });
+
 // ProducerRealize
 ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
                                  Stmt body, String storage_scope, Span span) {
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index c0abf953ee..c75eb52f92 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -63,6 +63,8 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) {
   this->VisitStmt(op->body);
 }
 
+void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); }
+
 void StmtVisitor::VisitStmt_(const StoreNode* op) {
   LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
 }
@@ -336,6 +338,18 @@ Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) {
   }
 }
 
+Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
+  Stmt body = this->VisitStmt(op->body);
+
+  if (body.same_as(op->body)) {
+    return GetRef<Stmt>(op);
+  } else {
+    auto n = CopyOnWrite(op);
+    n->body = std::move(body);
+    return Stmt(n);
+  }
+}
+
 Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
   Stmt then_case = this->VisitStmt(op->then_case);
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 33b145d3a4..2909915c32 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -181,6 +181,7 @@ TEST(IRF, StmtVisitor) {
     DataType dtype = DataType::Float(32);
     Var buf_var("b", PointerType(PrimType(dtype)));
     Buffer buffer = decl_buffer({16});
+    body = DeclBuffer(buffer, std::move(body));
     BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
     MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
 
@@ -309,6 +310,7 @@ TEST(IRF, StmtMutator) {
     DataType dtype = DataType::Float(32);
     Var buf_var("b", PointerType(PrimType(dtype)));
     Buffer buffer = decl_buffer({16});
+    body = DeclBuffer(buffer, std::move(body));
     BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
     MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
     // construct block and block_realize
@@ -318,8 +320,8 @@ TEST(IRF, StmtMutator) {
     body = v(std::move(block_realize));
     // the body should be changed
     Block new_block = body.as<BlockRealizeNode>()->block;
-    ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x));
-    ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x));
+    ICHECK(new_block->body.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
+    ICHECK(new_block->init.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
     ICHECK(new_block->reads[0]->region[0]->min.same_as(x));
     ICHECK(new_block->writes[0]->region[0]->min.same_as(x));
     ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x));
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 528357339c..0a2cec6011 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3303,6 +3303,18 @@ def void_ptr():
     return func
 
 
+def decl_buffer():
+    @T.prim_func
+    def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None:
+        A_flattened = T.decl_buffer(data=A.data, shape=(256,), dtype="float32")
+        B_flattened = T.decl_buffer(data=B.data, shape=(256,), dtype="float32")
+        C_alias = T.decl_buffer(data=A_flattened.data, shape=(256,), dtype="float32")
+        for i in range(256):
+            B_flattened[i] = A_flattened[i] + C_alias[i] + T.float32(1.0)
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3342,6 +3354,7 @@ ir_generator = tvm.testing.parameter(
     buffer_ramp_access_as_slice_index,
     let_expression,
     void_ptr,
+    decl_buffer,
 )
 
 

Reply via email to