This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a86a6cbb61 [Unity] Fix Unary Op Legalization (#14789)
a86a6cbb61 is described below
commit a86a6cbb618605ad9be0c9c78b205611944c8ac8
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon May 8 10:28:33 2023 +0800
[Unity] Fix Unary Op Legalization (#14789)
This PR adds support for unary op legalization, which is currently missing.
And clean up the test cases.
---
python/tvm/relax/transform/legalize_ops/unary.py | 13 +-
python/tvm/script/parser/core/parser.py | 15 +
.../relax/test_transform_legalize_ops_unary.py | 1251 +-------------------
3 files changed, 86 insertions(+), 1193 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py
b/python/tvm/relax/transform/legalize_ops/unary.py
index 9d19abe16e..104a874679 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -21,18 +21,27 @@ from .common import _call_topi_without_attr,
register_legalize
# To avoid conflict of IRModule function name and libc function name, we add
# "tir_" as the prefix of the generated PrimFunc name.
register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs"))
+register_legalize("relax.acos", _call_topi_without_attr(topi.acos, "tir_acos"))
+register_legalize("relax.acosh", _call_topi_without_attr(topi.acosh,
"tir_acosh"))
+register_legalize("relax.asin", _call_topi_without_attr(topi.asin, "tir_asin"))
+register_legalize("relax.asinh", _call_topi_without_attr(topi.asinh,
"tir_asinh"))
+register_legalize("relax.atan", _call_topi_without_attr(topi.atan, "tir_atan"))
+register_legalize("relax.atanh", _call_topi_without_attr(topi.atanh,
"tir_atanh"))
register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
-register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
+register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh"))
register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
register_legalize("relax.floor", _call_topi_without_attr(topi.floor,
"tir_floor"))
+register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
register_legalize("relax.negative", _call_topi_without_attr(topi.negative,
"tir_negative"))
register_legalize("relax.round", _call_topi_without_attr(topi.round,
"tir_round"))
register_legalize("relax.rsqrt", _call_topi_without_attr(topi.rsqrt,
"tir_rsqrt"))
register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid,
"tir_sigmoid"))
register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign"))
-register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh"))
register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin"))
+register_legalize("relax.sinh", _call_topi_without_attr(topi.sinh, "tir_sinh"))
+register_legalize("relax.square", _call_topi_without_attr(lambda x: x * x,
"tir_square"))
register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
+register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
diff --git a/python/tvm/script/parser/core/parser.py
b/python/tvm/script/parser/core/parser.py
index 6c08680a6c..380b428c4d 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -687,3 +687,18 @@ class Parser(doc.NodeVisitor):
The visiting result.
"""
return _dispatch(self, "Return")(self, node)
+
+ def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint:
disable=invalid-name
+ """The general nonlocal visiting method.
+
+ Parameters
+ ----------
+ node : doc.Nonlocal
+ The doc AST nonlocal node.
+
+ Returns
+ -------
+ res : Any
+ The visiting result.
+ """
+ return _dispatch(self, "Nonlocal")(self, node)
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py
b/tests/python/relax/test_transform_legalize_ops_unary.py
index 398fffdbdb..7d8a0fe9d7 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -15,1229 +15,98 @@
# specific language governing permissions and limitations
# under the License.
+from typing import Callable
+
+import pytest
import tvm
+from tvm import topi
import tvm.testing
from tvm.relax.transform import LegalizeOps
+import tvm.script
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
-def test_abs():
- # fmt: off
- @tvm.script.ir_module
- class Abs:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.abs(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_abs(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1],
dtype="float32")
- # fmt: on
-
- mod = LegalizeOps()(Abs)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_abs_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Abs:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.abs(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_abs, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1],
dtype="float32")
- # fmt: on
-
- mod = LegalizeOps()(Abs)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil():
- # fmt: off
- @tvm.script.ir_module
- class Ceil:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.ceil(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2,
3), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Ceil)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil_int():
- # fmt: off
- @tvm.script.ir_module
- class Ceil:
- @R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
- gv: R.Tensor((2, 3), "int32") = R.ceil(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
- gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((2,
3), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Ceil)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_ceil_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Ceil:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.ceil(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
- compute = T.match_buffer(var_compute, (m, n))
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_ceil, (x,), out_sinfo=R.Tensor((m,
n), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Ceil)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_cos():
- # fmt: off
+def _test_static_shape(name: str, relax_op: Callable, te_func: Callable,
dtype: str):
@tvm.script.ir_module
- class Cos:
+ class Before:
@R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.cos(x)
+ def main(x: R.Tensor((2, 3), dtype)):
+ nonlocal dtype
+ gv = relax_op(x)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((2, 3),
dtype="float32"))
+ def main(x: R.Tensor((2, 3), dtype)):
+ nonlocal dtype
+ gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}")
return gv
- @T.prim_func
- def tir_cos(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Cos)
+ mod = LegalizeOps()(Before)
tvm.ir.assert_structural_equal(mod, Expected)
-def test_cos_symbolic():
- # fmt: off
+def _test_symbolic_shape(name: str, relax_op: Callable, te_func: Callable,
dtype: str):
@tvm.script.ir_module
- class Cos:
+ class Before:
@R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.cos(x)
+ def main(x: R.Tensor(("m", "n"), dtype)):
+ nonlocal dtype
+ gv = relax_op(x)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_cos, (x,), R.Tensor((m, n),
dtype="float32"))
+ def main(x: R.Tensor(("m", "n"), dtype)):
+ nonlocal dtype
+ gv = R.emit_te(te_func, x, primfunc_name_hint=f"tir_{name}")
return gv
- @T.prim_func
- def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Cos)
+ mod = LegalizeOps()(Before)
tvm.ir.assert_structural_equal(mod, Expected)
-def test_exp():
- # fmt: off
- @tvm.script.ir_module
- class Exp:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.exp(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1],
dtype="float32")
- # fmt: on
-
- mod = LegalizeOps()(Exp)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_exp_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Exp:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.exp(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_exp, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1],
dtype="float32")
- # fmt: on
-
- mod = LegalizeOps()(Exp)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor():
- # fmt: off
- @tvm.script.ir_module
- class Floor:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.floor(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2,
3), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Floor)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor_int():
- # fmt: off
- @tvm.script.ir_module
- class Floor:
- @R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
- gv: R.Tensor((2, 3), "int32") = R.floor(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
- gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((2,
3), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Floor)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_floor_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Floor:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.floor(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
- compute = T.match_buffer(var_compute, (m, n))
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_floor, (x,), out_sinfo=R.Tensor((m,
n), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Floor)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_log():
- # fmt: off
- @tvm.script.ir_module
- class Log:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.log(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_log(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Log)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_log_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Log:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.log(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_log, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Log)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_negative():
- # fmt: off
- @tvm.script.ir_module
- class Negative:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.negative(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_negative(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] *
T.float32(-1)
- # fmt: on
-
- mod = LegalizeOps()(Negative)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_negative_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Negative:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.negative(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_negative, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] *
T.float32(-1)
- # fmt: on
-
- mod = LegalizeOps()(Negative)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round():
- # fmt: off
- @tvm.script.ir_module
- class Round:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.round(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @T.prim_func
- def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2,
3), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Round)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round_int():
- # fmt: off
- @tvm.script.ir_module
- class Round:
- @R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
- gv: R.Tensor((2, 3), "int32") = R.round(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @T.prim_func
- def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
- gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((2,
3), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Round)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_round_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Round:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.round(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @T.prim_func
- def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
- compute = T.match_buffer(var_compute, (m, n))
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
-
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_round, (x,), out_sinfo=R.Tensor((m,
n), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Round)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_rsqrt():
- # fmt: off
- @tvm.script.ir_module
- class Rsqrt:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.rsqrt(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_rsqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Rsqrt)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_rsqrt_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Rsqrt:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.rsqrt(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_rsqrt, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_rsqrt(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.rsqrt(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Rsqrt)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sigmoid():
- # fmt: off
- @tvm.script.ir_module
- class Sigmoid:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.sigmoid(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sigmoid(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sigmoid)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sigmoid_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Sigmoid:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.sigmoid(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_sigmoid, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sigmoid)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign():
- # fmt: off
- @tvm.script.ir_module
- class Sign:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.sign(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("T_sign"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(rxplaceholder[v_ax0, v_ax1])
- T.writes(T_sign[v_ax0, v_ax1])
- T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) <
rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1]
< T.float32(0), T.float32(-1), T.float32(0)))
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2,
3), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Sign)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign_int():
- # fmt: off
- @tvm.script.ir_module
- class Sign:
- @R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
- gv: R.Tensor((2, 3), "int32") = R.sign(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")):
- T.func_attr({"tir.noalias": True})
- for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("T_sign"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(rxplaceholder[v_ax0, v_ax1])
- T.writes(T_sign[v_ax0, v_ax1])
- T_sign[v_ax0, v_ax1] = T.Select(T.int64(0) <
T.Cast("int64", rxplaceholder[v_ax0, v_ax1]), 1, T.Select(T.Cast("int64",
rxplaceholder[v_ax0, v_ax1]) < T.int64(0), -1, 0))
-
- @R.function
- def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
- gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((2,
3), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Sign)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sign_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Sign:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.sign(x)
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func
- def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
- T_sign = T.match_buffer(var_T_sign, (m, n))
- for ax0, ax1 in T.grid(m, n):
- with T.block("T_sign"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(rxplaceholder[v_ax0, v_ax1])
- T.writes(T_sign[v_ax0, v_ax1])
- T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) <
rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1]
< T.float32(0), T.float32(-1), T.float32(0)))
-
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_sign, (x,), out_sinfo=R.Tensor((m,
n), dtype="float32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Sign)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sin():
- # fmt: off
- @tvm.script.ir_module
- class Sin:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.sin(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sin)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sin_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Sin:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.sin(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_sin, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sin)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sinh():
- # fmt: off
- @tvm.script.ir_module
- class Sinh:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.sinh(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
- cls = Expected
- gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sinh(
- rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"),
- compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),
- ):
- T.func_attr({"tir.noalias": T.bool(True)})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1])
- # fmt: on
-
- mod = LegalizeOps()(Sinh)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sinh_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Sinh:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.sinh(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(
- x: R.Tensor(("m", "n"), dtype="float32")
- ) -> R.Tensor(("m", "n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- cls = Expected
- gv = R.call_tir(cls.tir_sinh, (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sinh(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": T.bool(True)})
- m, n = T.int64(), T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
- compute = T.match_buffer(var_compute, (m, n))
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[v_i0, v_i1])
- T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.sinh(rxplaceholder[v_i0, v_i1])
- # fmt: on
-
- mod = LegalizeOps()(Sinh)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sqrt():
- # fmt: off
- @tvm.script.ir_module
- class Sqrt:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.sqrt(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sqrt)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_sqrt_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Sqrt:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.sqrt(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_sqrt, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Sqrt)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_tanh():
- # fmt: off
- @tvm.script.ir_module
- class Tanh:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv: R.Tensor((2, 3), "float32") = R.tanh(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
- gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((2, 3),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_tanh(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
- for i0, i1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Tanh)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_tanh_symbolic():
- # fmt: off
- @tvm.script.ir_module
- class Tanh:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.tanh(x)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_tanh, (x,), R.Tensor((m, n),
dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1])
- # fmt: on
-
- mod = LegalizeOps()(Tanh)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_clip_symbolic():
- @tvm.script.ir_module
- class Clip:
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
- m = T.int64()
- n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8)
- return gv
-
- @tvm.script.ir_module
- class Expected:
- @R.function
- def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
- m = T.int64()
- n = T.int64()
- gv = R.call_tir(Expected.tir_clip, (x,), out_sinfo=R.Tensor((m,
n), dtype="float32"))
- return gv
-
- @T.prim_func
- def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle):
- T.func_attr({"tir.noalias": True})
- m = T.int64()
- n = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n],
dtype="float32")
- compute = T.match_buffer(var_compute, [m, n], dtype="float32")
- for i0, i1 in T.grid(m, n):
- with T.block("compute"):
- v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- compute[v_i0, v_i1] = T.max(
- T.min(rxplaceholder[v_i0, v_i1], T.float32(8)),
T.float32(5)
- )
-
- mod = LegalizeOps()(Clip)
- tvm.ir.assert_structural_equal(mod, Expected)
[email protected](
+ "name, relax_op, te_func, dtype",
+ [
+ ("abs", R.abs, topi.abs, "float32"),
+ ("acos", R.acos, topi.acos, "float32"),
+ ("acosh", R.acosh, topi.acosh, "float32"),
+ ("asin", R.asin, topi.asin, "float32"),
+ ("asinh", R.asinh, topi.asinh, "float32"),
+ ("atan", R.atan, topi.atan, "float32"),
+ ("atanh", R.atanh, topi.atanh, "float32"),
+ ("ceil", R.ceil, topi.ceil, "float32"),
+ ("ceil", R.ceil, topi.identity, "int32"),
+ ("cos", R.cos, topi.cos, "float32"),
+ ("cosh", R.cosh, topi.cosh, "float32"),
+ ("exp", R.exp, topi.exp, "float32"),
+ ("floor", R.floor, topi.floor, "float32"),
+ ("floor", R.floor, topi.identity, "int32"),
+ ("log", R.log, topi.log, "float32"),
+ ("negative", R.negative, topi.negative, "float32"),
+ ("round", R.round, topi.round, "float32"),
+ ("round", R.round, topi.identity, "int32"),
+ ("rsqrt", R.rsqrt, topi.rsqrt, "float32"),
+ ("sigmoid", R.sigmoid, topi.sigmoid, "float32"),
+ ("sign", R.sign, topi.sign, "float32"),
+ ("sign", R.sign, topi.sign, "int32"),
+ ("sin", R.sin, topi.sin, "float32"),
+ ("sinh", R.sinh, topi.sinh, "float32"),
+ ("sqrt", R.sqrt, topi.sqrt, "float32"),
+ ("square", R.square, lambda x: topi.multiply(x, x), "float32"),
+ ("tan", R.tan, topi.tan, "float32"),
+ ("tanh", R.tanh, topi.tanh, "float32"),
+ ("clip", lambda x: R.clip(x, 5, 8), lambda x: topi.clip(x, 5, 8),
"float32"),
+ ],
+)
+def test_unary_ops(name: str, relax_op: Callable, te_func: Callable, dtype:
str):
+ _test_static_shape(name, relax_op, te_func, dtype)
+ _test_symbolic_shape(name, relax_op, te_func, dtype)
if __name__ == "__main__":