gemini-code-assist[bot] commented on code in PR #18417:
URL: https://github.com/apache/tvm/pull/18417#discussion_r2485351378


##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4580,47 +4580,51 @@ def forward(self, x, y):
     class Expected0:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
-                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
+                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The decomposition of `torch.stack` with `axis=0` into `concat` and `reshape` is valid for contiguous tensors, but it is less canonical than `expand_dims` followed by `concat`. The `axis=-1` case in `Expected3` already uses `expand_dims` and `concat`, which translates `stack`'s semantics of inserting a new dimension more directly. Using the same `expand_dims` + `concat` approach for `axis=0` would keep the decomposition consistent across axes and make it more robust and easier to understand.
   
   For example:
   ```python
   lv: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(x, axis=0)
   lv1: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(y, axis=0)
   lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=0)
   gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)
   ```
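   
   As a sanity check on the equivalence claim above (the `concat` + `reshape` form works here because a row-major reshape of the concatenated result recovers the stacked layout), here is a minimal NumPy sketch, not part of the test suite, comparing all three forms for `axis=0`:
   ```python
   import numpy as np

   x = np.arange(6, dtype="float32").reshape(2, 3)
   y = -np.arange(6, dtype="float32").reshape(2, 3)

   # Reference semantics of torch.stack(..., dim=0).
   ref = np.stack((x, y), axis=0)  # shape (2, 2, 3)

   # Decomposition used in the PR: concat along axis 0, then reshape.
   via_reshape = np.concatenate((x, y), axis=0).reshape(2, 2, 3)

   # Decomposition suggested above: expand_dims each input, then concat.
   via_expand = np.concatenate((x[np.newaxis], y[np.newaxis]), axis=0)

   assert np.array_equal(ref, via_reshape)
   assert np.array_equal(ref, via_expand)
   ```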



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4580,47 +4580,51 @@ def forward(self, x, y):
     class Expected0:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
-                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
+                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
     @I.ir_module
     class Expected1:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1)
-                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1)
+                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Similar to the `axis=0` case, the decomposition of `torch.stack` with `axis=1` into `concat` and `reshape` is less canonical than using `expand_dims` and `concat`. Using a consistent decomposition strategy across all axes would improve the robustness and readability of the translated IR.
   
   For example:
   ```python
   lv: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(x, axis=1)
   lv1: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(y, axis=1)
   lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=1)
   gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)
   ```
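   
   To make the `axis=1` equivalence concrete (the `(2, 6)` concat reshapes back to the stacked `(2, 2, 3)` layout because the per-row interleaving of `x` and `y` matches row-major order), here is a minimal NumPy sketch, not part of the test suite:
   ```python
   import numpy as np

   x = np.arange(6, dtype="float32").reshape(2, 3)
   y = -np.arange(6, dtype="float32").reshape(2, 3)

   # Reference semantics of torch.stack(..., dim=1).
   ref = np.stack((x, y), axis=1)  # shape (2, 2, 3)

   # PR decomposition: concat along axis 1 gives (2, 6); reshape recovers
   # the stacked layout.
   via_reshape = np.concatenate((x, y), axis=1).reshape(2, 2, 3)

   # Suggested decomposition: expand_dims each input, then concat on axis 1.
   via_expand = np.concatenate((x[:, np.newaxis], y[:, np.newaxis]), axis=1)

   assert np.array_equal(ref, via_reshape)
   assert np.array_equal(ref, via_expand)
   ```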



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
