This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c6d71b  [UnitTests][CMSISNN] Mark CMSISNN with skipif they are missing libraries (#9179)
7c6d71b is described below

commit 7c6d71bd50dc62b360d11ac358ad608403dda57a
Author: Lunderberg <[email protected]>
AuthorDate: Mon Oct 4 21:09:32 2021 -0500

    [UnitTests][CMSISNN] Mark CMSISNN with skipif they are missing libraries (#9179)
    
    * [UnitTests][CMSISNN] Mark CMSISNN with skipif they are missing libraries
    
    Show test as skipped, rather than failing test.
    
    * Added tvm.testing.requires_cmsisnn
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  4 ++++
 python/tvm/testing/utils.py                        | 14 ++++++++++++++
 tests/python/contrib/test_cmsisnn/test_networks.py |  5 ++++-
 tests/python/contrib/test_cmsisnn/test_softmax.py  |  3 +++
 4 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index db584fb..cf0e915 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -24,6 +24,10 @@ from ...dataflow_pattern import is_constant, is_op, wildcard
 from .register import register_pattern_table
 
 
+def enabled():
+    return bool(tvm.get_global_func("relay.ext.cmsisnn", True))
+
+
 def partition_for_cmsisnn(mod, params=None, **opts):
     """Partition the graph greedily offloading supported
     operators on Cortex-M using CMSIS-NN
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 39c759c..4188fea 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -82,6 +82,7 @@ import tvm._ffi
 from tvm.contrib import nvcc, cudnn
 from tvm.error import TVMError
 from tvm.relay.op.contrib.ethosn import ethosn_available
+from tvm.relay.op.contrib import cmsisnn
 
 
 def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
@@ -838,6 +839,19 @@ def requires_hexagon(*args):
     return _compose(args, _requires_hexagon)
 
 
+def requires_cmsisnn(*args):
+    """Mark a test as requiring the CMSIS NN library.
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+
+    requirements = [pytest.mark.skipif(not cmsisnn.enabled(), reason="CMSIS NN not enabled")]
+    return _compose(args, requirements)
+
+
 def requires_package(*packages):
     """Mark a test as requiring python packages to run.
 
diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py
index b14a15c..e78b06c 100644
--- a/tests/python/contrib/test_cmsisnn/test_networks.py
+++ b/tests/python/contrib/test_cmsisnn/test_networks.py
@@ -19,9 +19,10 @@
 
 import sys
 
-import numpy as np
 import pytest
+import numpy as np
 
+import tvm.testing
 from tvm import relay
 from tvm.contrib.download import download_testdata
 from tvm.relay.op.contrib import cmsisnn
@@ -74,6 +75,8 @@ def convert_to_relay(
 
 
 @skip_if_no_reference_system
+@tvm.testing.requires_package("tflite")
+@tvm.testing.requires_cmsisnn
 def test_cnn_small():
     # download the model
     base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py
index 12e11c3..b030437 100644
--- a/tests/python/contrib/test_cmsisnn/test_softmax.py
+++ b/tests/python/contrib/test_cmsisnn/test_softmax.py
@@ -23,6 +23,7 @@ import itertools
 import numpy as np
 import pytest
 
+import tvm.testing
 from tvm import relay
 from tvm.relay.op.contrib import cmsisnn
 
@@ -62,6 +63,7 @@ def make_model(
 
 @skip_if_no_reference_system
 @pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]])
+@tvm.testing.requires_cmsisnn
 def test_softmax_int8(zero_point, scale):
     interface_api = "c"
     use_unpacked_api = True
@@ -132,6 +134,7 @@ def parameterize_for_invalid_model(test):
 
 
 @parameterize_for_invalid_model
+@tvm.testing.requires_cmsisnn
 def test_invalid_softmax(in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale):
     model = make_model(
         [1, 16, 16, 3], in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale

Reply via email to