This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new db96ee80e7 [Unity] Add More Ops For FX Translator (#14348)
db96ee80e7 is described below

commit db96ee80e72281742becd14a6edacd19b2f8a881
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Mar 22 11:00:20 2023 -0400

    [Unity] Add More Ops For FX Translator (#14348)
    
    This PR makes two changes:
    1. Add the Relax ops Maximum and Minimum
    2. Add translation functions for the torch functions/methods silu, to, ones, full, 
masked_fill_, mean, rsqrt, neg, and max in the FX translator
---
 python/tvm/relax/frontend/torch/fx_translator.py   | 131 +++++++++-
 python/tvm/relax/op/binary.py                      |  37 +++
 python/tvm/relax/transform/legalize_ops/binary.py  |   3 +
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 src/relax/op/tensor/binary.cc                      |   5 +
 src/relax/op/tensor/binary.h                       |   8 +
 tests/python/relax/test_frontend_dynamo.py         | 137 +++++++++-
 tests/python/relax/test_frontend_from_fx.py        | 163 +++++++++++-
 tests/python/relax/test_op_binary.py               |   2 +
 .../relax/test_transform_legalize_ops_binary.py    | 280 +++++++++++++++++++++
 .../relax/test_tvmscript_parser_op_arith_cmp.py    |   2 +
 11 files changed, 765 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index a2e2afe668..ef6793cc67 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -75,6 +75,8 @@ class TorchFXImporter:
             return "int64"
         elif input_type in ["int32", "torch.int32", torch.int32]:
             return "int32"
+        elif input_type in ["bool", "torch.bool", torch.bool]:
+            return "bool"
         else:
             raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
 
@@ -151,6 +153,15 @@ class TorchFXImporter:
             arg = relax.const(arg, "float32")
         return self.block_builder.emit(relax.op.sqrt(arg))
 
+    def _rsqrt(self, node: fx.node.Node) -> relax.Expr:
+        arg = self.env[node.args[0]]
+        if isinstance(arg, (int, float)):
+            arg = relax.const(arg, "float32")
+        sqrt = self.block_builder.emit(relax.op.sqrt(arg))
+        return self.block_builder.emit(
+            relax.op.divide(relax.const(1, sqrt.struct_info.dtype), sqrt)
+        )
+
     def _round(self, node: fx.node.Node) -> relax.Expr:
         if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
             raise ValueError("specifying decimals for round is not supported 
yet")
@@ -161,8 +172,21 @@ class TorchFXImporter:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.add, lhs, rhs)
+        elif isinstance(lhs, relax.expr.Constant):
+            return self._call_binary_op(
+                relax.op.add, lhs, relax.const(rhs, 
dtype=lhs.struct_info.dtype)
+            )
+        elif isinstance(rhs, relax.expr.Constant):
+            return self._call_binary_op(
+                relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), 
rhs
+            )
         return lhs + rhs
 
+    def _max(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.maximum, lhs, rhs)
+
     def _floordiv(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -181,6 +205,10 @@ class TorchFXImporter:
             return self._call_binary_op(relax.op.power, lhs, rhs)
         return lhs**rhs
 
+    def _neg(self, node: fx.node.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.negative(x))
+
     def _sub(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -279,7 +307,7 @@ class TorchFXImporter:
     def _tensor(self, node: fx.node.Node) -> relax.Var:
         dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
         if isinstance(node.args[0], float):
-            return relax.const(node.args[0], dtype if dtype is not None else 
"float64")
+            return relax.const(node.args[0], dtype if dtype is not None else 
"float32")
         elif isinstance(node.args[0], int):
             return relax.const(node.args[0], dtype if dtype is not None else 
"int64")
         raise ValueError("torch.tensor with value not a float or int is not 
accepted")
@@ -324,14 +352,65 @@ class TorchFXImporter:
             )
         )
 
+    def _ones(self, node: fx.node.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = args[0]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        dtype = (
+            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
+            if "dtype" in node.kwargs
+            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), 
self.env)
+        )
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
+            )
+        )
+
+    def _full(self, node: fx.node.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = args[0]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        dtype = (
+            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
+            if "dtype" in node.kwargs
+            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), 
self.env)
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else 
relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
     ########## Statistical ##########
 
     def _sum(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
         if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0]))
