slyubomirsky commented on code in PR #15878:
URL: https://github.com/apache/tvm/pull/15878#discussion_r1346695352


##########
tests/python/relax/test_relax_operators.py:
##########
@@ -248,6 +250,89 @@ def pure_copy(x: R.Tensor((3, 4), "float32")):
     assert (copy_found.numpy() == arr).all()
 
 
+def test_op_call_inplace_packed():
+    # in this case we can use the same test as above
+    @tvm.script.ir_module
+    class CallInplaceTest:
+        @R.function
+        def pure_copy(x: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "vm.builtin.copy",
+                x,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    @tvm.register_func("test.inplace.add")
+    def inplace_add(a, b):
+        arr_a = a.numpy()
+        arr_b = b.numpy()
+        for i in range(len(arr_a)):
+            for j in range(len(arr_a[i])):
+                arr_a[i][j] = arr_a[i][j] + arr_b[i][j]
+        a.copyfrom(arr_a)
+        return a
+
+    @tvm.script.ir_module
+    class CallInplaceAddTest:
+        @R.function
+        def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.add",
+                x,
+                y,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(1)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    sum = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b))
+    assert result == tvm_arr_a
+    assert (result.numpy() == sum).all()
+
+    @tvm.register_func("test.inplace.tuple_add")
+    def inplace_tuple_add(a, b):
+        arr_a = a.numpy()
+        arr_b = b.numpy()
+        c = tvm.nd.array(arr_a + arr_b)
+        for i in range(len(arr_a)):
+            for j in range(len(arr_a[i])):
+                arr_a[i][j] = arr_a[i][j] + arr_b[i][j]
+        a.copyfrom(arr_a)
+        return tvm.runtime.container.ADT(0, [a, c])
+
+    @tvm.script.ir_module
+    class CallInplaceTuple:
+        @R.function
+        def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.tuple_add",
+                x,
+                y,
+                inplace_indices=[0, -1],
+                sinfo_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(2)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    sum = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    tvm_arr_b = tvm.nd.array(arr_b)
+    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+    assert result[0] == tvm_arr_a
+    assert (result[0].numpy() == sum).all()
+    assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
+    assert (result[1].numpy() == sum).all()

Review Comment:
   If we want, we can check reference equality/disequality at run time as part 
of the implementation of this op. I'm not sure it would be worth it, though.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to