This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a86a6cbb61 [Unity] Fix Unary Op Legalization (#14789)
a86a6cbb61 is described below

commit a86a6cbb618605ad9be0c9c78b205611944c8ac8
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon May 8 10:28:33 2023 +0800

    [Unity] Fix Unary Op Legalization (#14789)
    
    This PR adds legalization support for the unary ops that were previously
    missing, and cleans up the test cases.
---
 python/tvm/relax/transform/legalize_ops/unary.py   |   13 +-
 python/tvm/script/parser/core/parser.py            |   15 +
 .../relax/test_transform_legalize_ops_unary.py     | 1251 +-------------------
 3 files changed, 86 insertions(+), 1193 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index 9d19abe16e..104a874679 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -21,18 +21,27 @@ from .common import _call_topi_without_attr, register_legalize
 # To avoid conflict of IRModule function name and libc function name, we add
 # "tir_" as the prefix of the generated PrimFunc name.
 register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs"))
+register_legalize("relax.acos", _call_topi_without_attr(topi.acos, "tir_acos"))
+register_legalize("relax.acosh", _call_topi_without_attr(topi.acosh, 
"tir_acosh"))
+register_legalize("relax.asin", _call_topi_without_attr(topi.asin, "tir_asin"))
+register_legalize("relax.asinh", _call_topi_without_attr(topi.asinh, 
"tir_asinh"))
+register_legalize("relax.atan", _call_topi_without_attr(topi.atan, "tir_atan"))
+register_legalize("relax.atanh", _call_topi_without_attr(topi.atanh, 
"tir_atanh"))
 register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
 register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
-register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
+register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh"))
 register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
 register_legalize("relax.floor", _call_topi_without_attr(topi.floor, 
"tir_floor"))
+register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
 register_legalize("relax.negative", _call_topi_without_attr(topi.negative, 
"tir_negative"))
 register_legalize("relax.round", _call_topi_without_attr(topi.round, 
"tir_round"))
 register_legalize("relax.rsqrt", _call_topi_without_attr(topi.rsqrt, 
"tir_rsqrt"))
 register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, 
"tir_sigmoid"))
 register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign"))
-register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh"))
 register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin"))
+register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh"))
+register_legalize("relax.square", _call_topi_without_attr(lambda x: x * x, 
"tir_square"))
 register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
