This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c2d485a01 [TVMScript] `T.match_buffer` syntax sugar in arguments for TVMScript printer (#13801)
6c2d485a01 is described below

commit 6c2d485a011fbfbd426353c6fc1254f3385d826e
Author: Yaxing Cai <[email protected]>
AuthorDate: Thu Jan 19 18:27:26 2023 -0800

    [TVMScript] `T.match_buffer` syntax sugar in arguments for TVMScript printer (#13801)
    
    This PR implements the `T.match_buffer` syntax sugar for the new TVMScript
    printer. The sugar replaces a `T.handle` argument of a `T.prim_func` with
    the simple buffer matched to it. For example, it will change
    ```python
    @T.prim_func
    def func(a: T.handle, b: T.handle, c: T.handle):
      A = T.match_buffer(a, [128], dtype="float32")
      B = T.match_buffer(b, [128, 128], dtype="int32")
      C = T.match_buffer(c, [128, 128, 128], dtype="uint8")
      T.evaluate(0)
    ```
    into
    ```python
    @T.prim_func
    def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128, 128), "int32"), C: T.Buffer((128, 128, 128), "uint8")):
      T.evaluate(0)
    ```
    
    Co-authored-by: Junru Shao <[email protected]>
---
 src/script/printer/tir/buffer.cc                   |   8 ++
 src/script/printer/tir/function.cc                 | 105 +++++++++++++++++++++
 src/script/printer/tir/utils.h                     |  11 +++
 .../python/unittest/test_tvmscript_printer_tir.py  |  52 +++++++++-
 4 files changed, 174 insertions(+), 2 deletions(-)

diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
index 5400328fe2..126a6e5827 100644
--- a/src/script/printer/tir/buffer.cc
+++ b/src/script/printer/tir/buffer.cc
@@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
                     /*args=*/args);
 }
 
+/*!
+ * \brief Print `buffer` in the annotation form `T.Buffer(shape, dtype)`.
+ * \note Only the "shape" and "dtype" entries produced by BufferAttrs are used; every other
+ *       attribute is dropped, so this form is only faithful for buffers whose remaining
+ *       attributes need not be printed.
+ */
+ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
+                   const IRDocsifier& d) {
+  const Map<String, ExprDoc> buffer_attrs = BufferAttrs(buffer, p, frame, d);
+  // "shape" is assumed to always be present (`.value()` is called unchecked).
+  ExprDoc shape_doc = buffer_attrs.Get("shape").value();
+  // BufferAttrs may leave out "dtype"; the annotation always spells it out, so fall back to
+  // the buffer's own dtype.
+  ExprDoc dtype_doc =
+      buffer_attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
+  return TIR("Buffer")->Call(/*args=*/{shape_doc, dtype_doc}, /*kwargs_keys=*/{},
+                             /*kwargs_values=*/{});
+}
+
 Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
                          const IRDocsifier& d) {
   int n = indices.size();
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index f0f84e81d5..6094eefb65 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -16,6 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/stmt_functor.h>
+
 #include "./utils.h"
 
 namespace tvm {
@@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
   return "main";
 }
 
+/*!
+ * \brief Check whether a parameter buffer can be printed with the argument sugar
+ *        `A: T.Buffer(shape, dtype)` instead of an explicit `T.match_buffer`.
+ * \param buf The buffer to check.
+ * \return True if the buffer has no strides, its shape and element offset reference no free
+ *         variables, a constant element offset equals 0, and its scope, alignment, offset
+ *         factor, buffer type and axis separators match the values tested below.
+ */
+bool IsSimpleBuffer(const tir::Buffer& buf) {
+  // Explicit strides cannot be expressed by the sugar. Rejecting them up front also means
+  // there are no stride expressions left to inspect afterwards.
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  // Shape extents must not refer to free variables (e.g. a symbolic size).
+  for (const PrimExpr& shp_i : buf->shape) {
+    if (!tir::UndefinedVars(shp_i).empty()) {
+      return false;
+    }
+  }
+  // The element offset must not refer to free variables, and if it is an integer constant
+  // it must be 0.
+  if (!tir::UndefinedVars(buf->elem_offset).empty()) {
+    return false;
+  }
+  if (const auto* elem_offset = buf->elem_offset.as<IntImmNode>()) {
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
+         buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
+         buf->axis_separators.empty();
+}
+
+/*!
+ * \brief Count the references to `v` inside `f`.
+ *
+ * The body, every parameter, and every `buffer_map` entry (its key plus the data, shape,
+ * strides and element offset of its buffer) are scanned. A parameter that has a `buffer_map`
+ * entry is therefore counted once as a parameter and once as a map key, so a count of exactly
+ * 2 for such a parameter means no other scanned location refers to it.
+ *
+ * \param f The function to scan.
+ * \param v The variable whose references are counted.
+ * \return The number of references found.
+ * \note Each call traverses the whole function.
+ */
+int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
+  class OccurrenceCounter : public tir::StmtExprVisitor {
+   public:
+    int count = 0;
+    const tir::VarNode* target = nullptr;
+
+    void VisitExpr_(const tir::VarNode* op) final {
+      if (op == target) {
+        ++count;
+      }
+      tir::StmtExprVisitor::VisitExpr_(op);
+    }
+
+    // Stores, loads and buffer declarations additionally have their buffer's fields scanned.
+    void VisitStmt_(const tir::BufferStoreNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitStmt_(op);
+    }
+
+    void VisitExpr_(const tir::BufferLoadNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitExpr_(op);
+    }
+
+    void VisitStmt_(const tir::DeclBufferNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitStmt_(op);
+    }
+
+    /*! \brief Count references in the data, shape, strides and elem_offset of `buffer`. */
+    void VisitBuffer(const tir::BufferNode* buffer) {
+      VisitExpr(buffer->data);
+      for (const PrimExpr& shape_i : buffer->shape) {
+        VisitExpr(shape_i);
+      }
+      for (const PrimExpr& stride_i : buffer->strides) {
+        VisitExpr(stride_i);
+      }
+      VisitExpr(buffer->elem_offset);
+    }
+  };
+
+  OccurrenceCounter counter;
+  counter.target = v.get();
+  counter(f->body);
+  // Distinct loop name so the variable being counted (`v`) is not shadowed.
+  for (const tir::Var& param : f->params) {
+    counter(param);
+  }
+  for (const auto& kv : f->buffer_map) {
+    counter(kv.first);
+    counter.VisitBuffer(kv.second.get());
+  }
+  return counter.count;
+}
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, 
IRDocsifier d) -> Doc {
       With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
       int n_args = func->params.size();
+      std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
+      for (const auto& pair : func->buffer_map) {
+        const tir::VarNode* data_var = pair.second->data.get();
+        if (!buffer_data_counter.count(data_var)) {
+          buffer_data_counter.insert({data_var, 0});
+        }
+        ++buffer_data_counter.at(data_var);
+      }
       // Step 1. Handle `func->params`
       Array<AssignDoc> args;
       args.reserve(n_args);
+      std::unordered_set<const tir::BufferNode*> buffer_inlined;
       for (int i = 0; i < n_args; ++i) {
         tir::Var var = func->params[i];
         ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
+        if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) 
{
+          tir::Buffer buffer = func->buffer_map[var];
+          if (IsSimpleBuffer(buffer) && 
buffer_data_counter.at(buffer->data.get()) == 1) {
+            ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
+            args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
+                                     BufferAttn(buffer, buffer_p, *frame, d)));
+            buffer_inlined.insert(buffer.get());
+            continue;
+          }
+        }
         ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, 
