This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b77d659c9a [microNPU][ETHOSU] Fix concatenation with reused buffers (#15428)
b77d659c9a is described below

commit b77d659c9afcfe3b80ea9b10a748cb15f4fe6539
Author: Aleksei-grovety <[email protected]>
AuthorDate: Tue Aug 1 19:00:46 2023 +0300

    [microNPU][ETHOSU] Fix concatenation with reused buffers (#15428)
    
    Add a pass to copy concatenation arguments which are used more than once in
    concatenation operation to prevent a situation where an argument used in
    multiple concatenations will be written to only one resulting buffer.
---
 python/tvm/relay/backend/contrib/ethosu/codegen.py | 101 +++++++++++++++++++++
 tests/python/contrib/test_ethosu/test_codegen.py   |  18 ++++
 2 files changed, 119 insertions(+)

diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index f4cea5df13..b2fc5f0af2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -444,6 +444,106 @@ def replicate_pads(mod):
     return mod
 
 
+class AnalyzeConcatArgs(ExprVisitor):
+    """Traverses the graph to determine which arguments were passed into the
+    concatenation operation and how many times they are used. The result is
+    maintained in `args_usage` and is a dictionary where the key is the concatenation argument and
+    the value is the number of uses of this argument.
+
+    Attributes
+    ----------
+    args_usage : Dict[tvm.relay.expr.Call, int]
+        Mapping from concatenation arguments to count their usage as concatenate arguments.
+    """
+
+    def __init__(self):
+        self.args_usage = defaultdict(int)
+        super().__init__()
+
+    def visit_call(self, call: relay.Call):
+        args = []
+
+        # Expand tuples
+        for arg in call.args:
+            if isinstance(arg, relay.Tuple):
+                args.extend(arg.fields)
+            else:
+                args.append(arg)
+
+        if isinstance(call.op, tvm.ir.Op) and call.op.name == "concatenate":
+            for arg in args:
+                if isinstance(arg, relay.Call):
+                    self.args_usage[arg] += 1
+
+        super().visit_call(call)
+
+
+class ConcatArgsCopier(ExprMutator):
+    """A pass for copying concatenation arguments that are used in multiple concatenation
+    operations. For a concatenation argument that is used n times, n - 1 copy operations
+    will be created.
+
+    Attributes
+    ----------
+    args_usage : Dict[tvm.relay.expr.Call, int]
+        Mapping from concatenation arguments to count their usage as concatenate arguments.
+    """
+
+    def __init__(self, args_usage):
+        super().__init__()
+        self.args_usage = args_usage
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        if isinstance(call.op, tvm.ir.Op) and call.op == relay.op.get("concatenate"):
+            args = []
+
+            # Expand tuples
+            for arg in call.args:
+                if isinstance(arg, relay.Tuple):
+                    args.extend(arg.fields)
+                else:
+                    args.append(arg)
+            new_args = []
+            for arg in args:
+                visited = self.visit(arg)
+                if self.args_usage[arg] > 1:
+                    # Add copy operation
+                    lut = relay.const([], "int8")
+                    new_op = op.ethosu_identity(visited, lut)
+                    new_args.append(new_op)
+                    self.args_usage[arg] -= 1
+                else:
+                    new_args.append(visited)
+
+            new_args = [relay.Tuple(new_args)]
+        else:
+            new_args = [self.visit(arg) for arg in call.args]
+        new_op = self.visit(call.op)
+        new_call = _expr.CallWithFields(
+            call, new_op, new_args, call.attrs, call.type_args, None, call.span
+        )
+        return new_call
+
+
+@util.create_npu_function_pass(opt_level=1)
+class CopyReusedConcatBuffers:
+    """Register CopyReusedConcatBuffers as a Relay pass."""
+
+    def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
+        """A pass to copy concatenation arguments which are used more than once in
+        concatenation operation. This is the preparation for the next RemoveConcatenates
+        pass to prevent a situation where an argument used in multiple concatenations
+        will be written to only one resulting buffer."""
+
+        analyze = AnalyzeConcatArgs()
+        analyze.visit(func)
+
+        return ConcatArgsCopier(analyze.args_usage).visit(func)
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 def IdentityOptimizer():  # pylint: disable=invalid-name
     """Pass that removes redundant identities
 
@@ -585,6 +685,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     """
     mod = OutlineCompilerFunctions("ethos-u")(mod)
     mod = LegalizeEthosU()(mod)
+    mod = CopyReusedConcatBuffers()(mod)
     mod = LUTsOptimizer()(mod)
     mod = relay.transform.InferType()(mod)
     mod = IdentityOptimizer()(mod)
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index d56b8b6ec9..e094bb74b2 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1170,6 +1170,24 @@ def test_tflite_concat(shapes, axis, accel_type):
     infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, 
enable_cascader=False)
 
 
+def test_tflite_concat_with_reused_args():
+    np.random.seed(0)
+    shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]
+    axis = 2
+    accel_type = "ethos-u55-256"
+
+    @tf.function
+    def concat_func(*inputs):
+        op = tf.add(inputs[0], inputs[1])
+        op2 = tf.concat((inputs[0], inputs[2], op), axis)
+        op = tf.concat((inputs[0], inputs[3], op), axis)
+        op = tf.nn.max_pool2d(op, (1, 1), (1, 2), "SAME")
+        op = tf.add(op, op2)
+        return op
+
+    infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False)
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 def test_tflite_sigmoid(accel_type):
     np.random.seed(0)

Reply via email to