This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc7125b31e [Hexagon] Propagate QNN Concat Quantization Params to
Inputs (#15258)
dc7125b31e is described below
commit dc7125b31ee456f06d6cf215a40c7ffec83a7532
Author: arangasa <[email protected]>
AuthorDate: Mon Jul 10 19:21:15 2023 +0530
[Hexagon] Propagate QNN Concat Quantization Params to Inputs (#15258)
* [Hexagon] Propagate qnn.concat quantization params to its inputs, eliminating
redundant requantization where possible, and convert it into a plain concat
* Fix pylint issue
* Add relay IR snippet before and after transformation
* Better test file description comment
---
python/tvm/contrib/hexagon/transform.py | 105 ++++++++++++++++++++-
.../test_hexagon/test_relay_simplify_qnn_concat.py | 101 ++++++++++++++++++++
2 files changed, 203 insertions(+), 3 deletions(-)
diff --git a/python/tvm/contrib/hexagon/transform.py
b/python/tvm/contrib/hexagon/transform.py
index 2e5e84342b..664739dea5 100644
--- a/python/tvm/contrib/hexagon/transform.py
+++ b/python/tvm/contrib/hexagon/transform.py
@@ -21,8 +21,16 @@ import functools as ft
import tvm
from tvm import relay
-from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
-from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
+from tvm.relay.dataflow_pattern import (
+ DFPatternCallback,
+ is_constant,
+ is_op,
+ is_tuple,
+ rewrite,
+ wildcard,
+)
+from tvm.relay.expr import Call
+
from ..._ffi.registry import register_func
### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():
def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx): # pylint:
disable=unused-argument
-
"""Generic VTCM allocation
Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
"""Remove the empty pad operator."""
mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
return mod
+
+
class simplify_qnn_concat_in_func(DFPatternCallback):
    """
    Propagate qnn.concat's quantization params to its inputs,
    and try to avoid redundant requantization while doing so.

    Replace

    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
      %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
      %2 = (%0, %1, %q3);
      %3 = (0.0425042f, 0.00345f, 0.0486874f);
      %4 = (0, 0, 0);
      qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
    }

    with

    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
      %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
      %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
      %3 = (%1, %2, %q3);
      concatenate(%3, axis=1)
    }
    """

    def __init__(self):
        super().__init__()
        # Wildcards for qnn.concatenate's five inputs:
        # (tuple of quantized tensors, tuple of input scales, tuple of input
        #  zero points, output scale, output zero point).
        self.qvals = wildcard()
        self.scales = wildcard()
        self.zps = wildcard()
        self.out_scale = wildcard()
        self.out_zp = wildcard()
        self.pattern = is_op("qnn.concatenate")(
            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
        )

    def callback(self, pre, post, node_map):
        """Rewrite a matched qnn.concatenate into a plain concatenate.

        Every input is brought to the concat's output quantization params:
        an input already produced by qnn.requantize is redirected to the
        output params (instead of stacking a second requantize on top);
        an input whose params already equal the output's is kept as-is;
        any other input gets a new qnn.requantize inserted.
        """
        in_qvals = node_map[self.qvals][0]
        in_scales = node_map[self.scales][0]
        in_zps = node_map[self.zps][0]
        # Hoist the loop-invariant output quantization params.
        out_scale = node_map[self.out_scale][0]
        out_zp = node_map[self.out_zp][0]

        new_qvals = []
        # Index-based walk: the matched relay tuples support len()/[] access.
        for i in range(len(in_qvals)):
            qval, scale, zp = in_qvals[i], in_scales[i], in_zps[i]
            # TODO Generalize for all qnn ops
            if isinstance(qval, Call) and qval.op.name == "qnn.requantize":
                # Propagate the concat's scale/zp into the existing
                # requantize op, avoiding a redundant requantization.
                new_qvals.append(
                    relay.qnn.op.requantize(
                        qval.args[0],
                        qval.args[1],
                        qval.args[2],
                        out_scale,
                        out_zp,
                        **qval.attrs,
                    )
                )
            elif scale == out_scale and zp == out_zp:
                # Quantization params already match the output; retain qval.
                new_qvals.append(qval)
            else:
                # Params differ: requantize to the concat's output params.
                new_qvals.append(
                    relay.qnn.op.requantize(
                        qval,
                        scale,
                        zp,
                        out_scale,
                        out_zp,
                        axis=post.attrs["axis"],
                        out_dtype=post.checked_type.dtype,
                    )
                )

        # All inputs now share the output params, so a plain concatenate
        # replaces the qnn.concatenate.
        return relay.op.concatenate(new_qvals, node_map[self.pattern][0].attrs["axis"])
+
+
# Right now context is ignored
@tvm.ir.transform.module_pass(opt_level=1)
def simplify_qnn_concat(mod, _=None):
    """Module pass: apply the qnn.concat simplification to every function in *mod*."""
    rewriter = simplify_qnn_concat_in_func()
    for gvar in list(mod.functions.keys()):
        mod[gvar] = rewrite(rewriter, mod[gvar])
    return mod
diff --git
a/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
new file mode 100644
index 0000000000..ad1d7592fc
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test hexagon relay transform - qnn.concat optimization
+"""
+import tvm
+from tvm import relay, testing
+from tvm.contrib.hexagon.transform import simplify_qnn_concat
+
+
def get_test_module():
    """Creates a test relay module and returns it."""
    input_a = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
    input_b = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
    input_c = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")

    # Quantization constants shared across the graph.
    b_scale = relay.const(0.000109401, dtype="float32")
    concat_scale = relay.const(0.0486874, dtype="float32")
    pool_scale = relay.const(0.0425042, dtype="float32")
    mid_scale = relay.const(0.00345, dtype="float32")
    zero_pt = relay.const(0, dtype="int32")

    pooled = relay.op.nn.max_pool2d(
        input_a,
        pool_size=[3, 3],
        strides=[1, 1],
        padding=[1, 1],
        dilation=[1, 1],
        ceil_mode=False,
        layout="NHWC",
    )
    requantized = relay.qnn.op.requantize(
        input_b, b_scale, zero_pt, mid_scale, zero_pt, axis=1, out_dtype="uint8"
    )

    graph = relay.qnn.op.concatenate(
        relay.expr.Tuple([pooled, requantized, input_c]),
        relay.expr.Tuple([pool_scale, mid_scale, concat_scale]),
        relay.expr.Tuple([zero_pt, zero_pt, zero_pt]),
        concat_scale,
        zero_pt,
        axis=1,
    )
    func = relay.Function(relay.analysis.free_vars(graph), graph)
    return tvm.IRModule.from_expr(func)
