This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c28c86f7d [Relax][PyTorch] Support binary, statistical and search ops for ExportedProgram importer (#17424)
7c28c86f7d is described below

commit 7c28c86f7d3121ce2adc179475fdb1922c86b942
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Sep 28 22:30:15 2024 +0900

    [Relax][PyTorch] Support binary, statistical and search ops for ExportedProgram importer (#17424)
    
    * support binary ops
    
    * support mean
    
    * support sum
    
    * support argmax and argmin
---
 .../frontend/torch/base_fx_graph_translator.py     |  62 +++
 .../frontend/torch/exported_program_translator.py  |  25 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  62 ---
 .../relax/test_frontend_from_exported_program.py   | 512 +++++++++++++++++++++
 4 files changed, 599 insertions(+), 62 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d52b3d598f..a41b9b6d4f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -185,6 +185,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    ########## Binary Ops ##########
+
+    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            def promote_binary_op_args(lhs, rhs):
+                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    return lhs, rhs
+                elif isinstance(lhs, relax.Expr):
+                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
+                elif isinstance(rhs, relax.Expr):
+                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+                    return relax.const(lhs, rhs.struct_info.dtype), rhs
+                else:
+                    assert False
+
+            def call_binary_op(op, lhs, rhs):
+                lhs, rhs = promote_binary_op_args(lhs, rhs)
+                return self.block_builder.emit(op(lhs, rhs))
+
+            lhs, rhs = self.retrieve_args(node)
+            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+                return call_binary_op(relax_op, lhs, rhs)
+            elif isinstance(lhs, relax.expr.Constant):
+                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
+            elif isinstance(rhs, relax.expr.Constant):
+                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
+            return intrinsic_op(lhs, rhs)
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -283,6 +316,35 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    ########## Statistical ##########
+
+    def _mean(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+
+    def _sum(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
+        if len(args) == 1:
+            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
+        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+
+    ########## Search ##########
+
+    def _argmax_argmin(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node):
+            x = self.env[node.args[0]]
+            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+            return self.block_builder.emit(op(x, dim, keepdim))
+
+        return convert
+
     ########## Manipulation ##########
 
     def _reshape(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1ceddad7d7..11594690cd 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
+from functools import partial
 from typing import Callable, Dict, List, Tuple
 
 import torch
@@ -76,6 +77,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
+        import operator
+
         return {
             # unary
             "acos.default": self._unary_op(relax.op.acos),
@@ -109,11 +112,33 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "tanh.default": self._unary_op(relax.op.tanh),
             "tril.default": self._tril_triu(relax.op.tril),
             "triu.default": self._tril_triu(relax.op.triu),
+            # binary
+            "add.Tensor": self._binary_op(relax.op.add, operator.add),
+            "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
+            "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
+            "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
+            "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
+            "lt.Scalar": self._binary_op(relax.op.less, operator.lt),
+            "lt.Tensor": self._binary_op(relax.op.less, operator.lt),
+            "matmul.default": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
+            "max.other": self._binary_op(relax.op.maximum, max),
+            "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+            "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
+            "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
+            "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
             # neural network
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "conv2d.default": self._conv2d,
             "linear.default": self._linear,
             "max_pool2d.default": self._max_pool2d,
+            # statistical
+            "mean.dim": self._mean,
+            "sum.dim_IntList": self._sum,
+            # search
+            "argmax.default": self._argmax_argmin(relax.op.argmax),
+            "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "view.default": self._reshape,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6f7c6fa2c5..dc6ebc2eb3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -96,39 +96,6 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return convert
 
-    ########## Binary Ops ##########
-
-    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node) -> relax.Var:
-            def promote_binary_op_args(lhs, rhs):
-                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
-                    return lhs, rhs
-                elif isinstance(lhs, relax.Expr):
-                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
-                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
-                elif isinstance(rhs, relax.Expr):
-                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
-                    return relax.const(lhs, rhs.struct_info.dtype), rhs
-                else:
-                    assert False
-
-            def call_binary_op(op, lhs, rhs):
-                lhs, rhs = promote_binary_op_args(lhs, rhs)
-                return self.block_builder.emit(op(lhs, rhs))
-
-            lhs, rhs = self.retrieve_args(node)
-            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-                return call_binary_op(relax_op, lhs, rhs)
-            elif isinstance(lhs, relax.expr.Constant):
-                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
-            elif isinstance(rhs, relax.expr.Constant):
-                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
-            return intrinsic_op(lhs, rhs)
-
-        return convert
-
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -794,35 +761,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
         return self.block_builder.emit(relax.Tuple(ret))
 
-    ########## Statistical ##########
-
-    def _mean(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
-        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
-
-    def _sum(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
-
-    ########## Search ##########
-
-    def _argmax_argmin(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node):
-            x = self.env[node.args[0]]
-            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
-            return self.block_builder.emit(op(x, dim, keepdim))
-
-        return convert
-
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6c17d96004..25e6dbfae3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -790,6 +790,372 @@ def test_tril_triu():
     verify_model(Triu(), example_args, {}, expected_triu)
 
 
+def test_binary():
+    example_args1 = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
+
+    # Add
+    class Add1(Module):
+        def forward(self, lhs, rhs):
+            return lhs + rhs
+
+    @tvm.script.ir_module
+    class expected_add1:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="float32"),
+            rhs: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Add2(Module):
+        def forward(self, lhs):
+            return lhs + 1.0
+
+    @tvm.script.ir_module
+    class expected_add2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Add1(), example_args1, {}, expected_add1)
+    verify_model(Add2(), example_args2, {}, expected_add2)
+
+    # True div
+    class TrueDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs / rhs
+
+    @tvm.script.ir_module
+    class expected_truediv1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class TrueDiv2(Module):
+        def forward(self, lhs):
+            return lhs / 1.0
+
+    @tvm.script.ir_module
+    class expected_truediv2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(TrueDiv1(), example_args1, {}, expected_truediv1)
+    verify_model(TrueDiv2(), example_args2, {}, expected_truediv2)
+
+    # EQ
+    class EQ1(Module):
+        def forward(self, lhs, rhs):
+            return lhs == rhs
+
+    @tvm.script.ir_module
+    class expected_eq1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class EQ2(Module):
+        def forward(self, lhs):
+            return lhs == 1.0
+
+    @tvm.script.ir_module
+    class expected_eq2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(EQ1(), example_args1, {}, expected_eq1)
+    verify_model(EQ2(), example_args2, {}, expected_eq2)
+
+    # Floor div
+    class FloorDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs // rhs
+
+    @tvm.script.ir_module
+    class expected_floordiv1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class FloorDiv2(Module):
+        def forward(self, lhs):
+            return lhs // 1.0
+
+    @tvm.script.ir_module
+    class expected_floordiv2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1)
+    verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2)
+
+    # LT
+    class LT1(Module):
+        def forward(self, lhs, rhs):
+            return lhs < rhs
+
+    @tvm.script.ir_module
+    class expected_lt1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class LT2(Module):
+        def forward(self, lhs):
+            return lhs < 1.0
+
+    @tvm.script.ir_module
+    class expected_lt2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(LT1(), example_args1, {}, expected_lt1)
+    verify_model(LT2(), example_args2, {}, expected_lt2)
+
+    # MatMul
+    class MatMul1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.matmul(x, y)
+
+    @tvm.script.ir_module
+    class expected_matmul1:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32"),
+            input_2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(MatMul1(), example_args1, {}, expected_matmul1)
+
+    # Max
+    class Max1(Module):
+        def forward(self, x, y):
+            return torch.max(x, y)
+
+    @I.ir_module
+    class expected_max1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32"),
+            inp_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Max1(), example_args1, {}, expected_max1)
+
+    # Mul
+    class Mul1(Module):
+        def forward(self, lhs, rhs):
+            return lhs * rhs
+
+    @tvm.script.ir_module
+    class expected_mul1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Mul2(Module):
+        def forward(self, lhs):
+            return lhs * 1.0
+
+    @tvm.script.ir_module
+    class expected_mul2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Mul1(), example_args1, {}, expected_mul1)
+    verify_model(Mul2(), example_args2, {}, expected_mul2)
+
+    # Power
+    class Power1(Module):
+        def forward(self, lhs, rhs):
+            return lhs**rhs
+
+    @tvm.script.ir_module
+    class expected_power1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Power2(Module):
+        def forward(self, lhs):
+            return lhs**1.0
+
+    @tvm.script.ir_module
+    class expected_power2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Power1(), example_args1, {}, expected_power1)
+    verify_model(Power2(), example_args2, {}, expected_power2)
+
+    # Sub
+    class Sub1(Module):
+        def forward(self, lhs, rhs):
+            return lhs - rhs
+
+    @tvm.script.ir_module
+    class expected_sub1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Sub2(Module):
+        def forward(self, lhs):
+            return lhs - 1.0
+
+    @tvm.script.ir_module
+    class expected_sub2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sub1(), example_args1, {}, expected_sub1)
+    verify_model(Sub2(), example_args2, {}, expected_sub2)
+
+
 def test_adaptive_avgpool2d():
     class AdaptiveAvgPool2d0(Module):
         def __init__(self):
