This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3cb4597ed4 [CMSIS-NN] Fixed error in finding input's dtype in maxpool (#11701)
3cb4597ed4 is described below

commit 3cb4597ed48360e3f3d80161d1c03f833072d28e
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Wed Jun 15 16:47:26 2022 +0100

    [CMSIS-NN] Fixed error in finding input's dtype in maxpool (#11701)
---
 python/tvm/relay/op/contrib/cmsisnn.py            | 19 +++----
 tests/python/contrib/test_cmsisnn/test_pooling.py | 65 ++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 1a06867e54..09831929e5 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -31,12 +31,6 @@ def enabled():
     return "cmsis-nn" in Target.list_kinds()
 
 
-def _find_last(pattern):
-    if hasattr(pattern, "args"):
-        return _find_last(pattern.args[0])
-    return pattern
-
-
 def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
     """Partition the graph greedily offloading supported
     operators on Cortex-M using CMSIS-NN
@@ -206,17 +200,17 @@ def pattern_table():
     def check_qnn_avg_pool2d(pattern):
         """Check if avg pool2d is supported by CMSIS-NN."""
         output = pattern
-        input_var = _find_last(pattern)
 
         if str(pattern.op.name) == "clip":
             pooling = pattern.args[0].args[0]
         else:
             pooling = pattern.args[0]
+        input_op = pooling.args[0].args[0]
 
         return (
             pooling.attrs.layout == "NHWC"
-            and bool(input_var.checked_type.shape[0] == 1)
-            and input_var.checked_type.dtype == "int8"
+            and int(input_op.checked_type.shape[0]) == 1
+            and input_op.checked_type.dtype == "int8"
             and output.checked_type.dtype == "int8"
         )
 
@@ -229,17 +223,18 @@ def pattern_table():
     def check_qnn_max_pool2d(pattern):
         """Check if max pool2d is supported by CMSIS-NN."""
         output = pattern
-        input_var = _find_last(pattern)
+        input_op = None
 
         if str(pattern.op.name) == "clip":
             pooling = pattern.args[0]
         else:
             pooling = pattern
+        input_op = pooling.args[0]
 
         return (
             pooling.attrs.layout == "NHWC"
-            and bool(input_var.checked_type.shape[0] == 1)
-            and input_var.checked_type.dtype == "int8"
+            and int(input_op.checked_type.shape[0]) == 1
+            and input_op.checked_type.dtype == "int8"
             and output.checked_type.dtype == "int8"
         )
 
diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py
index 6b719cdc99..a59dba0f78 100644
--- a/tests/python/contrib/test_cmsisnn/test_pooling.py
+++ b/tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -45,11 +45,15 @@ def make_model(
     zero_point=-33,
     relu_type="RELU",
     layout="NHWC",
+    input_op=None,
 ):
     """Return a model and any parameters it may have,
     all parameters are defaulted to known good values
     """
-    op = relay.var("input", shape=shape, dtype=dtype)
+    if input_op:
+        op = input_op
+    else:
+        op = relay.var("input", shape=shape, dtype=dtype)
     pad_ = (0, 0, 0, 0)
     if padding == "SAME":
         dilation = (1, 1)
@@ -135,6 +139,65 @@ def test_op_int8(
     )
 
 
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize(
+    "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
+)
+@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
+def test_int8_pool_with_float32_input(
+    pool_size,
+    strides,
+    padding,
+    relu_type,
+):
+    """Tests QNN maxpool partitions with float32 input"""
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+    in_shape = (1, 28, 28, 12)
+    zero_point, scale = (-34, 0.0256)
+
+    input_ = relay.var("input", shape=in_shape, dtype="float32")
+    op = relay.op.add(input_, input_)
+    op = relay.qnn.op.quantize(op, relay.const(scale), relay.const(zero_point), -1, "int8")
+
+    model = make_model(
+        pool_op=relay.nn.max_pool2d,
+        shape=in_shape,
+        pool_size=pool_size,
+        strides=strides,
+        padding=padding,
+        scale=scale,
+        zero_point=zero_point,
+        relu_type=relu_type,
+        input_op=op,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # validate the output
+    np.random.seed(0)
+    inputs = {"input": np.random.uniform(0, 1, in_shape).astype("float32")}
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            params=None,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d])
 def test_invalid_datatype(op):

Reply via email to