+            return self.block_builder.emit(relax.op.sum(args[0], 
keepdims=keepdim))
         return self.block_builder.emit(relax.op.sum(args[0], args[1]))
 
+    def _mean(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
+        if len(args) == 1:
+            return self.block_builder.emit(relax.op.mean(args[0], 
keepdims=keepdim))
+        return self.block_builder.emit(relax.op.mean(args[0], args[1], 
keepdims=keepdim))
+
     ########## DataType ##########
 
     def _float(self, node: fx.node.Node) -> relax.Var:
@@ -345,6 +424,19 @@ class TorchFXImporter:
         dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
         return self.block_builder.emit(relax.op.astype(x, dtype))
 
+    def _to(self, node: fx.node.Node) -> relax.Var:
+        import torch
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = TorchFXImporter._convert_data_type(node.args[1], 
self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], 
self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
     ########## Linear Algebra ##########
 
     def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
@@ -500,6 +592,16 @@ class TorchFXImporter:
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        value = node.args[2]
+        rx_value = relax.const(value)
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        output = self.block_builder.emit(relax.op.where(mask, values, x))
+        self.env[node.args[0]] = output
+        return output
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
@@ -847,6 +949,10 @@ class TorchFXImporter:
             expand_dim = []
             i = 0
             shape = self.shape_of(x)
+            non_ellipsis_cnt = 0
+            for index in node.args[1]:
+                if isinstance(index, (int, slice)):
+                    non_ellipsis_cnt += 1
             for index in node.args[1]:
                 if isinstance(index, int):
                     begin.append(index)
@@ -862,6 +968,13 @@ class TorchFXImporter:
                     i = i + 1
                 elif index is None:
                     expand_dim.append(len(axes) + len(expand_dim))
+                elif index is Ellipsis:
+                    for _ in range(len(shape) - non_ellipsis_cnt):
+                        begin.append(0)
+                        end.append(shape[i])
+                        stride.append(1)
+                        axes.append(i)
+                        i += 1
                 else:
                     raise ValueError("Unsupported index type: " + 
str(type(index)))
             while i < len(shape):
@@ -869,7 +982,7 @@ class TorchFXImporter:
                 end.append(shape[i])
                 stride.append(1)
                 axes.append(i)
-                i = i + 1
+                i += 1
             sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, 
begin, end, stride))
             sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
@@ -957,17 +1070,25 @@ class TorchFXImporter:
             "clamp": self._clamp,
             "relu": lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
             "gelu": lambda node: 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+            "silu": lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
             "tanh": lambda node: 
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,
             "getitem": self._getitem,
             "contiguous": lambda node: self.env[node.args[0]],
-            "to": lambda node: self.env[node.args[0]],
+            "to": self._to,
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
             "layer_norm": self._layer_norm,
             "index_select": self._index_select,
             "masked_fill": self._masked_fill,