@@ -1094,6 +1460,152 @@ def test_maxpool2d():
     verify_model(MaxPool2d3(), example_args, {}, expected3)
 
 
+def test_mean():
+    class Mean(Module):
+        def forward(self, input):
+            return input.mean(-1)
+
+    class MeanKeepDim(Module):
+        def forward(self, input: torch.Tensor):
+            return input.mean(-1, keepdim=True)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False)
+                gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True)
+                gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+    verify_model(Mean(), example_args, {}, Expected1)
+    verify_model(MeanKeepDim(), example_args, {}, Expected2)
+
+
+def test_sum():
+    class Sum(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False)
+                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Sum(), example_args, {}, expected1)
+
+
+def test_argmax_argmin():
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+
+    class Argmax1(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmax(input, dim=-1)
+
+    class Argmax2(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmax(input, dim=-1, keepdim=True)
+
+    @tvm.script.ir_module
+    class expected_argmax1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False)
+                gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_argmax2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True)
+                gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Argmax1(), example_args, {}, expected_argmax1)
+    verify_model(Argmax2(), example_args, {}, expected_argmax2)
+
+    class Argmin1(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmin(input)
+
+    class Argmin2(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmin(input, keepdim=True)
+
+    @tvm.script.ir_module
+    class expected_argmin1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_argmin2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True)
+                gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Argmin1(), example_args, {}, expected_argmin1)
+    verify_model(Argmin2(), example_args, {}, expected_argmin2)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):

Reply via email to