slyubomirsky commented on code in PR #16120:
URL: https://github.com/apache/tvm/pull/16120#discussion_r1431894344


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -385,58 +385,45 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Step 1. Create buffers for function params
-
-    // Record which fields in a tuple passed as a parameter are actually 
accessed by the function.
-    std::unordered_set<const Object*> tuple_param;
-    for (auto param : func->params) {
-      if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
-        tuple_param.insert(param.get());
-      }
-    }
-
-    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
-      if (auto tup_get = e.as<TupleGetItemNode>();
-          tup_get && tuple_param.count(tup_get->tuple.get())) {
-        
func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
-      }
-    });
-
+    std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
-      auto sinfo = GetStructInfo(relax_param);
-      if (sinfo->IsInstance<ShapeStructInfoNode>()) {
-        // It's a symbolic shape var, no need to alloc Buffers.
-        continue;
-      }
-
-      auto [params, buffers] = [=]() {
-        if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
-          // Add only those tuple fields which are actually used by the 
function body into the
-          // function parameters.
-          int index = 0;
-          Array<tir::Var> params;
-          Array<tir::Buffer> buffers;
-          for (auto i : 
func_info_.used_tuple_field_indices[relax_param.get()]) {
-            auto [ret_params, ret_buffers] =
-                CreateParamsAndBuffers(tuple->fields[i], 
relax_param->name_hint(), index);
-            ICHECK_EQ(ret_params.size(), ret_buffers.size());
-            // Adding tuple field results to the end of params and buffers.
-            params.insert(params.end(), ret_params.begin(), ret_params.end());
-            buffers.insert(buffers.end(), ret_buffers.begin(), 
ret_buffers.end());
-            index += ret_params.size();
+      size_t size_before = prim_func_params.size();
+      CollectPrimFuncParams(relax_param, &prim_func_params);
+
+      auto param_buffers = [&]() -> Array<tir::Buffer> {
+        Array<tir::Buffer> out;
+        for (size_t i = size_before; i < prim_func_params.size(); i++) {
+          if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
+            out.push_back(buf.value());
           }
-          return std::make_pair(params, buffers);
-        } else {
-          return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
         }
+        return out;
       }();
 
-      ICHECK_EQ(params.size(), buffers.size());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.buffer_map.Set(params[i], buffers[i]);
-        func_info_.params.push_back(params[i]);
+      func_info_.expr2buffers.Set(relax_param, param_buffers);
+    }
+
+    // Move all scalar params after buffer params.
+    std::stable_sort(prim_func_params.begin(), prim_func_params.end(),

Review Comment:
   I assume it's important for the relative ordering to be preserved (hence the 
stable sort); if so, it might be good to call that out in a comment.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to