This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b073bbefe4 [Unity][Op] Full support of Relax op `power` (#14171)
b073bbefe4 is described below
commit b073bbefe4a539485cfcd4c8c0d3ac099b702813
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Mar 3 02:25:44 2023 +0800
[Unity][Op] Full support of Relax op `power` (#14171)
This PR provides full support for `R.power`, including op registration,
legalization, overloading `__pow__` for Expr, and the torch fx frontend.
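For context, a minimal usage sketch (not part of this change; the module name, shapes, and dtypes below are illustrative only). It builds a small IRModule that calls `R.power`, checks that the overloaded `**` operator on Relax expressions produces the same `relax.power` call, and runs the new legalization:

    # Illustrative sketch, not part of the commit.
    import tvm
    from tvm import relax
    from tvm.relax.transform import LegalizeOps
    from tvm.script import relax as R

    @tvm.script.ir_module
    class PowerExample:
        @R.function
        def main(
            x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
        ) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.power(x, y)
            return gv

    # relax.op.power and the overloaded __pow__ build the same call node.
    x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
    y = relax.Var("y", relax.TensorStructInfo((2, 3), "float32"))
    assert (x**y).op == tvm.ir.Op.get("relax.power")

    # Legalization lowers R.power to a call_tir over topi.power.
    lowered = LegalizeOps()(PowerExample)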
---
python/tvm/relax/expr.py | 6 ++
python/tvm/relax/frontend/torch/fx_translator.py | 7 ++
python/tvm/relax/op/binary.py | 18 +++++
python/tvm/relax/transform/legalize_ops/binary.py | 1 +
python/tvm/script/ir_builder/relax/ir.py | 4 +-
src/relax/op/tensor/binary.cc | 1 +
src/relax/op/tensor/binary.h | 3 +
tests/python/relax/test_frontend_from_fx.py | 47 +++++++++++--
tests/python/relax/test_op_binary.py | 2 +
tests/python/relax/test_op_misc.py | 1 +
.../relax/test_transform_legalize_ops_binary.py | 80 ++++++++++++++++++++++
tests/python/relax/test_tvmscript_parser.py | 6 +-
.../relax/test_tvmscript_parser_op_arith_cmp.py | 1 +
13 files changed, 170 insertions(+), 7 deletions(-)
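The torch fx path touched above can be exercised roughly as follows (a hedged sketch, not taken from this commit; the module and input shapes are made up). `**` on traced tensors is dispatched through the new "pow" entry in the importer's convert map:

    # Illustrative sketch, not part of the commit.
    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class PowModule(nn.Module):
        def forward(self, lhs, rhs):
            return lhs**rhs  # imported as relax.op.power

    # Trace the module and import it; each entry pairs an input shape with a dtype.
    graph_model = fx.symbolic_trace(PowModule())
    input_info = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    mod = from_fx(graph_model, input_info)
    print(mod.script())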
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index a20181e6fc..ab332eed61 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -188,6 +188,12 @@ class ExprWithOp(Expr, Scriptable):
def __rmod__(self, other: Expr) -> "ExprWithOp":
return _binary_rhs_helper(other)
+ def __pow__(self, other: Expr) -> "ExprWithOp":
+ return _binary_op_helper(self, other, _op_ffi_api.power) # type: ignore
+
+ def __rpow__(self, other: Expr) -> "ExprWithOp":
+ return _binary_rhs_helper(other)
+
def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> "ExprWithOp":
"""Call the variable (if it represents a function).
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 4acad61855..e80f73096c 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -158,6 +158,12 @@ class TorchFXImporter:
return self._call_binary_op(relax.op.multiply, lhs, rhs)
return lhs * rhs
+ def _pow(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.power, lhs, rhs)
+ return lhs**rhs
+
def _sub(self, node: fx.node.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -641,6 +647,7 @@ class TorchFXImporter:
"floordiv": self._floordiv,
"mul": self._mul,
"sub": self._sub,
+ "pow": self._pow,
"sqrt": self._sqrt,
"lt": self._lt,
"truediv": self._truediv,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 4042f9bbc9..ead59cdf7b 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -103,6 +103,24 @@ def multiply(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.multiply(x1, x2) # type: ignore
+def power(x1: Expr, x2: Expr) -> Expr:
+ """Power with numpy-style broadcasting.
+
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.power(x1, x2) # type: ignore
+
+
def subtract(x1: Expr, x2: Expr) -> Expr:
"""Subtraction with numpy-style broadcasting.
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index 55b832021a..ffda767233 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -45,6 +45,7 @@ register_legalize("relax.add", _binary(topi.add))
register_legalize("relax.divide", _binary(topi.divide))
register_legalize("relax.floor_divide", _binary(topi.floor_divide))
register_legalize("relax.multiply", _binary(topi.multiply))
+register_legalize("relax.power", _binary(topi.power))
register_legalize("relax.subtract", _binary(topi.subtract))
register_legalize("relax.equal", _binary(topi.equal))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 045fe9ddd9..8fa55f7495 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -83,9 +83,10 @@ from tvm.relax.op import (
negative,
not_equal,
null_value,
- permute_dims,
ones,
ones_like,
+ permute_dims,
+ power,
print,
prod,
reshape,
@@ -585,6 +586,7 @@ __all__ = [
"ones_like",
"output",
"permute_dims",
+ "power",
"prim_value",
"print",
"prod",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index b7a07c5202..7e8480ee16 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -84,6 +84,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
/***************** Comparison operators *****************/
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b565b159bb..0a48e727e6 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -70,6 +70,9 @@ Expr floor_divide(Expr x1, Expr x2);
/*! \brief Multiplication with numpy-style broadcasting. */
Expr multiply(Expr x1, Expr x2);
+/*! \brief Power with numpy-style broadcasting. */
+Expr power(Expr x1, Expr x2);
+
/*! \brief Subtraction with numpy-style broadcasting. */
Expr subtract(Expr x1, Expr x2);
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e216010667..137713869e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -982,13 +982,52 @@ def test_binary():
verify_model(FloorDiv1(), input_info1, {}, expected9)
verify_model(FloorDiv2(), input_info2, {}, expected10)
+ # Power
+ class Power1(Module):
+ def forward(self, lhs, rhs):
+ return lhs**rhs
+
+ @tvm.script.ir_module
+ class expected11:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class Power2(Module):
+ def forward(self, lhs):
+ return lhs**1.0
+
+ @tvm.script.ir_module
+ class expected12:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Power1(), input_info1, {}, expected11)
+ verify_model(Power2(), input_info2, {}, expected12)
+
# LT
class LT1(Module):
def forward(self, lhs, rhs):
return lhs < rhs
@tvm.script.ir_module
- class expected11:
+ class expected13:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
@@ -1006,7 +1045,7 @@ def test_binary():
return lhs < 1.0
@tvm.script.ir_module
- class expected12:
+ class expected14:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
@@ -1018,8 +1057,8 @@ def test_binary():
R.output(gv)
return gv
- verify_model(LT1(), input_info1, {}, expected11)
- verify_model(LT2(), input_info2, {}, expected12)
+ verify_model(LT1(), input_info1, {}, expected13)
+ verify_model(LT2(), input_info2, {}, expected14)
@tvm.testing.requires_gpu
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index a4ae8ce31a..56263bc4ee 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -31,6 +31,7 @@ def test_op_correctness():
assert relax.op.divide(x, y).op == Op.get("relax.divide")
assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
+ assert relax.op.power(x, y).op == Op.get("relax.power")
assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
assert relax.op.equal(x, y).op == Op.get("relax.equal")
@@ -51,6 +51,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
(relax.op.divide,),
(relax.op.floor_divide,),
(relax.op.multiply,),
+ (relax.op.power,),
(relax.op.subtract,),
)
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index 65772baadf..523a628fa9 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -68,6 +68,7 @@ def test_implicit_op():
_check_call(x * y, "multiply")
_check_call(x / y, "divide")
_check_call(x // y, "floor_divide")
+ _check_call(x**y, "power")
# _check_call(x % y, "mod") <= relax.mod is not implemented yet
# Cast
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py
index c99fb885c4..79b4e8a8a7 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -519,6 +519,86 @@ def test_multiply_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_power():
+ # fmt: off
+ @tvm.script.ir_module
+ class Power:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+ gv: R.Tensor((4, 3, 2, 3), "float32") = R.power(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_power: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
+ with T.block("T_power"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+ T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_power[v_ax0, v_ax1, v_ax2, v_ax3] = T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+
+ @R.function
+ def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
+ gv = R.call_tir(power, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32"))
+ return gv
+
+ # fmt: on
+
+ mod = LegalizeOps()(Power)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_power_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Power:
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b",
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv: R.Tensor((a, b, c, d), "float32") = R.power(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def power(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_power: T.handle):
+ T.func_attr({"tir.noalias": True})
+ c = T.int64()
+ d = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d))
+ a = T.int64()
+ b = T.int64()
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1)))
+ T_power = T.match_buffer(var_T_power, (a, b, c, d))
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
+ with T.block("T_power"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+ T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_power[v_ax0, v_ax1, v_ax2, v_ax3] = T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+
+ @R.function
+ def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y:
R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c",
"d"), dtype="float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ gv = R.call_tir(power, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Power)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_subtract():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 9636a98b41..b885697c73 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1023,6 +1023,7 @@ def test_arith_operators():
a3 = x * y
a4 = x / y
a5 = x // y
+ a6 = x**y
c0 = x > y
c1 = x < y
@@ -1033,7 +1034,7 @@ def test_arith_operators():
t0 = tuple_expr[0]
t1 = tuple_expr[1]
t2 = tuple_expr[0][0] # <= Will normalize to two bindings
- return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2
+ return a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
@@ -1047,6 +1048,7 @@ def test_arith_operators():
a3 = bb.emit(relax.op.multiply(x, y))
a4 = bb.emit(relax.op.divide(x, y))
a5 = bb.emit(relax.op.floor_divide(x, y))
+ a6 = bb.emit(relax.op.power(x, y))
c0 = bb.emit(relax.op.greater(x, y))
c1 = bb.emit(relax.op.less(x, y))
@@ -1058,7 +1060,7 @@ def test_arith_operators():
t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1))
tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0))
t2 = bb.emit(relax.TupleGetItem(tmp, 0))
- bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2)))
+ bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2)))
_check(foo, bb.get()["foo"])
diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index ffb8576b27..7fdd109cca 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -104,6 +104,7 @@ def test_unary_check(unary_check_op: Callable):
(relax.op.divide,),
(relax.op.floor_divide,),
(relax.op.multiply,),
+ (relax.op.power,),
(relax.op.subtract,),
)