var_p->Attr("type_annotation"));
         args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
       }
@@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         tir::Var param = func->params[i];
         if (func->buffer_map.count(param)) {
           tir::Buffer buffer = func->buffer_map[param];
+          if (buffer_inlined.count(buffer.get())) {
+            continue;
+          }
           ExprDoc param = args[i]->lhs;
           ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
           ExprDoc lhs =
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 047513dcb3..183400d974 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) {
 ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const 
Array<ExprDoc>& args,
                    const ObjectPath& p, const Frame& frame, const IRDocsifier& 
d);
 
+/*!
+ * \brief Print a buffer as the annotation `T.Buffer(shape, dtype)`, e.g. for use as the
+ *        type annotation of a function parameter.
+ * \param buffer The buffer to be printed
+ * \param p The object path of the buffer
+ * \param frame The frame
+ * \param d The IRDocsifier
+ * \return The ExprDoc of the `T.Buffer(shape, dtype)` annotation
+ */
+ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
+                   const IRDocsifier& d);
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index d62a1cd12c..201428b74c 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -57,10 +57,56 @@ def test_prim_func():
         func,
         expected="""
 @T.prim_func
+def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), 
"float32")):
+    T.evaluate(0)""",
+    )
+
+
+def test_prim_func_no_sugar_inlined_buffer():
+    a = tir.Var("a", "handle")
+    b = tir.Var("b", "handle")
+    func = tir.PrimFunc(
+        params=[a, b],
+        ret_type=None,
+        buffer_map={
+            a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
+            b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
+        },
+        body=tir.Evaluate(a),
+    )
+    _assert_print(
+        func,
+        expected="""
[email protected]_func
+def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
+    A = T.match_buffer(a, (128, 128))
+    T.evaluate(a)
+""",
+    )
+
+
+def test_prim_func_no_sugar_shared_buffer_data():
+    a = tir.Var("a", "handle")
+    b = tir.Var("b", "handle")
+    buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", 
name="A").data
+    func = tir.PrimFunc(
+        params=[a, b],
+        ret_type=None,
+        buffer_map={
+            a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", 
data=buffer_data),
+            b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", 
data=buffer_data),
+        },
+        body=tir.Evaluate(0),
+    )
+    _assert_print(
+        func,
+        expected="""
[email protected]_func
 def main(a: T.handle, b: T.handle):
     A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (256, 256))
-    T.evaluate(0)""",
+    B = T.match_buffer(b, (256, 256), data=A.data)
+    T.evaluate(0)
+""",
     )
 
 
@@ -641,6 +687,8 @@ def main():
 
 if __name__ == "__main__":
     test_prim_func()
+    test_prim_func_no_sugar_inlined_buffer()
+    test_prim_func_no_sugar_shared_buffer_data()
     test_block_realize()
     test_block()
     test_buffer()

Reply via email to