This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aa67a6a01c [Hexagon] Add USMP tests (#11279)
aa67a6a01c is described below

commit aa67a6a01cdc241e63816dd0621474531d3725f5
Author: Mehrdad Hessar <[email protected]>
AuthorDate: Fri May 13 15:38:20 2022 -0700

    [Hexagon] Add USMP tests (#11279)
    
    * Add USMP tests
    
    * Address Chris comments
    
    * Address Chris comment on assert
    
    * trigger
---
 python/tvm/testing/usmp.py                         |  39 +++++++
 tests/python/contrib/test_hexagon/conftest.py      |  15 +--
 tests/python/contrib/test_hexagon/test_launcher.py |  13 +--
 tests/python/contrib/test_hexagon/test_models.py   |  26 ++---
 .../contrib/test_hexagon/test_thread_pool.py       |   5 +-
 tests/python/contrib/test_hexagon/test_usmp.py     | 112 +++++++++++++++++++++
 .../contrib/test_hexagon/topi/test_batch_matmul.py |   5 +-
 .../test_hexagon/topi/test_cache_read_write.py     |   7 +-
 .../contrib/test_hexagon/topi/test_conv2d_nchw.py  |   3 +-
 .../contrib/test_hexagon/topi/test_conv2d_nhwc.py  |   3 +-
 .../test_hexagon/topi/test_conv2d_transpose.py     |   3 +-
 .../python/contrib/test_hexagon/topi/test_dense.py |  10 +-
 .../test_hexagon/topi/test_depthwise_conv2d.py     |   3 +-
 .../contrib/test_hexagon/topi/test_pooling.py      |   3 +-
 .../contrib/test_hexagon/topi/test_reduce.py       |   5 +-
 .../contrib/test_hexagon/topi/test_softmax.py      |   3 +-
 tests/python/relay/aot/test_crt_aot_usmp.py        |  33 +++---
 17 files changed, 229 insertions(+), 59 deletions(-)

diff --git a/python/tvm/testing/usmp.py b/python/tvm/testing/usmp.py
new file mode 100644
index 0000000000..c35ac255c3
--- /dev/null
+++ b/python/tvm/testing/usmp.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" This file contains USMP tests harnesses."""
+
+import tvm
+
+
+def is_tvm_backendallocworkspace_calls(mod: tvm.runtime.module) -> bool:
+    """TVMBackendAllocWorkspace call check.
+
+    This checker checks whether any c-source produced has 
TVMBackendAllocWorkspace calls.
+    If USMP is invoked, none of them should have TVMBAW calls
+    """
+    dso_modules = mod._collect_dso_modules()
+    for dso_mod in dso_modules:
+        if dso_mod.type_key not in ["c", "llvm"]:
+            assert (
+                False
+            ), 'Current AoT codegen flow should only produce type "c" or 
"llvm" runtime modules'
+
+        source = dso_mod.get_source()
+        if source.count("TVMBackendAllocWorkspace") != 0:
+            return True
+
+    return False
diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py
index e09329b76b..f76181e06d 100644
--- a/tests/python/contrib/test_hexagon/conftest.py
+++ b/tests/python/contrib/test_hexagon/conftest.py
@@ -21,13 +21,14 @@
 import os
 import random
 import socket
-from typing import Optional
+from typing import Optional, Union
 
 import pytest
 
 import tvm
 import tvm.rpc.tracker
-from tvm.contrib.hexagon.build import HexagonLauncher
+from tvm.contrib.hexagon.build import HexagonLauncher, HexagonLauncherRPC
+from tvm.contrib.hexagon.session import Session
 
 HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
 TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
@@ -84,7 +85,7 @@ listen_port_max = 9000  # Below the search range end (port_end=9199) of RPC serv
 previous_port = None
 
 
-def get_free_port():
+def get_free_port() -> int:
 
     global previous_port
     if previous_port is None:
@@ -100,7 +101,7 @@ def get_free_port():
 
 
 @pytest.fixture(scope="session")
-def _tracker_info() -> (str, int):
+def _tracker_info() -> Union[str, int]:
     env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
     env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")
 
@@ -156,7 +157,9 @@ def adb_server_socket() -> str:
 
 
 @tvm.testing.fixture
-def hexagon_launcher(request, android_serial_number, rpc_server_port, 
adb_server_socket):
+def hexagon_launcher(
+    request, android_serial_number, rpc_server_port, adb_server_socket
+) -> HexagonLauncherRPC:
     if android_serial_number is None:
         yield None
     else:
@@ -181,7 +184,7 @@ def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server
 
 
 @tvm.testing.fixture
-def hexagon_session(hexagon_launcher):
+def hexagon_session(hexagon_launcher) -> Session:
     if hexagon_launcher is None:
         yield None
     else:
diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py
index 861ad4f15b..7dadc8f2f4 100644
--- a/tests/python/contrib/test_hexagon/test_launcher.py
+++ b/tests/python/contrib/test_hexagon/test_launcher.py
@@ -23,12 +23,13 @@ import tvm.testing
 from tvm import te
 from tvm import relay
 from tvm.relay.backend import Executor, Runtime
+from tvm.contrib.hexagon.session import Session
 
 from .conftest import requires_hexagon_toolchain
 
 
 @requires_hexagon_toolchain
-def test_add(hexagon_session):
+def test_add(hexagon_session: Session):
     dtype = "int8"
     A = tvm.te.placeholder((2,), dtype=dtype)
     B = tvm.te.placeholder((1,), dtype=dtype)
@@ -53,7 +54,7 @@ def test_add(hexagon_session):
 
 
 @requires_hexagon_toolchain
-def test_add_vtcm(hexagon_session):
+def test_add_vtcm(hexagon_session: Session):
     dtype = "int8"
     A = tvm.te.placeholder((2,), dtype=dtype)
     B = tvm.te.placeholder((1,), dtype=dtype)
@@ -122,7 +123,7 @@ class TestMatMul:
 
 
 @requires_hexagon_toolchain
-def test_graph_executor(hexagon_session):
+def test_graph_executor(hexagon_session: Session):
     dtype = "float32"
     data = relay.var("data", relay.TensorType((1, 64, 64, 3), dtype))
     weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype))
