Aleksei-grovety commented on code in PR #14909: URL: https://github.com/apache/tvm/pull/14909#discussion_r1227824435
########## python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py: ########## @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Adds pads so that each conv2d operator has only one consumer" Review Comment: Maybe "Adds pads to each conv2d operator so that each pad has only one consumer."? 
########## tests/python/contrib/test_ethosu/test_legalize.py: ########## @@ -3646,5 +3648,148 @@ def _visit(stmt): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)]) +@pytest.mark.parametrize("kernel_shape", [(3, 3)]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("op_padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)]) +@pytest.mark.parametrize( + "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")] +) +def test_tflite_shared_pad_legalize( + ifm_shape, + kernel_shape, + strides, + dilation, + op_padding, + sep_padding, + op_pairs, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + + x = tf.pad( + x, + [ + [0, 0], + [sep_padding[0], sep_padding[2]], + [sep_padding[1], sep_padding[3]], + [0, 0], + ], + "CONSTANT", + ) + + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + + if op_pairs[0] == "depthwise": + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x1 = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + else: + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x1 = tf.nn.conv2d( + x, + weight, + strides=tf_strides, + padding=op_padding, + dilations=dilation, + ) Review Comment: The code here duplicates lines 3706-3721; it could be moved into a separate helper function. ########## python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py: ########## @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Adds pads so that each conv2d operator has only one consumer" + +import tvm +from tvm import relay + +from ..expr_functor import ExprMutator, Call +from .. import expr as _expr + + +class PadsWithMultipleConsumersReplicator(ExprMutator): + """A pass to to handle the situation when nn.pad operator has + more than one qnn.conv2d consumer. + + pad + / \ + Conv2D Conv2D + + In this case, because of the peculiarities of pattern parsing, + conv2d does not get into the composite for the NPU. + Therefore, pads are added so that each has only one consumer. 
+ """ + + def __init__(self): + ExprMutator.__init__(self) + self.hashes = set() + + def visit_call(self, call): + if ( + isinstance(call.op, tvm.ir.Op) + and isinstance(call.args[0], Call) + and isinstance(call.args[0].op, tvm.ir.Op) + and call.op == relay.op.get("qnn.conv2d") + and call.args[0].op == relay.op.get("nn.pad") + ): + if tvm.ir.structural_hash(call.args[0]) not in self.hashes: + self.hashes.add(tvm.ir.structural_hash(call.args[0])) + else: + used_pad = self.visit(call.args[0]) + used_pad_args = [self.visit(arg) for arg in used_pad.args] + new_pad = Call( + used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span + ) + new_pad = self.visit(new_pad) + new_conv2d_args = [] + for i, arg in enumerate(call.args): + if i == 0: + new_conv2d_args.append(self.visit(new_pad)) Review Comment: new_pad is visited twice the first time on line 59. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
