ekalda commented on a change in pull request #9530:
URL: https://github.com/apache/tvm/pull/9530#discussion_r754430628
##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -741,6 +741,96 @@ def __call__(self, *args, **kwargs):
pass
+class UnaryElementwiseRewriter(DFPatternCallback):
+ """
+ Convert ethosu unary elementwise composite function to
+ ethosu_unary_elementwise operators
+ """
+
+ def __init__(self, params_class, pattern):
+ super().__init__(require_type=True)
+ self.params_class = params_class
+ self.pattern = pattern
+
+ def callback(
+ self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map:
tvm.ir.container.Map
+ ) -> tvm.relay.Expr:
+ params = self.params_class(post.op.body)
+ params.ifm.tensor = post.args[0]
+
+ if str(params.ofm.layout) != "NHWC":
+ raise UnsupportedLayout(str(params.ofm.layout))
+
+ activation_map = {"clip": "CLIP"}
+ if params.activation:
+ activation = activation_map[params.activation.op.name]
+ clip_min = int(params.activation.attrs.a_min)
+ clip_max = int(params.activation.attrs.a_max)
+ else:
+ activation = "NONE"
+ clip_min = 0
+ clip_max = 0
+
+ # We don't yet support activation functions that use LUT.
+ lut = relay.const([], dtype="int8")
+
+ unary_input_shape = params.ifm.shape
+ # If the input tensor is not 4D, enter reshapes before and after the
unary operator
+ if len(params.ifm.shape) == 4:
+ unary_input = params.ifm.tensor
+ else:
+ while len(unary_input_shape) < 4:
+ unary_input_shape = [1] + unary_input_shape
Review comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]