ekalda commented on a change in pull request #9547:
URL: https://github.com/apache/tvm/pull/9547#discussion_r757435585
##########
File path: python/tvm/relay/backend/contrib/ethosu/codegen.py
##########
@@ -22,6 +22,109 @@
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.expr_functor import ExprMutator
+from tvm.ir.transform import Pass
+
+# pylint: disable=unused-import
+from tvm.relay.backend.contrib.ethosu.op import op_attrs
+from tvm.relay.backend.contrib.ethosu import op
+
+
class OptimizeLUTs(ExprMutator):
    """A pass to merge an identity operator with a LUT based activation function
    with a preceding operator, provided that operator can do a table lookup for
    the activation in the hardware."""

    def __init__(self):
        super().__init__()
        # Map of NPU operator names that support a fused LUT activation in
        # hardware to the Python constructors that build them.
        self.lut_ops = {
            "contrib.ethosu.conv2d": op.ethosu_conv2d,
            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
            "contrib.ethosu.pooling": op.ethosu_pooling,
        }

    def create_op_with_lut(self, call):
        """Extract the parameters and attributes from the NPU operator and create
        a new operator with LUT.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The current call node being visited.

        Returns
        -------
        tvm.relay.expr.Call
            The new operator with LUT.
        """
        identity = call
        ethosu_op = call.args[0]
        lut = identity.args[1]
        activation = identity.attrs.activation

        # Check the precondition before doing any work: the producer must be
        # one of the operators that supports LUT fusion.
        assert ethosu_op.op.name in self.lut_ops

        new_attrs = dict(ethosu_op.attrs)
        new_attrs["activation"] = activation

        # Assume that LUT is always the last argument
        new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)]
        new_args.append(lut)

        return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)

    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Recursively visit call nodes in the input graph and if an
        ethosu.identity operator with LUT is found and the preceding operator
        has a LUT attribute, create a new NPU operator.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The current call node being visited.

        Returns
        -------
        tvm.relay.expr.Call
            The input call node in the case the current call node does
            not refer to an Op. Else, a new call node with a new operator.
        """
        new_call = call
        lut_activations = ["TANH", "LUT"]

        if (
            call.op.name == "contrib.ethosu.identity"
            and call.attrs.activation in lut_activations
            and isinstance(call.args[0], tvm.relay.expr.Call)
        ):
            producer_op = call.args[0]
            # Check if the producer can do a LUT operation
            if producer_op.op.name in self.lut_ops:
                # Check the producer doesn't already have a LUT
                has_lut = producer_op.attrs.activation in lut_activations
                if not has_lut:
                    new_call = self.create_op_with_lut(call)

        # Continue the traversal into the (possibly rewritten) call's arguments.
        new_call = super().visit_call(new_call)

        return new_call
+
+
[email protected]_pass(opt_level=1, name="LutOptimizer")
+class LUTsOptimizer(Pass):
+ """Register LutOptimizer as a relay pass."""
+
+ def transform_function(
+ self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
+ ) -> tvm.IRModule:
+ """Visit relay nodes in the given module.
+ Parameters
+ ----------
+ func : tvm.relay.function.Function
+ The function to apply the layout optimization pass to.
Review comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]