This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6c2d485a01 [TVMScript] `T.match_buffer` syntax sugar in arguments for
TVMScript printer (#13801)
6c2d485a01 is described below
commit 6c2d485a011fbfbd426353c6fc1254f3385d826e
Author: Yaxing Cai <[email protected]>
AuthorDate: Thu Jan 19 18:27:26 2023 -0800
[TVMScript] `T.match_buffer` syntax sugar in arguments for TVMScript
printer (#13801)
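This PR implements the `T.match_buffer` syntax sugar for the new TVMScript
printer. When a `T.handle` argument of a `T.prim_func` is bound by
`T.match_buffer` to a simple buffer (and the handle is not referenced anywhere
else, and no other buffer shares its data pointer), the printer replaces the
`T.handle` annotation with the buffer itself. For example, it will change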
```python
@T.prim_func
def func(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [128], dtype="float32")
B = T.match_buffer(b, [128, 128], dtype="int32")
C = T.match_buffer(c, [128, 128, 128], dtype="uint8")
```
into
```python
@T.prim_func
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128, 128), "int32"), C: T.Buffer((128, 128, 128), "uint8")):
T.evaluate(0)
```
Co-authored-by: Junru Shao <[email protected]>
---
src/script/printer/tir/buffer.cc | 8 ++
src/script/printer/tir/function.cc | 105 +++++++++++++++++++++
src/script/printer/tir/utils.h | 11 +++
.../python/unittest/test_tvmscript_printer_tir.py | 52 +++++++++-
4 files changed, 174 insertions(+), 2 deletions(-)
diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
index 5400328fe2..126a6e5827 100644
--- a/src/script/printer/tir/buffer.cc
+++ b/src/script/printer/tir/buffer.cc
@@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const
String& method, const Array<
/*args=*/args);
}
+/*!
+ * \brief Build the `T.Buffer(shape, dtype)` annotation used when a buffer is
+ *  printed inline as a function-parameter annotation.
+ */
+ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
+                   const IRDocsifier& d) {
+  const Map<String, ExprDoc> buffer_attrs = BufferAttrs(buffer, p, frame, d);
+  // Assumes BufferAttrs always emits `shape`; `dtype` falls back to the
+  // buffer's own dtype when BufferAttrs leaves it out.
+  const ExprDoc shape_doc = buffer_attrs.Get("shape").value();
+  const ExprDoc dtype_doc =
+      buffer_attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
+  const Array<ExprDoc> positional{shape_doc, dtype_doc};
+  return TIR("Buffer")->Call(positional, {}, {});
+}
+
Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
const IRDocsifier& d) {
int n = indices.size();
diff --git a/src/script/printer/tir/function.cc
b/src/script/printer/tir/function.cc
index f0f84e81d5..6094eefb65 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/stmt_functor.h>
+
#include "./utils.h"
namespace tvm {
@@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const
tir::PrimFunc& f) {
return "main";
}
+/*!
+ * \brief Check whether a parameter buffer can be printed inline as a
+ *  `T.Buffer(shape, dtype)` annotation instead of an explicit `T.match_buffer`.
+ *
+ * The annotation form only carries shape and dtype, so the buffer must have:
+ * no explicit strides, a shape free of variables, an element offset free of
+ * variables (and equal to zero when it is an integer literal), "global" scope,
+ * default alignment, an offset factor of 1, the default buffer type, and no
+ * axis separators.
+ */
+bool IsSimpleBuffer(const tir::Buffer& buf) {
+  // Explicit strides cannot be expressed by the annotation. Because `strides`
+  // is guaranteed empty past this point, stride expressions need no check.
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  // Every shape extent must be free of variables.
+  for (const PrimExpr& extent : buf->shape) {
+    if (!tir::UndefinedVars(extent).empty()) {
+      return false;
+    }
+  }
+  // The element offset must be free of variables; an integer-literal offset
+  // must additionally be zero.
+  if (!tir::UndefinedVars(buf->elem_offset).empty()) {
+    return false;
+  }
+  if (const auto* offset_imm = buf->elem_offset.as<IntImmNode>()) {
+    if (offset_imm->value != 0) {
+      return false;
+    }
+  }
+  return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
+         buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
+         buf->axis_separators.empty();
+}
+
+/*!
+ * \brief Count how many times `v` is referenced in a PrimFunc.
+ *
+ * References are counted in the function body, in `f->params`, in the keys of
+ * `f->buffer_map`, and in the data/shape/strides/elem_offset fields of every
+ * buffer in `f->buffer_map` or accessed by a BufferStore, BufferLoad or
+ * DeclBuffer node in the body. A parameter present in `buffer_map` is thus
+ * counted at least twice: once as a param and once as a buffer_map key.
+ */
+int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
+  // Visitor that bumps `hits` for every VarNode identical to `target`,
+  // also scanning the buffer fields that the default visitor does not walk.
+  class OccurrenceCounter : public tir::StmtExprVisitor {
+   public:
+    const tir::VarNode* target = nullptr;
+    int hits = 0;
+
+    void VisitExpr_(const tir::VarNode* op) final {
+      hits += (op == target) ? 1 : 0;
+      tir::StmtExprVisitor::VisitExpr_(op);
+    }
+
+    void VisitExpr_(const tir::BufferLoadNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitExpr_(op);
+    }
+
+    void VisitStmt_(const tir::BufferStoreNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitStmt_(op);
+    }
+
+    void VisitStmt_(const tir::DeclBufferNode* op) final {
+      VisitBuffer(op->buffer.get());
+      tir::StmtExprVisitor::VisitStmt_(op);
+    }
+
+    /*! \brief Visit the var-bearing fields of a buffer. */
+    void VisitBuffer(const tir::BufferNode* buffer) {
+      VisitExpr(buffer->data);
+      for (const PrimExpr& extent : buffer->shape) {
+        VisitExpr(extent);
+      }
+      for (const PrimExpr& stride : buffer->strides) {
+        VisitExpr(stride);
+      }
+      VisitExpr(buffer->elem_offset);
+    }
+  };
+
+  OccurrenceCounter counter;
+  counter.target = v.get();
+  counter(f->body);
+  for (const tir::Var& param : f->params) {
+    counter(param);
+  }
+  for (const auto& kv : f->buffer_map) {
+    counter(kv.first);
+    counter.VisitBuffer(kv.second.get());
+  }
+  return counter.hits;
+}
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p,
IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
+ std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
+ for (const auto& pair : func->buffer_map) {
+ const tir::VarNode* data_var = pair.second->data.get();
+ if (!buffer_data_counter.count(data_var)) {
+ buffer_data_counter.insert({data_var, 0});
+ }
+ ++buffer_data_counter.at(data_var);
+ }
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
+ std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
+ if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var))
{
+ tir::Buffer buffer = func->buffer_map[var];
+ if (IsSimpleBuffer(buffer) &&
buffer_data_counter.at(buffer->data.get()) == 1) {
+ ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
+ args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
+ BufferAttn(buffer, buffer_p, *frame, d)));
+ buffer_inlined.insert(buffer.get());
+ continue;
+ }
+ }
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation,
var_p->Attr("type_annotation"));
args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
}
@@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Var param = func->params[i];
if (func->buffer_map.count(param)) {
tir::Buffer buffer = func->buffer_map[param];
+ if (buffer_inlined.count(buffer.get())) {
+ continue;
+ }
ExprDoc param = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
ExprDoc lhs =
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 047513dcb3..183400d974 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj,
ReprPrinter* p) {
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const
Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier&
d);
+/*!
+ * \brief Build the `T.Buffer(shape, dtype)` annotation for a buffer printed inline as a parameter
+ * \param buffer The buffer to be annotated
+ * \param p The object path of the buffer
+ * \param frame The frame
+ * \param d The IRDocsifier
+ * \return The ExprDoc `T.Buffer(shape, dtype)`; dtype defaults to the buffer's own dtype
+ */
+ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const
Frame& frame,
+ const IRDocsifier& d);
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index d62a1cd12c..201428b74c 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -57,10 +57,56 @@ def test_prim_func():
func,
expected="""
@T.prim_func
+def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256),
"float32")):
+ T.evaluate(0)""",
+ )
+
+
+def test_prim_func_no_sugar_inlined_buffer():
+ a = tir.Var("a", "handle")
+ b = tir.Var("b", "handle")
+ func = tir.PrimFunc(
+ params=[a, b],
+ ret_type=None,
+ buffer_map={
+ a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
+ b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
+ },
+ body=tir.Evaluate(a),
+ )
+ _assert_print(
+ func,
+ expected="""
[email protected]_func
+def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
+ A = T.match_buffer(a, (128, 128))
+ T.evaluate(a)
+""",
+ )
+
+
+def test_prim_func_no_sugar_shared_buffer_data():
+ a = tir.Var("a", "handle")
+ b = tir.Var("b", "handle")
+ buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32",
name="A").data
+ func = tir.PrimFunc(
+ params=[a, b],
+ ret_type=None,
+ buffer_map={
+ a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A",
data=buffer_data),
+ b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B",
data=buffer_data),
+ },
+ body=tir.Evaluate(0),
+ )
+ _assert_print(
+ func,
+ expected="""
+@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (256, 256))
- T.evaluate(0)""",
+ B = T.match_buffer(b, (256, 256), data=A.data)
+ T.evaluate(0)
+""",
)
@@ -641,6 +687,8 @@ def main():
if __name__ == "__main__":
test_prim_func()
+ test_prim_func_no_sugar_inlined_buffer()
+ test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()