+register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
 register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
 register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index 6c08680a6c..380b428c4d 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -687,3 +687,18 @@ class Parser(doc.NodeVisitor):
             The visiting result.
         """
         return _dispatch(self, "Return")(self, node)
+
+    def visit_Nonlocal(self, node: doc.Nonlocal) -> Any:  # pylint: 
disable=invalid-name
+        """The general nonlocal visiting method.
+
+        Parameters
+        ----------
+        node : doc.Nonlocal
+            The doc AST nonlocal node.
+
+        Returns
+        -------
+        res : Any
+            The visiting result.
+        """
+        return _dispatch(self, "Nonlocal")(self, node)
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py
index 398fffdbdb..7d8a0fe9d7 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -15,1229 +15,98 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from typing import Callable
+
+import pytest
 import tvm
+from tvm import topi
 import tvm.testing
 from tvm.relax.transform import LegalizeOps
+import tvm.script
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
 
 
-def test_abs():
-    # fmt: off
-    @tvm.script.ir_module
-    class Abs:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.abs(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((2, 3), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_abs(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], 
dtype="float32")
-    # fmt: on
-
-    mod = LegalizeOps()(Abs)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_abs_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Abs:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.abs(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((m, n), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], 
dtype="float32")
-    # fmt: on
-
-    mod = LegalizeOps()(Abs)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil():
-    # fmt: off
-    @tvm.script.ir_module
-    class Ceil:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.ceil(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 
3), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Ceil)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil_int():
-    # fmt: off
-    @tvm.script.ir_module
-    class Ceil:
-        @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-            gv: R.Tensor((2, 3), "int32") = R.ceil(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), 
dtype="int32"):
-            gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 
3), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Ceil)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Ceil:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.ceil(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
-            compute = T.match_buffer(var_compute, (m, n))
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((m, 
n), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Ceil)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_cos():
-    # fmt: off
+def _test_static_shape(name: str, relax_op: Callable, te_func: Callable, 
dtype: str):
     @tvm.script.ir_module
-    class Cos:
+    class Before:
         @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.cos(x)
+        def main(x: R.Tensor((2, 3), dtype)):
+            nonlocal dtype
+            gv = relax_op(x)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((2, 3), 
dtype="float32"))
+        def main(x: R.Tensor((2, 3), dtype)):
+            nonlocal dtype
+            gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}")
             return gv
 
-        @T.prim_func
-        def tir_cos(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Cos)
+    mod = LegalizeOps()(Before)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_cos_symbolic():
-    # fmt: off
+def _test_symbolic_shape(name: str, relax_op: Callable, te_func: Callable, 
dtype: str):
     @tvm.script.ir_module
-    class Cos:
+    class Before:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.cos(x)
+        def main(x: R.Tensor(("m", "n"), dtype)):
+            nonlocal dtype
+            gv = relax_op(x)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((m, n), 
dtype="float32"))
+        def main(x: R.Tensor(("m", "n"), dtype)):
+            nonlocal dtype
+            gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}")
             return gv
 
-        @T.prim_func
-        def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Cos)
+    mod = LegalizeOps()(Before)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_exp():
-    # fmt: off
-    @tvm.script.ir_module
-    class Exp:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.exp(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((2, 3), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], 
dtype="float32")
-    # fmt: on
-
-    mod = LegalizeOps()(Exp)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_exp_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Exp:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.exp(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((m, n), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], 
dtype="float32")
-    # fmt: on
-
-    mod = LegalizeOps()(Exp)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor():
-    # fmt: off
-    @tvm.script.ir_module
-    class Floor:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.floor(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 
3), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Floor)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor_int():
-    # fmt: off
-    @tvm.script.ir_module
-    class Floor:
-        @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-            gv: R.Tensor((2, 3), "int32") = R.floor(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), 
dtype="int32"):
-            gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 
3), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Floor)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Floor:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.floor(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
-            compute = T.match_buffer(var_compute, (m, n))
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((m, 
n), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Floor)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_log():
-    # fmt: off
-    @tvm.script.ir_module
-    class Log:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.log(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((2, 3), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_log(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Log)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_log_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Log:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.log(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((m, n), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Log)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_negative():
-    # fmt: off
-    @tvm.script.ir_module
-    class Negative:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.negative(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((2, 3), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_negative(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * 
T.float32(-1)
-    # fmt: on
-
-    mod = LegalizeOps()(Negative)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_negative_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Negative:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.negative(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((m, n), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * 
T.float32(-1)
-    # fmt: on
-
-    mod = LegalizeOps()(Negative)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round():
-    # fmt: off
-    @tvm.script.ir_module
-    class Round:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.round(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), 
dtype="float32"):
-            gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 
3), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Round)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round_int():
-    # fmt: off
-    @tvm.script.ir_module
-    class Round:
-        @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-            gv: R.Tensor((2, 3), "int32") = R.round(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), 
dtype="int32"):
-            gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 
3), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Round)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Round:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.round(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
-            compute = T.match_buffer(var_compute, (m, n))
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
-
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((m, 
n), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Round)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_rsqrt():
-    # fmt: off
-    @tvm.script.ir_module
-    class Rsqrt:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv: R.Tensor((2, 3), "float32") = R.rsqrt(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), 
"float32"):
-            gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((2, 3), 
dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_rsqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Rsqrt)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_rsqrt_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Rsqrt:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.rsqrt(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_rsqrt(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Rsqrt)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sigmoid():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sigmoid:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.sigmoid(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((2, 3), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sigmoid(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sigmoid)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sigmoid_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sigmoid:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.sigmoid(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sigmoid)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sign:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.sign(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_sign"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(rxplaceholder[v_ax0, v_ax1])
-                    T.writes(T_sign[v_ax0, v_ax1])
-                    T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0)))
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
-            gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Sign)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign_int():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sign:
-        @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-            gv: R.Tensor((2, 3), "int32") = R.sign(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")):
-            T.func_attr({"tir.noalias": True})
-            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_sign"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(rxplaceholder[v_ax0, v_ax1])
-                    T.writes(T_sign[v_ax0, v_ax1])
-                    T_sign[v_ax0, v_ax1] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[v_ax0, v_ax1]), 1, T.Select(T.Cast("int64", rxplaceholder[v_ax0, v_ax1]) < T.int64(0), -1, 0))
-
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
-            gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Sign)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sign:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.sign(x)
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func
-        def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
-            T_sign = T.match_buffer(var_T_sign, (m, n))
-            for ax0, ax1 in T.grid(m, n):
-                with T.block("T_sign"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(rxplaceholder[v_ax0, v_ax1])
-                    T.writes(T_sign[v_ax0, v_ax1])
-                    T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0)))
-
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Sign)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sin():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sin:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.sin(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((2, 3), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sin)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sin_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sin:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.sin(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sin)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sinh():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sinh:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.sinh(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
-            cls = Expected
-            gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sinh(
-            rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"),
-            compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),
-        ):
-            T.func_attr({"tir.noalias": T.bool(True)})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sinh)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sinh_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sinh:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.sinh(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(
-            x: R.Tensor(("m", "n"), dtype="float32")
-        ) -> R.Tensor(("m", "n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            cls = Expected
-            gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sinh(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": T.bool(True)})
-            m, n = T.int64(), T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
-            compute = T.match_buffer(var_compute, (m, n))
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[v_i0, v_i1])
-                    T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sinh)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sqrt():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sqrt:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.sqrt(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((2, 3), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sqrt)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sqrt_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Sqrt:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.sqrt(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Sqrt)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_tanh():
-    # fmt: off
-    @tvm.script.ir_module
-    class Tanh:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.tanh(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((2, 3), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_tanh(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Tanh)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_tanh_symbolic():
-    # fmt: off
-    @tvm.script.ir_module
-    class Tanh:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.tanh(x)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1])
-    # fmt: on
-
-    mod = LegalizeOps()(Tanh)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_clip_symbolic():
-    @tvm.script.ir_module
-    class Clip:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
-            m = T.int64()
-            n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8)
-            return gv
-
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
-            m = T.int64()
-            n = T.int64()
-            gv = R.call_tir(Expected.tir_clip, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
-            return gv
-
-        @T.prim_func
-        def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle):
-            T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
-                with T.block("compute"):
-                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    compute[v_i0, v_i1] = T.max(
-                        T.min(rxplaceholder[v_i0, v_i1], T.float32(8)), T.float32(5)
-                    )
-
-    mod = LegalizeOps()(Clip)
-    tvm.ir.assert_structural_equal(mod, Expected)
+@pytest.mark.parametrize(
+    "name, relax_op, te_func, dtype",
+    [
+        ("abs", R.abs, topi.abs, "float32"),
+        ("acos", R.acos, topi.acos, "float32"),
+        ("acosh", R.acosh, topi.acosh, "float32"),
+        ("asin", R.asin, topi.asin, "float32"),
+        ("asinh", R.asinh, topi.asinh, "float32"),
+        ("atan", R.atan, topi.atan, "float32"),
+        ("atanh", R.atanh, topi.atanh, "float32"),
+        ("ceil", R.ceil, topi.ceil, "float32"),
+        ("ceil", R.ceil, topi.identity, "int32"),
+        ("cos", R.cos, topi.cos, "float32"),
+        ("cosh", R.cosh, topi.cosh, "float32"),
+        ("exp", R.exp, topi.exp, "float32"),
+        ("floor", R.floor, topi.floor, "float32"),
+        ("floor", R.floor, topi.identity, "int32"),
+        ("log", R.log, topi.log, "float32"),
+        ("negative", R.negative, topi.negative, "float32"),
+        ("round", R.round, topi.round, "float32"),
+        ("round", R.round, topi.identity, "int32"),
+        ("rsqrt", R.rsqrt, topi.rsqrt, "float32"),
+        ("sigmoid", R.sigmoid, topi.sigmoid, "float32"),
+        ("sign", R.sign, topi.sign, "float32"),
+        ("sign", R.sign, topi.sign, "int32"),
+        ("sin", R.sin, topi.sin, "float32"),
+        ("sinh", R.sinh, topi.sinh, "float32"),
+        ("sqrt", R.sqrt, topi.sqrt, "float32"),
+        ("square", R.square, lambda x: topi.multiply(x, x), "float32"),
+        ("tan", R.tan, topi.tan, "float32"),
+        ("tanh", R.tanh, topi.tanh, "float32"),
+        ("clip", lambda x: R.clip(x, 5, 8), lambda x: topi.clip(x, 5, 8), "float32"),
+    ],
+)
+def test_unary_ops(name: str, relax_op: Callable, te_func: Callable, dtype: str):
+    _test_static_shape(name, relax_op, te_func, dtype)
+    _test_symbolic_shape(name, relax_op, te_func, dtype)
 
 
 if __name__ == "__main__":

Reply via email to