+
+
def get_expected_output_module():
    """Returns manually created expected output module."""
    q1_var = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
    q2_var = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
    q3_var = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")

    # Quantization constants of the expected (simplified) graph.
    q2_scale = relay.const(0.000109401, dtype="float32")
    concat_scale = relay.const(0.0486874, dtype="float32")
    pool_scale = relay.const(0.0425042, dtype="float32")
    zero_pt = relay.const(0, dtype="int32")

    pooled = relay.op.nn.max_pool2d(
        q1_var,
        pool_size=[3, 3],
        strides=[1, 1],
        padding=[1, 1],
        dilation=[1, 1],
        ceil_mode=False,
        layout="NHWC",
    )
    # Both branches are requantized to the concat's output params, so the
    # final op is a plain concatenate.
    branch1 = relay.qnn.op.requantize(
        pooled, pool_scale, zero_pt, concat_scale, zero_pt, axis=1, out_dtype="uint8"
    )
    branch2 = relay.qnn.op.requantize(
        q2_var, q2_scale, zero_pt, concat_scale, zero_pt, axis=1, out_dtype="uint8"
    )
    out_graph = relay.op.concatenate(relay.expr.Tuple([branch1, branch2, q3_var]), axis=1)
    out_func = relay.Function(relay.analysis.free_vars(out_graph), out_graph)
    return tvm.IRModule.from_expr(out_func)
+
+
def test_simplify_qnn_concat():
    """Transformed module must be structurally equal to the expected module."""
    actual = tvm.relay.transform.InferType()(get_test_module())
    actual = simplify_qnn_concat(actual)

    expected = tvm.relay.transform.InferType()(get_expected_output_module())

    assert tvm.ir.structural_equal(actual["main"], expected["main"])
+
+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    testing.main()