junrushao commented on code in PR #13801:
URL: https://github.com/apache/tvm/pull/13801#discussion_r1080543155
##########
src/script/printer/tir/function.cc:
##########
@@ -34,16 +36,54 @@ String FindFunctionName(const IRDocsifier& d, const
tir::PrimFunc& f) {
return "main";
}
+/*!
+ * \brief Check whether a buffer carries only default layout attributes.
+ *
+ * A buffer is considered simple when all of the following hold:
+ *  - it has no explicit strides;
+ *  - none of its shape expressions contain undefined variables;
+ *  - its `elem_offset` contains no undefined variables and, when it is an
+ *    integer immediate, equals zero;
+ *  - its storage scope is "global";
+ *  - its `data_alignment` equals `runtime::kAllocAlignment`;
+ *  - its `offset_factor` is 1;
+ *  - its `buffer_type` is `tir::BufferType::kDefault`;
+ *  - it has no axis separators.
+ *
+ * \param buf The buffer to inspect.
+ * \return true if every condition above is satisfied, false otherwise.
+ */
+bool IsSimpleBuffer(const tir::Buffer& buf) {
+  // Any explicit stride disqualifies the buffer, so no per-stride inspection
+  // is needed past this point.
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  for (const PrimExpr& shp_i : buf->shape) {
+    if (!tir::UndefinedVars(shp_i).empty()) {
+      return false;
+    }
+  }
+  if (!tir::UndefinedVars(buf->elem_offset).empty()) {
+    return false;
+  }
+  // A constant element offset is only acceptable when it is exactly zero.
+  if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  return buf.scope() == "global" &&
+         buf->data_alignment == runtime::kAllocAlignment &&
+         buf->offset_factor == 1 &&
+         buf->buffer_type == tir::BufferType::kDefault &&
+         buf->axis_separators.empty();
+}
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p,
IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
+ std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
+ if (func->buffer_map.count(var)) {
+ tir::Buffer buffer = func->buffer_map[var];
+ ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
+ if (IsSimpleBuffer(buffer)) {
+ args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
+ BufferAttn(buffer, buffer_p, *frame, d)));
+ buffer_inlined.insert(buffer.get());
+ continue;
+ }
+ }
Review Comment:
It's likely that we will have to check the following conditions to make sure
the syntactic sugar is correct (please add these corner cases to the tests
accordingly):
- The TIR var `func->params[i]` is used only once in this PrimFunc, i.e. in
`func->params[i]`
- The buffer's data field `func->buffer_map[var]->data` is not shared with
other buffers in `func->buffer_map`
I know the existing TVMScript printer isn't carefully implemented to take
this into consideration, but let's do this right this time :-)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]