This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b073bbefe4 [Unity][Op] Full support of Relax op `power` (#14171)
b073bbefe4 is described below

commit b073bbefe4a539485cfcd4c8c0d3ac099b702813
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Mar 3 02:25:44 2023 +0800

    [Unity][Op] Full support of Relax op `power` (#14171)
    
    This PR provides full support for `R.power`, including op registration,
    legalization, overloading `__pow__`/`__rpow__` for Expr, and the torch fx frontend.
---
 python/tvm/relax/expr.py                           |  6 ++
 python/tvm/relax/frontend/torch/fx_translator.py   |  7 ++
 python/tvm/relax/op/binary.py                      | 18 +++++
 python/tvm/relax/transform/legalize_ops/binary.py  |  1 +
 python/tvm/script/ir_builder/relax/ir.py           |  4 +-
 src/relax/op/tensor/binary.cc                      |  1 +
 src/relax/op/tensor/binary.h                       |  3 +
 tests/python/relax/test_frontend_from_fx.py        | 47 +++++++++++--
 tests/python/relax/test_op_binary.py               |  2 +
 tests/python/relax/test_op_misc.py                 |  1 +
 .../relax/test_transform_legalize_ops_binary.py    | 80 ++++++++++++++++++++++
 tests/python/relax/test_tvmscript_parser.py        |  6 +-
 .../relax/test_tvmscript_parser_op_arith_cmp.py    |  1 +
 13 files changed, 170 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index a20181e6fc..ab332eed61 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -188,6 +188,12 @@ class ExprWithOp(Expr, Scriptable):
     def __rmod__(self, other: Expr) -> "ExprWithOp":
         return _binary_rhs_helper(other)
 
+    def __pow__(self, other: Expr) -> "ExprWithOp":
+        return _binary_op_helper(self, other, _op_ffi_api.power)  # type: 
ignore
+
+    def __rpow__(self, other: Expr) -> "ExprWithOp":
+        return _binary_rhs_helper(other)
+
     def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = 
None) -> "ExprWithOp":
         """Call the variable (if it represents a function).
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 4acad61855..e80f73096c 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -158,6 +158,12 @@ class TorchFXImporter:
             return self._call_binary_op(relax.op.multiply, lhs, rhs)
         return lhs * rhs
 
+    def _pow(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.power, lhs, rhs)
+        return lhs**rhs
+
     def _sub(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -641,6 +647,7 @@ class TorchFXImporter:
             "floordiv": self._floordiv,
             "mul": self._mul,
             "sub": self._sub,
+            "pow": self._pow,
             "sqrt": self._sqrt,
             "lt": self._lt,
             "truediv": self._truediv,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 4042f9bbc9..ead59cdf7b 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -103,6 +103,24 @@ def multiply(x1: Expr, x2: Expr) -> Expr:
     return _ffi_api.multiply(x1, x2)  # type: ignore
 
 
+def power(x1: Expr, x2: Expr):
+    """Power with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.power(x1, x2)  # type: ignore
+
+
 def subtract(x1: Expr, x2: Expr) -> Expr:
     """Subtraction with numpy-style broadcasting.
 
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py 
b/python/tvm/relax/transform/legalize_ops/binary.py
index 55b832021a..ffda767233 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -45,6 +45,7 @@ register_legalize("relax.add", _binary(topi.add))
 register_legalize("relax.divide", _binary(topi.divide))
 register_legalize("relax.floor_divide", _binary(topi.floor_divide))
 register_legalize("relax.multiply", _binary(topi.multiply))
+register_legalize("relax.power", _binary(topi.power))
 register_legalize("relax.subtract", _binary(topi.subtract))
 register_legalize("relax.equal", _binary(topi.equal))
 
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 045fe9ddd9..8fa55f7495 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -83,9 +83,10 @@ from tvm.relax.op import (
     negative,
     not_equal,
     null_value,
-    permute_dims,
     ones,
     ones_like,
+    permute_dims,
+    power,
     print,
     prod,
     reshape,
@@ -585,6 +586,7 @@ __all__ = [
     "ones_like",
     "output",
     "permute_dims",
+    "power",
     "prim_value",
     "print",
     "prod",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index b7a07c5202..7e8480ee16 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -84,6 +84,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
 
 /***************** Comparison operators *****************/
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b565b159bb..0a48e727e6 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -70,6 +70,9 @@ Expr floor_divide(Expr x1, Expr x2);
 /*! \brief Multiplication with numpy-style broadcasting. */
 Expr multiply(Expr x1, Expr x2);
 
+/*! \brief Power with numpy-style broadcasting. */
+Expr power(Expr x1, Expr x2);
+
 /*! \brief Subtraction with numpy-style broadcasting. */
 Expr subtract(Expr x1, Expr x2);
 
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index e216010667..137713869e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -982,13 +982,52 @@ def test_binary():
     verify_model(FloorDiv1(), input_info1, {}, expected9)
     verify_model(FloorDiv2(), input_info2, {}, expected10)
 
+    # Power
+    class Power1(Module):
+        def forward(self, lhs, rhs):
+            return lhs**rhs
+
+    @tvm.script.ir_module
+    class expected11:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, 
rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Power2(Module):
+        def forward(self, lhs):
+            return lhs**1.0
+
+    @tvm.script.ir_module
+    class expected12:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, 
R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Power1(), input_info1, {}, expected11)
+    verify_model(Power2(), input_info2, {}, expected12)
+
     # LT
     class LT1(Module):
         def forward(self, lhs, rhs):
             return lhs < rhs
 
     @tvm.script.ir_module
-    class expected11:
+    class expected13:
         @R.function
         def main(
             lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
@@ -1006,7 +1045,7 @@ def test_binary():
             return lhs < 1.0
 
     @tvm.script.ir_module
-    class expected12:
+    class expected14:
         @R.function
         def main(
             lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
@@ -1018,8 +1057,8 @@ def test_binary():
                 R.output(gv)
             return gv
 
-    verify_model(LT1(), input_info1, {}, expected11)
-    verify_model(LT2(), input_info2, {}, expected12)
+    verify_model(LT1(), input_info1, {}, expected13)
+    verify_model(LT2(), input_info2, {}, expected14)
 
 
 @tvm.testing.requires_gpu
diff --git a/tests/python/relax/test_op_binary.py 
b/tests/python/relax/test_op_binary.py
index a4ae8ce31a..56263bc4ee 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -31,6 +31,7 @@ def test_op_correctness():
     assert relax.op.divide(x, y).op == Op.get("relax.divide")
     assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
     assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
+    assert relax.op.power(x, y).op == Op.get("relax.power")
     assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
 
     assert relax.op.equal(x, y).op == Op.get("relax.equal")
@@ -51,6 +52,7 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     (relax.op.divide,),
     (relax.op.floor_divide,),
     (relax.op.multiply,),
+    (relax.op.power,),
     (relax.op.subtract,),
 )
 
diff --git a/tests/python/relax/test_op_misc.py 
b/tests/python/relax/test_op_misc.py
index 65772baadf..523a628fa9 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -68,6 +68,7 @@ def test_implicit_op():
     _check_call(x * y, "multiply")
     _check_call(x / y, "divide")
     _check_call(x // y, "floor_divide")
+    _check_call(x**y, "power")
     # _check_call(x % y, "mod") <= relax.mod is not implemented yet
 
     # Cast
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py 
b/tests/python/relax/test_transform_legalize_ops_binary.py
index c99fb885c4..79b4e8a8a7 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -519,6 +519,86 @@ def test_multiply_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_power():
+    # fmt: off
+    @tvm.script.ir_module
+    class Power:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), 
"float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+            gv: R.Tensor((4, 3, 2, 3), "float32") = R.power(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), 
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), 
T.int64(2), T.int64(1)), "float32"), T_power: T.Buffer((T.int64(4), T.int64(3), 
T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), 
T.int64(2), T.int64(3)):
+                with T.block("T_power"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], 
rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_power[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, 
v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 
2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
+            gv = R.call_tir(power, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), 
dtype="float32"))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(Power)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_power_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Power:
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", 
"c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv: R.Tensor((a, b, c, d), "float32") = R.power(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def power(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, 
var_T_power: T.handle):
+            T.func_attr({"tir.noalias": True})
+            c = T.int64()
+            d = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, 
d))
+            a = T.int64()
+            b = T.int64()
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, 
T.int64(1)))
+            T_power = T.match_buffer(var_T_power, (a, b, c, d))
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
+                with T.block("T_power"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], 
rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_power[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.pow(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, 
v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: 
R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c", 
"d"), dtype="float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv = R.call_tir(power, (x, y), out_sinfo=R.Tensor((a, b, c, d), 
dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Expected)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_subtract():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index 9636a98b41..b885697c73 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1023,6 +1023,7 @@ def test_arith_operators():
         a3 = x * y
         a4 = x / y
         a5 = x // y
+        a6 = x**y
 
         c0 = x > y
         c1 = x < y
@@ -1033,7 +1034,7 @@ def test_arith_operators():
         t0 = tuple_expr[0]
         t1 = tuple_expr[1]
         t2 = tuple_expr[0][0]  # <= Will normalize to two bindings
-        return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2
+        return a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2
 
     m = tir.Var("m", "int64")
     n = tir.Var("n", "int64")
@@ -1047,6 +1048,7 @@ def test_arith_operators():
         a3 = bb.emit(relax.op.multiply(x, y))
         a4 = bb.emit(relax.op.divide(x, y))
         a5 = bb.emit(relax.op.floor_divide(x, y))
+        a6 = bb.emit(relax.op.power(x, y))
 
         c0 = bb.emit(relax.op.greater(x, y))
         c1 = bb.emit(relax.op.less(x, y))
@@ -1058,7 +1060,7 @@ def test_arith_operators():
         t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1))
         tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0))
         t2 = bb.emit(relax.TupleGetItem(tmp, 0))
-        bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, 
c3, t0, t1, t2)))
+        bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, a6, c0, c1, 
c2, c3, t0, t1, t2)))
 
     _check(foo, bb.get()["foo"])
 
diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py 
b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index ffb8576b27..7fdd109cca 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -104,6 +104,7 @@ def test_unary_check(unary_check_op: Callable):
     (relax.op.divide,),
     (relax.op.floor_divide,),
     (relax.op.multiply,),
+    (relax.op.power,),
     (relax.op.subtract,),
 )
 

Reply via email to