This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6c63e0db53 [ETHOSU][MicroNPU][Pass] Add a pass to replicate pads
(#14909)
6c63e0db53 is described below
commit 6c63e0db53a9fee57b22b53941a8e634c1c90227
Author: sergio-grovety <[email protected]>
AuthorDate: Mon Jul 17 18:37:52 2023 +0300
[ETHOSU][MicroNPU][Pass] Add a pass to replicate pads (#14909)
Added a pass to handle the situation when an nn.pad operator has more than
one qnn.conv2d consumer.
pad
/ \
Conv2D Conv2D
In this case, because of the peculiarities of pattern parsing, conv2d does
not get into the composite for the NPU. Therefore, the pad is replicated so
that each copy has only one consumer.
---------
Co-authored-by: Sergey Smirnov
<[email protected]>
Co-authored-by: Arina <[email protected]>
Co-authored-by: arina.naumova <[email protected]>
---
python/tvm/relay/backend/contrib/ethosu/codegen.py | 89 +++++++++++++-
python/tvm/relay/op/contrib/ethosu.py | 4 +-
tests/python/contrib/test_ethosu/test_codegen.py | 63 ++++++++++
tests/python/contrib/test_ethosu/test_legalize.py | 132 ++++++++++++++++++++-
4 files changed, 285 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py
b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index e40053d49a..f4cea5df13 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -32,7 +32,8 @@ from tvm.contrib.ethosu.cascader import (
)
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util,
vela_api
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor, Call
+from tvm.relay import expr as _expr
# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -357,6 +358,92 @@ class LayoutOptimizer:
pass
+class PadsWithMultipleConsumersReplicator(ExprMutator):
+ """A pass to handle the situation when nn.pad operator has
+ more than one qnn.conv2d consumer.
+
+ pad
+ / \
+ Conv2D Conv2D
+
+ In this case, because of the peculiarities of pattern parsing,
+ conv2d does not get into the composite for the NPU.
+ Therefore, pads are added so that each has only one consumer.
+ """
+
+ def __init__(self):
+ super().__init__()
+ # a set to record hashes of pads that already have one qnn.conv2d
consumer
+ self.hashes = set()
+
+ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+ if (
+ isinstance(call.op, tvm.ir.Op)
+ and isinstance(call.args[0], Call)
+ and isinstance(call.args[0].op, tvm.ir.Op)
+ and call.op == relay.op.get("qnn.conv2d")
+ and call.args[0].op == relay.op.get("nn.pad")
+ ):
+ if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
+ # add the hash of nn.pad to set
+ self.hashes.add(tvm.ir.structural_hash(call.args[0]))
+ else:
+ # if this pad already has a conv2d consumer, duplicate the pad
+ # and make it an input for current conv2d
+ used_pad = self.visit(call.args[0])
+ used_pad_args = [self.visit(arg) for arg in used_pad.args]
+ new_pad = Call(
+ used_pad.op, used_pad_args, used_pad.attrs,
used_pad.type_args, used_pad.span
+ )
+ new_conv2d_args = []
+ for i, arg in enumerate(call.args):
+ if i == 0:
+ new_conv2d_args.append(self.visit(new_pad))
+ else:
+ new_conv2d_args.append(self.visit(arg))
+ new_conv2d_op = self.visit(call.op)
+ expr__ = _expr.CallWithFields(
+ call,
+ new_conv2d_op,
+ new_conv2d_args,
+ call.attrs,
+ call.type_args,
+ None,
+ call.span,
+ )
+ return expr__
+
+ new_args = [self.visit(arg) for arg in call.args]
+ new_op = self.visit(call.op)
+ expr__ = _expr.CallWithFields(
+ call, new_op, new_args, call.attrs, call.type_args, None, call.span
+ )
+ return expr__
+
+
+def replicate_pads(mod):
+ """Traverses the Relay graph to replicate nn.pad operators if they have
+ multiple qnn.conv2d consumers. This removes the situation where
+ e.g. pad+conv2d matches qnn_conv2d_pattern, but cannot be grouped
+ because several conv2d operators use the same pad operation.
+
+ Parameters
+ ----------
+ tvm.ir.IRModule
+ The IRModule that gets generated from a relay frontend.
+
+ Returns
+ -------
+ tvm.ir.IRModule
+ The IRModule without nn.pad operators with multiple consumers.
+ """
+ replicator = PadsWithMultipleConsumersReplicator()
+ for global_var, func in mod.functions.items():
+ func = replicator.visit(func)
+ mod.update_func(global_var, func)
+ return mod
+
+
def IdentityOptimizer(): # pylint: disable=invalid-name
"""Pass that removes redundant identities
diff --git a/python/tvm/relay/op/contrib/ethosu.py
b/python/tvm/relay/op/contrib/ethosu.py
index 0796ccf62a..386ef9038e 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -2341,13 +2341,15 @@ def partition_for_ethosu(
mod : IRModule
The partitioned IRModule with external global functions
"""
- from tvm.relay.backend.contrib.ethosu import preprocess
+ from tvm.relay.backend.contrib.ethosu import preprocess, codegen
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
pattern = relay.op.contrib.get_pattern_table("ethos-u")
mod = relay.transform.InferType()(mod)
+ mod = codegen.replicate_pads(mod)
+ mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index cb1592c041..d56b8b6ec9 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -157,6 +157,69 @@ def test_ethosu_conv2d_double(
infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)
[email protected]("accel_type", ACCEL_TYPES)
[email protected](
+ "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d",
"depthwise")]
+)
+def test_tflite_shared_pad(
+ accel_type,
+ op_pairs,
+):
+ np.random.seed(0)
+
+ ifm_shape = (1, 55, 32, 3)
+ kernel_shape = (3, 3)
+ strides = (3, 2)
+ dilation = (1, 1)
+ activation_function = "RELU"
+ op_padding = "SAME"
+ sep_padding = (0, 0, 1, 1)
+
+ @tf.function
+ def tf_function(x):
+ def make_depthwise_or_conv2d(pair_idx, x):
+ # The input strides to the TensorFlow API need to be of shape 1x4
+ tf_strides = [1, strides[0], strides[1], 1]
+ if op_pairs[pair_idx] == "depthwise":
+ weight_shape = [kernel_shape[0], kernel_shape[1],
ifm_shape[3], 1]
+ weight = tf.constant(np.random.uniform(size=weight_shape),
dtype=tf.float32)
+ op = tf.nn.depthwise_conv2d(
+ x, weight, strides=tf_strides, padding=op_padding,
dilations=dilation
+ )
+ else:
+ weight_shape = [kernel_shape[0], kernel_shape[1],
ifm_shape[3], 3]
+ weight = tf.constant(np.random.uniform(size=weight_shape),
dtype=tf.float32)
+ op = tf.nn.conv2d(
+ x,
+ weight,
+ strides=tf_strides,
+ padding=op_padding,
+ dilations=dilation,
+ )
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ x = tf.pad(
+ x,
+ [
+ [0, 0],
+ [sep_padding[0], sep_padding[2]],
+ [sep_padding[1], sep_padding[3]],
+ [0, 0],
+ ],
+ "CONSTANT",
+ )
+
+ x1 = make_depthwise_or_conv2d(0, x)
+ x2 = make_depthwise_or_conv2d(1, x)
+
+ x3 = tf.math.add(x1, x2)
+ return x3
+
+ infra.compare_tvm_with_tflite(tf_function, [ifm_shape], accel_type)
+
+
@pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10,
1e10)])
def test_out_of_range_scaling(weight_min, weight_max):
np.random.seed(0)
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py
b/tests/python/contrib/test_ethosu/test_legalize.py
index c952a13c52..6dd533c730 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -31,7 +31,7 @@ from tvm import relay
from tvm.relay.backend.contrib.ethosu import legalize, preprocess
from tvm.relay import dataflow_pattern
from tvm.relay.op.contrib import ethosu
-from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.backend.contrib.ethosu import util, codegen
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.frontend.tflite import get_pad_value
from tvm.relay.expr_functor import ExprVisitor
@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
want to add the operator's pattern to the pattern table so that the
compiler
wouldn't attempt to offload an operator without full stack support."""
mod = relay.transform.InferType()(mod)
+ mod = mod = codegen.replicate_pads(mod)
+ mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
@@ -3676,5 +3678,133 @@ def test_tflite_softmax():
verify(mod["tvmgen_default_ethos_u_main_0"])
[email protected]("ifm_shape", [(1, 55, 55, 3)])
[email protected]("kernel_shape", [(3, 3)])
[email protected]("strides, dilation", [((1, 1), (1, 1))])
[email protected]("op_padding", ["SAME", "VALID"])
[email protected]("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
[email protected](
+ "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d",
"depthwise")]
+)
+def test_tflite_shared_pad_legalize(
+ ifm_shape,
+ kernel_shape,
+ strides,
+ dilation,
+ op_padding,
+ sep_padding,
+ op_pairs,
+):
+ dtype = "int8"
+
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, x):
+ def make_depthwise_or_conv2d(pair_idx):
+ if op_pairs[pair_idx] == "depthwise":
+ weight_shape = [kernel_shape[0], kernel_shape[1],
ifm_shape[3], 1]
+ weight =
tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+ return tf.nn.depthwise_conv2d(
+ x, weight, strides=tf_strides, padding=op_padding,
dilations=dilation
+ )
+ weight_shape = [kernel_shape[0], kernel_shape[1],
ifm_shape[3], 3]
+ weight = tf.constant(np.random.uniform(size=weight_shape),
dtype=tf.float32)
+ return tf.nn.conv2d(
+ x,
+ weight,
+ strides=tf_strides,
+ padding=op_padding,
+ dilations=dilation,
+ )
+
+ x = tf.pad(
+ x,
+ [
+ [0, 0],
+ [sep_padding[0], sep_padding[2]],
+ [sep_padding[1], sep_padding[3]],
+ [0, 0],
+ ],
+ "CONSTANT",
+ )
+
+ # The input strides to the TensorFlow API need to be of shape
1x4
+ tf_strides = [1, strides[0], strides[1], 1]
+
+ x1 = make_depthwise_or_conv2d(0)
+ x2 = make_depthwise_or_conv2d(1)
+
+ x3 = tf.math.add(x1, x2)
+ return x3
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ tf.TensorSpec(ifm_shape, dtype=tf.float32)
+ )
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple(ifm_shape))
+ yield [data.astype(np.float32)]
+
+ converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ conv2d_pattern_table = [
+ (
+ ethosu.QnnConv2DParams.composite_name,
+ ethosu.qnn_conv2d_pattern(),
+ lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+ ),
+ (
+ ethosu.QnnDepthwiseConv2DParams.composite_name,
+ ethosu.qnn_depthwise_conv2d_pattern(),
+ lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+ ),
+ ]
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ mod, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"input": ifm_shape},
+ dtype_dict={"input": dtype},
+ )
+
+ mod["main"] = bind_params_by_name(mod["main"], params)
+ mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+ mod["tvmgen_default_ethos_u_main_0"],
+ )
+ mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite(
+ [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+ mod["tvmgen_default_ethos_u_main_1"],
+ )
+
+ if op_pairs[0] == "depthwise":
+ assert (
+ mod["tvmgen_default_ethos_u_main_0"].body.op.name ==
"contrib.ethosu.depthwise_conv2d"
+ )
+ else:
+ assert mod["tvmgen_default_ethos_u_main_0"].body.op.name ==
"contrib.ethosu.conv2d"
+
+ if op_pairs[1] == "depthwise":
+ assert (
+ mod["tvmgen_default_ethos_u_main_1"].body.op.name ==
"contrib.ethosu.depthwise_conv2d"
+ )
+ else:
+ assert mod["tvmgen_default_ethos_u_main_1"].body.op.name ==
"contrib.ethosu.conv2d"
+
+
if __name__ == "__main__":
tvm.testing.main()