This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit e9a6e49f3a552d67d958375663d9a908735ed9bd
Author: Chaosfan <[email protected]>
AuthorDate: Sun Feb 19 12:29:53 2023 +0800

    [Unity][TVMScript] Overload `__neg__` for relax expr (#14045)
    
    This PR overloads `__neg__` given that `relax.negative` is now supported. 
Besides, it adds `test_op_misc.py` and brings tests for calling overloaded 
operators.
---
 python/tvm/relax/expr.py                    |  2 +-
 tests/python/relax/test_op_misc.py          | 98 +++++++++++++++++++++++++++++
 tests/python/relax/test_tvmscript_parser.py | 49 +++++++++++++++
 3 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index f1cf815d8e..a20181e6fc 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -135,7 +135,7 @@ class ExprWithOp(Expr, Scriptable):
         return _op_ffi_api.astype(self, dtype)  # type: ignore
 
     def __neg__(self) -> "ExprWithOp":
-        raise ValueError("relax.negative is not supported yet.")
+        return _op_ffi_api.negative(self)  # type: ignore
 
     def __lt__(self, other: Expr) -> "ExprWithOp":
         return _binary_op_helper(self, other, _op_ffi_api.less)  # type: ignore
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
new file mode 100644
index 0000000000..65772baadf
--- /dev/null
+++ b/tests/python/relax/test_op_misc.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax as rx
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
[email protected]_func("test.op.identity", override=True)
+def identity_packed(a):
+    return tvm.nd.array(a.asnumpy())
+
+
[email protected]_func
+def identity_tir(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [54, 96])
+    B = T.match_buffer(b, [54, 96])
+
+    for i, j in T.grid(54, 96):
+        with T.block("compute"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj]
+
+
+def test_call_tir() -> None:
+    v0 = rx.Var("v0", R.Tensor([54, 96], "float32"))
+    v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32"))
+    v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32"))
+
+
+def test_implicit_op():
+    m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64")
+    x = rx.Var("x", R.Tensor([m, n], "float32"))
+    y = rx.Var("y", R.Tensor([m, n], "float32"))
+
+    def _check_call(expr, op_name: str):
+        assert isinstance(expr, rx.Call)
+        if not op_name.startswith("relax."):
+            op_name = "relax." + op_name
+        op = tvm.ir.Op.get(op_name)
+        assert expr.op == op
+
+    # Comparison operators
+    _check_call(x > y, "greater")
+    _check_call(x >= y, "greater_equal")
+    _check_call(x < y, "less")
+    _check_call(x <= y, "less_equal")
+
+    # Arithmetic operators
+    _check_call(-x, "negative")
+    _check_call(x + y, "add")
+    _check_call(x - y, "subtract")
+    _check_call(x * y, "multiply")
+    _check_call(x / y, "divide")
+    _check_call(x // y, "floor_divide")
+    # _check_call(x % y, "mod") <= relax.mod is not implemented yet
+
+    # Cast
+    _check_call(x.astype("float32"), "astype")
+
+    # Call
+    call_expr = x(y)(y)
+    assert isinstance(call_expr.op, rx.Call)
+    assert call_expr.op.op == x
+
+    # GetTupleItem
+    ## Eager get item for tuple
+    tuple_expr = rx.Tuple((x, y))
+    assert tuple_expr[0] == x
+    assert tuple_expr[1] == y
+
+    ## Eager get item for ShapeExpr
+    shape_expr = rx.ShapeExpr((1, 2))
+    assert shape_expr[0] == 1
+    assert shape_expr[1] == 2
+
+    ## Create TupleGetItem for other expr
+    assert isinstance(x[0], rx.TupleGetItem)
+    assert isinstance(x[1][0], rx.TupleGetItem)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 8df125ac72..b458b290ec 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -972,6 +972,55 @@ def test_symbolic_shape_computing():
             return z
 
 
+def test_arith_operators():
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
+        a0 = -x
+        a1 = x + y
+        a2 = x - y
+        a3 = x * y
+        a4 = x / y
+        a5 = x // y
+
+        c0 = x > y
+        c1 = x < y
+        c2 = x >= y
+        c3 = x <= y
+
+        tuple_expr = ((x, x), y)
+        t0 = tuple_expr[0]
+        t1 = tuple_expr[1]
+        t2 = tuple_expr[0][0]  # <= Will normalize to two bindings
+        return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2
+
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", relax.TensorStructInfo([m, n], "float32"))
+    y = relax.Var("y", relax.TensorStructInfo([m, n], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x, y)):
+        a0 = bb.emit(relax.op.negative(x))
+        a1 = bb.emit(relax.op.add(x, y))
+        a2 = bb.emit(relax.op.subtract(x, y))
+        a3 = bb.emit(relax.op.multiply(x, y))
+        a4 = bb.emit(relax.op.divide(x, y))
+        a5 = bb.emit(relax.op.floor_divide(x, y))
+
+        c0 = bb.emit(relax.op.greater(x, y))
+        c1 = bb.emit(relax.op.less(x, y))
+        c2 = bb.emit(relax.op.greater_equal(x, y))
+        c3 = bb.emit(relax.op.less_equal(x, y))
+
+        tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y)))
+        t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0))
+        t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1))
+        tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0))
+        t2 = bb.emit(relax.TupleGetItem(tmp, 0))
+        bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2)))
+
+    _check(foo, bb.get()["foo"])
+
+
 # TODO(relax-team): enable this when vm ops are ready
 @pytest.mark.xfail
 def test_vm_ops():

Reply via email to