This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6248b5db43 [Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583)
6248b5db43 is described below

commit 6248b5db43505fbcfb13cc289d11877d5d2649e8
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Sat Dec 13 14:29:23 2025 +0700

    [Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583)
    
    ## Issue 1: Without Dim
    ### Summary:
    In the _sum function (BaseFXGraphImporter), after retrieve_args, args[1] is
    [] and is still passed into relax.op.sum, so the result is incorrect.
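    
    The distinction is easiest to see in NumPy, whose reductions separate the two cases the same way (a minimal illustration, not TVM code):
    ```
    import numpy as np
    
    x = np.ones((2, 3), dtype="float32")
    print(np.sum(x, axis=()).shape)    # (2, 3): reducing over an empty axis tuple is a no-op
    print(np.sum(x, axis=None).shape)  # (): reducing over all axes yields a scalar
    ```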
    ### Steps to Reproduce
    - Module
    ```
    class SumWithoutDim(nn.Module):
        def forward(self, x):
            return torch.sum(x)
    ```
    - Actual Relax module:
    ```
    class Module:
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False)
                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    - Result:
    ```
    Input: tensor([[1., 1., 1.], [1., 1., 1.]])
    Torch output: tensor(6.)
    Torch output shape: torch.Size([])
    TVM output: [[1. 1. 1.]  [1. 1. 1.]]
    TVM output shape: (2, 3)
    ```
    ### Expected
    ```
    class Module:
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    - Result: TVM output: 6.0; TVM output shape: ()
    
    ## Issue 2: Keep Dim
    ### Summary:
    In the _sum function (BaseFXGraphImporter), the keepdim value was
    previously read only from node.kwargs and was not passed into
    relax.op.sum when a dim was given. Now keepdim is also taken from
    args[2] and passed through.
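    
    How the positional form surfaces in the FX graph can be checked with a short trace (a minimal sketch; assumes a torch build with torch.fx available):
    ```
    import torch
    from torch import fx
    
    def f(x):
        return torch.sum(x, 1, True)  # dim and keepdim passed positionally
    
    for node in fx.symbolic_trace(f).graph.nodes:
        if node.op == "call_function":
            print(node.args, node.kwargs)  # (x, 1, True) {} -- keepdim is in args[2], not kwargs
    ```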
    ### Steps to Reproduce
    - Module
    ```
    class SumKeepDim(nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1, keepdim=True)
    ```
    - Actual Relax module:
    ```
    class Module:
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False)
                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    - Result:
    ```
    Input: tensor([[1., 1., 1.], [1., 1., 1.]])
    Torch output: tensor([[3.], [3.]])
    Torch output shape: torch.Size([2, 1])
    TVM VM output: [3. 3.]
    TVM VM output shape: (2,)
    ```
    ### Expected
    ```
    class Module:
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True)
                gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    - Result: TVM output: [[3.] [3.]]; TVM output shape: (2, 1)
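    
    For reference, the argument normalization introduced by the fix can be exercised standalone (a sketch with a hypothetical helper name; the real logic lives in BaseFXGraphImporter._sum in the diff below):
    ```
    def normalize_sum_args(args, kwargs):
        # dim/keepdim may arrive positionally (args[1], args[2]) or as kwargs;
        # an empty dim list/tuple means "reduce over all axes", i.e. dim=None.
        dim = args[1] if len(args) > 1 else kwargs.get("dim", None)
        if isinstance(dim, (list, tuple)) and len(dim) == 0:
            dim = None
        keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)
        return dim, keepdim
    
    x = object()  # stand-in for the tensor argument
    assert normalize_sum_args((x,), {}) == (None, False)                # torch.sum(x)
    assert normalize_sum_args((x, []), {}) == (None, False)             # empty dim -> full reduce
    assert normalize_sum_args((x, (2, 1), True), {}) == ((2, 1), True)  # positional keepdim
    assert normalize_sum_args((x,), {"dim": 1, "keepdim": True}) == (1, True)
    ```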
---
 .../frontend/torch/base_fx_graph_translator.py     | 10 +++--
 .../relax/test_frontend_from_exported_program.py   | 48 +++++++++++++++++++---
 2 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 47eb666210..f7d54a6216 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1628,10 +1628,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if isinstance(dim, (list, tuple)) and len(dim) == 0:
+            dim = None
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01e16e7564..4a84b50cc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4945,6 +4945,14 @@ def test_sum():
         def forward(self, x):
             return torch.sum(x, (2, 1))
 
+    class SumKeepDim(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1), keepdim=True)
+
+    class SumWithoutDim(Module):
+        def forward(self, x):
+            return torch.sum(x)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -4958,8 +4966,36 @@ def test_sum():
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(
+                    inp_0, axis=[2, 1], keepdims=True
+                )
+                gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
     verify_model(Sum(), example_args, {}, expected1)
+    verify_model(SumKeepDim(), example_args, {}, expected2)
+    verify_model(SumWithoutDim(), example_args, {}, expected3)
 
 
 def test_argmax_argmin():
@@ -7840,7 +7876,7 @@ def test_cross_entropy():
     @tvm.script.ir_module
     class Expected1:
         @R.function
-        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
                 lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
@@ -7863,11 +7899,11 @@ def test_cross_entropy():
                 lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
                     R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                 )
-                lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
-                lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
-                lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
-                lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
-                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
+                lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False)
+                lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32")
+                lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False)
+                lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,)
                 R.output(gv)
             return gv
 
