This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b29ab5c6ba [Hexagon] Add test to show scheduling of resnet50 with async dma pipe… (#13352)
b29ab5c6ba is described below

commit b29ab5c6baea8884a5b9c0702c1c93da6a830866
Author: Noah Verke <[email protected]>
AuthorDate: Thu Nov 17 19:21:14 2022 -0800

    [Hexagon] Add test to show scheduling of resnet50 with async dma pipe… (#13352)
    
    * [Hexagon] Add test to show scheduling of resnet50 with async dma pipelines using metaschedule
    
    * lint
---
 python/tvm/contrib/hexagon/meta_schedule.py        |  23 ++-
 python/tvm/tir/tensor_intrin/hexagon.py            | 178 ++++++++---------
 .../metaschedule_e2e/test_resnet50_int8.py         | 218 ++++++++++++++++++---
 3 files changed, 299 insertions(+), 120 deletions(-)

diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py
index aaf3f8c7f8..dcc7d232d8 100644
--- a/python/tvm/contrib/hexagon/meta_schedule.py
+++ b/python/tvm/contrib/hexagon/meta_schedule.py
@@ -17,7 +17,14 @@
 """Meta schedule tuning utilities for Hexagon."""
 import os
 import tempfile
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional
+import tvm
+
+from tvm.ir.module import IRModule
+from tvm.runtime import Module, NDArray
+from tvm.target import Target
+from tvm.driver import build as tvm_build
+from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
 from tvm.contrib.popen_pool import PopenPoolExecutor
 from tvm.meta_schedule.utils import cpu_count, derived_object
 from tvm.meta_schedule.builder import LocalBuilder
@@ -121,14 +128,24 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path
     return costs
 
 
-def get_hexagon_local_builder():
+def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
     """Return Hexagon-compatible Builder for meta schedule."""
 
     def export_func(mod):
         binary_path = export_module(mod, tempfile.mkdtemp())
         return str(binary_path)
 
-    return LocalBuilder(f_export=export_func)
+    def default_build_with_context(
+        mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
+    ) -> Module:
+        with pass_context:
+            mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
+            return tvm_build(mod, target=target)
+
+    if pass_context is not None:
+        return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
+    else:
+        return LocalBuilder(f_export=export_func)
 
 
 def get_hexagon_rpc_runner(
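
The new optional pass_context argument above lets callers thread TIR build options into the meta-schedule builder; when it is supplied, each candidate is built inside that context (after RemoveWeightLayoutRewriteBlock), otherwise the default LocalBuilder build path is used. A minimal usage sketch, not part of the commit, reusing the flag names from the async-DMA test further down:

    import tvm
    from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder

    # Config keys mirror the ones set by test_async_dma_resnet50 below.
    pass_context = tvm.transform.PassContext(
        opt_level=3,
        config={
            "tir.use_async_copy": 1,
            "tir.merge_async_commit_queue_scope": False,
        },
    )
    # With a PassContext the builder compiles candidates inside it; with no
    # argument it behaves exactly as before.
    builder = get_hexagon_local_builder(pass_context)
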
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 306c8cd2e1..49c12c3e9d 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -20,98 +20,100 @@ from tvm.script import tir as T
 from .. import TensorIntrin
 
 
-@T.prim_func
-def dot_product_32x4_u8u8i32_desc(
-    A: T.Buffer((4,), "uint8", offset_factor=1),
-    B: T.Buffer((32, 4), "uint8", offset_factor=1),
-    C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:32], A[0:4], B[0:32, 0:4])
-        T.writes(C[0:32])
-        for i in T.serial(0, 32):
-            for k in T.serial(0, 4):
-                with T.block("update"):
-                    vi, vk = T.axis.remap("SR", [i, k])
-                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
-
-
-@T.prim_func
-def dot_product_32x4_u8u8i32_vrmpy(
-    A: T.Buffer((4,), "uint8", offset_factor=1),
-    B: T.Buffer((32, 4), "uint8", offset_factor=1),
-    C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:32], A[0:4], B[0:32, 0:4])
-        T.writes(C[0:32])
-
-        A_u8x4 = A.vload([0], "uint8x4")
-        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
-        B_i8x128 = B.vload([0, 0], dtype="uint8x128")
-        B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
-
-        C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
-            T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
-            T.uint32(3),
-            C[T.ramp(T.int32(0), 1, 32)],
-            B_i32x32,
-            A_i32,
-            dtype="int32x32",
-        )
-
-
-@T.prim_func
-def dot_product_32x4_u8i8i32_desc(
-    A: T.Buffer((4,), "uint8", offset_factor=1),
-    B: T.Buffer((32, 4), "int8", offset_factor=1),
-    C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:32], A[0:4], B[0:32, 0:4])
-        T.writes(C[0:32])
-        for i in T.serial(0, 32):
-            for k in T.serial(0, 4):
-                with T.block("update"):
-                    vi, vk = T.axis.remap("SR", [i, k])
-                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
-
-
-@T.prim_func
-def dot_product_32x4_u8i8i32_vrmpy(
-    A: T.Buffer((4,), "uint8", offset_factor=1),
-    B: T.Buffer((32, 4), "int8", offset_factor=1),
-    C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:32], A[0:4], B[0:32, 0:4])
-        T.writes(C[0:32])
-
-        A_u8x4 = A.vload([0], "uint8x4")
-        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
-        B_i8x128 = B.vload([0, 0], dtype="int8x128")
-        B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
-
-        C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
-            T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
-            T.uint32(3),
-            C[T.ramp(T.int32(0), 1, 32)],
-            T.broadcast(A_i32, 32),
-            B_i32x32,
-            dtype="int32x32",
-        )
+def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
+    @T.prim_func
+    def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+            T.writes(C[0:32])
+            for i in T.serial(0, 32):
+                for k in T.serial(0, 4):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+    @T.prim_func
+    def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+            T.writes(C[0:32])
+
+            A_u8x4 = A.vload([0], "uint8x4")
+            A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+            B_i8x128 = B.vload([0, 0], dtype="uint8x128")
+            B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+
+            C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
+                T.uint32(3),
+                C[T.ramp(T.int32(0), 1, 32)],
+                B_i32x32,
+                A_i32,
+                dtype="int32x32",
+            )
+
+    return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy
+
+
+def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
+    @T.prim_func
+    def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+            T.writes(C[0:32])
+            for i in T.serial(0, 32):
+                for k in T.serial(0, 4):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+    @T.prim_func
+    def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+            T.writes(C[0:32])
+
+            A_u8x4 = A.vload([0], "uint8x4")
+            A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+            B_i8x128 = B.vload([0, 0], dtype="int8x128")
+            B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+
+            C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
+                T.uint32(3),
+                C[T.ramp(T.int32(0), 1, 32)],
+                T.broadcast(A_i32, 32),
+                B_i32x32,
+                dtype="int32x32",
+            )
+
+    return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy
 
 
 VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy"
 
-TensorIntrin.register(
-    VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy
-)
+TensorIntrin.register(VRMPY_u8u8i32_INTRIN, *generate_dot_product_32x4_u8u8i32())
 
 VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy"
 
-TensorIntrin.register(
-    VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy
-)
+TensorIntrin.register(VRMPY_u8i8i32_INTRIN, *generate_dot_product_32x4_u8i8i32())
+
+VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
+TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))
+
+VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
+TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))
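
As a usage sketch (not part of this commit), the VTCM-scoped intrinsics registered above are meant to be tensorized against loops whose operands have already been staged into "global.vtcm"; the helper below is hypothetical, and the real usage is in _schedule_async_dma_conv2d in the test diff further down:

    from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_VTCM_INTRIN

    def tensorize_vtcm_vrmpy(sch, conv2d_block):
        # Assumes the two innermost loops are already split to 32 output
        # channels by 4 reduction elements, and that inputs/outputs were
        # cache_read/cache_write into "global.vtcm" so the buffer scopes match
        # the VTCM intrinsic descriptor registered above.
        *_, oc_i, ic_i = sch.get_loops(conv2d_block)
        sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)
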
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index 91eb67bbf4..e15b0a4e7d 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -18,7 +18,8 @@
 
 import os
 import tempfile
-from typing import Optional
+from types import MappingProxyType
+from typing import Any, Mapping, Optional
 
 import numpy as np
 import pytest
@@ -34,7 +35,11 @@ from tvm.contrib.hexagon.meta_schedule import (
 from tvm.meta_schedule import postproc, schedule_rule
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.analysis import has_block
-from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
+from tvm.tir.tensor_intrin.hexagon import (
+    VRMPY_u8i8i32_INTRIN,
+    VRMPY_u8u8i32_INTRIN,
+    VRMPY_u8i8i32_VTCM_INTRIN,
+)
 
 from ..infrastructure import get_hexagon_target
 
@@ -133,7 +138,6 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # from 36 to 23, with negligible performance difference.
             module_equality="anchor-block",
         )
-
         return ms.relay_integration.compile_relay(
             database=database,
             mod=mod,
@@ -142,10 +146,13 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
         )
 
 
[email protected]("End-to-end tuning is skipped on CI.")
 @tvm.testing.requires_hexagon
 def test_resnet50(hexagon_launcher):
     """Test Resnet50."""
+
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
     if not os.path.exists(MODEL_JSON):
         pytest.skip(msg="Run python export_models.py first.")
 
@@ -200,6 +207,44 @@ def test_resnet50(hexagon_launcher):
         print(debug_ex.profile(input_name=inp.copy()))
 
 
+def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp, benchmark=False):
+    """Evaluate the Modules against llvm version."""
+    with hexagon_launcher.create_session() as session:
+        graph_mod = session.get_executor_from_factory(hexagon_lowered)
+        graph_mod.set_input(input_name, inp.copy())
+        graph_mod.run()
+        output = graph_mod.get_output(0).numpy()
+
+        llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+        llvm_graph_mod.set_input(input_name, inp.copy())
+        llvm_graph_mod.run()
+        ref_result = llvm_graph_mod.get_output(0).numpy()
+
+        if benchmark:
+            time_ms = graph_mod.benchmark(session.device, number=1, repeat=1).mean * 1e3
+            print("hexagon time elapsed: ", time_ms)
+            debug_ex = session.get_graph_debug_executor(
+                hexagon_lowered.get_graph_json(), hexagon_lowered.lib
+            )
+            print(debug_ex.profile(input_name=inp.copy()))
+
+        np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
+
+
+def load_model():
+    """Load renset50 model."""
+    if not os.path.exists(MODEL_JSON):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(MODEL_JSON, "r") as file:
+        mod = tvm.ir.load_json(file.read())
+
+    with open(MODEL_PARAMS, "rb") as file:
+        params = relay.load_param_dict(file.read())
+
+    return mod, params
+
+
 def _schedule_packed_8x8x32_conv2d():
     """Manually schedule a conv2d block, created from TE compute op via 
CreatePrimFunc,
     using 8x8x32 packed layout.
@@ -268,22 +313,39 @@ def _schedule_packed_8x8x32_conv2d():
     return schedule_fn
 
 
-def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
+def tune_conv2d_template(
+    mod,
+    scheduler,
+    schedule_tag,
+    params,
+    hexagon_launcher,
+    pass_config: Mapping[str, Any] = MappingProxyType({}),
+):
     """Generate packed 8*8*32 template."""
 
-    def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV):
-        _schedule_packed_8x8x32_conv2d()(sch, conv2d_block)
+    def schedule_rule_conv2d(sch: Schedule, conv2d_block: BlockRV):
+        scheduler()(sch, conv2d_block)
         return [sch]
 
-    register_func("meta_schedule.conv2d_NCHWc_int8.hexagon", schedule_rule_conv2d_packed_8x8x32)
+    register_func(
+        "meta_schedule.conv2d_NCHWc_int8.{}.hexagon".format(schedule_tag), 
schedule_rule_conv2d
+    )
 
     def schedule_conv2d_for_tune(sch: Schedule):
-        _schedule_packed_8x8x32_conv2d()(sch)
+        scheduler()(sch)
 
     # This line is necessary for link-params to take effect during
     # task extraction and relay.build(...).
     mod = mod.with_attr("executor", EXECUTOR)
 
+    pass_context = None
+    if len(pass_config.items()) > 0:
+        pass_context = (
+            tvm.transform.PassContext(opt_level=3, config=pass_config)
+            if pass_config is not None
+            else None
+        )
+
     with tempfile.TemporaryDirectory() as work_dir:
         database = ms.relay_integration.tune_relay(
             mod=mod,
@@ -294,8 +356,8 @@ def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
             max_trials_per_task=1,
             num_trials_per_iter=1,
             strategy="replay-trace",
-            builder=get_hexagon_local_builder(),
-            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            builder=get_hexagon_local_builder(pass_context),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=1),
             # Apply MS auto scheduling rules for all blocks, but utilize
             # the custom block scheduling strategy registered above for
             # blocks annotated as `schedule_rule:meta_schedule.conv2d_NCHWc_int8`
@@ -318,33 +380,37 @@ def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
             # are treated as distinct tuning tasks.
             module_equality="ignore-ndarray",
         )
+
+        # Add default options so that it still uses the base config.
+        pass_config["relay.backend.use_meta_schedule"] = True
+        pass_config["relay.backend.tir_converter"] = "default"
         return ms.relay_integration.compile_relay(
             database=database,
             mod=mod,
             target=TARGET_HEXAGON,
             params=params,
+            pass_config=pass_config,
         )
 
 
[email protected]("End-to-end tuning is skipped on CI.")
 @tvm.testing.requires_hexagon
 def test_packed_8x8x32_resnet50(hexagon_launcher):
     """Test packed 8*8*32 Resnet50"""
-    if not os.path.exists(MODEL_JSON):
-        pytest.skip(msg="Run python export_models.py first.")
 
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
+    mod, params = load_model()
 
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
     inp = np.random.randn(1, 3, 224, 224).astype("float32")
     input_name = "image"
 
     do_tune = True
 
     if do_tune:
-        hexagon_lowered = tune_packed_8x8x32_template(mod, params, hexagon_launcher)
+        hexagon_lowered = tune_conv2d_template(
+            mod, _schedule_packed_8x8x32_conv2d, "packed_8x8x32", params, hexagon_launcher
+        )
     else:
         with tvm.transform.PassContext(opt_level=3):
             hexagon_lowered = relay.build(
@@ -361,18 +427,112 @@ def test_packed_8x8x32_resnet50(hexagon_launcher):
             params=params,
         )
 
-    with hexagon_launcher.start_session() as session:
-        graph_mod = session.get_executor_from_factory(hexagon_lowered)
-        graph_mod.set_input(input_name, inp.copy())
-        graph_mod.run()
-        hexagon_output = graph_mod.get_output(0).numpy()
+    evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp)
 
-        llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
-        llvm_graph_mod.set_input(input_name, inp.copy())
-        llvm_graph_mod.run()
-        ref_result = llvm_graph_mod.get_output(0).numpy()
 
-        np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5)
+def _schedule_async_dma_conv2d():
+    """Manually schedule a conv2d block, created from TE compute op via 
CreatePrimFunc,
+    using 8x8x32 packed layout.
+    """
+
+    def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
+        if conv2d_block is None:
+            if has_block(sch, "conv2d_NCHWc_int8"):
+                conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+            else:
+                return False
+
+        assert "conv2d_NCHWc_int8" in 
sch.get(conv2d_block).annotations["schedule_rule"]
+
+        # Apply scheduling
+
+        post_blocks = sch.get_consumers(conv2d_block)
+        if len(post_blocks) > 0:
+            # Fuse all intermediate post ops into the last op.
+            # This is equivalent to the traverse_inline function used in TE schedules.
+            while True:
+                next_post_blocks = []
+                for post_block in post_blocks:
+                    next_consumers = sch.get_consumers(post_block)
+                    if len(next_consumers) > 0:
+                        sch.compute_inline(post_block)
+                    next_post_blocks += next_consumers
+                if len(next_post_blocks) == 0:
+                    assert len(post_blocks) == 1
+                    outer_block = post_blocks[0]
+                    break
+                post_blocks = next_post_blocks
+        else:
+            outer_block = conv2d_block
+
+        # Move the conv2d mma into the injective post mma compute block
+        if outer_block != conv2d_block:
+            loops = sch.get_loops(outer_block)
+            # Compute at the second loop for pipelining.
+            sch.compute_at(conv2d_block, loops[1], preserve_unit_loops=True)
+
+        # Add cache for input and output for copying data to vtcm.
+        input_a_cache = sch.cache_read(conv2d_block, 0, "global.vtcm")
+        sch.compute_at(input_a_cache, sch.get_loops(conv2d_block)[1])
+        sch.fuse(*sch.get_loops(input_a_cache)[2:])
+
+        input_b_cache = sch.cache_read(conv2d_block, 1, "global.vtcm")
+        sch.compute_at(input_b_cache, sch.get_loops(conv2d_block)[1])
+        sch.fuse(*sch.get_loops(input_b_cache)[2:])
+
+        output_cache_write = sch.cache_write(conv2d_block, 0, "global.vtcm")
+        sch.fuse(*sch.get_loops(output_cache_write)[2:])
+
+        conv2d_loops = sch.get_loops(block=conv2d_block)
+        o_c, k_h, k_w, x_0, x_1, i_c = conv2d_loops[-6:]
+        ic_o, ic_i = sch.split(loop=i_c, factors=[None, 4], preserve_unit_iters=True)
+        oc_o, oc_i = sch.split(loop=o_c, factors=[None, 32], preserve_unit_iters=True)
+        sch.reorder(oc_o, k_h, k_w, x_0, x_1, ic_o, oc_i, ic_i)
+        new_loops = sch.get_loops(block=conv2d_block)
+        sch.parallel(new_loops[4])
+        sch.unroll(new_loops[5])
+        # TODO(nverke): Add compute optimizations here.
+        sch.blockize(loop=oc_i)
+
+        sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)
+
+        pipeline_loop = conv2d_loops[1]
+        sch.annotate(pipeline_loop, "software_pipeline_stage", [0, 0, 1, 2, 3])
+        sch.annotate(pipeline_loop, "software_pipeline_order", [0, 1, 2, 3, 4])
+        sch.annotate(pipeline_loop, "software_pipeline_async_stages", [0, 2])
+
+        return True
+
+    return schedule_fn
+
+
[email protected]_hexagon
+def test_async_dma_resnet50(hexagon_launcher):
+    """Test async dma Resnet50"""
+
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
+    mod, params = load_model()
+
+    inp = np.random.randn(1, 3, 224, 224).astype("float32")
+    input_name = "image"
+
+    pass_config = {
+        "tir.use_async_copy": 1,
+        "tir.merge_async_commit_queue_scope": False,
+        "relay.backend.use_meta_schedule": True,
+        "relay.backend.tir_converter": "default",
+    }
+
+    hexagon_lowered = tune_conv2d_template(
+        mod, _schedule_async_dma_conv2d, "async_dma", params, hexagon_launcher, pass_config
+    )
+    with tvm.transform.PassContext(opt_level=3):
+        llvm_lowered = tvm.relay.build(
+            mod, tvm.target.Target(TARGET_LLVM, host=TARGET_LLVM), params=params
+        )
+    evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp, True)
 
 
 if __name__ == "__main__":
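
A note on the software-pipeline annotations used in _schedule_async_dma_conv2d above: software_pipeline_stage assigns each child block of the annotated loop to a stage, software_pipeline_order fixes their relative emission order, and software_pipeline_async_stages marks which stages run asynchronously; with "tir.use_async_copy" enabled in the PassContext those stages are lowered to Hexagon DMA copies into and out of VTCM. Restating the calls with comments (the block-to-entry mapping is an assumption based on the schedule above):

    # One entry per child block under the pipelined loop (assumed here to be
    # the two VTCM cache reads, the tensorized conv2d, and the write-back /
    # post-op blocks); blocks in different stages overlap across iterations.
    sch.annotate(pipeline_loop, "software_pipeline_stage", [0, 0, 1, 2, 3])
    # Relative ordering of those blocks within one pipelined iteration.
    sch.annotate(pipeline_loop, "software_pipeline_order", [0, 1, 2, 3, 4])
    # Stages 0 and 2 execute asynchronously, i.e. as DMA copies when
    # "tir.use_async_copy" is set.
    sch.annotate(pipeline_loop, "software_pipeline_async_stages", [0, 2])
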