+            "ones": self._ones,
+            "full": self._full,
+            "masked_fill_": self._inplace_masked_fill,
+            "mean": self._mean,
+            "rsqrt": self._rsqrt,
+            "neg": self._neg,
+            "max": self._max,
         }
 
     def from_fx(
@@ -1029,7 +1150,7 @@ class TorchFXImporter:
                         assert len(args) == 1
                         if (
                             unwrap_unit_return_tuple
-                            and isinstance(args[0], (tuple, relax.Tuple))
+                            and isinstance(args[0], (tuple, list, relax.Tuple))
                             and len(args[0]) == 1
                         ):
                             output = self.block_builder.emit_output(args[0][0])
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index ead59cdf7b..09a0c30f19 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -248,3 +248,40 @@ def not_equal(x1: Expr, x2: Expr) -> Expr:
         The computed result.
     """
     return _ffi_api.not_equal(x1, x2)  # type: ignore
+
+
+###################### Min/Max operators ######################
+
+
+def maximum(x1: Expr, x2: Expr) -> Expr:
+    """Element-wise maximum
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.maximum(x1, x2)
+
+
+def minimum(x1: Expr, x2: Expr) -> Expr:
+    """Element-wise minimum
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.minimum(x1, x2)
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py 
b/python/tvm/relax/transform/legalize_ops/binary.py
index ffda767233..897b676518 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -54,3 +54,6 @@ register_legalize("relax.greater_equal", 
_binary(topi.greater_equal))
 register_legalize("relax.less", _binary(topi.less))
 register_legalize("relax.less_equal", _binary(topi.less_equal))
 register_legalize("relax.not_equal", _binary(topi.not_equal))
+
+register_legalize("relax.maximum", _binary(topi.maximum))
+register_legalize("relax.minimum", _binary(topi.minimum))
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index ae0918a082..d344891609 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -82,9 +82,11 @@ from tvm.relax.op import (
     make_closure,
     matmul,
     max,
+    maximum,
     mean,
     memory,
     min,
+    minimum,
     multiply,
     negative,
     not_equal,
@@ -596,9 +598,11 @@ __all__ = [
     "make_closure",
     "matmul",
     "max",
+    "maximum",
     "mean",
     "memory",
     "min",
+    "minimum",
     "multiply",
     "negative",
     "not_equal",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 30cd748308..96d1f01e8a 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -118,5 +118,10 @@ RELAX_REGISTER_CMP_OP_AND_IMPL(less);
 RELAX_REGISTER_CMP_OP_AND_IMPL(less_equal);
 RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal);
 
+/***************** Min/Max operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(minimum);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(maximum);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 086e37f883..e386f9019f 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -98,6 +98,14 @@ Expr less_equal(Expr x1, Expr x2);
 /*! \brief Broadcasted element-wise test for (lhs != rhs). */
 Expr not_equal(Expr x1, Expr x2);
 
+/***************** Min/Max *****************/
+
+/*! \brief Element-wise minimum */
+Expr minimum(Expr x1, Expr x2);
+
+/*! \brief Element-wise maximum */
+Expr maximum(Expr x1, Expr x2);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_dynamo.py 
b/tests/python/relax/test_frontend_dynamo.py
index 765ca9b6f0..72ea193a02 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -25,7 +25,9 @@ import tvm.testing
 import torch
 import torch._dynamo as dynamo
 from tvm.relax.frontend.torch import relax_dynamo
-from tvm.script.parser import relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def test_relax_dynamo():
@@ -230,5 +232,138 @@ def test_subgraph_capture():
     tvm.ir.assert_structural_equal(mod, expected)
 
 
+def verify_dynamo_model(torch_model, input_info, binding, expected):
+    import torch
+    import torch._dynamo as dynamo
+    from tvm.relax.frontend.torch import from_fx
+
+    args = []
+    for info in input_info:
+        args.append(torch.zeros(*info[0], dtype=_convert_data_type(info[1])))
+    graph_model = dynamo.export(torch_model, *args)[0]
+    mod = from_fx(graph_model, input_info, unwrap_unit_return_tuple=True)
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    expected = relax.transform.BindParams("main", binding)(expected)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
+def _convert_data_type(input_type):
+    """converts the PyTorch scalar type input_type to a TVM dtype."""
+    import torch  # type: ignore
+
+    input_type = input_type.lower() if isinstance(input_type, str) else 
input_type
+    if input_type == "float32":
+        return torch.float32
+    elif input_type == "float16":
+        return torch.float16
+    elif input_type == "int64":
+        return torch.int64
+    elif input_type == "int32":
+        return torch.int32
+    elif input_type == "bool":
+        return torch.bool
+    else:
+        raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
+
+
[email protected]_gpu
+def test_ones():
+    import torch
+    from torch.nn import Module
+
+    class Ones(Module):
+        def forward(self, input):
+            return torch.ones((10, 10), dtype=torch.float32)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_dynamo_model(
+        Ones(),
+        [([256, 256], "float32")],
+        {},
+        Expected1,
+    )
+
+
[email protected]_gpu
+def test_full():
+    import torch
+    from torch.nn import Module
+
+    class Full(Module):
+        def forward(self, input):
+            return torch.full((10, 10), 1, dtype=torch.float32)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_dynamo_model(
+        Full(),
+        [([256, 256], "float32")],
+        {},
+        Expected1,
+    )
+
+
[email protected]_gpu
+def test_masked_fill():
+    import torch
+    from torch.nn import Module
+
+    class MaskedFill(Module):
+        def forward(self, mask, input):
+            return input.masked_fill(mask, 0)
+
+    class InplaceMaskedFill(Module):
+        def forward(self, mask, input):
+            input.masked_fill_(mask, 0)
+            return input
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="bool"), inp_1: R.Tensor((256, 
256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.full_like(
+                    inp_1, R.const(0, "int32"), dtype="void"
+                )
+                lv1: R.Tensor((256, 256), dtype="float32") = R.where(inp_0, 
lv, inp_1)
+                gv: R.Tensor((256, 256), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_dynamo_model(
+        MaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")], {}, 
Expected1
+    )
+    verify_dynamo_model(
+        InplaceMaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")], 
{}, Expected1
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 2e69795d51..d201cb111c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -19,10 +19,13 @@ import pytest
 import tvm
 from tvm import relax
 import tvm.testing
-from tvm.script.parser import ir as I, relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def verify_model(torch_model, input_info, binding, expected):
+    import torch
     from torch import fx
     from tvm.relax.frontend.torch import from_fx
 
@@ -831,6 +834,10 @@ def test_silu():
         def forward(self, input):
             return self.silu(input)
 
+    class SiLU2(Module):
+        def forward(self, input):
+            return torch.nn.functional.silu(input)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -845,6 +852,7 @@ def test_silu():
             return gv
 
     verify_model(SiLU(), input_info, {}, expected1)
+    verify_model(SiLU2(), input_info, {}, expected1)
 
 
 @tvm.testing.requires_gpu
@@ -2496,5 +2504,158 @@ def test_argmin():
     verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2)
 
 
[email protected]_gpu
+def test_to():
+    import torch
+    from torch.nn import Module
+
+    class To1(Module):
+        def forward(self, input):
+            return input.to(torch.float16)
+
+    class To2(Module):
+        def forward(self, input):
+            return input.to("cpu")
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float16") = R.astype(inp_0, 
dtype="float16")
+                gv: R.Tensor((256, 256), dtype="float16") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((256, 256), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    verify_model(To1(), [([256, 256], "float32")], {}, Expected1)
+    verify_model(To2(), [([256, 256], "float32")], {}, Expected2)
+
+
[email protected]_gpu
+def test_mean():
+    import torch
+    from torch.nn import Module
+
+    class Mean(Module):
+        def forward(self, input):
+            return input.mean(-1)
+
+    class MeanKeepDim(Module):
+        def forward(self, input):
+            return input.mean(-1, keepdim=True)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> 
R.Tensor((256,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, 
axis=[-1], keepdims=False)
+                gv: R.Tensor((256,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, 
axis=[-1], keepdims=True)
+                gv: R.Tensor((256, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Mean(), [([256, 256], "float32")], {}, Expected1)
+    verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2)
+
+
[email protected]_gpu
+def test_rsqrt():
+    import torch
+    from torch.nn import Module
+
+    class Rsqrt(Module):
+        def forward(self, input):
+            return torch.rsqrt(input)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.sqrt(inp_0)
+                lv1: R.Tensor((256, 256), dtype="float32") = 
R.divide(R.const(1, "float32"), lv)
+                gv: R.Tensor((256, 256), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
+
+
[email protected]_gpu
+def test_neg():
+    import torch
+    from torch.nn import Module
+
+    class Neg(Module):
+        def forward(self, input):
+            return -input
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.negative(inp_0)
+                gv: R.Tensor((256, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Neg(), [([256, 256], "float32")], {}, Expected1)
+
+
[email protected]_gpu
+def test_max():
+    import torch
+    from torch.nn import Module
+
+    class Max(Module):
+        def forward(self, x, y):
+            return torch.max(x, y)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.maximum(inp_0, 
inp_1)
+                gv: R.Tensor((256, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], 
{}, Expected1)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_op_binary.py 
b/tests/python/relax/test_op_binary.py
index 56263bc4ee..809fe7e98f 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -54,6 +54,8 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     (relax.op.multiply,),
     (relax.op.power,),
     (relax.op.subtract,),
+    (relax.op.maximum,),
+    (relax.op.minimum,),
 )
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py 
b/tests/python/relax/test_transform_legalize_ops_binary.py
index 5847413713..dc14a0c3fd 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -1327,5 +1327,285 @@ def test_not_equal_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
def test_maximum():
    """LegalizeOps lowers R.maximum with broadcasting to a T.max prim_func.

    Input shapes (1, 2, 3) and (4, 3, 2, 1) broadcast to (4, 3, 2, 3); the
    legalized module must call a generated prim_func whose body computes
    ``T.max`` per element, reading the size-1 axes at index 0.
    """
    # fmt: off
    @tvm.script.ir_module
    class Maximum:
        @R.function
        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
            gv: R.Tensor((4, 3, 2, 3), "float32") = R.maximum(x, y)
            return gv


    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
            gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def maximum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_maximum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
                with T.block("T_maximum"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_maximum[ax0, ax1, ax2, ax3])
                    T_maximum[ax0, ax1, ax2, ax3] = T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
    # fmt: on

    mod = LegalizeOps()(Maximum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_maximum_with_arg0_constant_scalar():
    """LegalizeOps inlines a scalar constant operand of R.maximum.

    The ``R.const(1, "float32")`` operand does not become a prim_func argument;
    it is folded into the TIR body as ``T.float32(1)``, so call_tir receives
    only the tensor ``x``.

    NOTE(review): the constant here is the *second* argument to R.maximum,
    while ``test_maximum_with_arg1_constant_scalar`` below puts it first — the
    arg0/arg1 suffixes of these two test names appear swapped; confirm intent.
    """
    # fmt: off
    @tvm.script.ir_module
    class Maximum:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), dtype="float32") = R.maximum(x, R.const(1, "float32"))
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_maximum"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_maximum[ax0, ax1])
                    T_maximum[ax0, ax1] = T.max(rxplaceholder[ax0, ax1], T.float32(1))
    # fmt: on

    mod = LegalizeOps()(Maximum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_maximum_with_arg1_constant_scalar():
    """LegalizeOps inlines a scalar constant as the first operand of R.maximum.

    Mirror of the previous test with operands flipped: the folded constant
    appears as the first argument of ``T.max`` in the generated prim_func.

    NOTE(review): here the constant is the *first* argument to R.maximum — the
    arg0/arg1 suffixes of this test and the previous one appear swapped;
    confirm intent.
    """
    # fmt: off
    @tvm.script.ir_module
    class Maximum:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), dtype="float32") = R.maximum(R.const(1, "float32"), x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_maximum"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_maximum[ax0, ax1])
                    T_maximum[ax0, ax1] = T.max(T.float32(1), rxplaceholder[ax0, ax1])
    # fmt: on

    mod = LegalizeOps()(Maximum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_maximum_symbolic():
    """LegalizeOps handles R.maximum with symbolic shapes.

    Shapes (1, c, d) and (a, b, c, 1) broadcast to the symbolic (a, b, c, d);
    the generated prim_func takes T.handle parameters bound via T.match_buffer
    and loops over the symbolic extents.
    """
    # fmt: off
    @tvm.script.ir_module
    class Maximum:
        @R.function
        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            gv: R.Tensor((a, b, c, d), "float32") = R.maximum(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((a, b, c, d), dtype="float32"))
            return gv

        @T.prim_func
        def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_maximum: T.handle):
            T.func_attr({"tir.noalias": True})
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32")
            T_maximum = T.match_buffer(var_T_maximum, [a, b, c, d], dtype="float32")
            for i0, i1, i2, i3 in T.grid(a, b, c, d):
                with T.block("T_maximum"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_maximum[ax0, ax1, ax2, ax3])
                    T_maximum[ax0, ax1, ax2, ax3] = T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
    # fmt: on

    mod = LegalizeOps()(Maximum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_minimum():
    """LegalizeOps lowers R.minimum with broadcasting to a T.min prim_func.

    Twin of ``test_maximum``: shapes (1, 2, 3) and (4, 3, 2, 1) broadcast to
    (4, 3, 2, 3), and the legalized body computes ``T.min`` per element.
    """
    # fmt: off
    @tvm.script.ir_module
    class Minimum:
        @R.function
        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
            gv: R.Tensor((4, 3, 2, 3), "float32") = R.minimum(x, y)
            return gv


    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
            gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def minimum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_minimum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
                with T.block("T_minimum"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_minimum[ax0, ax1, ax2, ax3])
                    T_minimum[ax0, ax1, ax2, ax3] = T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
    # fmt: on

    mod = LegalizeOps()(Minimum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_minimum_with_arg0_constant_scalar():
    """LegalizeOps inlines a scalar constant operand of R.minimum.

    The ``R.const(1, "float32")`` operand is folded into the TIR body as
    ``T.float32(1)`` rather than passed through call_tir.

    NOTE(review): as with the maximum twins, the constant here is the *second*
    R.minimum argument despite the ``arg0`` suffix — confirm the naming intent.
    """
    # fmt: off
    @tvm.script.ir_module
    class Minimum:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), dtype="float32") = R.minimum(x, R.const(1, "float32"))
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_minimum"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_minimum[ax0, ax1])
                    T_minimum[ax0, ax1] = T.min(rxplaceholder[ax0, ax1], T.float32(1))
    # fmt: on

    mod = LegalizeOps()(Minimum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_minimum_with_arg1_constant_scalar():
    """LegalizeOps inlines a scalar constant as the first operand of R.minimum.

    Mirror of the previous test with operands flipped: the folded constant
    appears as the first argument of ``T.min`` in the generated prim_func.
    """
    # fmt: off
    @tvm.script.ir_module
    class Minimum:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), dtype="float32") = R.minimum(R.const(1, "float32"), x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func
        def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_minimum"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[ax0, ax1])
                    T.writes(T_minimum[ax0, ax1])
                    T_minimum[ax0, ax1] = T.min(T.float32(1), rxplaceholder[ax0, ax1])
    # fmt: on

    mod = LegalizeOps()(Minimum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_minimum_symbolic():
    """LegalizeOps handles R.minimum with symbolic shapes.

    Twin of ``test_maximum_symbolic``: (1, c, d) and (a, b, c, 1) broadcast to
    the symbolic (a, b, c, d), with the prim_func binding its T.handle
    parameters via T.match_buffer and computing ``T.min`` per element.
    """
    # fmt: off
    @tvm.script.ir_module
    class Minimum:
        @R.function
        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            gv: R.Tensor((a, b, c, d), "float32") = R.minimum(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((a, b, c, d), dtype="float32"))
            return gv

        @T.prim_func
        def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_minimum: T.handle):
            T.func_attr({"tir.noalias": True})
            a = T.int64()
            b = T.int64()
            c = T.int64()
            d = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32")
            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32")
            T_minimum = T.match_buffer(var_T_minimum, [a, b, c, d], dtype="float32")
            for i0, i1, i2, i3 in T.grid(a, b, c, d):
                with T.block("T_minimum"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
                    T.writes(T_minimum[ax0, ax1, ax2, ax3])
                    T_minimum[ax0, ax1, ax2, ax3] = T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
    # fmt: on

    mod = LegalizeOps()(Minimum)
    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py 
b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index 7fdd109cca..d43e9a626b 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -106,6 +106,8 @@ def test_unary_check(unary_check_op: Callable):
     (relax.op.multiply,),
     (relax.op.power,),
     (relax.op.subtract,),
+    (relax.op.maximum,),
+    (relax.op.minimum,),
 )
 
 


Reply via email to