This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new db96ee80e7 [Unity] Add More Ops For FX Translator (#14348)
db96ee80e7 is described below
commit db96ee80e72281742becd14a6edacd19b2f8a881
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Mar 22 11:00:20 2023 -0400
[Unity] Add More Ops For FX Translator (#14348)
This PR makes two changes:
1. Add the Relax ops `maximum` and `minimum`.
2. Add translation functions to the FX translator for the torch functions/methods
silu, to, ones, full, masked_fill_, mean, rsqrt, neg, and max.
---
python/tvm/relax/frontend/torch/fx_translator.py | 131 +++++++++-
python/tvm/relax/op/binary.py | 37 +++
python/tvm/relax/transform/legalize_ops/binary.py | 3 +
python/tvm/script/ir_builder/relax/ir.py | 4 +
src/relax/op/tensor/binary.cc | 5 +
src/relax/op/tensor/binary.h | 8 +
tests/python/relax/test_frontend_dynamo.py | 137 +++++++++-
tests/python/relax/test_frontend_from_fx.py | 163 +++++++++++-
tests/python/relax/test_op_binary.py | 2 +
.../relax/test_transform_legalize_ops_binary.py | 280 +++++++++++++++++++++
.../relax/test_tvmscript_parser_op_arith_cmp.py | 2 +
11 files changed, 765 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a2e2afe668..ef6793cc67 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -75,6 +75,8 @@ class TorchFXImporter:
return "int64"
elif input_type in ["int32", "torch.int32", torch.int32]:
return "int32"
+ elif input_type in ["bool", "torch.bool", torch.bool]:
+ return "bool"
else:
raise NotImplementedError("input_type {} is not handled
yet".format(input_type))
@@ -151,6 +153,15 @@ class TorchFXImporter:
arg = relax.const(arg, "float32")
return self.block_builder.emit(relax.op.sqrt(arg))
+    def _rsqrt(self, node: fx.node.Node) -> relax.Expr:
+        """Translate ``rsqrt`` as ``1 / sqrt(x)``.
+
+        A plain Python ``int``/``float`` operand is first wrapped in a
+        ``float32`` constant. The numerator ``1`` is built with the dtype
+        reported by the emitted square root, so both sides of the division
+        share one dtype.
+        """
+        operand = self.env[node.args[0]]
+        if isinstance(operand, (int, float)):
+            operand = relax.const(operand, "float32")
+        root = self.block_builder.emit(relax.op.sqrt(operand))
+        one = relax.const(1, root.struct_info.dtype)
+        reciprocal = relax.op.divide(one, root)
+        return self.block_builder.emit(reciprocal)
+
def _round(self, node: fx.node.Node) -> relax.Expr:
if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
raise ValueError("specifying decimals for round is not supported
yet")
@@ -161,8 +172,21 @@ class TorchFXImporter:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.add, lhs, rhs)
+ elif isinstance(lhs, relax.expr.Constant):
+ return self._call_binary_op(
+ relax.op.add, lhs, relax.const(rhs,
dtype=lhs.struct_info.dtype)
+ )
+ elif isinstance(rhs, relax.expr.Constant):
+ return self._call_binary_op(
+ relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype),
rhs
+ )
return lhs + rhs
+    def _max(self, node: fx.node.Node) -> relax.Expr:
+        """Translate the two-operand form ``max(lhs, rhs)``.
+
+        When at least one operand is a ``relax.Var`` the call is emitted as
+        ``relax.op.maximum``. Two plain Python numbers are folded with the
+        builtin ``max``, mirroring the scalar fallback of ``_add``. Any other
+        operand combination is rejected explicitly; it used to fall off the
+        end of the function and hand ``None`` back to the importer.
+
+        Raises
+        ------
+        NotImplementedError
+            If the operands are neither relax variables nor Python numbers.
+        """
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.maximum, lhs, rhs)
+        if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)):
+            return max(lhs, rhs)
+        raise NotImplementedError(
+            "max of {} and {} is not handled yet".format(type(lhs), type(rhs))
+        )
+
def _floordiv(self, node: fx.node.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -181,6 +205,10 @@ class TorchFXImporter:
return self._call_binary_op(relax.op.power, lhs, rhs)
return lhs**rhs
+    def _neg(self, node: fx.node.Node) -> relax.Expr:
+        """Translate unary negation into ``relax.op.negative``."""
+        operand = self.env[node.args[0]]
+        negated = relax.op.negative(operand)
+        return self.block_builder.emit(negated)
+
def _sub(self, node: fx.node.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -279,7 +307,7 @@ class TorchFXImporter:
def _tensor(self, node: fx.node.Node) -> relax.Var:
dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
if isinstance(node.args[0], float):
- return relax.const(node.args[0], dtype if dtype is not None else
"float64")
+ return relax.const(node.args[0], dtype if dtype is not None else
"float32")
elif isinstance(node.args[0], int):
return relax.const(node.args[0], dtype if dtype is not None else
"int64")
raise ValueError("torch.tensor with value not a float or int is not
accepted")
@@ -324,14 +352,65 @@ class TorchFXImporter:
)
)
+    def _ones(self, node: fx.node.Node) -> relax.Var:
+        """Translate ``torch.ones`` into ``relax.op.full`` with fill value 1.
+
+        The shape may be passed either as one list/tuple or as separate
+        positional integers (``torch.ones(2, 3)``). In the variadic form all
+        positional arguments are used as dimensions; previously only the
+        first one was kept. The element type comes from the ``dtype`` keyword
+        when present, otherwise from ``torch.get_default_dtype()``.
+        """
+        import torch
+
+        args = self.retrieve_args(node)
+        if isinstance(args[0], (list, tuple)):
+            size = args[0]
+        else:
+            # Variadic form: every positional argument is one dimension.
+            size = tuple(args)
+        size = relax.ShapeExpr(size)
+        if "dtype" in node.kwargs:
+            torch_dtype = str(node.kwargs["dtype"])
+        else:
+            torch_dtype = torch.get_default_dtype()
+        dtype = TorchFXImporter._convert_data_type(torch_dtype, self.env)
+        return self.block_builder.emit(relax.op.full(size, relax.const(1, dtype), dtype))
+
+    def _full(self, node: fx.node.Node) -> relax.Var:
+        """Translate ``torch.full(size, fill_value)`` into ``relax.op.full``.
+
+        A non-sequence ``size`` is treated as a one-dimensional shape. The
+        element type comes from the ``dtype`` keyword when present, otherwise
+        from ``torch.get_default_dtype()``. A fill value that is already a
+        relax constant is passed through; anything else is wrapped in a
+        constant of the chosen dtype.
+        """
+        import torch
+
+        args = self.retrieve_args(node)
+        shape = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
+        shape = relax.ShapeExpr(shape)
+        if "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env)
+        fill_value = args[1]
+        if not isinstance(fill_value, relax.expr.Constant):
+            fill_value = relax.const(fill_value, dtype)
+        filled = relax.op.full(shape, fill_value, dtype)
+        return self.block_builder.emit(filled)
+
########## Statistical ##########
def _sum(self, node: fx.node.Node) -> relax.Var:
+ # Translate ``sum``: the optional second positional argument is the
+ # reduction axis, and the ``keepdim`` keyword is forwarded as ``keepdims``
+ # on both the full-reduction and the per-axis path.
args = self.retrieve_args(node)
+ keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
if len(args) == 1:
- return self.block_builder.emit(relax.op.sum(args[0]))
+ return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
- return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+ return self.block_builder.emit(relax.op.sum(args[0], args[1], keepdims=keepdim))
+    def _mean(self, node: fx.node.Node) -> relax.Var:
+        """Translate ``mean`` into ``relax.op.mean``.
+
+        With a single positional argument the whole tensor is reduced;
+        otherwise the second positional argument is passed as the axis. The
+        ``keepdim`` keyword (default ``False``) is forwarded as ``keepdims``.
+        """
+        args = self.retrieve_args(node)
+        keepdims = node.kwargs.get("keepdim", False)
+        if len(args) == 1:
+            reduced = relax.op.mean(args[0], keepdims=keepdims)
+        else:
+            reduced = relax.op.mean(args[0], args[1], keepdims=keepdims)
+        return self.block_builder.emit(reduced)
+
########## DataType ##########
def _float(self, node: fx.node.Node) -> relax.Var:
@@ -345,6 +424,19 @@ class TorchFXImporter:
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))
+    def _to(self, node: fx.node.Node) -> relax.Var:
+        """Translate ``Tensor.to``; only a dtype change produces a relax op.
+
+        The target dtype is the first positional argument after the tensor
+        that is a ``torch.dtype``, falling back to the ``dtype`` keyword.
+        This also covers ``x.to(device, dtype)``, which the earlier
+        exact-two-argument check ignored. Calls that carry no dtype (for
+        example ``x.to("cpu")``) return the input unchanged.
+        """
+        import torch
+
+        x = self.env[node.args[0]]
+        target = next(
+            (arg for arg in node.args[1:] if isinstance(arg, torch.dtype)),
+            node.kwargs.get("dtype"),
+        )
+        if target is None:
+            # Device-only or otherwise dtype-free call: nothing to cast.
+            return x
+        dtype = TorchFXImporter._convert_data_type(target, self.env)
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
########## Linear Algebra ##########
def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
@@ -500,6 +592,16 @@ class TorchFXImporter:
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
return self.block_builder.emit(relax.op.where(mask, values, x))
+    def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var:
+        """Translate ``masked_fill_`` as ``where(mask, full_like(x, value), x)``.
+
+        Besides returning the emitted result, the environment entry of the
+        input node is rebound to it, so later lookups of that node see the
+        filled tensor.
+        """
+        tensor = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        fill = relax.const(node.args[2])
+        filled = self.block_builder.emit(relax.op.full_like(tensor, fill))
+        result = self.block_builder.emit(relax.op.where(mask, filled, tensor))
+        # Model the in-place update by redirecting the input node to the result.
+        self.env[node.args[0]] = result
+        return result
+
########## Search ##########
def _argmax_argmin(self, op: Callable) -> Callable:
@@ -847,6 +949,10 @@ class TorchFXImporter:
expand_dim = []
i = 0
shape = self.shape_of(x)
+ non_ellipsis_cnt = 0
+ for index in node.args[1]:
+ if isinstance(index, (int, slice)):
+ non_ellipsis_cnt += 1
for index in node.args[1]:
if isinstance(index, int):
begin.append(index)
@@ -862,6 +968,13 @@ class TorchFXImporter:
i = i + 1
elif index is None:
expand_dim.append(len(axes) + len(expand_dim))
+ elif index is Ellipsis:
+ for _ in range(len(shape) - non_ellipsis_cnt):
+ begin.append(0)
+ end.append(shape[i])
+ stride.append(1)
+ axes.append(i)
+ i += 1
else:
raise ValueError("Unsupported index type: " +
str(type(index)))
while i < len(shape):
@@ -869,7 +982,7 @@ class TorchFXImporter:
end.append(shape[i])
stride.append(1)
axes.append(i)
- i = i + 1
+ i += 1
sliced = self.block_builder.emit(relax.op.strided_slice(x, axes,
begin, end, stride))
sliced_shape = list(self.shape_of(sliced))
for i in expand_dim:
@@ -957,17 +1070,25 @@ class TorchFXImporter:
"clamp": self._clamp,
"relu": lambda node:
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
"gelu": lambda node:
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+ "silu": lambda node:
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
"tanh": lambda node:
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"interpolate": self._interpolate,
"size": self._size,
"getattr": self._getattr,
"getitem": self._getitem,
"contiguous": lambda node: self.env[node.args[0]],
- "to": lambda node: self.env[node.args[0]],
+ "to": self._to,
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
"layer_norm": self._layer_norm,
"index_select": self._index_select,
"masked_fill": self._masked_fill,
+ "ones": self._ones,
+ "full": self._full,
+ "masked_fill_": self._inplace_masked_fill,
+ "mean": self._mean,
+ "rsqrt": self._rsqrt,
+ "neg": self._neg,
+ "max": self._max,
}
def from_fx(
@@ -1029,7 +1150,7 @@ class TorchFXImporter:
assert len(args) == 1
if (
unwrap_unit_return_tuple
- and isinstance(args[0], (tuple, relax.Tuple))
+ and isinstance(args[0], (tuple, list, relax.Tuple))
and len(args[0]) == 1
):
output = self.block_builder.emit_output(args[0][0])
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index ead59cdf7b..09a0c30f19 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -248,3 +248,40 @@ def not_equal(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.not_equal(x1, x2) # type: ignore
+
+
+###################### Min/Max operators ######################
+
+
+def maximum(x1: Expr, x2: Expr) -> Expr:
+    """Compute the element-wise maximum of two tensors, with broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    result = _ffi_api.maximum(x1, x2)  # type: ignore
+    return result
+
+
+def minimum(x1: Expr, x2: Expr) -> Expr:
+    """Compute the element-wise minimum of two tensors, with broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    result = _ffi_api.minimum(x1, x2)  # type: ignore
+    return result
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index ffda767233..897b676518 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -54,3 +54,6 @@ register_legalize("relax.greater_equal",
_binary(topi.greater_equal))
register_legalize("relax.less", _binary(topi.less))
register_legalize("relax.less_equal", _binary(topi.less_equal))
register_legalize("relax.not_equal", _binary(topi.not_equal))
+
+register_legalize("relax.maximum", _binary(topi.maximum))
+register_legalize("relax.minimum", _binary(topi.minimum))
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index ae0918a082..d344891609 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -82,9 +82,11 @@ from tvm.relax.op import (
make_closure,
matmul,
max,
+ maximum,
mean,
memory,
min,
+ minimum,
multiply,
negative,
not_equal,
@@ -596,9 +598,11 @@ __all__ = [
"make_closure",
"matmul",
"max",
+ "maximum",
"mean",
"memory",
"min",
+ "minimum",
"multiply",
"negative",
"not_equal",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 30cd748308..96d1f01e8a 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -118,5 +118,10 @@ RELAX_REGISTER_CMP_OP_AND_IMPL(less);
RELAX_REGISTER_CMP_OP_AND_IMPL(less_equal);
RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal);
+/***************** Min/Max operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(minimum);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(maximum);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 086e37f883..e386f9019f 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -98,6 +98,14 @@ Expr less_equal(Expr x1, Expr x2);
/*! \brief Broadcasted element-wise test for (lhs != rhs). */
Expr not_equal(Expr x1, Expr x2);
+/***************** Min/Max *****************/
+
+/*! \brief Broadcasted element-wise minimum of x1 and x2. */
+Expr minimum(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise maximum of x1 and x2. */
+Expr maximum(Expr x1, Expr x2);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index 765ca9b6f0..72ea193a02 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -25,7 +25,9 @@ import tvm.testing
import torch
import torch._dynamo as dynamo
from tvm.relax.frontend.torch import relax_dynamo
-from tvm.script.parser import relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
def test_relax_dynamo():
@@ -230,5 +232,138 @@ def test_subgraph_capture():
tvm.ir.assert_structural_equal(mod, expected)
+def verify_dynamo_model(torch_model, input_info, binding, expected):
+    """Export ``torch_model`` with dynamo, import it via ``from_fx`` and compare.
+
+    ``input_info`` is a list of ``(shape, dtype_name)`` entries; a zero tensor
+    is built for each one and used as the example input for the export.
+    ``binding`` maps parameter names to values that are bound into
+    ``expected`` before the structural-equality check.
+    """
+    import torch
+    import torch._dynamo as dynamo
+    from tvm.relax.frontend.torch import from_fx
+
+    example_inputs = [
+        torch.zeros(*entry[0], dtype=_convert_data_type(entry[1])) for entry in input_info
+    ]
+    exported = dynamo.export(torch_model, *example_inputs)
+    graph_model = exported[0]
+    mod = from_fx(graph_model, input_info, unwrap_unit_return_tuple=True)
+    bound_params = {name: tvm.nd.array(value) for name, value in binding.items()}
+    expected = relax.transform.BindParams("main", bound_params)(expected)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
+def _convert_data_type(input_type):
+    """Convert a dtype name such as ``"float32"`` to the matching ``torch.dtype``.
+
+    String names are matched case-insensitively. Supported names are
+    ``float32``, ``float16``, ``int64``, ``int32`` and ``bool``.
+
+    Raises
+    ------
+    NotImplementedError
+        If ``input_type`` is not one of the supported names.
+    """
+    import torch  # type: ignore
+
+    supported = {
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "int64": torch.int64,
+        "int32": torch.int32,
+        "bool": torch.bool,
+    }
+    key = input_type.lower() if isinstance(input_type, str) else input_type
+    if isinstance(key, str) and key in supported:
+        return supported[key]
+    raise NotImplementedError("input_type {} is not handled yet".format(key))
+
+
+@tvm.testing.requires_gpu
+def test_ones():
+ """Check that ``torch.ones`` with an explicit dtype imports as ``R.full`` of 1."""
+ import torch
+ from torch.nn import Module
+
+ # The module ignores its input; the expected function still takes inp_0.
+ class Ones(Module):
+ def forward(self, input):
+ return torch.ones((10, 10), dtype=torch.float32)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.full(
+ R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
+ )
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_dynamo_model(
+ Ones(),
+ [([256, 256], "float32")],
+ {},
+ Expected1,
+ )
+
+
+@tvm.testing.requires_gpu
+def test_full():
+ """Check that ``torch.full`` with an explicit dtype imports as ``R.full``."""
+ import torch
+ from torch.nn import Module
+
+ # The module ignores its input; the expected function still takes inp_0.
+ class Full(Module):
+ def forward(self, input):
+ return torch.full((10, 10), 1, dtype=torch.float32)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.full(
+ R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
+ )
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_dynamo_model(
+ Full(),
+ [([256, 256], "float32")],
+ {},
+ Expected1,
+ )
+
+
+@tvm.testing.requires_gpu
+def test_masked_fill():
+ """Check ``masked_fill`` and ``masked_fill_`` import to the same ``full_like`` + ``where`` IR."""
+ import torch
+ from torch.nn import Module
+
+ class MaskedFill(Module):
+ def forward(self, mask, input):
+ return input.masked_fill(mask, 0)
+
+ # In-place variant: mutates the input and returns it.
+ class InplaceMaskedFill(Module):
+ def forward(self, mask, input):
+ input.masked_fill_(mask, 0)
+ return input
+
+ # Both modules are verified against this single expected module.
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="bool"), inp_1: R.Tensor((256,
256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float32") = R.full_like(
+ inp_1, R.const(0, "int32"), dtype="void"
+ )
+ lv1: R.Tensor((256, 256), dtype="float32") = R.where(inp_0,
lv, inp_1)
+ gv: R.Tensor((256, 256), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_dynamo_model(
+ MaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")], {},
Expected1
+ )
+ verify_dynamo_model(
+ InplaceMaskedFill(), [([256, 256], "bool"), ([256, 256], "float32")],
{}, Expected1
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 2e69795d51..d201cb111c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -19,10 +19,13 @@ import pytest
import tvm
from tvm import relax
import tvm.testing
-from tvm.script.parser import ir as I, relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
def verify_model(torch_model, input_info, binding, expected):
+ import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx
@@ -831,6 +834,10 @@ def test_silu():
def forward(self, input):
return self.silu(input)
+ class SiLU2(Module):
+ def forward(self, input):
+ return torch.nn.functional.silu(input)
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -845,6 +852,7 @@ def test_silu():
return gv
verify_model(SiLU(), input_info, {}, expected1)
+ verify_model(SiLU2(), input_info, {}, expected1)
@tvm.testing.requires_gpu
@@ -2496,5 +2504,158 @@ def test_argmin():
verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2)
+@tvm.testing.requires_gpu
+def test_to():
+ """Check ``Tensor.to``: a dtype argument becomes ``R.astype``, a device argument is a no-op."""
+ import torch
+ from torch.nn import Module
+
+ # Positional dtype argument.
+ class To1(Module):
+ def forward(self, input):
+ return input.to(torch.float16)
+
+ # Device-only argument; no dtype is requested.
+ class To2(Module):
+ def forward(self, input):
+ return input.to("cpu")
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float16"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float16") = R.astype(inp_0,
dtype="float16")
+ gv: R.Tensor((256, 256), dtype="float16") = lv
+ R.output(gv)
+ return gv
+
+ # Expected2 returns the input unchanged.
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((256, 256), dtype="float32") = inp_0
+ R.output(gv)
+ return gv
+
+ verify_model(To1(), [([256, 256], "float32")], {}, Expected1)
+ verify_model(To2(), [([256, 256], "float32")], {}, Expected2)
+
+
+@tvm.testing.requires_gpu
+def test_mean():
+ """Check ``Tensor.mean(-1)`` imports as ``R.mean`` over axis -1, with and without ``keepdim``."""
+ import torch
+ from torch.nn import Module
+
+ class Mean(Module):
+ def forward(self, input):
+ return input.mean(-1)
+
+ class MeanKeepDim(Module):
+ def forward(self, input):
+ return input.mean(-1, keepdim=True)
+
+ # keepdims=False: the reduced axis is dropped from the result shape.
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(inp_0: R.Tensor((256, 256), dtype="float32")) ->
R.Tensor((256,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0,
axis=[-1], keepdims=False)
+ gv: R.Tensor((256,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # keepdims=True: the reduced axis is kept with extent 1.
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 1), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0,
axis=[-1], keepdims=True)
+ gv: R.Tensor((256, 1), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Mean(), [([256, 256], "float32")], {}, Expected1)
+ verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2)
+
+
+@tvm.testing.requires_gpu
+def test_rsqrt():
+ """Check ``torch.rsqrt`` imports as ``R.sqrt`` followed by ``R.divide`` of constant 1."""
+ import torch
+ from torch.nn import Module
+
+ class Rsqrt(Module):
+ def forward(self, input):
+ return torch.rsqrt(input)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float32") = R.sqrt(inp_0)
+ lv1: R.Tensor((256, 256), dtype="float32") =
R.divide(R.const(1, "float32"), lv)
+ gv: R.Tensor((256, 256), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
+
+
+@tvm.testing.requires_gpu
+def test_neg():
+ """Check that unary minus on a tensor imports as ``R.negative``."""
+ import torch
+ from torch.nn import Module
+
+ class Neg(Module):
+ def forward(self, input):
+ return -input
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float32") = R.negative(inp_0)
+ gv: R.Tensor((256, 256), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Neg(), [([256, 256], "float32")], {}, Expected1)
+
+
+@tvm.testing.requires_gpu
+def test_max():
+ """Check that the two-tensor form ``torch.max(x, y)`` imports as ``R.maximum``."""
+ import torch
+ from torch.nn import Module
+
+ class Max(Module):
+ def forward(self, x, y):
+ return torch.max(x, y)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32"),
+ inp_1: R.Tensor((256, 256), dtype="float32"),
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((256, 256), dtype="float32") = R.maximum(inp_0,
inp_1)
+ gv: R.Tensor((256, 256), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")],
{}, Expected1)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 56263bc4ee..809fe7e98f 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -54,6 +54,8 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
(relax.op.multiply,),
(relax.op.power,),
(relax.op.subtract,),
+ (relax.op.maximum,),
+ (relax.op.minimum,),
)
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py
b/tests/python/relax/test_transform_legalize_ops_binary.py
index 5847413713..dc14a0c3fd 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -1327,5 +1327,285 @@ def test_not_equal_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_maximum():
+ """LegalizeOps lowers ``R.maximum`` on broadcastable static shapes to a ``T.max`` prim_func."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Maximum:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1),
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+ gv: R.Tensor((4, 3, 2, 3), "float32") = R.maximum(x, y)
+ return gv
+
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1),
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+ gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((4, 3, 2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def maximum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2),
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3),
T.int64(2), T.int64(1)), "float32"), T_maximum: T.Buffer((T.int64(4),
T.int64(3), T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2),
T.int64(3)):
+ with T.block("T_maximum"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_maximum[ax0, ax1, ax2, ax3])
+ T_maximum[ax0, ax1, ax2, ax3] =
T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2,
T.int64(0)])
+ # fmt: on
+
+ mod = LegalizeOps()(Maximum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_maximum_with_arg0_constant_scalar():
+ """Legalize ``R.maximum(x, const)``: the float32 constant 1 is inlined as the second ``T.max`` operand."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Maximum:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), dtype="float32") = R.maximum(x, R.const(1,
"float32"))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_maximum"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1])
+ T.writes(T_maximum[ax0, ax1])
+ T_maximum[ax0, ax1] = T.max(rxplaceholder[ax0, ax1],
T.float32(1))
+ # fmt: on
+
+ mod = LegalizeOps()(Maximum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_maximum_with_arg1_constant_scalar():
+ """Legalize ``R.maximum(const, x)``: the float32 constant 1 is inlined as the first ``T.max`` operand."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Maximum:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), dtype="float32") = R.maximum(R.const(1,
"float32"), x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv = R.call_tir(Expected.maximum, (x,), R.Tensor((2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_maximum"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1])
+ T.writes(T_maximum[ax0, ax1])
+ T_maximum[ax0, ax1] = T.max(T.float32(1),
rxplaceholder[ax0, ax1])
+ # fmt: on
+
+ mod = LegalizeOps()(Maximum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_maximum_symbolic():
+ """LegalizeOps lowers ``R.maximum`` on symbolic shapes (a, b, c, d) to a ``T.max`` prim_func."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Maximum:
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv: R.Tensor((a, b, c, d), "float32") = R.maximum(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv = R.call_tir(Expected.maximum, (x, y), R.Tensor((a, b, c, d),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_T_maximum: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c,
d], dtype="float32")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c,
T.int64(1)], dtype="float32")
+ T_maximum = T.match_buffer(var_T_maximum, [a, b, c, d],
dtype="float32")
+ for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ with T.block("T_maximum"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_maximum[ax0, ax1, ax2, ax3])
+ T_maximum[ax0, ax1, ax2, ax3] =
T.max(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2,
T.int64(0)])
+ # fmt: on
+
+ mod = LegalizeOps()(Maximum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_minimum():
+ """LegalizeOps lowers ``R.minimum`` on broadcastable static shapes to a ``T.min`` prim_func."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Minimum:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1),
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+ gv: R.Tensor((4, 3, 2, 3), "float32") = R.minimum(x, y)
+ return gv
+
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1),
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+ gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((4, 3, 2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def minimum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2),
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3),
T.int64(2), T.int64(1)), "float32"), T_minimum: T.Buffer((T.int64(4),
T.int64(3), T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2),
T.int64(3)):
+ with T.block("T_minimum"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_minimum[ax0, ax1, ax2, ax3])
+ T_minimum[ax0, ax1, ax2, ax3] =
T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2,
T.int64(0)])
+ # fmt: on
+
+ mod = LegalizeOps()(Minimum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_minimum_with_arg0_constant_scalar():
+ """Legalize ``R.minimum(x, const)``: the float32 constant 1 is inlined as the second ``T.min`` operand."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Minimum:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), dtype="float32") = R.minimum(x, R.const(1,
"float32"))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_minimum"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1])
+ T.writes(T_minimum[ax0, ax1])
+ T_minimum[ax0, ax1] = T.min(rxplaceholder[ax0, ax1],
T.float32(1))
+ # fmt: on
+
+ mod = LegalizeOps()(Minimum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_minimum_with_arg1_constant_scalar():
+ """Legalize ``R.minimum(const, x)``: the float32 constant 1 is inlined as the first ``T.min`` operand."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Minimum:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), dtype="float32") = R.minimum(R.const(1,
"float32"), x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv = R.call_tir(Expected.minimum, (x,), R.Tensor((2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_minimum"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[ax0, ax1])
+ T.writes(T_minimum[ax0, ax1])
+ T_minimum[ax0, ax1] = T.min(T.float32(1),
rxplaceholder[ax0, ax1])
+ # fmt: on
+
+ mod = LegalizeOps()(Minimum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_minimum_symbolic():
+ """LegalizeOps lowers ``R.minimum`` on symbolic shapes (a, b, c, d) to a ``T.min`` prim_func."""
+ # fmt: off
+ @tvm.script.ir_module
+ class Minimum:
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv: R.Tensor((a, b, c, d), "float32") = R.minimum(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv = R.call_tir(Expected.minimum, (x, y), R.Tensor((a, b, c, d),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_T_minimum: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c,
d], dtype="float32")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c,
T.int64(1)], dtype="float32")
+ T_minimum = T.match_buffer(var_T_minimum, [a, b, c, d],
dtype="float32")
+ for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ with T.block("T_minimum"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3],
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_minimum[ax0, ax1, ax2, ax3])
+ T_minimum[ax0, ax1, ax2, ax3] =
T.min(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2,
T.int64(0)])
+ # fmt: on
+
+ mod = LegalizeOps()(Minimum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index 7fdd109cca..d43e9a626b 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -106,6 +106,8 @@ def test_unary_check(unary_check_op: Callable):
(relax.op.multiply,),
(relax.op.power,),
(relax.op.subtract,),
+ (relax.op.maximum,),
+ (relax.op.minimum,),
)