gemini-code-assist[bot] commented on code in PR #18410:
URL: https://github.com/apache/tvm/pull/18410#discussion_r2480208310


##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -5890,12 +5890,18 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype
                 lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
                     x, axis=1, descending=True, dtype="int32"
                 )
-                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,)
+                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
+                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+                    lv1,
+                    lv,
+                )
+                lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
+                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The generated IR for `argsort` is unnecessarily complex. It computes both 
sorted values (`lv1`) and indices (`lv`), then creates a tuple (`lv2`), only to 
extract the indices (`lv3`). The sorted values are computed via 
`gather_elements` but are never used for the final result. This seems to be a 
result of decomposing `torch.argsort` into `torch.sort` and then taking the 
indices.
   
   While a Dead Code Elimination (DCE) pass might clean this up, a more direct translation for `argsort` that never materializes the unused sorted values would be more efficient; see the sketch below.
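
   For illustration, a direct lowering could look roughly like the following. This is a sketch only: the `_argsort` name is hypothetical, and the `self.env` / `self.block_builder` plumbing is assumed to follow the converter pattern used elsewhere in `base_fx_graph_translator.py`:

   ```python
   from tvm import relax
   from torch import fx

   def _argsort(self, node: fx.Node) -> relax.Var:
       # Emit only the index computation: no sorted values, no tuple, no getitem.
       x = self.env[node.args[0]]
       dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
       descending = node.kwargs.get("descending", False)
       return self.block_builder.emit(
           relax.op.argsort(x, axis=dim, descending=descending, dtype="int32")
       )
   ```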



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6216,13 +6266,24 @@ def main(
             input: R.Tensor((9, 9), dtype="float32")
         ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
-                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+                lv: R.Tensor((9,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
+                lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
+                lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
+                lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
+                lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
+                lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
+                lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
+                lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
+                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The generated IR for `linspace` is highly inefficient. It computes two 
expressions, `lv4` and `lv8`, which are mathematically equivalent for the given 
inputs (`i * 0.125` for `i` in `[0, 8]`). Then it uses `R.where` to select 
between these identical values. The entire `where` operation and the 
computation of `lv5` through `lv8` are redundant. The IR could be simplified to 
just compute `lv4` and use that as the result. This suggests an issue in the 
PyTorch decomposition logic for `linspace` that should be investigated.
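
   If a direct translation were preferred over relying on the PyTorch decomposition, a minimal sketch could look like this (assumptions: the `_linspace` name, positional `start, end, steps` arguments, and a float32 result are illustrative, not the current frontend code):

   ```python
   from tvm import relax
   from torch import fx

   def _linspace(self, node: fx.Node) -> relax.Var:
       # Fold the scalar step at translation time and emit a single
       # affine expression: start + arange(steps) * step.
       start, end, steps = node.args[:3]
       step = (end - start) / (steps - 1) if steps > 1 else 0.0
       rng = self.block_builder.emit(relax.op.arange(0, steps, 1, dtype="float32"))
       scaled = self.block_builder.emit(
           relax.op.multiply(rng, relax.const(float(step), "float32"))
       )
       return self.block_builder.emit(
           relax.op.add(scaled, relax.const(float(start), "float32"))
       )
   ```

   For the test case above (`start=0`, `end=1`, `steps=9`) this folds to `arange(9) * 0.125`, i.e. exactly `lv4` without the dead `where` branch.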



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6187,21 +6220,38 @@ def forward(self, x):
     @tvm.script.ir_module
     class Expected1:
         @R.function
-        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
-                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
-                    lv,
-                    targets=R.const([0, 1, 2, 1], dtype="int64"),
-                    reduction="mean",
-                    ignore_index=-100,
+                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=1)
+                lv1: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                )
+                lv2: R.Tensor((), dtype="int64") = R.const(0, "int64")
+                lv3: R.Tensor((4,), dtype="int64") = R.where(
+                    lv1, R.const([0, 1, 2, 1], dtype="int64"), lv2
                 )
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                lv4: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv3, axis=[1])
+                lv5: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv, lv4, axis=1)
+                lv6: R.Tensor((4,), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tensor((4,), dtype="float32") = R.negative(lv6)
+                lv8: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                )
+                lv9: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv10: R.Tensor((4,), dtype="float32") = R.where(lv8, lv7, lv9)
+                lv11: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                )
+                lv12: R.Tensor((4,), dtype="bool") = R.sum(lv11, axis=[], keepdims=False)
+                lv13: R.Tensor((4,), dtype="float32") = R.astype(lv12, dtype="float32")
+                lv14: R.Tensor((4,), dtype="float32") = R.sum(lv10, axis=[], keepdims=False)
+                lv15: R.Tensor((4,), dtype="float32") = R.divide(lv14, lv13)
+                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv15,)
                 R.output(gv)
             return gv

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The expected IR for `test_cross_entropy` has incorrect shape annotations. 
`torch.nn.CrossEntropyLoss` with the default `reduction='mean'` should return a 
scalar tensor. However, the return type of the `main` function is annotated as 
`R.Tuple(R.Tensor((4,), dtype="float32"))`.
   
   Looking at the IR, `lv12` and `lv14` are results of `R.sum` with `axis=[]`, 
which should produce scalar tensors (shape `()`), but they are annotated with 
shape `(4,)`. Consequently, `lv15` (the final result) is also annotated with 
shape `(4,)` instead of `()`. The function signature and intermediate type 
annotations should be corrected to reflect that a scalar is being computed.
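
   Under that reading, the tail of the expected IR would carry scalar shapes along these lines (a sketch: only the shape annotations change, dtypes are kept as generated):

   ```python
   # def main(...) -> R.Tuple(R.Tensor((), dtype="float32")):
   lv12: R.Tensor((), dtype="bool") = R.sum(lv11, axis=[], keepdims=False)
   lv13: R.Tensor((), dtype="float32") = R.astype(lv12, dtype="float32")
   lv14: R.Tensor((), dtype="float32") = R.sum(lv10, axis=[], keepdims=False)
   lv15: R.Tensor((), dtype="float32") = R.divide(lv14, lv13)
   gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv15,)
   ```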



##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1722,6 +1722,9 @@ def _split(self, node: fx.Node) -> relax.Var:
     def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        # Support both "dim" and "dims" parameters
+        if dim is None:
+            dim = node.kwargs.get("dims", None)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The logic to get `dim` can be simplified into a single line by chaining 
`dict.get` calls. This makes the code more concise and easier to read.
   
   ```suggestion
           dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims"))
   ```
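
   One caveat worth noting: `dict.get` only falls back to its second argument when the key is absent, so the chained form no longer reaches `"dims"` if `dim` is passed explicitly as `None`. If that case matters, an explicit `None` check preserves the current behavior:

   ```python
   dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim")
   if dim is None:
       dim = node.kwargs.get("dims")
   ```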