@@ -178,7 +179,7 @@ def test_graph_executor(hexagon_session):
 
 
 @requires_hexagon_toolchain
-def test_graph_executor_multiple_conv2d(hexagon_session):
+def test_graph_executor_multiple_conv2d(hexagon_session: Session):
     dtype = "float32"
     input_shape = (1, 8, 8, 3)
     w1_shape = (5, 5, 3, 1)
@@ -255,7 +256,7 @@ def test_graph_executor_multiple_conv2d(hexagon_session):
 
 
 @requires_hexagon_toolchain
-def test_aot_executor(hexagon_session, aot_host_target, aot_target):
+def test_aot_executor(hexagon_session: Session, aot_host_target, aot_target):
     dtype = "float32"
     input_shape = (1, 128, 128, 3)
     w_shape = (5, 5, 3, 8)
@@ -314,7 +315,7 @@ def test_aot_executor(hexagon_session, aot_host_target, aot_target):
 
 
 @requires_hexagon_toolchain
-def test_aot_executor_multiple_conv2d(hexagon_session, aot_host_target, 
aot_target):
+def test_aot_executor_multiple_conv2d(hexagon_session: Session, 
aot_host_target, aot_target):
     dtype = "float32"
     input_shape = (1, 8, 8, 3)
     w1_shape = (5, 5, 3, 1)
diff --git a/tests/python/contrib/test_hexagon/test_models.py b/tests/python/contrib/test_hexagon/test_models.py
index 0ce66a455e..649cc5b3f4 100644
--- a/tests/python/contrib/test_hexagon/test_models.py
+++ b/tests/python/contrib/test_hexagon/test_models.py
@@ -15,20 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
 import sys
 import pytest
 import numpy as np
 
 import tvm.testing
-from tvm import te
 from tvm import relay
 from tvm.relay.backend import Executor, Runtime
+from tvm.contrib.hexagon.session import Session
 
 from .conftest import requires_hexagon_toolchain
 
-MOBILENET_MODEL = ""
-
 
 def get_mobilenet():
     """Download and import mobilenet model with ONNX"""
