This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc7125b31e [Hexagon] Propagate QNN Concat Quantization Params to Inputs (#15258)
dc7125b31e is described below

commit dc7125b31ee456f06d6cf215a40c7ffec83a7532
Author: arangasa <[email protected]>
AuthorDate: Mon Jul 10 19:21:15 2023 +0530

    [Hexagon] Propagate QNN Concat Quantization Params to Inputs (#15258)
    
    * [Hexagon] Propagate qnn.concat quantization params to inputs, eliminating redundant requantization when possible, and convert it to a plain concat
    
    * Fix pylint issue
    
    * Add relay IR snippet before and after transformation
    
    * Better test file description comment
---
 python/tvm/contrib/hexagon/transform.py            | 105 ++++++++++++++++++++-
 .../test_hexagon/test_relay_simplify_qnn_concat.py | 101 ++++++++++++++++++++
 2 files changed, 203 insertions(+), 3 deletions(-)
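
For intuition, the rewrite leans on the affine quantization identity real = scale * (q - zp): requantizing through an intermediate scale/zero-point and then into the concat's output params gives the same result, up to one extra rounding step, as requantizing straight into the output params. A small NumPy sanity check of that folding (the helper functions and numeric values below are illustrative, not part of this commit):

    import numpy as np

    def dequant(q, scale, zp):
        # affine dequantization: real = scale * (q - zp)
        return scale * (q.astype("float32") - zp)

    def requant(q, in_scale, in_zp, out_scale, out_zp):
        # re-express q in the out_scale/out_zp domain (uint8 range)
        real = dequant(q, in_scale, in_zp)
        return np.clip(np.round(real / out_scale + out_zp), 0, 255).astype("uint8")

    q = np.array([10, 100, 200], dtype="uint8")
    # old graph: hop to an intermediate domain, then to the concat output domain
    two_hop = requant(requant(q, 0.02, 0, 0.05, 0), 0.05, 0, 0.08, 0)
    # new graph: single hop straight to the concat output domain
    one_hop = requant(q, 0.02, 0, 0.08, 0)
    print(two_hop, one_hop)  # match up to the rounding of the intermediate hop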

diff --git a/python/tvm/contrib/hexagon/transform.py b/python/tvm/contrib/hexagon/transform.py
index 2e5e84342b..664739dea5 100644
--- a/python/tvm/contrib/hexagon/transform.py
+++ b/python/tvm/contrib/hexagon/transform.py
@@ -21,8 +21,16 @@ import functools as ft
 
 import tvm
 from tvm import relay
-from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
-from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_op,
+    is_tuple,
+    rewrite,
+    wildcard,
+)
+from tvm.relay.expr import Call
+
 from ..._ffi.registry import register_func
 
 ### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():
 
 
 def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disable=unused-argument
-
     """Generic VTCM allocation
 
     Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
     """Remove the empty pad operator."""
     mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
     return mod
+
+
+class simplify_qnn_concat_in_func(DFPatternCallback):
+
+    """
+    Propagate qnn.concat's quantization params to its inputs,
+    and try to avoid redundant requantization while doing so.
+
+    Replace
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+        %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+        %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
+        %2 = (%0, %1, %q3);
+        %3 = (0.0425042f, 0.00345f, 0.0486874f);
+        %4 = (0, 0, 0);
+        qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
+    }
+
+    with
+
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+        %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+        %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+        %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+        %3 = (%1, %2, %q3);
+        concatenate(%3, axis=1)
+    }
+    """
+
+    def __init__(self):
+        super(simplify_qnn_concat_in_func, self).__init__()
+        self.qvals = wildcard()
+        self.scales = wildcard()
+        self.zps = wildcard()
+        self.out_scale = wildcard()
+        self.out_zp = wildcard()
+        self.pattern = is_op("qnn.concatenate")(
+            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
+        )
+
+    def callback(self, pre, post, node_map):
+        in_qvals = node_map[self.qvals][0]
+        in_scales = node_map[self.scales][0]
+        in_zps = node_map[self.zps][0]
+        new_qvals = []
+        for i in range(len(in_qvals)):
+            new_requant_args = []
+            # TODO Generalize for all qnn ops
+            if isinstance(in_qvals[i], Call) and (in_qvals[i].op.name == "qnn.requantize"):
+                # propagate scale/zp of qnn.concat to this requantize op
+                for j in range(3):
+                    new_requant_args.append(in_qvals[i].args[j])
+                new_requant_args += [node_map[self.out_scale][0], node_map[self.out_zp][0]]
+                new_qvals.append(relay.qnn.op.requantize(*new_requant_args, **(in_qvals[i].attrs)))
+            else:
+                # simply create a new requantize op if there is a change in quantization params
+                # if not, just retain the old qval
+                if (in_scales[i] == node_map[self.out_scale][0]) and (
+                    in_zps[i] == node_map[self.out_zp][0]
+                ):
+                    new_qvals.append(in_qvals[i])
+                else:
+                    new_requant_args += [
+                        in_qvals[i],
+                        in_scales[i],
+                        in_zps[i],
+                        node_map[self.out_scale][0],
+                        node_map[self.out_zp][0],
+                    ]
+                    new_qvals.append(
+                        relay.qnn.op.requantize(
+                            *new_requant_args,
+                            axis=post.attrs["axis"],
+                            out_dtype=post.checked_type.dtype,
+                        )
+                    )
+
+        new_op = relay.op.concatenate(
+            new_qvals,
+            node_map[self.pattern][0].attrs["axis"],
+        )
+        return new_op
+
+
+# Right now context is ignored
[email protected]_pass(opt_level=1)
+def simplify_qnn_concat(mod, _=None):
+    for global_var in mod.functions.keys():
+        mod[global_var] = rewrite(simplify_qnn_concat_in_func(), mod[global_var])
+    return mod
diff --git a/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
new file mode 100644
index 0000000000..ad1d7592fc
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test hexagon relay transform - qnn.concat optimization
+"""
+import tvm
+from tvm import relay, testing
+from tvm.contrib.hexagon.transform import simplify_qnn_concat
+
+
+def get_test_module():
+    """Creates a test relay module and returns it."""
+    q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    s2 = relay.const(0.000109401, dtype="float32")
+    s3 = relay.const(0.0486874, dtype="float32")
+    s4 = relay.const(0.0425042, dtype="float32")
+    s5 = relay.const(0.00345, dtype="float32")
+    z1 = relay.const(0, dtype="int32")
+    r1 = relay.op.nn.max_pool2d(
+        q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    r2 = relay.qnn.op.requantize(q2, s2, z1, s5, z1, axis=1, out_dtype="uint8")
+    q_tuple = relay.expr.Tuple([r1, r2, q3])
+    s_tuple = relay.expr.Tuple([s4, s5, s3])
+    z_tuple = relay.expr.Tuple([z1, z1, z1])
+    graph = relay.qnn.op.concatenate(q_tuple, s_tuple, z_tuple, s3, z1, axis=1)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+    return mod
+
+
+def get_expected_output_module():
+    """Returns manually created expected output module."""
+    out_q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    out_s2 = relay.const(0.000109401, dtype="float32")
+    out_s3 = relay.const(0.0486874, dtype="float32")
+    out_s4 = relay.const(0.0425042, dtype="float32")
+    out_z1 = relay.const(0, dtype="int32")
+    nn_max_pool = relay.op.nn.max_pool2d(
+        out_q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    out_r1 = relay.qnn.op.requantize(
+        nn_max_pool, out_s4, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_r2 = relay.qnn.op.requantize(
+        out_q2, out_s2, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_q_tuple = relay.expr.Tuple([out_r1, out_r2, out_q3])
+    out_graph = relay.op.concatenate(out_q_tuple, axis=1)
+
+    out_func = relay.Function(relay.analysis.free_vars(out_graph), out_graph)
+    out_mod = tvm.IRModule.from_expr(out_func)
+    return out_mod
+
+
+def test_simplify_qnn_concat():
+    mod = get_test_module()
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = simplify_qnn_concat(mod)
+
+    out_mod = get_expected_output_module()
+    out_mod = tvm.relay.transform.InferType()(out_mod)
+
+    assert tvm.ir.structural_equal(mod["main"], out_mod["main"])
+
+
+if __name__ == "__main__":
+    testing.main()
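
For completeness, a minimal sketch of how the new pass could be applied to an arbitrary module. The two-input module below is illustrative and not taken from this commit; the API calls mirror the test file above. Note that InferType must run first, since the callback reads checked_type:

    import tvm
    from tvm import relay
    from tvm.contrib.hexagon.transform import simplify_qnn_concat

    a = relay.var("a", shape=(1, 4, 8, 8), dtype="uint8")
    b = relay.var("b", shape=(1, 4, 8, 8), dtype="uint8")
    out_scale = relay.const(0.05, dtype="float32")
    zero = relay.const(0, dtype="int32")
    # `a` already carries the output quantization params; `b` does not.
    concat = relay.qnn.op.concatenate(
        relay.expr.Tuple([a, b]),
        relay.expr.Tuple([out_scale, relay.const(0.02, dtype="float32")]),
        relay.expr.Tuple([zero, zero]),
        out_scale,
        zero,
        axis=1,
    )
    func = relay.Function(relay.analysis.free_vars(concat), concat)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    mod = simplify_qnn_concat(mod)
    # Expected: `a` feeds concatenate directly, `b` gains one qnn.requantize,
    # and qnn.concatenate is replaced by a plain concatenate.
    print(mod["main"])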
