Lunderberg commented on code in PR #16120:
URL: https://github.com/apache/tvm/pull/16120#discussion_r1432197350


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -719,64 +697,46 @@ class FusedTIRConstructor : public ExprVisitor {
   }
 
   /*!
-   * \brief Create an TIR func params and buffers with specified relax type 
and shape
+   * \brief Collect TIR func params and buffers with specified relax type and 
shape
    * \param struct_info The struct info
    * \param name_hint The name hint for params and buffers
-   * \param index The index used for unique name_hint if type is Tuple.
-   *              -1 means no need to add postfix since the relax param is not 
a Tuple.
-   * \return The created TIR func params and buffers
+   * \param out The vector into which to collect the params/buffers
    */
-  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
-      StructInfo struct_info, const String& name_hint, int index = -1) {
-    Array<tir::Var> params;
-    Array<tir::Buffer> buffers;
-    // The symbolic shape params must be defined at the end of the param list.
-    bool symbolic_shape_param_started = false;
+  static void CollectPrimFuncParams(const Var& relax_param,
+                                    std::vector<Variant<tir::Var, 
tir::Buffer>>* out) {
+    auto struct_info = GetStructInfo(relax_param);
+
+    CHECK(!struct_info.as<TupleStructInfoNode>())
+        << "InternalError: "
+        << "All tuple parameters should be expanded before this point in 
FuseTIR.  "
+        << "However, parameter " << relax_param << " has struct info " << 
struct_info;
+
+    auto name_hint = relax_param->name_hint();
+
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a Tensor, we directly create a tir var and 
buffer
+      // Case 1. The relax param is a Tensor, we directly create a tir var and 
buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
-      ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with 
symbolic shape.";
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the 
param "
-             "list.";
-      String name = index == -1 ? name_hint : name_hint + "_" + 
std::to_string(index);
+      ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a 
known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
-      // Differentiate buffer name and param name by adding prefix `v_` to 
param
-      // Every symbol should be unique in TVMScript, and Buffer is used more 
than param
-      // So we decide to make sure buffer names have better readability.
-      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
-      params.push_back(std::move(param));
-      buffers.push_back(std::move(buffer));
-    } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field 
until it's a Tensor
-      // Enable postfix
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the 
param "
-             "list.";
-      if (index == -1) index = 0;
-      for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        auto [ret_params, ret_buffers] = 
CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-        ICHECK_EQ(ret_params.size(), ret_buffers.size());
-        // Adding tuple field results to the end of params and buffers.
-        params.insert(params.end(), ret_params.begin(), ret_params.end());
-        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-        index += ret_params.size();
-      }
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, 
name_hint);
+      out->push_back(std::move(buffer));
+
+    } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
+      // Case 2. The relax param is a scalar, we directly create a tir var
+      ICHECK(prim_value->value->IsInstance<tir::VarNode>());
+      out->push_back(Downcast<tir::Var>(prim_value->value));
+
     } else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) 
{
-      // Case 3. the relax param is a scalar, we directly create a tir var
-      symbolic_shape_param_started = true;
-      ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a 
Tuple field.";
+      // Case 3. The relax param is a tuple of scalars, each represented as a 
tir var
       for (const auto& var : shape_expr->values.value()) {
         ICHECK(var->IsInstance<tir::VarNode>());
-        params.push_back(Downcast<tir::Var>(var));
+        out->push_back(Downcast<tir::Var>(var));
       }
     } else {
       ICHECK(false) << "TypeError: The param type of PrimFunc is expected to 
be Tensor, Tuple or "

Review Comment:
   Good catch. I've updated the error message to remove `Tuple` from the list of
expected types and to add `PrimValue`, since it is now handled.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to