This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7c28c86f7d [Relax][PyTorch] Support binary, statistical and search ops
for ExportedProgram importer (#17424)
7c28c86f7d is described below
commit 7c28c86f7d3121ce2adc179475fdb1922c86b942
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Sep 28 22:30:15 2024 +0900
[Relax][PyTorch] Support binary, statistical and search ops for
ExportedProgram importer (#17424)
* support binary ops
* support mean
* support sum
* support argmax and argmin
---
.../frontend/torch/base_fx_graph_translator.py | 62 +++
.../frontend/torch/exported_program_translator.py | 25 +
python/tvm/relax/frontend/torch/fx_translator.py | 62 ---
.../relax/test_frontend_from_exported_program.py | 512 +++++++++++++++++++++
4 files changed, 599 insertions(+), 62 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d52b3d598f..a41b9b6d4f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -185,6 +185,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return convert
+ ########## Binary Ops ##########
+
+    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+        """Return a converter for a two-operand node.
+
+        Parameters
+        ----------
+        relax_op : Callable
+            Relax operator emitted when at least one operand is a Relax Var or Constant.
+        intrinsic_op : Callable
+            Python operator applied directly when neither operand is a Relax Var or
+            Constant (e.g. both operands are plain Python values).
+
+        Returns
+        -------
+        Callable
+            A function mapping an ``fx.Node`` to the emitted Relax value, or to the
+            result of ``intrinsic_op`` in the non-Relax case.
+        """
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            def promote_binary_op_args(lhs, rhs):
+                # Wrap a non-Relax (scalar) operand as a constant with the dtype of
+                # the tensor operand; Relax operands are passed through untouched.
+                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    return lhs, rhs
+                elif isinstance(lhs, relax.Expr):
+                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
+                elif isinstance(rhs, relax.Expr):
+                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+                    return relax.const(lhs, rhs.struct_info.dtype), rhs
+                else:
+                    raise TypeError(
+                        f"Binary op expects a Relax operand, got {type(lhs)} and {type(rhs)}"
+                    )
+
+            def call_binary_op(op, lhs, rhs):
+                lhs, rhs = promote_binary_op_args(lhs, rhs)
+                return self.block_builder.emit(op(lhs, rhs))
+
+            lhs, rhs = self.retrieve_args(node)
+            # Var and Constant operands both go through promote_binary_op_args, which
+            # wraps only the non-Relax side. This also covers Constant-Constant
+            # operands without passing an existing Constant back into relax.const.
+            relax_operand = (relax.Var, relax.expr.Constant)
+            if isinstance(lhs, relax_operand) or isinstance(rhs, relax_operand):
+                return call_binary_op(relax_op, lhs, rhs)
+            return intrinsic_op(lhs, rhs)
+
+        return convert
+
########## Neural Network ##########
def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -283,6 +316,35 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self._max_pool2d_impl(x, kernel_size, stride, padding,
dilation, ceil_mode)
+ ########## Statistical ##########
+
+    def _mean(self, node: fx.Node) -> relax.Var:
+        """Convert a mean-reduction node to ``relax.op.mean``.
+
+        ``dim`` and ``keepdim`` are taken from positional arguments when present,
+        otherwise from keyword arguments (defaults: ``dim=None``, ``keepdim=False``).
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+
+    def _sum(self, node: fx.Node) -> relax.Var:
+        """Convert a sum-reduction node to ``relax.op.sum``.
+
+        ``dim`` and ``keepdim`` are taken from positional arguments when present,
+        otherwise from keyword arguments (defaults: ``dim=None``, ``keepdim=False``),
+        mirroring ``_mean``.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        # Read keepdim from either position so it is honored together with dim,
+        # e.g. for torch.sum(x, 1, keepdim=True) or x.sum(1, True).
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
+
+ ########## Search ##########
+
+    def _argmax_argmin(self, op: Callable) -> Callable:
+        """Return a converter that lowers an argmax/argmin node with ``op``.
+
+        ``op`` is called as ``op(x, dim, keepdim)``. ``dim`` and ``keepdim`` come from
+        positional arguments when present, otherwise from keyword arguments
+        (defaults: ``dim=None``, ``keepdim=False``).
+        """
+        from torch import fx
+
+        def convert(node: fx.Node):
+            x = self.env[node.args[0]]
+            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+            return self.block_builder.emit(op(x, dim, keepdim))
+
+        return convert
+
########## Manipulation ##########
def _reshape(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1ceddad7d7..11594690cd 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
+from functools import partial
from typing import Callable, Dict, List, Tuple
import torch
@@ -76,6 +77,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
+ import operator
+
return {
# unary
"acos.default": self._unary_op(relax.op.acos),
@@ -109,11 +112,33 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"tanh.default": self._unary_op(relax.op.tanh),
"tril.default": self._tril_triu(relax.op.tril),
"triu.default": self._tril_triu(relax.op.triu),
+ # binary
+ "add.Tensor": self._binary_op(relax.op.add, operator.add),
+ "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
+ "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
+ "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
+ "floor_divide.default": self._binary_op(relax.op.floor_divide,
operator.floordiv),
+ "lt.Scalar": self._binary_op(relax.op.less, operator.lt),
+ "lt.Tensor": self._binary_op(relax.op.less, operator.lt),
+ "matmul.default": self._binary_op(
+ partial(relax.op.linear_algebra.matmul, out_dtype="float32"),
operator.matmul
+ ),
+ "max.other": self._binary_op(relax.op.maximum, max),
+ "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+ "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
+ "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
+ "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
# neural network
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"conv2d.default": self._conv2d,
"linear.default": self._linear,
"max_pool2d.default": self._max_pool2d,
+ # statistical
+ "mean.dim": self._mean,
+ "sum.dim_IntList": self._sum,
+ # search
+ "argmax.default": self._argmax_argmin(relax.op.argmax),
+ "argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"view.default": self._reshape,
}
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6f7c6fa2c5..dc6ebc2eb3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -96,39 +96,6 @@ class TorchFXImporter(BaseFXGraphImporter):
return convert
- ########## Binary Ops ##########
-
- def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) ->
Callable:
- from torch import fx
-
- def convert(node: fx.Node) -> relax.Var:
- def promote_binary_op_args(lhs, rhs):
- if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
- return lhs, rhs
- elif isinstance(lhs, relax.Expr):
- assert isinstance(lhs.struct_info, relax.TensorStructInfo)
- return lhs, relax.const(rhs, lhs.struct_info.dtype)
- elif isinstance(rhs, relax.Expr):
- assert isinstance(rhs.struct_info, relax.TensorStructInfo)
- return relax.const(lhs, rhs.struct_info.dtype), rhs
- else:
- assert False
-
- def call_binary_op(op, lhs, rhs):
- lhs, rhs = promote_binary_op_args(lhs, rhs)
- return self.block_builder.emit(op(lhs, rhs))
-
- lhs, rhs = self.retrieve_args(node)
- if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
- return call_binary_op(relax_op, lhs, rhs)
- elif isinstance(lhs, relax.expr.Constant):
- return call_binary_op(relax_op, lhs, relax.const(rhs,
dtype=lhs.struct_info.dtype))
- elif isinstance(rhs, relax.expr.Constant):
- return call_binary_op(relax_op, relax.const(lhs,
dtype=rhs.struct_info.dtype), rhs)
- return intrinsic_op(lhs, rhs)
-
- return convert
-
########## Neural Network ##########
def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -794,35 +761,6 @@ class TorchFXImporter(BaseFXGraphImporter):
ret.append(self.block_builder.emit(relax.op.squeeze(split[i],
axis=dim)))
return self.block_builder.emit(relax.Tuple(ret))
- ########## Statistical ##########
-
- def _mean(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
- keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
- return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
-
- def _sum(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
- if len(args) == 1:
- return self.block_builder.emit(relax.op.sum(args[0],
keepdims=keepdim))
- return self.block_builder.emit(relax.op.sum(args[0], args[1]))
-
- ########## Search ##########
-
- def _argmax_argmin(self, op: Callable) -> Callable:
- from torch import fx
-
- def convert(node: fx.Node):
- x = self.env[node.args[0]]
- dim = node.args[1] if len(node.args) > 1 else
node.kwargs.get("dim", None)
- keepdim = node.args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
- return self.block_builder.emit(op(x, dim, keepdim))
-
- return convert
-
########## Manipulation ##########
def _cat(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6c17d96004..25e6dbfae3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -790,6 +790,372 @@ def test_tril_triu():
verify_model(Triu(), example_args, {}, expected_triu)
+def test_binary():
+    """Check binary ops with tensor-tensor and tensor-scalar operands."""
+    example_args1 = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
+
+    # Add
+    class Add1(Module):
+        def forward(self, lhs, rhs):
+            return lhs + rhs
+
+    @tvm.script.ir_module
+    class expected_add1:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="float32"),
+            rhs: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Add2(Module):
+        def forward(self, lhs):
+            return lhs + 1.0
+
+    @tvm.script.ir_module
+    class expected_add2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Add1(), example_args1, {}, expected_add1)
+    verify_model(Add2(), example_args2, {}, expected_add2)
+
+    # True div
+    class TrueDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs / rhs
+
+    @tvm.script.ir_module
+    class expected_truediv1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class TrueDiv2(Module):
+        def forward(self, lhs):
+            return lhs / 1.0
+
+    @tvm.script.ir_module
+    class expected_truediv2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(TrueDiv1(), example_args1, {}, expected_truediv1)
+    verify_model(TrueDiv2(), example_args2, {}, expected_truediv2)
+
+    # EQ
+    class EQ1(Module):
+        def forward(self, lhs, rhs):
+            return lhs == rhs
+
+    @tvm.script.ir_module
+    class expected_eq1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class EQ2(Module):
+        def forward(self, lhs):
+            return lhs == 1.0
+
+    @tvm.script.ir_module
+    class expected_eq2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(EQ1(), example_args1, {}, expected_eq1)
+    verify_model(EQ2(), example_args2, {}, expected_eq2)
+
+    # Floor div
+    class FloorDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs // rhs
+
+    @tvm.script.ir_module
+    class expected_floordiv1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class FloorDiv2(Module):
+        def forward(self, lhs):
+            return lhs // 1.0
+
+    @tvm.script.ir_module
+    class expected_floordiv2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1)
+    verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2)
+
+    # LT
+    class LT1(Module):
+        def forward(self, lhs, rhs):
+            return lhs < rhs
+
+    @tvm.script.ir_module
+    class expected_lt1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class LT2(Module):
+        def forward(self, lhs):
+            return lhs < 1.0
+
+    @tvm.script.ir_module
+    class expected_lt2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(LT1(), example_args1, {}, expected_lt1)
+    verify_model(LT2(), example_args2, {}, expected_lt2)
+
+    # MatMul
+    class MatMul1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.matmul(x, y)
+
+    @tvm.script.ir_module
+    class expected_matmul1:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32"),
+            input_2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(MatMul1(), example_args1, {}, expected_matmul1)
+
+    # Max
+    class Max1(Module):
+        def forward(self, x, y):
+            return torch.max(x, y)
+
+    @I.ir_module
+    class expected_max1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32"),
+            inp_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Max1(), example_args1, {}, expected_max1)
+
+    # Mul
+    class Mul1(Module):
+        def forward(self, lhs, rhs):
+            return lhs * rhs
+
+    @tvm.script.ir_module
+    class expected_mul1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Mul2(Module):
+        def forward(self, lhs):
+            return lhs * 1.0
+
+    @tvm.script.ir_module
+    class expected_mul2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Mul1(), example_args1, {}, expected_mul1)
+    verify_model(Mul2(), example_args2, {}, expected_mul2)
+
+    # Power
+    class Power1(Module):
+        def forward(self, lhs, rhs):
+            return lhs**rhs
+
+    @tvm.script.ir_module
+    class expected_power1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Power2(Module):
+        def forward(self, lhs):
+            return lhs**1.0
+
+    @tvm.script.ir_module
+    class expected_power2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Power1(), example_args1, {}, expected_power1)
+    verify_model(Power2(), example_args2, {}, expected_power2)
+
+    # Sub
+    class Sub1(Module):
+        def forward(self, lhs, rhs):
+            return lhs - rhs
+
+    @tvm.script.ir_module
+    class expected_sub1:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            rhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Sub2(Module):
+        def forward(self, lhs):
+            return lhs - 1.0
+
+    @tvm.script.ir_module
+    class expected_sub2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sub1(), example_args1, {}, expected_sub1)
+    verify_model(Sub2(), example_args2, {}, expected_sub2)
+
+
def test_adaptive_avgpool2d():
class AdaptiveAvgPool2d0(Module):
def __init__(self):
@@ -1094,6 +1460,152 @@ def test_maxpool2d():
verify_model(MaxPool2d3(), example_args, {}, expected3)
+def test_mean():
+    """Check mean over the last axis, with and without keepdim."""
+
+    class Mean(Module):
+        def forward(self, input):
+            return input.mean(-1)
+
+    class MeanKeepDim(Module):
+        def forward(self, input: torch.Tensor):
+            return input.mean(-1, keepdim=True)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False)
+                gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True)
+                gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+    verify_model(Mean(), example_args, {}, Expected1)
+    verify_model(MeanKeepDim(), example_args, {}, Expected2)
+
+
+def test_sum():
+    """Check sum over an explicit tuple of axes."""
+
+    class Sum(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False)
+                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Sum(), example_args, {}, expected1)
+
+
+def test_argmax_argmin():
+    """Check argmax with dim and argmin over all elements, with and without keepdim."""
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+
+    class Argmax1(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmax(input, dim=-1)
+
+    class Argmax2(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmax(input, dim=-1, keepdim=True)
+
+    @tvm.script.ir_module
+    class expected_argmax1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False)
+                gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_argmax2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True)
+                gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Argmax1(), example_args, {}, expected_argmax1)
+    verify_model(Argmax2(), example_args, {}, expected_argmax2)
+
+    class Argmin1(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmin(input)
+
+    class Argmin2(Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, input):
+            return torch.argmin(input, keepdim=True)
+
+    @tvm.script.ir_module
+    class expected_argmin1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_argmin2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True)
+                gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Argmin1(), example_args, {}, expected_argmin1)
+    verify_model(Argmin2(), example_args, {}, expected_argmin2)
+
+
def test_view():
class View(Module):
def forward(self, x):