This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 96aca9d85fd178bdfcd68ddd909d24624d286b1b Author: Siyuan Feng <[email protected]> AuthorDate: Mon May 8 10:28:33 2023 +0800 [Unity] Fix Unary Op Legalization (#14789) This PR adds support for unary ops legalization, which is missing currently. And clean up the test cases. --- python/tvm/relax/transform/legalize_ops/unary.py | 13 +- python/tvm/script/parser/core/parser.py | 15 + .../relax/test_transform_legalize_ops_unary.py | 1251 +------------------- 3 files changed, 86 insertions(+), 1193 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py index 9d19abe16e..104a874679 100644 --- a/python/tvm/relax/transform/legalize_ops/unary.py +++ b/python/tvm/relax/transform/legalize_ops/unary.py @@ -21,18 +21,27 @@ from .common import _call_topi_without_attr, register_legalize # To avoid conflict of IRModule function name and libc function name, we add # "tir_" as the prefix of the generated PrimFunc name. register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs")) +register_legalize("relax.acos", _call_topi_without_attr(topi.acos, "tir_acos")) +register_legalize("relax.acosh", _call_topi_without_attr(topi.acosh, "tir_acosh")) +register_legalize("relax.asin", _call_topi_without_attr(topi.asin, "tir_asin")) +register_legalize("relax.asinh", _call_topi_without_attr(topi.asinh, "tir_asinh")) +register_legalize("relax.atan", _call_topi_without_attr(topi.atan, "tir_atan")) +register_legalize("relax.atanh", _call_topi_without_attr(topi.atanh, "tir_atanh")) register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil")) register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos")) -register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log")) +register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh")) register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp")) register_legalize("relax.floor", _call_topi_without_attr(topi.floor, "tir_floor")) +register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log")) register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative")) register_legalize("relax.round", _call_topi_without_attr(topi.round, "tir_round")) register_legalize("relax.rsqrt", _call_topi_without_attr(topi.rsqrt, "tir_rsqrt")) register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, "tir_sigmoid")) register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign")) -register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh")) register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin")) +register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh")) +register_legalize("relax.square", _call_topi_without_attr(lambda x: x * x, "tir_square")) register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt")) +register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan")) register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh")) register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip")) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 6c08680a6c..380b428c4d 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -687,3 +687,18 @@ class Parser(doc.NodeVisitor): The visiting result. """ return _dispatch(self, "Return")(self, node) + + def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint: disable=invalid-name + """The general nonlocal visiting method. + + Parameters + ---------- + node : doc.Nonlocal + The doc AST nonlocal node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Nonlocal")(self, node) diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py index 398fffdbdb..7d8a0fe9d7 100644 --- a/tests/python/relax/test_transform_legalize_ops_unary.py +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -15,1229 +15,98 @@ # specific language governing permissions and limitations # under the License. +from typing import Callable + +import pytest import tvm +from tvm import topi import tvm.testing from tvm.relax.transform import LegalizeOps +import tvm.script from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T -def test_abs(): - # fmt: off - @tvm.script.ir_module - class Abs: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.abs(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_abs(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") - # fmt: on - - mod = LegalizeOps()(Abs) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_abs_symbolic(): - # fmt: off - @tvm.script.ir_module - class Abs: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.abs(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") - # fmt: on - - mod = LegalizeOps()(Abs) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_ceil(): - # fmt: off - @tvm.script.ir_module - class Ceil: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.ceil(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Ceil) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_ceil_int(): - # fmt: off - @tvm.script.ir_module - class Ceil: - @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): - gv: R.Tensor((2, 3), "int32") = R.ceil(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] - - @R.function - def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): - gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Ceil) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_ceil_symbolic(): - # fmt: off - @tvm.script.ir_module - class Ceil: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.ceil(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) - compute = T.match_buffer(var_compute, (m, n)) - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Ceil) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_cos(): - # fmt: off +def _test_static_shape(name: str, relax_op: Callable, te_func: Callable, dtype: str): @tvm.script.ir_module - class Cos: + class Before: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.cos(x) + def main(x: R.Tensor((2, 3), dtype)): + nonlocal dtype + gv = relax_op(x) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((2, 3), dtype="float32")) + def main(x: R.Tensor((2, 3), dtype)): + nonlocal dtype + gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}") return gv - @T.prim_func - def tir_cos(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Cos) + mod = LegalizeOps()(Before) tvm.ir.assert_structural_equal(mod, Expected) -def test_cos_symbolic(): - # fmt: off +def _test_symbolic_shape(name: str, relax_op: Callable, te_func: Callable, dtype: str): @tvm.script.ir_module - class Cos: + class Before: @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.cos(x) + def main(x: R.Tensor(("m", "n"), dtype)): + nonlocal dtype + gv = relax_op(x) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((m, n), dtype="float32")) + def main(x: R.Tensor(("m", "n"), dtype)): + nonlocal dtype + gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}") return gv - @T.prim_func - def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Cos) + mod = LegalizeOps()(Before) tvm.ir.assert_structural_equal(mod, Expected) -def test_exp(): - # fmt: off - @tvm.script.ir_module - class Exp: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.exp(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") - # fmt: on - - mod = LegalizeOps()(Exp) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_exp_symbolic(): - # fmt: off - @tvm.script.ir_module - class Exp: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.exp(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") - # fmt: on - - mod = LegalizeOps()(Exp) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_floor(): - # fmt: off - @tvm.script.ir_module - class Floor: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.floor(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Floor) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_floor_int(): - # fmt: off - @tvm.script.ir_module - class Floor: - @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): - gv: R.Tensor((2, 3), "int32") = R.floor(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] - - @R.function - def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): - gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Floor) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_floor_symbolic(): - # fmt: off - @tvm.script.ir_module - class Floor: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.floor(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) - compute = T.match_buffer(var_compute, (m, n)) - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Floor) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_log(): - # fmt: off - @tvm.script.ir_module - class Log: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.log(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_log(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Log) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_log_symbolic(): - # fmt: off - @tvm.script.ir_module - class Log: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.log(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Log) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_negative(): - # fmt: off - @tvm.script.ir_module - class Negative: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.negative(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_negative(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) - # fmt: on - - mod = LegalizeOps()(Negative) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_negative_symbolic(): - # fmt: off - @tvm.script.ir_module - class Negative: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.negative(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) - # fmt: on - - mod = LegalizeOps()(Negative) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_round(): - # fmt: off - @tvm.script.ir_module - class Round: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.round(x) - return gv - - @tvm.script.ir_module - class Expected: - @T.prim_func - def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Round) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_round_int(): - # fmt: off - @tvm.script.ir_module - class Round: - @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): - gv: R.Tensor((2, 3), "int32") = R.round(x) - return gv - - @tvm.script.ir_module - class Expected: - @T.prim_func - def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1] - - @R.function - def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): - gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Round) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_round_symbolic(): - # fmt: off - @tvm.script.ir_module - class Round: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.round(x) - return gv - - @tvm.script.ir_module - class Expected: - @T.prim_func - def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) - compute = T.match_buffer(var_compute, (m, n)) - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1]) - - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Round) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_rsqrt(): - # fmt: off - @tvm.script.ir_module - class Rsqrt: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.rsqrt(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_rsqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Rsqrt) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_rsqrt_symbolic(): - # fmt: off - @tvm.script.ir_module - class Rsqrt: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.rsqrt(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_rsqrt(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Rsqrt) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sigmoid(): - # fmt: off - @tvm.script.ir_module - class Sigmoid: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.sigmoid(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_sigmoid(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sigmoid) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sigmoid_symbolic(): - # fmt: off - @tvm.script.ir_module - class Sigmoid: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.sigmoid(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sigmoid) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sign(): - # fmt: off - @tvm.script.ir_module - class Sign: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.sign(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_sign"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax0, v_ax1]) - T.writes(T_sign[v_ax0, v_ax1]) - T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0))) - - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Sign) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sign_int(): - # fmt: off - @tvm.script.ir_module - class Sign: - @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): - gv: R.Tensor((2, 3), "int32") = R.sign(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) - for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_sign"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax0, v_ax1]) - T.writes(T_sign[v_ax0, v_ax1]) - T_sign[v_ax0, v_ax1] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[v_ax0, v_ax1]), 1, T.Select(T.Cast("int64", rxplaceholder[v_ax0, v_ax1]) < T.int64(0), -1, 0)) - - @R.function - def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"): - gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Sign) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sign_symbolic(): - # fmt: off - @tvm.script.ir_module - class Sign: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.sign(x) - return gv - - @I.ir_module - class Expected: - @T.prim_func - def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) - T_sign = T.match_buffer(var_T_sign, (m, n)) - for ax0, ax1 in T.grid(m, n): - with T.block("T_sign"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax0, v_ax1]) - T.writes(T_sign[v_ax0, v_ax1]) - T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0))) - - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - # fmt: on - - mod = LegalizeOps()(Sign) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sin(): - # fmt: off - @tvm.script.ir_module - class Sin: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.sin(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_sin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sin) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sin_symbolic(): - # fmt: off - @tvm.script.ir_module - class Sin: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.sin(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sin) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sinh(): - # fmt: off - @tvm.script.ir_module - class Sinh: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.sinh(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - cls = Expected - gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_sinh( - rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), - compute: T.Buffer((T.int64(2), T.int64(3)), "float32"), - ): - T.func_attr({"tir.noalias": T.bool(True)}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1]) - # fmt: on - - mod = LegalizeOps()(Sinh) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sinh_symbolic(): - # fmt: off - @tvm.script.ir_module - class Sinh: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.sinh(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main( - x: R.Tensor(("m", "n"), dtype="float32") - ) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - cls = Expected - gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_sinh(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m, n = T.int64(), T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n)) - compute = T.match_buffer(var_compute, (m, n)) - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1]) - # fmt: on - - mod = LegalizeOps()(Sinh) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sqrt(): - # fmt: off - @tvm.script.ir_module - class Sqrt: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.sqrt(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_sqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sqrt) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_sqrt_symbolic(): - # fmt: off - @tvm.script.ir_module - class Sqrt: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.sqrt(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Sqrt) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_tanh(): - # fmt: off - @tvm.script.ir_module - class Tanh: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.tanh(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((2, 3), dtype="float32")) - return gv - - @T.prim_func - def tir_tanh(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Tanh) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_tanh_symbolic(): - # fmt: off - @tvm.script.ir_module - class Tanh: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.tanh(x) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) - # fmt: on - - mod = LegalizeOps()(Tanh) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_clip_symbolic(): - @tvm.script.ir_module - class Clip: - @R.function - def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.int64() - n = T.int64() - gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8) - return gv - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.tir_clip, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) - return gv - - @T.prim_func - def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") - for i0, i1 in T.grid(m, n): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - compute[v_i0, v_i1] = T.max( - T.min(rxplaceholder[v_i0, v_i1], T.float32(8)), T.float32(5) - ) - - mod = LegalizeOps()(Clip) - tvm.ir.assert_structural_equal(mod, Expected) [email protected]( + "name, relax_op, te_func, dtype", + [ + ("abs", R.abs, topi.abs, "float32"), + ("acos", R.acos, topi.acos, "float32"), + ("acosh", R.acosh, topi.acosh, "float32"), + ("asin", R.asin, topi.asin, "float32"), + ("asinh", R.asinh, topi.asinh, "float32"), + ("atan", R.atan, topi.atan, "float32"), + ("atanh", R.atanh, topi.atanh, "float32"), + ("ceil", R.ceil, topi.ceil, "float32"), + ("ceil", R.ceil, topi.identity, "int32"), + ("cos", R.cos, topi.cos, "float32"), + ("cosh", R.cosh, topi.cosh, "float32"), + ("exp", R.exp, topi.exp, "float32"), + ("floor", R.floor, topi.floor, "float32"), + ("floor", R.floor, topi.identity, "int32"), + ("log", R.log, topi.log, "float32"), + ("negative", R.negative, topi.negative, "float32"), + ("round", R.round, topi.round, "float32"), + ("round", R.round, topi.identity, "int32"), + ("rsqrt", R.rsqrt, topi.rsqrt, "float32"), + ("sigmoid", R.sigmoid, topi.sigmoid, "float32"), + ("sign", R.sign, topi.sign, "float32"), + ("sign", R.sign, topi.sign, "int32"), + ("sin", R.sin, topi.sin, "float32"), + ("sinh", R.sinh, topi.sinh, "float32"), + ("sqrt", R.sqrt, topi.sqrt, "float32"), + ("square", R.square, lambda x: topi.multiply(x, x), "float32"), + ("tan", R.tan, topi.tan, "float32"), + ("tanh", R.tanh, topi.tanh, "float32"), + ("clip", lambda x: R.clip(x, 5, 8), lambda x: topi.clip(x, 5, 8), "float32"), + ], +) +def test_unary_ops(name: str, relax_op: Callable, te_func: Callable, dtype: str): + _test_static_shape(name, relax_op, te_func, dtype) + _test_symbolic_shape(name, relax_op, te_func, dtype) if __name__ == "__main__":
