junrushao commented on code in PR #13005:
URL: https://github.com/apache/tvm/pull/13005#discussion_r1020503370


##########
tests/python/contrib/test_hexagon/test_async_dma_pipeline.py:
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test different strategies for loading data into vtcm before running HVX workloads. """
+
+import numpy as np
+import tvm
+import pytest
+
+from tvm.script import tir as T
+from numpy.random import default_rng
+
+from tvm.tir.function import TensorIntrin
+
+VRMPY_SIZE_B = 128
+VRMPY_SIZE_INT32 = 32
+
+
+def conv_approximation(size_a, size_w):
+    a_shape = (size_a, VRMPY_SIZE_B)
+    w_shape = (size_w, VRMPY_SIZE_B)
+    out_shape = (size_a, VRMPY_SIZE_INT32)
+
+    @T.prim_func
+    def operator(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = T.match_buffer(a, a_shape, dtype="uint8")
+        W = T.match_buffer(b, w_shape, dtype="uint8")
+        C = T.match_buffer(c, out_shape, dtype="int32")
+        for n, i in T.grid(size_a, size_w):
+            with T.block("C"):
+                vn, vi = T.axis.remap("SR", [n, i])
+                T.reads(A[vn, 0:VRMPY_SIZE_B], W[vi, 0:VRMPY_SIZE_B], C[vn, 0:VRMPY_SIZE_INT32])
+                T.writes(C[vn, 0:VRMPY_SIZE_INT32])
+                with T.init():
+                    for x in T.serial(VRMPY_SIZE_INT32):
+                        C[vn, x] = 0
+                C[vn, T.ramp(0, 1, 32)] = T.call_llvm_intrin(
+                    T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
+                    T.uint32(3),
+                    C[vn, T.ramp(0, 1, 32)],
+                    T.reinterpret(A[vn, T.ramp(0, 1, 128)], dtype="int32x32"),
+                    T.reinterpret(W[vi, T.ramp(0, 1, 128)], dtype="int32x32"),
+                    dtype="int32x32",
+                )
+        # Currently async DMA lowering does not add any wait to the end of schedules so
+        # for timing purposes we are manually adding a wait to ensure that all copies
+        # are complete when the schedule exits.
+        T.evaluate(
+            T.tvm_call_packed(
+                "device_api.hexagon.dma_wait",
+                0,  # QueueId
+                0,  # Wait for 0 in flight
+                dtype="int32",
+            )
+        )
+
+    return tvm.tir.Schedule(operator)
+
+
+def evaluate(hexagon_session, sch, a, b, size_a, expected_output, use_async_copy=0):
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    with tvm.transform.PassContext(config={"tir.use_async_copy": use_async_copy}):
+        func_tir = tvm.build(
+            sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
+        )
+    module = hexagon_session.load_module(func_tir)
+
+    a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device)
+    b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device)
+    c_hexagon = tvm.runtime.ndarray.array(
+        np.zeros((size_a, VRMPY_SIZE_INT32), dtype="int32"), device=hexagon_session.device
+    )
+
+    if tvm.testing.utils.IS_IN_CI:
+        # Run with reduced number and repeat for CI
+        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
+    else:
+        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
+
+    time = timer(a_hexagon, b_hexagon, c_hexagon)
+    tvm.testing.assert_allclose(c_hexagon.asnumpy(), expected_output)
+    return round(time.mean * 1000, 4)
+
+
+@pytest.fixture
+def input_a(size_a):
+    return default_rng().integers(0, 8, (size_a, VRMPY_SIZE_B), dtype="uint8")
+
+
+@pytest.fixture
+def input_w(size_w):
+    return default_rng().integers(0, 8, (size_w, VRMPY_SIZE_B), dtype="uint8")
+
+
+@pytest.fixture
+def expected_output(size_a, size_w, input_a, input_w):
+    if tvm.testing.utils.IS_IN_CI and (size_a > 1024 or size_w > 1):
+        pytest.skip("Skipping test since it takes too long in CI.")
+    expected_output = np.zeros((size_a, VRMPY_SIZE_INT32), dtype="int32")
+    for n in range(size_a):
+        for x in range(size_w):
+            for i in range(VRMPY_SIZE_INT32):
+                for r in range(4):
+                    expected_output[n, i] += np.uint32(input_a[n, i * 4 + r]) * np.uint32(
+                        input_w[x, i * 4 + r]
+                    )
+    return expected_output
+
+
+def get_single_dma_schedule(size_a, size_w):
+    a_shape = (size_a, VRMPY_SIZE_B)
+    w_shape = (size_w, VRMPY_SIZE_B)
+    out_shape = (size_a, VRMPY_SIZE_INT32)
+
+    a_bytes = size_a * VRMPY_SIZE_B
+    w_bytes = size_w * VRMPY_SIZE_B
+
+    @T.prim_func
+    def operator(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = T.match_buffer(a, a_shape, dtype="uint8", mem_scope="global")
+        W = T.match_buffer(b, w_shape, dtype="uint8", mem_scope="global")
+        C = T.match_buffer(c, out_shape, dtype="int32", mem_scope="global")
+        A_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm")
+        W_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm")
+        C_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm")

Review Comment:
   @nverke I just wanted to report an issue that I happened to detect using the new TVMScript parser.
   
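   `T.alloc_buffer` doesn't accept a `mem_scope` parameter; it uses `scope` instead. The current generation of the TVMScript parser doesn't report any error for this kind of misuse (which is surprising); instead, it silently assumes `scope="global"`. As a result, `A_global_vtcm`, `W_global_vtcm` and `C_global_vtcm` are allocated in global memory rather than in VTCM; these allocations should pass `scope="global.vtcm"` instead.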
   
   To see exactly what is happening, you can print the function with `operator.show()`, which gives:
   
   ```python
   @T.prim_func
   def func(a_buffer: T.Buffer[(1024, 128), "uint8"], w_buffer: T.Buffer[(1, 128), "uint8"], c_buffer: T.Buffer[(1024, 32), "int32"]):
       # function attr dict
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       # body
       # with T.block("root")
       a_global_vtcm = T.alloc_buffer([1024, 128], dtype="uint8") ## <==== note it's "global" rather than "global.vtcm"
       w_global_vtcm = T.alloc_buffer([1, 128], dtype="uint8")
       c_global_vtcm = T.alloc_buffer([1024, 32], dtype="int32")
       T.evaluate(T.tvm_call_packed("device_api.hexagon.mem_copy_DLTensor", T.tvm_stack_make_array(a_global_vtcm.data, T.tvm_stack_make_shape(1024, 128, dtype="handle"), 0, 2, "uint8", 0, dtype="handle"), T.tvm_stack_make_array(a_buffer.data, T.tvm_stack_make_shape(1024, 128, dtype="handle"), 0, 2, "uint8", 0, dtype="handle"), T.Cast("int32", 131072), dtype="int32"))
       T.evaluate(T.tvm_call_packed("device_api.hexagon.mem_copy_DLTensor", T.tvm_stack_make_array(w_global_vtcm.data, T.tvm_stack_make_shape(1, 128, dtype="handle"), 0, 2, "uint8", 0, dtype="handle"), T.tvm_stack_make_array(w_buffer.data, T.tvm_stack_make_shape(1, 128, dtype="handle"), 0, 2, "uint8", 0, dtype="handle"), T.Cast("int32", 128), dtype="int32"))
       for n, index_0 in T.grid(1024, 1):
           with T.block("c_buffer"):
               vn_index, vi_index = T.axis.remap("SR", [n, index_0])
               T.reads(a_global_vtcm[vn_index, 0 : 128], w_global_vtcm[vi_index, 0 : 128], c_global_vtcm[vn_index, 0 : 32])
               T.writes(c_global_vtcm[vn_index, 0 : 32])
               with T.init():
                   for x in T.serial(32):
                       c_global_vtcm[vn_index, x] = 0
               c_global_vtcm[vn_index, 0:32] = c_global_vtcm[vn_index, 0:32] + T.call_llvm_intrin(3885, T.uint32(2), T.reinterpret(a_global_vtcm[vn_index, 0:128], dtype="int32x32"), T.reinterpret(w_global_vtcm[vi_index, 0:128], dtype="int32x32"), dtype="int32x32")
       T.evaluate(T.tvm_call_packed("device_api.hexagon.mem_copy_DLTensor", T.tvm_stack_make_array(c_buffer.data, T.tvm_stack_make_shape(1024, 128, dtype="handle"), 0, 2, "int32", 0, dtype="handle"), T.tvm_stack_make_array(c_global_vtcm.data, T.tvm_stack_make_shape(1024, 128, dtype="handle"), 0, 2, "int32", 0, dtype="handle"), T.Cast("int32", 131072), dtype="int32"))
   ``` 
   
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to