@@ -42,7 +39,7 @@ def get_mobilenet():
 
 
 @requires_hexagon_toolchain
-def test_mobilenet(hexagon_session):
+def test_mobilenet(hexagon_session: Session):
     dtype = "float32"
     onnx_model = get_mobilenet()
 
@@ -88,8 +85,11 @@ def test_mobilenet(hexagon_session):
     tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, 
atol=1e-5)
 
 
+enable_usmp = tvm.testing.parameter(False, True)
+
+
 @requires_hexagon_toolchain
-def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
+def test_mobilenet_aot(hexagon_session: Session, aot_host_target, aot_target, 
enable_usmp):
     if hexagon_session._launcher._serial_number == "simulator":
         pytest.skip(msg="Skip on simulator due to long runtime.")
 
@@ -104,7 +104,8 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
     inputs = {input_name: data_in}
 
     target_llvm = tvm.target.Target("llvm")
-    with tvm.transform.PassContext(opt_level=3):
+    config = {"tir.usmp.enable": enable_usmp}
+    with tvm.transform.PassContext(opt_level=3, config=config):
         hexagon_lowered = tvm.relay.build(
             relay_mod,
             tvm.target.Target(aot_target, host=aot_host_target),
@@ -113,6 +114,12 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
             params=params,
         )
 
+    aot_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
+    aot_mod.set_input(**inputs)
+    aot_mod.run()
+    hexagon_output = aot_mod.get_output(0).numpy()
+
+    with tvm.transform.PassContext(opt_level=3):
         llvm_lowered = tvm.relay.build(
             relay_mod,
             tvm.target.Target(target_llvm, host=target_llvm),
@@ -121,11 +128,6 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
             params=params,
         )
 
-    aot_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
-    aot_mod.set_input(**inputs)
-    aot_mod.run()
-    hexagon_output = aot_mod.get_output(0).numpy()
-
     llvm_graph_mod = 
tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
     llvm_graph_mod.set_input(**inputs)
     llvm_graph_mod.run()
diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py
index a054049146..8a35bff7e7 100644
--- a/tests/python/contrib/test_hexagon/test_thread_pool.py
+++ b/tests/python/contrib/test_hexagon/test_thread_pool.py
@@ -20,6 +20,7 @@ import pytest
 
 import tvm
 import tvm.contrib.hexagon
+from tvm.contrib.hexagon.session import Session
 import tvm.script
 import tvm.testing
 from tvm import te
@@ -53,7 +54,7 @@ class ElemwiseSumIRModule:
                 C[vi] = A[vi] + B[vi]
 
 
-def generate_add_test_data(hexagon_session, n=128 * 1024):
+def generate_add_test_data(hexagon_session: Session, n=128 * 1024):
     a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), 
hexagon_session.device)
     b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), 
hexagon_session.device)
     c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device)
@@ -85,7 +86,7 @@ def test_speedup(hexagon_session, capsys):
 
 
 @requires_hexagon_toolchain
-def test_elemwise_sum_parallel(hexagon_session):
+def test_elemwise_sum_parallel(hexagon_session: Session):
     if hexagon_session is None:
         pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not 
set.")
 
