This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit e9a6e49f3a552d67d958375663d9a908735ed9bd Author: Chaosfan <[email protected]> AuthorDate: Sun Feb 19 12:29:53 2023 +0800 [Unity][TVMScript] Overload `__neg__` for relax expr (#14045) This PR overloads `__neg__` given that `relax.negative` is now supported. Besides, it adds `test_op_misc.py` and brings tests for calling overloaded operators. --- python/tvm/relax/expr.py | 2 +- tests/python/relax/test_op_misc.py | 98 +++++++++++++++++++++++++++++ tests/python/relax/test_tvmscript_parser.py | 49 +++++++++++++++ 3 files changed, 148 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index f1cf815d8e..a20181e6fc 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -135,7 +135,7 @@ class ExprWithOp(Expr, Scriptable): return _op_ffi_api.astype(self, dtype) # type: ignore def __neg__(self) -> "ExprWithOp": - raise ValueError("relax.negative is not supported yet.") + return _op_ffi_api.negative(self) # type: ignore def __lt__(self, other: Expr) -> "ExprWithOp": return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py new file mode 100644 index 0000000000..65772baadf --- /dev/null +++ b/tests/python/relax/test_op_misc.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax as rx +from tvm.script import relax as R +from tvm.script import tir as T + + +@tvm.register_func("test.op.identity", override=True) +def identity_packed(a): + return tvm.nd.array(a.asnumpy()) + + +@T.prim_func +def identity_tir(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [54, 96]) + B = T.match_buffer(b, [54, 96]) + + for i, j in T.grid(54, 96): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + +def test_call_tir() -> None: + v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) + v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) + + +def test_implicit_op(): + m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float32")) + y = rx.Var("y", R.Tensor([m, n], "float32")) + + def _check_call(expr, op_name: str): + assert isinstance(expr, rx.Call) + if not op_name.startswith("relax."): + op_name = "relax."
+ op_name + op = tvm.ir.Op.get(op_name) + assert expr.op == op + + # Comparison operators + _check_call(x > y, "greater") + _check_call(x >= y, "greater_equal") + _check_call(x < y, "less") + _check_call(x <= y, "less_equal") + + # Arithmetic operators + _check_call(-x, "negative") + _check_call(x + y, "add") + _check_call(x - y, "subtract") + _check_call(x * y, "multiply") + _check_call(x / y, "divide") + _check_call(x // y, "floor_divide") + # _check_call(x % y, "mod") <= relax.mod is not implemented yet + + # Cast + _check_call(x.astype("float32"), "astype") + + # Call + call_expr = x(y)(y) + assert isinstance(call_expr.op, rx.Call) + assert call_expr.op.op == x + + # GetTupleItem + ## Eager get item for tuple + tuple_expr = rx.Tuple((x, y)) + assert tuple_expr[0] == x + assert tuple_expr[1] == y + + ## Eager get item for ShapeExpr + shape_expr = rx.ShapeExpr((1, 2)) + assert shape_expr[0] == 1 + assert shape_expr[1] == 2 + + ## Create TupleGetItem for other expr + assert isinstance(x[0], rx.TupleGetItem) + assert isinstance(x[1][0], rx.TupleGetItem) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 8df125ac72..b458b290ec 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -972,6 +972,55 @@ def test_symbolic_shape_computing(): return z +def test_arith_operators(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): + a0 = -x + a1 = x + y + a2 = x - y + a3 = x * y + a4 = x / y + a5 = x // y + + c0 = x > y + c1 = x < y + c2 = x >= y + c3 = x <= y + + tuple_expr = ((x, x), y) + t0 = tuple_expr[0] + t1 = tuple_expr[1] + t2 = tuple_expr[0][0] # <= Will normalize to two bindings + return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2 + + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) 
+ y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + a0 = bb.emit(relax.op.negative(x)) + a1 = bb.emit(relax.op.add(x, y)) + a2 = bb.emit(relax.op.subtract(x, y)) + a3 = bb.emit(relax.op.multiply(x, y)) + a4 = bb.emit(relax.op.divide(x, y)) + a5 = bb.emit(relax.op.floor_divide(x, y)) + + c0 = bb.emit(relax.op.greater(x, y)) + c1 = bb.emit(relax.op.less(x, y)) + c2 = bb.emit(relax.op.greater_equal(x, y)) + c3 = bb.emit(relax.op.less_equal(x, y)) + + tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y))) + t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1)) + tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t2 = bb.emit(relax.TupleGetItem(tmp, 0)) + bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2))) + + _check(foo, bb.get()["foo"]) + + # TODO(relax-team): enable this when vm ops are ready @pytest.mark.xfail def test_vm_ops():
