This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b29ab5c6ba [Hexagon] Add test to show scheduling of resnet50 with
async dma pipe… (#13352)
b29ab5c6ba is described below
commit b29ab5c6baea8884a5b9c0702c1c93da6a830866
Author: Noah Verke <[email protected]>
AuthorDate: Thu Nov 17 19:21:14 2022 -0800
[Hexagon] Add test to show scheduling of resnet50 with async dma pipe…
(#13352)
* [Hexagon] Add test to show scheduling of resnet50 with async dma
pipelines using metaschedule
* lint
---
python/tvm/contrib/hexagon/meta_schedule.py | 23 ++-
python/tvm/tir/tensor_intrin/hexagon.py | 178 ++++++++---------
.../metaschedule_e2e/test_resnet50_int8.py | 218 ++++++++++++++++++---
3 files changed, 299 insertions(+), 120 deletions(-)
diff --git a/python/tvm/contrib/hexagon/meta_schedule.py
b/python/tvm/contrib/hexagon/meta_schedule.py
index aaf3f8c7f8..dcc7d232d8 100644
--- a/python/tvm/contrib/hexagon/meta_schedule.py
+++ b/python/tvm/contrib/hexagon/meta_schedule.py
@@ -17,7 +17,14 @@
"""Meta schedule tuning utilities for Hexagon."""
import os
import tempfile
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional
+import tvm
+
+from tvm.ir.module import IRModule
+from tvm.runtime import Module, NDArray
+from tvm.target import Target
+from tvm.driver import build as tvm_build
+from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.meta_schedule.utils import cpu_count, derived_object
from tvm.meta_schedule.builder import LocalBuilder
@@ -121,14 +128,24 @@ def _worker_func(hexagon_launcher, evaluator_config,
alloc_repeat, artifact_path
return costs
-def get_hexagon_local_builder():
- """Return Hexagon-compatible Builder for meta schedule."""
+def get_hexagon_local_builder(pass_context: Optional[tvm.transform.PassContext] = None):
+ """Return Hexagon-compatible Builder; builds run under pass_context when one is given."""
def export_func(mod):
binary_path = export_module(mod, tempfile.mkdtemp())
return str(binary_path)
- return LocalBuilder(f_export=export_func)
+ def default_build_with_context(
+ mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
+ ) -> Module:
+ with pass_context:
+ mod =
RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
+ return tvm_build(mod, target=target)
+
+ if pass_context is None:
+ return LocalBuilder(f_export=export_func)
+ # Build under the caller's PassContext so its config options take effect.
+ return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
def get_hexagon_rpc_runner(
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py
b/python/tvm/tir/tensor_intrin/hexagon.py
index 306c8cd2e1..49c12c3e9d 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -20,98 +20,100 @@ from tvm.script import tir as T
from .. import TensorIntrin
-@T.prim_func
-def dot_product_32x4_u8u8i32_desc(
- A: T.Buffer((4,), "uint8", offset_factor=1),
- B: T.Buffer((32, 4), "uint8", offset_factor=1),
- C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
- with T.block("root"):
- T.reads(C[0:32], A[0:4], B[0:32, 0:4])
- T.writes(C[0:32])
- for i in T.serial(0, 32):
- for k in T.serial(0, 4):
- with T.block("update"):
- vi, vk = T.axis.remap("SR", [i, k])
- C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk],
"int32")
-
-
-@T.prim_func
-def dot_product_32x4_u8u8i32_vrmpy(
- A: T.Buffer((4,), "uint8", offset_factor=1),
- B: T.Buffer((32, 4), "uint8", offset_factor=1),
- C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
- with T.block("root"):
- T.reads(C[0:32], A[0:4], B[0:32, 0:4])
- T.writes(C[0:32])
-
- A_u8x4 = A.vload([0], "uint8x4")
- A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
- B_i8x128 = B.vload([0, 0], dtype="uint8x128")
- B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
-
- C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
- T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
- T.uint32(3),
- C[T.ramp(T.int32(0), 1, 32)],
- B_i32x32,
- A_i32,
- dtype="int32x32",
- )
-
-
-@T.prim_func
-def dot_product_32x4_u8i8i32_desc(
- A: T.Buffer((4,), "uint8", offset_factor=1),
- B: T.Buffer((32, 4), "int8", offset_factor=1),
- C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
- with T.block("root"):
- T.reads(C[0:32], A[0:4], B[0:32, 0:4])
- T.writes(C[0:32])
- for i in T.serial(0, 32):
- for k in T.serial(0, 4):
- with T.block("update"):
- vi, vk = T.axis.remap("SR", [i, k])
- C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk],
"int32")
-
-
-@T.prim_func
-def dot_product_32x4_u8i8i32_vrmpy(
- A: T.Buffer((4,), "uint8", offset_factor=1),
- B: T.Buffer((32, 4), "int8", offset_factor=1),
- C: T.Buffer((32,), "int32", offset_factor=1),
-) -> None:
- with T.block("root"):
- T.reads(C[0:32], A[0:4], B[0:32, 0:4])
- T.writes(C[0:32])
-
- A_u8x4 = A.vload([0], "uint8x4")
- A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
- B_i8x128 = B.vload([0, 0], dtype="int8x128")
- B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
-
- C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
- T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
- T.uint32(3),
- C[T.ramp(T.int32(0), 1, 32)],
- T.broadcast(A_i32, 32),
- B_i32x32,
- dtype="int32x32",
- )
+def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
+ @T.prim_func
+ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle)
-> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+ B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1,
scope=mem_scope)
+ C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+ with T.block("root"):
+ T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+ T.writes(C[0:32])
+ for i in T.serial(0, 32):
+ for k in T.serial(0, 4):
+ with T.block("update"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi,
vk], "int32")
+
+ @T.prim_func
+ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle)
-> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+ B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1,
scope=mem_scope)
+ C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+ with T.block("root"):
+ T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+ T.writes(C[0:32])
+
+ A_u8x4 = A.vload([0], "uint8x4")
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+ B_i8x128 = B.vload([0, 0], dtype="uint8x128")
+ B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+
+ C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+ T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
+ T.uint32(3),
+ C[T.ramp(T.int32(0), 1, 32)],
+ B_i32x32,
+ A_i32,
+ dtype="int32x32",
+ )
+
+ return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy
+
+
+def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
+ @T.prim_func
+ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle)
-> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+ B = T.match_buffer(b, (32, 4), "int8", offset_factor=1,
scope=mem_scope)
+ C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+ with T.block("root"):
+ T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+ T.writes(C[0:32])
+ for i in T.serial(0, 32):
+ for k in T.serial(0, 4):
+ with T.block("update"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi,
vk], "int32")
+
+ @T.prim_func
+ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle)
-> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
+ B = T.match_buffer(b, (32, 4), "int8", offset_factor=1,
scope=mem_scope)
+ C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+ with T.block("root"):
+ T.reads(C[0:32], A[0:4], B[0:32, 0:4])
+ T.writes(C[0:32])
+
+ A_u8x4 = A.vload([0], "uint8x4")
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+ B_i8x128 = B.vload([0, 0], dtype="int8x128")
+ B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+
+ C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
+ T.uint32(3),
+ C[T.ramp(T.int32(0), 1, 32)],
+ T.broadcast(A_i32, 32),
+ B_i32x32,
+ dtype="int32x32",
+ )
+
+ return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy
VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy"
-TensorIntrin.register(
- VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc,
dot_product_32x4_u8u8i32_vrmpy
-)
+TensorIntrin.register(VRMPY_u8u8i32_INTRIN,
*generate_dot_product_32x4_u8u8i32())
VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy"
-TensorIntrin.register(
- VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc,
dot_product_32x4_u8i8i32_vrmpy
-)
+TensorIntrin.register(VRMPY_u8i8i32_INTRIN,
*generate_dot_product_32x4_u8i8i32())
+
+VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
+TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8u8i32("global.vtcm"))
+
+VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
+TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8i8i32("global.vtcm"))
diff --git
a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index 91eb67bbf4..e15b0a4e7d 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -18,7 +18,8 @@
import os
import tempfile
-from typing import Optional
+from types import MappingProxyType
+from typing import Any, Mapping, Optional
import numpy as np
import pytest
@@ -34,7 +35,11 @@ from tvm.contrib.hexagon.meta_schedule import (
from tvm.meta_schedule import postproc, schedule_rule
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.analysis import has_block
-from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN,
VRMPY_u8u8i32_INTRIN
+from tvm.tir.tensor_intrin.hexagon import (
+ VRMPY_u8i8i32_INTRIN,
+ VRMPY_u8u8i32_INTRIN,
+ VRMPY_u8i8i32_VTCM_INTRIN,
+)
from ..infrastructure import get_hexagon_target
@@ -133,7 +138,6 @@ def tune_vrmpy_auto_tensorize(mod, params,
hexagon_launcher):
# from 36 to 23, with negligible performance difference.
module_equality="anchor-block",
)
-
return ms.relay_integration.compile_relay(
database=database,
mod=mod,
@@ -142,10 +146,13 @@ def tune_vrmpy_auto_tensorize(mod, params,
hexagon_launcher):
)
-@pytest.mark.skip("End-to-end tuning is skipped on CI.")
@tvm.testing.requires_hexagon
def test_resnet50(hexagon_launcher):
"""Test Resnet50."""
+
+ if tvm.testing.utils.IS_IN_CI:
+ pytest.skip("Skipping test since it takes too long in CI.")
+
if not os.path.exists(MODEL_JSON):
pytest.skip(msg="Run python export_models.py first.")
@@ -200,6 +207,44 @@ def test_resnet50(hexagon_launcher):
print(debug_ex.profile(input_name=inp.copy()))
+def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name,
inp, benchmark=False):
+ """Run the Hexagon module, compare it to the LLVM reference, optionally profiling it."""
+ with hexagon_launcher.create_session() as session:
+ graph_mod = session.get_executor_from_factory(hexagon_lowered)
+ graph_mod.set_input(input_name, inp.copy())
+ graph_mod.run()
+ output = graph_mod.get_output(0).numpy()
+
+ llvm_graph_mod =
tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+ llvm_graph_mod.set_input(input_name, inp.copy())
+ llvm_graph_mod.run()
+ ref_result = llvm_graph_mod.get_output(0).numpy()
+
+ if benchmark:
+ time_ms = graph_mod.benchmark(session.device, number=1,
repeat=1).mean * 1e3
+ print("hexagon time elapsed: ", time_ms)
+ debug_ex = session.get_graph_debug_executor(
+ hexagon_lowered.get_graph_json(), hexagon_lowered.lib
+ )
+ print(debug_ex.profile(**{input_name: inp.copy()}))  # key by the graph input's name
+
+ np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
+
+
+def load_model():
+ """Load the exported resnet50 Relay module and params; skip if they are missing."""
+ if not os.path.exists(MODEL_JSON):
+ pytest.skip(msg="Run python export_models.py first.")
+
+ with open(MODEL_JSON, "r") as file:
+ mod = tvm.ir.load_json(file.read())
+
+ with open(MODEL_PARAMS, "rb") as file:
+ params = relay.load_param_dict(file.read())
+
+ return mod, params
+
+
def _schedule_packed_8x8x32_conv2d():
"""Manually schedule a conv2d block, created from TE compute op via
CreatePrimFunc,
using 8x8x32 packed layout.
@@ -268,22 +313,39 @@ def _schedule_packed_8x8x32_conv2d():
return schedule_fn
-def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
+def tune_conv2d_template(
+ mod,
+ scheduler,
+ schedule_tag,
+ params,
+ hexagon_launcher,
+ pass_config: Mapping[str, Any] = MappingProxyType({}),
+):
- """Generate packed 8*8*32 template."""
+ """Tune ``mod`` using ``scheduler`` for conv2d blocks, then compile it for Hexagon."""
- def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block:
BlockRV):
- _schedule_packed_8x8x32_conv2d()(sch, conv2d_block)
+ def schedule_rule_conv2d(sch: Schedule, conv2d_block: BlockRV):
+ scheduler()(sch, conv2d_block)
return [sch]
- register_func("meta_schedule.conv2d_NCHWc_int8.hexagon",
schedule_rule_conv2d_packed_8x8x32)
+ register_func(
+ "meta_schedule.conv2d_NCHWc_int8.{}.hexagon".format(schedule_tag),
schedule_rule_conv2d
+ )
def schedule_conv2d_for_tune(sch: Schedule):
- _schedule_packed_8x8x32_conv2d()(sch)
+ scheduler()(sch)
# This line is necessary for link-params to take effect during
# task extraction and relay.build(...).
mod = mod.with_attr("executor", EXECUTOR)
+ # Build under a custom PassContext only when extra pass options were given;
+ # otherwise get_hexagon_local_builder keeps LocalBuilder's default build.
+ pass_context = (
+ tvm.transform.PassContext(opt_level=3, config=pass_config)
+ if pass_config
+ else None
+ )
+
with tempfile.TemporaryDirectory() as work_dir:
database = ms.relay_integration.tune_relay(
mod=mod,
@@ -294,8 +356,8 @@ def tune_packed_8x8x32_template(mod, params,
hexagon_launcher):
max_trials_per_task=1,
num_trials_per_iter=1,
strategy="replay-trace",
- builder=get_hexagon_local_builder(),
- runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+ builder=get_hexagon_local_builder(pass_context),
+ runner=get_hexagon_rpc_runner(hexagon_launcher, number=1),
# Apply MS auto scheduling rules for all blocks, but utilize
# the custom block scheduling strategy registered above for
# blocks annotated as
`schedule_rule:meta_schedule.conv2d_NCHWc_int8`
@@ -318,33 +380,37 @@ def tune_packed_8x8x32_template(mod, params,
hexagon_launcher):
# are treated as distinct tuning tasks.
module_equality="ignore-ndarray",
)
+
+ # Add base-config defaults to a new dict; the default pass_config is read-only.
+ pass_config = {**pass_config, "relay.backend.use_meta_schedule": True,
+ "relay.backend.tir_converter": "default"}
return ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=TARGET_HEXAGON,
params=params,
+ pass_config=pass_config,
)
-@pytest.mark.skip("End-to-end tuning is skipped on CI.")
@tvm.testing.requires_hexagon
def test_packed_8x8x32_resnet50(hexagon_launcher):
"""Test packed 8*8*32 Resnet50"""
- if not os.path.exists(MODEL_JSON):
- pytest.skip(msg="Run python export_models.py first.")
- with open(MODEL_JSON, "r") as file:
- mod = tvm.ir.load_json(file.read())
+ if tvm.testing.utils.IS_IN_CI:
+ pytest.skip("Skipping test since it takes too long in CI.")
+
+ mod, params = load_model()
- with open(MODEL_PARAMS, "rb") as file:
- params = relay.load_param_dict(file.read())
inp = np.random.randn(1, 3, 224, 224).astype("float32")
input_name = "image"
do_tune = True
if do_tune:
- hexagon_lowered = tune_packed_8x8x32_template(mod, params,
hexagon_launcher)
+ hexagon_lowered = tune_conv2d_template(
+ mod, _schedule_packed_8x8x32_conv2d, "packed_8x8x32", params,
hexagon_launcher
+ )
else:
with tvm.transform.PassContext(opt_level=3):
hexagon_lowered = relay.build(
@@ -361,18 +427,112 @@ def test_packed_8x8x32_resnet50(hexagon_launcher):
params=params,
)
- with hexagon_launcher.start_session() as session:
- graph_mod = session.get_executor_from_factory(hexagon_lowered)
- graph_mod.set_input(input_name, inp.copy())
- graph_mod.run()
- hexagon_output = graph_mod.get_output(0).numpy()
+ evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name,
inp)
- llvm_graph_mod =
tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
- llvm_graph_mod.set_input(input_name, inp.copy())
- llvm_graph_mod.run()
- ref_result = llvm_graph_mod.get_output(0).numpy()
- np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4,
rtol=1e-5)
+def _schedule_async_dma_conv2d():
+ """Manually schedule a conv2d block, created from TE compute op via
CreatePrimFunc,
+ using 8x8x32 packed layout.
+ """
+
+ def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
+ if conv2d_block is None:
+ if has_block(sch, "conv2d_NCHWc_int8"):
+ conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+ else:
+ return False
+
+ assert "conv2d_NCHWc_int8" in
sch.get(conv2d_block).annotations["schedule_rule"]
+
+ # Apply scheduling
+
+ post_blocks = sch.get_consumers(conv2d_block)
+ if len(post_blocks) > 0:
+ # Fuse all intermediate post ops into the last op.
+ # This is equivalent to the traverse_inline function used in TE
schedules.
+ while True:
+ next_post_blocks = []
+ for post_block in post_blocks:
+ next_consumers = sch.get_consumers(post_block)
+ if len(next_consumers) > 0:
+ sch.compute_inline(post_block)
+ next_post_blocks += next_consumers
+ if len(next_post_blocks) == 0:
+ assert len(post_blocks) == 1
+ outer_block = post_blocks[0]
+ break
+ post_blocks = next_post_blocks
+ else:
+ outer_block = conv2d_block
+
+ # Move the conv2d mma into the injective post mma compute block
+ if outer_block != conv2d_block:
+ loops = sch.get_loops(outer_block)
+ # Compute at the second loop for pipelining.
+ sch.compute_at(conv2d_block, loops[1], preserve_unit_loops=True)
+
+ # Add cache for input and output for copying data to vtcm.
+ input_a_cache = sch.cache_read(conv2d_block, 0, "global.vtcm")
+ sch.compute_at(input_a_cache, sch.get_loops(conv2d_block)[1])
+ sch.fuse(*sch.get_loops(input_a_cache)[2:])
+
+ input_b_cache = sch.cache_read(conv2d_block, 1, "global.vtcm")
+ sch.compute_at(input_b_cache, sch.get_loops(conv2d_block)[1])
+ sch.fuse(*sch.get_loops(input_b_cache)[2:])
+
+ output_cache_write = sch.cache_write(conv2d_block, 0, "global.vtcm")
+ sch.fuse(*sch.get_loops(output_cache_write)[2:])
+
+ conv2d_loops = sch.get_loops(block=conv2d_block)
+ o_c, k_h, k_w, x_0, x_1, i_c = conv2d_loops[-6:]
+ ic_o, ic_i = sch.split(loop=i_c, factors=[None, 4],
preserve_unit_iters=True)
+ oc_o, oc_i = sch.split(loop=o_c, factors=[None, 32],
preserve_unit_iters=True)
+ sch.reorder(oc_o, k_h, k_w, x_0, x_1, ic_o, oc_i, ic_i)
+ new_loops = sch.get_loops(block=conv2d_block)
+ sch.parallel(new_loops[4])
+ sch.unroll(new_loops[5])
+ # TODO(nverke): Add compute optimizations here.
+ sch.blockize(loop=oc_i)
+
+ sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)
+
+ pipeline_loop = conv2d_loops[1]
+ sch.annotate(pipeline_loop, "software_pipeline_stage", [0, 0, 1, 2, 3])
+ sch.annotate(pipeline_loop, "software_pipeline_order", [0, 1, 2, 3, 4])
+ sch.annotate(pipeline_loop, "software_pipeline_async_stages", [0, 2])
+
+ return True
+
+ return schedule_fn
+
+
+@tvm.testing.requires_hexagon
+def test_async_dma_resnet50(hexagon_launcher):
+ """Test async dma Resnet50"""
+
+ if tvm.testing.utils.IS_IN_CI:
+ pytest.skip("Skipping test since it takes too long in CI.")
+
+ mod, params = load_model()
+
+ inp = np.random.randn(1, 3, 224, 224).astype("float32")
+ input_name = "image"
+
+ pass_config = {
+ "tir.use_async_copy": 1,
+ "tir.merge_async_commit_queue_scope": False,
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.tir_converter": "default",
+ }
+
+ hexagon_lowered = tune_conv2d_template(
+ mod, _schedule_async_dma_conv2d, "async_dma", params,
hexagon_launcher, pass_config
+ )
+ with tvm.transform.PassContext(opt_level=3):
+ llvm_lowered = tvm.relay.build(
+ mod, tvm.target.Target(TARGET_LLVM, host=TARGET_LLVM),
params=params
+ )
+ evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name,
inp, True)
if __name__ == "__main__":