diff --git a/tests/python/contrib/test_hexagon/test_usmp.py b/tests/python/contrib/test_hexagon/test_usmp.py
new file mode 100644
index 0000000000..116ecb4154
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_usmp.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import pytest
+import numpy as np
+
+import tvm.testing
+from tvm import te
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.contrib.hexagon.session import Session
+from tvm.testing.usmp import is_tvm_backendallocworkspace_calls
+
+from .conftest import requires_hexagon_toolchain
+
+usmp_enabled = tvm.testing.parameter(False, True)
+
+
+@requires_hexagon_toolchain
+def test_conv2d(hexagon_session: Session, aot_host_target, aot_target, 
usmp_enabled):
+    dtype = "float32"
+    input_shape = (1, 8, 8, 3)
+    w1_shape = (5, 5, 3, 1)
+    w2_shape = (5, 5, 1, 3)
+    data = relay.var("data", relay.TensorType(input_shape, dtype))
+    weight1 = relay.var("weight1", relay.TensorType(w1_shape, dtype))
+    weight2 = relay.var("weight2", relay.TensorType(w2_shape, dtype))
+    y1 = relay.nn.conv2d(
+        data,
+        weight1,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="float32",
+    )
+    y2 = relay.nn.conv2d(
+        y1,
+        weight2,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight1, weight2], y2)
+    relay_mod = tvm.IRModule.from_expr(f)
+    relay_mod = relay.transform.InferType()(relay_mod)
+
+    weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], 
w1_shape[3]).astype(
+        dtype=dtype
+    )
+    weight2_data = np.random.rand(w2_shape[0], w2_shape[1], w2_shape[2], 
w2_shape[3]).astype(
+        dtype=dtype
+    )
+    input_data = np.random.rand(
+        input_shape[0], input_shape[1], input_shape[2], input_shape[3]
+    ).astype(dtype=dtype)
+
+    params = {"weight1": weight1_data, "weight2": weight2_data}
+    inputs = {"data": input_data}
+
+    with tvm.transform.PassContext(opt_level=3, config={"tir.usmp.enable": 
usmp_enabled}):
+        lowered = tvm.relay.build(
+            relay_mod,
+            params=params,
+            target=tvm.target.Target(aot_target, host=aot_host_target),
+            runtime=Runtime("cpp"),
+            executor=Executor("aot", {"unpacked-api": False, "interface-api": 
"packed"}),
+        )
+
+    assert is_tvm_backendallocworkspace_calls(lowered.lib) != usmp_enabled
+
+    aot_mod = hexagon_session.get_executor_from_factory(lowered)
+    aot_mod.set_input(**inputs)
+    aot_mod.run()
+    hexagon_output = aot_mod.get_output(0).numpy()
+
+    target_llvm = tvm.target.Target("llvm")
+    with tvm.transform.PassContext(opt_level=3):
+        llvm_lowered = tvm.relay.build(
+            relay_mod,
+            tvm.target.Target(target_llvm, host=target_llvm),
+            runtime=Runtime("cpp"),
+            executor=Executor("graph"),
+        )
+
+    llvm_graph_mod = 
tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+    llvm_graph_mod.set_input(**params)
+    llvm_graph_mod.run(**inputs)
+    expected_output = llvm_graph_mod.get_output(0).numpy()
+
+    tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, 
atol=1e-5)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py
index d73ab46424..2816322b6d 100644
--- a/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py
+++ b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 
@@ -46,7 +47,7 @@ class TestMatMulFloat:
 
     # TODO(mehrdadh): add dynamic testing
     @requires_hexagon_toolchain
-    def test_batch_matmul(self, hexagon_session, x_batch, y_batch, M, N, K, 
dtype):
+    def test_batch_matmul(self, hexagon_session: Session, x_batch, y_batch, M, 
N, K, dtype):
         if dtype == "float16":
             pytest.xfail("float16 is not supported.")
 
@@ -98,7 +99,7 @@ class TestMatMulInt8:
     )
 
     @requires_hexagon_toolchain
-    def test_batch_matmul_int8(self, hexagon_session, x_batch, y_batch, M, N, 
K):
+    def test_batch_matmul_int8(self, hexagon_session: Session, x_batch, 
y_batch, M, N, K):
         dtype = "int8"
         out_dtype = "int8"
         assert x_batch == y_batch or x_batch == 1 or y_batch == 1
diff --git a/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
index 46e78f6683..bfb597f7b7 100644
--- a/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
+++ b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
@@ -17,6 +17,7 @@
 
 import pytest
 import numpy as np
+from tvm.contrib.hexagon.session import Session
 
 import tvm.testing
 from tvm import te
@@ -70,7 +71,7 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
     return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, 
dst: dst_buffer})
 
 
-def verify(hexagon_session, s, x, y, z, size):
+def verify(hexagon_session: Session, s, x, y, z, size):
     print(tvm.lower(s, [x, y, z]))
 
     target_hexagon = tvm.target.hexagon("v68", link_params=True)
@@ -98,7 +99,7 @@ def verify(hexagon_session, s, x, y, z, size):
 
 
 @requires_hexagon_toolchain
