ekalda commented on a change in pull request #9457:
URL: https://github.com/apache/tvm/pull/9457#discussion_r750115026
##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -123,6 +123,108 @@ def __call__(self, *args, **kwargs):
pass
+class StridedSliceRewriter(DFPatternCallback):
+ """This pass brings the strided slice out of the partitioned function"""
+
+ def __init__(self):
+ super().__init__(require_type=True, rewrite_once=True)
+ self.pattern = (wildcard().has_attr({"Composite": "ethosu.strided_slice"}))(wildcard())
+
+ def callback(
+ self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+ ) -> tvm.relay.Expr:
+ input = post.args[0]
+ attrs = post.op.body.attrs
+ begin = attrs.begin
+ end = attrs.end
+ strides = attrs.strides
+ axes = attrs.axes
+ slice_mode = attrs.slice_mode
+ strided_slice = relay.op.strided_slice(
+ input, begin, end, strides=strides, axes=axes, slice_mode=slice_mode
+ )
+ return strided_slice
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeStridedSlice:
+ """This is the pass that wraps StridedSliceRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(StridedSliceRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class ReshapeRewriter(DFPatternCallback):
+ """This pass brings the reshape out of the partitioned function"""
+
+ def __init__(self):
+ super().__init__(require_type=True, rewrite_once=True)
+ self.pattern = (wildcard().has_attr({"Composite": "ethosu.reshape"}))(wildcard())
+
+ def callback(
+ self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+ ) -> tvm.relay.Expr:
+ reshape_input = post.args[0]
+ new_shape = post.op.body.attrs.newshape
+ reshape = relay.op.reshape(reshape_input, newshape=new_shape)
+ return reshape
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeReshape:
+ """This is the pass that wraps ReshapeRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(ReshapeRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class NoOpRewriter(DFPatternCallback):
+ """This pass adds and idenity operator to reshape and strided slice to avoid a no op without a consumer"""
Review comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]