gemini-code-assist[bot] commented on code in PR #18764:
URL: https://github.com/apache/tvm/pull/18764#discussion_r2796508268


##########
src/relax/transform/fold_constant.cc:
##########
@@ -137,9 +137,51 @@ class ConstantFolder : public ExprMutator {
    * of the program.
    */
   bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Skip folding for creation ops (no tensor inputs) that produce large outputs.
+    // These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
+    // and folding them would materialize large constants in the binary.
+    static constexpr int64_t kMaxFoldElements = 1024;
+
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;
+      num_elements *= int_dim->value;

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The direct multiplication `num_elements *= int_dim->value;` can lead to an integer overflow if the tensor has a very large number of elements, which is undefined behavior in C++. It's safer to check for potential overflow before performing the multiplication. Additionally, you can break out of the loop as soon as `num_elements` exceeds `kMaxFoldElements`.
   
   ```suggestion
         int64_t d = int_dim->value;
         if (d < 0) {
           // Dimensions should be non-negative.
           return true;
         }
         if (d == 0) {
           num_elements = 0;
           break;
         }
         // Check for overflow and if we've already exceeded the threshold.
         if (num_elements > kMaxFoldElements / d) {
           num_elements = kMaxFoldElements + 1;
           break;
         }
         num_elements *= d;
   ```
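
   To illustrate the suggestion outside of TVM, here is a standalone sketch of the same overflow-safe counting loop. `ClampedNumElements` is a hypothetical helper (not part of the PR or the TVM API): it clamps the product at `kMaxFoldElements + 1` by testing `num_elements > kMaxFoldElements / d` with division, so the overflowing multiplication is never executed.

   ```cpp
   #include <cassert>
   #include <cstdint>
   #include <vector>

   // Threshold mirroring the constant in the PR's heuristic.
   constexpr int64_t kMaxFoldElements = 1024;

   // Hypothetical standalone analogue of the suggested loop body.
   // Returns the element count of `shape`, or kMaxFoldElements + 1 as a
   // saturated "too large" value; signed overflow (UB in C++) is avoided
   // by checking with division before each multiplication.
   int64_t ClampedNumElements(const std::vector<int64_t>& shape) {
     int64_t num_elements = 1;
     for (int64_t d : shape) {
       if (d < 0) return -1;  // malformed shape; caller decides what to do
       if (d == 0) return 0;  // empty tensor, trivially foldable
       // Would num_elements * d exceed the threshold? Clamp and stop early.
       if (num_elements > kMaxFoldElements / d) return kMaxFoldElements + 1;
       num_elements *= d;
     }
     return num_elements;
   }

   int main() {
     assert(ClampedNumElements({4, 8}) == 32);
     // Four dims of 1024 would overflow int64_t if multiplied naively
     // past the threshold; here the loop saturates instead.
     assert(ClampedNumElements({1024, 1024, 1024, 1024}) == kMaxFoldElements + 1);
     assert(ClampedNumElements({0, 42}) == 0);
     return 0;
   }
   ```

   The key design choice is that saturation happens before the multiply, so even adversarial shapes never trigger undefined behavior.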



##########
src/relax/transform/fold_constant.cc:
##########
@@ -137,9 +137,51 @@ class ConstantFolder : public ExprMutator {
    * of the program.
    */
   bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Skip folding for creation ops (no tensor inputs) that produce large outputs.
+    // These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
+    // and folding them would materialize large constants in the binary.
+    static constexpr int64_t kMaxFoldElements = 1024;
+
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;
+      num_elements *= int_dim->value;
+    }
+
+    if (num_elements <= kMaxFoldElements) return true;
+
+    // Large output. Only skip if there are no tensor inputs,
+    // i.e., this is a pure creation op.
+    bool has_tensor_input = false;
+    for (const auto& arg : call->args) {
+      auto arg_sinfo = GetStructInfo(arg);
+      if (arg_sinfo.as<TensorStructInfoNode>()) {
+        has_tensor_input = true;
+        break;
+      }
+      // Check for tuple of tensors (e.g., call_tir args)
+      if (const auto* tuple = arg.as<TupleNode>()) {
+        for (const auto& field : tuple->fields) {
+          if (GetStructInfo(field).as<TensorStructInfoNode>()) {
+            has_tensor_input = true;
+            break;
+          }
+        }
+        if (has_tensor_input) break;
+      }
+    }
+
+    return has_tensor_input;

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current logic for detecting tensor inputs doesn't handle nested tuples (e.g., a tuple containing another tuple of tensors). This could lead to incorrect folding decisions for operations with such argument structures.
   
   A recursive helper function would make the logic cleaner, more robust, and correct for nested structures. I recommend adding a private static helper method to `ConstantFolder` like this:
   
   ```cpp
     static bool ExprContainsTensor(const Expr& expr) {
       if (GetStructInfo(expr).as<TensorStructInfoNode>()) {
         return true;
       }
       if (const auto* tuple = expr.as<TupleNode>()) {
         for (const auto& field : tuple->fields) {
           if (ExprContainsTensor(field)) {
             return true;
           }
         }
       }
       return false;
     }
   ```
   
   Using this helper simplifies the main logic and correctly handles arbitrarily nested tuples. It would also be beneficial to add a test case for this scenario.
   
   ```suggestion
       // Large output. Only skip if there are no tensor inputs,
       // i.e., this is a pure creation op.
       for (const auto& arg : call->args) {
         if (ExprContainsTensor(arg)) {
           return true;
         }
       }
   
       return false;
   ```
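
   For a self-contained illustration of why the recursion matters, here is a sketch using a hypothetical `Node` type in place of Relax expressions (none of these names are TVM API): a leaf marked as a tensor, or a tuple whose fields may themselves be tuples. The recursive walk finds a tensor at any depth, which the original single-level loop would miss.

   ```cpp
   #include <cassert>
   #include <vector>

   // Hypothetical stand-in for a Relax expression: either a tensor leaf
   // or a tuple of child expressions (tuples may nest arbitrarily).
   struct Node {
     bool is_tensor = false;
     std::vector<Node> fields;  // non-empty only for tuples
   };

   // Mirrors the suggested ExprContainsTensor helper: a tensor leaf
   // matches directly; otherwise descend into each tuple field.
   bool ContainsTensor(const Node& n) {
     if (n.is_tensor) return true;
     for (const Node& f : n.fields) {
       if (ContainsTensor(f)) return true;
     }
     return false;
   }

   int main() {
     Node tensor{true, {}};
     Node scalar{false, {}};
     // ((tensor), scalar): the tensor sits one tuple level deeper than a
     // flat scan of the outer fields would look.
     Node nested{false, {Node{false, {tensor}}, scalar}};
     assert(ContainsTensor(nested));
     assert(!ContainsTensor(Node{false, {scalar, scalar}}));
     return 0;
   }
   ```

   A test case along these lines, phrased in terms of actual Relax tuples, would cover the nested-argument scenario the review flags.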



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

