This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3cb4597ed4 [CMSIS-NN] Fixed error in finding input's dtype in maxpool (#11701)
3cb4597ed4 is described below
commit 3cb4597ed48360e3f3d80161d1c03f833072d28e
Author: Ashutosh Parkhi <[email protected]>
AuthorDate: Wed Jun 15 16:47:26 2022 +0100
[CMSIS-NN] Fixed error in finding input's dtype in maxpool (#11701)
---
python/tvm/relay/op/contrib/cmsisnn.py | 19 +++----
tests/python/contrib/test_cmsisnn/test_pooling.py | 65 ++++++++++++++++++++++-
2 files changed, 71 insertions(+), 13 deletions(-)
diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 1a06867e54..09831929e5 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -31,12 +31,6 @@ def enabled():
return "cmsis-nn" in Target.list_kinds()
-def _find_last(pattern):
- if hasattr(pattern, "args"):
- return _find_last(pattern.args[0])
- return pattern
-
-
def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
"""Partition the graph greedily offloading supported
operators on Cortex-M using CMSIS-NN
@@ -206,17 +200,17 @@ def pattern_table():
def check_qnn_avg_pool2d(pattern):
"""Check if avg pool2d is supported by CMSIS-NN."""
output = pattern
- input_var = _find_last(pattern)
if str(pattern.op.name) == "clip":
pooling = pattern.args[0].args[0]
else:
pooling = pattern.args[0]
+ input_op = pooling.args[0].args[0]
return (
pooling.attrs.layout == "NHWC"
- and bool(input_var.checked_type.shape[0] == 1)
- and input_var.checked_type.dtype == "int8"
+ and int(input_op.checked_type.shape[0]) == 1
+ and input_op.checked_type.dtype == "int8"
and output.checked_type.dtype == "int8"
)
@@ -229,17 +223,18 @@ def pattern_table():
def check_qnn_max_pool2d(pattern):
"""Check if max pool2d is supported by CMSIS-NN."""
output = pattern
- input_var = _find_last(pattern)
+ input_op = None
if str(pattern.op.name) == "clip":
pooling = pattern.args[0]
else:
pooling = pattern
+ input_op = pooling.args[0]
return (
pooling.attrs.layout == "NHWC"
- and bool(input_var.checked_type.shape[0] == 1)
- and input_var.checked_type.dtype == "int8"
+ and int(input_op.checked_type.shape[0]) == 1
+ and input_op.checked_type.dtype == "int8"
and output.checked_type.dtype == "int8"
)
diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py
index 6b719cdc99..a59dba0f78 100644
--- a/tests/python/contrib/test_cmsisnn/test_pooling.py
+++ b/tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -45,11 +45,15 @@ def make_model(
zero_point=-33,
relu_type="RELU",
layout="NHWC",
+ input_op=None,
):
"""Return a model and any parameters it may have,
all parameters are defaulted to known good values
"""
- op = relay.var("input", shape=shape, dtype=dtype)
+ if input_op:
+ op = input_op
+ else:
+ op = relay.var("input", shape=shape, dtype=dtype)
pad_ = (0, 0, 0, 0)
if padding == "SAME":
dilation = (1, 1)
@@ -135,6 +139,65 @@ def test_op_int8(
)
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize(
+ "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
+)
+@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
+def test_int8_pool_with_float32_input(
+ pool_size,
+ strides,
+ padding,
+ relu_type,
+):
+ """Tests QNN maxpool partitions with float32 input"""
+ interface_api = "c"
+ use_unpacked_api = True
+ test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+ in_shape = (1, 28, 28, 12)
+ zero_point, scale = (-34, 0.0256)
+
+ input_ = relay.var("input", shape=in_shape, dtype="float32")
+ op = relay.op.add(input_, input_)
+ op = relay.qnn.op.quantize(op, relay.const(scale), relay.const(zero_point), -1, "int8")
+
+ model = make_model(
+ pool_op=relay.nn.max_pool2d,
+ shape=in_shape,
+ pool_size=pool_size,
+ strides=strides,
+ padding=padding,
+ scale=scale,
+ zero_point=zero_point,
+ relu_type=relu_type,
+ input_op=op,
+ )
+ orig_mod = make_module(model)
+
+ cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+ # validate pattern matching
+ assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+ # validate the output
+ np.random.seed(0)
+ inputs = {"input": np.random.uniform(0, 1, in_shape).astype("float32")}
+ output_list = generate_ref_data(orig_mod["main"], inputs)
+ compile_and_run(
+ AOTTestModel(
+ module=cmsisnn_mod,
+ inputs=inputs,
+ outputs=output_list,
+ params=None,
+ output_tolerance=1,
+ ),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ )
+
+
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d])
def test_invalid_datatype(op):