-def test_cache_read_write(hexagon_session):
+def test_cache_read_write(hexagon_session: Session):
     size = 128
     outer_shape = (size,)
     factor = 16
@@ -140,7 +141,7 @@ def layout_transform_2d(n):
 
 
 @requires_hexagon_toolchain
-def test_cache_read_write_2d(hexagon_session):
+def test_cache_read_write_2d(hexagon_session: Session):
     size = 128
     outer_shape = (size,)
     factor = 16
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py
index 12417e80af..b3d6832ffa 100644
--- a/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.nn.utils import get_pad_tuple
@@ -93,7 +94,7 @@ class BaseConv2DTests:
     @requires_hexagon_toolchain
     def test_conv2d_nchw(
         self,
-        hexagon_session,
+        hexagon_session: Session,
         batch,
         in_channel,
         in_size,
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py
index 60b0b7ea6d..30b54d5134 100644
--- a/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.nn.utils import get_pad_tuple
@@ -48,7 +49,7 @@ class BaseConv2DTests:
     @requires_hexagon_toolchain
     def test_conv2d_nhwc(
         self,
-        hexagon_session,
+        hexagon_session: Session,
         ref_data,
         batch,
         in_channel,
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
index 1dbac67aeb..0da740614f 100644
--- a/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
@@ -17,6 +17,7 @@
 """Test code for transposed convolution."""
 import numpy as np
 import tvm
+from tvm.contrib.hexagon.session import Session
 import tvm.testing
 from tvm import te
 from tvm import topi
@@ -70,7 +71,7 @@ class BaseConv2DTransposeTests:
     @requires_hexagon_toolchain
     def test_conv2d(
         self,
-        hexagon_session,
+        hexagon_session: Session,
         batch,
         in_channel,
         in_size,
diff --git a/tests/python/contrib/test_hexagon/topi/test_dense.py b/tests/python/contrib/test_hexagon/topi/test_dense.py
index 59a1573a6b..c63873a62d 100644
--- a/tests/python/contrib/test_hexagon/topi/test_dense.py
+++ b/tests/python/contrib/test_hexagon/topi/test_dense.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 
@@ -69,7 +70,14 @@ def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype,
 
 @requires_hexagon_toolchain
 def test_dense(
-    hexagon_session, batch_size, in_dim, out_dim, use_bias, in_dtype, 
out_dtype, dense_ref_data
+    hexagon_session: Session,
+    batch_size,
+    in_dim,
+    out_dim,
+    use_bias,
+    in_dtype,
+    out_dtype,
+    dense_ref_data,
 ):
     if in_dtype == "float16":
         pytest.xfail("float16 is not supported.")
diff --git a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py
index 6343a10f1f..ab2ce36e1f 100644
--- a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py
+++ b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py
@@ -21,6 +21,7 @@ import numpy as np
 import pytest
 
 import tvm
+from tvm.contrib.hexagon.session import Session
 import tvm.testing
 import tvm.topi.testing
 
@@ -157,7 +158,7 @@ class BaseDepthwiseConv2D:
     @requires_hexagon_toolchain
     def test_conv2d(
         self,
-        hexagon_session,
+        hexagon_session: Session,
         in_dtype,
         out_dtype,
         layout,
diff --git a/tests/python/contrib/test_hexagon/topi/test_pooling.py b/tests/python/contrib/test_hexagon/topi/test_pooling.py
index f05611f2f5..38b7f387e5 100644
--- a/tests/python/contrib/test_hexagon/topi/test_pooling.py
+++ b/tests/python/contrib/test_hexagon/topi/test_pooling.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 
@@ -57,7 +58,7 @@ class TestAdaptivePool:
     )
 
     @requires_hexagon_toolchain
-    def test_adaptive_pool(self, hexagon_session, dshape, out_size, pool_type, 
layout):
+    def test_adaptive_pool(self, hexagon_session: Session, dshape, out_size, 
pool_type, layout):
         dtype = "float32"
         np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
         np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, 
layout)
diff --git a/tests/python/contrib/test_hexagon/topi/test_reduce.py b/tests/python/contrib/test_hexagon/topi/test_reduce.py
index 7978e3854f..beacb8cd18 100644
--- a/tests/python/contrib/test_hexagon/topi/test_reduce.py
+++ b/tests/python/contrib/test_hexagon/topi/test_reduce.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 
 from ..conftest import requires_hexagon_toolchain
@@ -101,7 +102,9 @@ def ref_data(in_shape, axis, keepdims, reduce_type, dtype):
 
 
 @requires_hexagon_toolchain
-def test_reduce_map(hexagon_session, ref_data, in_shape, axis, keepdims, 
reduce_type, dtype):
+def test_reduce_map(
+    hexagon_session: Session, ref_data, in_shape, axis, keepdims, reduce_type, 
dtype
+):
     in_npy, in_npy_map, out_npy = ref_data
 
     # Build the logic and compile the function
diff --git a/tests/python/contrib/test_hexagon/topi/test_softmax.py b/tests/python/contrib/test_hexagon/topi/test_softmax.py
index 4825d1e524..6857decabf 100644
--- a/tests/python/contrib/test_hexagon/topi/test_softmax.py
+++ b/tests/python/contrib/test_hexagon/topi/test_softmax.py
@@ -22,6 +22,7 @@ import sys
 import tvm
 from tvm import topi
 from tvm import te
+from tvm.contrib.hexagon.session import Session
 import tvm.topi.testing
 from tvm.topi.utils import get_const_tuple
 
@@ -54,7 +55,7 @@ softmax_operation, shape = tvm.testing.parameters(
 
 
 @requires_hexagon_toolchain
-def test_softmax(hexagon_session, shape, dtype, softmax_operation):
+def test_softmax(hexagon_session: Session, shape, dtype, softmax_operation):
     if dtype == "float16":
         pytest.xfail("float16 is not supported.")
     A = te.placeholder(shape, dtype=dtype, name="A")
diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py
index ab7fb4167c..650cb4526f 100644
--- a/tests/python/relay/aot/test_crt_aot_usmp.py
+++ b/tests/python/relay/aot/test_crt_aot_usmp.py
@@ -43,20 +43,13 @@ from aot_test_utils import (
     run_and_check,
     create_relay_module_and_inputs_from_tflite_file,
 )
+from tvm.testing.usmp import is_tvm_backendallocworkspace_calls
 
 
-def check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module):
-    """This checker checks whether any c-source produced has 
TVMBackendAllocWorkspace calls.
-    If USMP is invoked, none of them should have TVMBAW calls"""
-    dso_modules = mod._collect_dso_modules()
-    for dso_mod in dso_modules:
-        assert (
-            dso_mod.type_key == "c"
-        ), 'Current CRT AoT codegen flow should only produce type "c" runtime 
modules'
-        source = dso_mod.get_source()
-        assert (
-            source.count("TVMBackendAllocWorkspace") == 0
-        ), "This is failing because USMP was unable to plan for every 
tir.allocate node"
+def _check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module):
+    assert (
+        is_tvm_backendallocworkspace_calls(mod) == False
+    ), "This is failing because USMP was unable to plan for every tir.allocate 
node."
 
 
 @pytest.mark.parametrize(
@@ -138,7 +131,7 @@ def test_conv2d(interface_api, use_unpacked_api, test_runner, groups, weight_sha
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -197,7 +190,7 @@ def test_byoc_microtvm(merge_compiler_regions):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -251,7 +244,7 @@ def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     # Checking the workspace size reported in model library format
     mlf_memory_map = mlf._build_function_memory_map(
@@ -330,7 +323,7 @@ def test_tflite_model_u3_usecase_single_external_pool(model_url, usmp_algo):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -390,7 +383,7 @@ def test_tflite_model_u3_usecase_two_external_pools(model_url, usmp_algo):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -458,7 +451,7 @@ def test_tflite_model_u2_usecase_two_models_with_a_single_external_pool(model_ur
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -526,7 +519,7 @@ def test_tflite_model_u4_usecase_single_external_pool(model_url, usmp_algo):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,
@@ -602,7 +595,7 @@ def test_tflite_model_u4_usecase_two_external_pools(model_url, usmp_algo):
     )
 
     for compiled_model in compiled_test_mods:
-        
check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+        
_check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
 
     run_and_check(
         models=compiled_test_mods,

Reply via email to