This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b77d659c9a [microNPU][ETHOSU] Fix concatenation with reused buffers (#15428)
b77d659c9a is described below
commit b77d659c9afcfe3b80ea9b10a748cb15f4fe6539
Author: Aleksei-grovety <[email protected]>
AuthorDate: Tue Aug 1 19:00:46 2023 +0300
[microNPU][ETHOSU] Fix concatenation with reused buffers (#15428)
Add a pass that copies concatenation arguments used more than once across
concatenation operations. This prevents an argument used in multiple
concatenations from being written to only one of the resulting buffers.
---
python/tvm/relay/backend/contrib/ethosu/codegen.py | 101 +++++++++++++++++++++
tests/python/contrib/test_ethosu/test_codegen.py | 18 ++++
2 files changed, 119 insertions(+)
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py
b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index f4cea5df13..b2fc5f0af2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -444,6 +444,106 @@ def replicate_pads(mod):
return mod
class AnalyzeConcatArgs(ExprVisitor):
    """Traverses the graph to determine which arguments are passed into
    concatenation operations and how many times each one is used.

    Only arguments that are ``relay.Call`` nodes are counted. When a
    concatenation input is a ``relay.Tuple`` literal, each of its fields is
    counted individually rather than the tuple itself.

    Attributes
    ----------
    args_usage : Dict[tvm.relay.expr.Call, int]
        Mapping from each concatenation argument to the number of times it is
        used as a concatenate argument.
    """

    def __init__(self):
        self.args_usage = defaultdict(int)
        super().__init__()

    def visit_call(self, call: relay.Call):
        # Only concatenations matter, so tuple expansion is done here instead
        # of building a throwaway argument list for every call in the graph.
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "concatenate":
            for arg in call.args:
                fields = arg.fields if isinstance(arg, relay.Tuple) else [arg]
                for field in fields:
                    if isinstance(field, relay.Call):
                        self.args_usage[field] += 1

        super().visit_call(call)
+
class ConcatArgsCopier(ExprMutator):
    """A pass for copying concatenation arguments that are used in multiple
    concatenation operations. For a concatenation argument that is used n
    times, n - 1 copy (``ethosu_identity``) operations are created. The last
    concatenation visited keeps the original argument.

    Attributes
    ----------
    args_usage : Dict[tvm.relay.expr.Call, int]
        Mapping from concatenation arguments to the number of their remaining
        uses as concatenate arguments. Decremented each time a copy is inserted.
    """

    def __init__(self, args_usage):
        super().__init__()
        self.args_usage = args_usage

    def _copy_if_reused(self, arg):
        """Visit ``arg`` and, if another concatenation still uses it, wrap the
        visited expression in an identity (copy) operation."""
        visited = self.visit(arg)
        # .get() avoids inserting zero entries into a defaultdict for
        # arguments that were never counted (constants, vars, ...).
        if self.args_usage.get(arg, 0) > 1:
            lut = relay.const([], "int8")
            self.args_usage[arg] -= 1
            return op.ethosu_identity(visited, lut)
        return visited

    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        if isinstance(call.op, tvm.ir.Op) and call.op == relay.op.get(
            "concatenate"
        ):
            new_args = []
            for arg in call.args:
                if isinstance(arg, relay.Tuple):
                    new_fields = [self._copy_if_reused(field) for field in arg.fields]
                    # Rebuild the tuple, keeping its source span.
                    new_args.append(relay.Tuple(new_fields, span=arg.span))
                else:
                    # Keep non-literal inputs as they are; wrapping them in a
                    # new Tuple would change the concatenate's input type.
                    new_args.append(self._copy_if_reused(arg))
        else:
            new_args = [self.visit(arg) for arg in call.args]
        new_op = self.visit(call.op)
        return _expr.CallWithFields(
            call, new_op, new_args, call.attrs, call.type_args, None, call.span
        )
+
+
@util.create_npu_function_pass(opt_level=1)
class CopyReusedConcatBuffers:
    """Register CopyReusedConcatBuffers as a Relay pass."""

    def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
        """A pass to copy concatenation arguments which are used more than once
        in concatenation operations. This prepares for the subsequent
        RemoveConcatenates pass, preventing a situation where an argument used
        in multiple concatenations is written to only one resulting buffer."""

        # First pass: count how often each argument feeds a concatenation.
        analyze = AnalyzeConcatArgs()
        analyze.visit(func)

        # Second pass: insert copies for every reused argument.
        return ConcatArgsCopier(analyze.args_usage).visit(func)

    def __call__(self, *args, **kwargs):
        # NOTE(review): intentionally empty, mirroring the file's other pass
        # classes; presumably the decorator provides the real pass entry point.
        pass
+
+
def IdentityOptimizer(): # pylint: disable=invalid-name
"""Pass that removes redundant identities
@@ -585,6 +685,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
"""
mod = OutlineCompilerFunctions("ethos-u")(mod)
mod = LegalizeEthosU()(mod)
+ mod = CopyReusedConcatBuffers()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
mod = IdentityOptimizer()(mod)
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index d56b8b6ec9..e094bb74b2 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1170,6 +1170,24 @@ def test_tflite_concat(shapes, axis, accel_type):
infra.compare_tvm_with_tflite(concat_func, shapes, accel_type,
enable_cascader=False)
def test_tflite_concat_with_reused_args():
    """Concatenations sharing inputs must each receive their own copy."""
    np.random.seed(0)
    shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]
    axis = 2
    accel_type = "ethos-u55-256"

    @tf.function
    def concat_func(*inputs):
        # Both inputs[0] and the sum feed two separate concatenations.
        summed = tf.add(inputs[0], inputs[1])
        short_concat = tf.concat((inputs[0], inputs[2], summed), axis)
        long_concat = tf.concat((inputs[0], inputs[3], summed), axis)
        pooled = tf.nn.max_pool2d(long_concat, (1, 1), (1, 2), "SAME")
        return tf.add(pooled, short_concat)

    infra.compare_tvm_with_tflite(
        concat_func, shapes, accel_type, enable_cascader=False
    )
+
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_sigmoid(accel_type):
np.random.seed(0)