junrushao commented on code in PR #13801:
URL: https://github.com/apache/tvm/pull/13801#discussion_r1080543155


##########
src/script/printer/tir/function.cc:
##########
@@ -34,16 +36,54 @@ String FindFunctionName(const IRDocsifier& d, const 
tir::PrimFunc& f) {
   return "main";
 }
 
+// Returns true only if `buf` passes every check below. The PrimFunc printer
+// calls this to decide whether a buffer-mapped parameter is emitted inline in
+// the argument list.
+bool IsSimpleBuffer(const tir::Buffer& buf) {
+  // Any explicit stride disqualifies the buffer. Since strides are known to be
+  // empty past this point, no per-stride inspection is needed.
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  // Reject the buffer if tir::UndefinedVars reports anything for a shape extent.
+  for (const PrimExpr& shp_i : buf->shape) {
+    if (!tir::UndefinedVars(shp_i).empty()) {
+      return false;
+    }
+  }
+  // The same restriction applies to the element offset.
+  if (!tir::UndefinedVars(buf->elem_offset).empty()) {
+    return false;
+  }
+  // A constant element offset is accepted only when it is zero.
+  if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  // Remaining attributes must match the fixed values compared against here.
+  return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
+         buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
+         buf->axis_separators.empty();
+}
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, 
IRDocsifier d) -> Doc {
       With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
       int n_args = func->params.size();
       // Step 1. Handle `func->params`
       Array<AssignDoc> args;
       args.reserve(n_args);
+      std::unordered_set<const tir::BufferNode*> buffer_inlined;
       for (int i = 0; i < n_args; ++i) {
         tir::Var var = func->params[i];
         ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
+        if (func->buffer_map.count(var)) {
+          tir::Buffer buffer = func->buffer_map[var];
+          ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
+          if (IsSimpleBuffer(buffer)) {
+            args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
+                                     BufferAttn(buffer, buffer_p, *frame, d)));
+            buffer_inlined.insert(buffer.get());
+            continue;
+          }
+        }

Review Comment:
   It's likely that we will have to check the following conditions to make sure 
the syntactic sugar is correct (please add those corner cases to the tests 
accordingly):
   - The TIR var `func->params[i]` is used only once in this PrimFunc, i.e. in 
`func->params[i]`
   - The buffer's data field `func->buffer_map[var]->data` is not shared with 
other buffers in `func->buffer_map`
   
   I know the existing TVMScript printer isn't carefully implemented to take 
this into consideration, but let's do this right this time :-)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to