masahi commented on a change in pull request #4497: [Relay] Add a PyTorch to Relay Parser URL: https://github.com/apache/incubator-tvm/pull/4497#discussion_r379886998
########## File path: python/tvm/relay/frontend/pytorch.py ########## @@ -0,0 +1,1026 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks +# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except +# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension +"""PT: PyTorch frontend.""" +import numpy as np + +import tvm +from tvm.ir import module as _module + +from .. import analysis as _analysis +from .. import expr as _expr +from .. 
import op as _op +from .common import get_relay_op +from .common import infer_shape as _infer_shape + +__all__ = ["from_pytorch"] + +# operator implementation +def _elemwise(name): + def _impl(inputs, input_types): + # TODO: Figure out a better way to get typing to work for tensor + scalar + type0 = input_types[0] + if isinstance(inputs[1], _expr.Expr): + type0 = input_types[1] + + type1 = input_types[1] + if isinstance(inputs[0], _expr.Expr): + type1 = input_types[0] + + data0 = _convert_elemwise_input(inputs[0], type0) + data1 = _convert_elemwise_input(inputs[1], type1) + + return get_relay_op(name)(data0, data1) + return _impl + +def _unsqueeze(): + def _impl(inputs, input_types): + data = inputs[0] + axis = inputs[1] + + return _op.transform.expand_dims(data, int(axis), 1) + return _impl + +def _concatenate(): + def _impl(inputs, input_types): + data = inputs[0] + axis = inputs[1] + + if isinstance(data, _expr.Expr): + data = [data] + + return _op.tensor.concatenate(data, int(axis)) + return _impl + +def _slice(): + def _impl(inputs, input_types): + data = inputs[0] + strides = [] + + if isinstance(data, _expr.Expr): + inferred_shape = _infer_shape(data) + end = [] + for infer in inferred_shape: + end.append(int(infer)) + if isinstance(data, _expr.Var): + end = inferred_shape + end = list(end) + else: + end = data.shape + + begin = [0]*len(end) + dim = int(inputs[1]) + begin[dim] = int(inputs[2]) + + if isinstance(inputs[3], str) and inputs[3].isdigit(): + end[dim] = min(end[dim], int(inputs[3])) + else: + end[dim] = inputs[3] + + strides.append(int(inputs[4])) + return _op.transform.strided_slice(data, begin, end, strides) + return _impl + +def _select(): + def _impl(inputs, input_types): + data = inputs[0] + dim = int(inputs[1]) + index = int(inputs[2]) + + return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim) + return _impl + +def _ones(): + def _impl(inputs, input_types): + data = inputs[0] + + import torch + if isinstance(data, 
_expr.Expr): + shape = _infer_shape(data) + elif isinstance(data, list): + shape = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + shape = data.shape + else: + assert "data type {} could not be parsed in ones op" % (type(data)) + + return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0])) + return _impl + +def _zeros(): + def _impl(inputs, input_types): + data = inputs[0] + + import torch + if isinstance(data, _expr.Expr): + shape = _infer_shape(data) + elif isinstance(data, list): + shape = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + shape = data.shape + else: + assert "data type {} could not be parsed in zeros op" % (type(data)) + + return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0])) + return _impl + +def _relu(): + def _impl(inputs, input_types): + data = inputs[0] + return _op.nn.relu(data) + return _impl + +def _adaptive_avg_2d(): + def _impl(inputs, input_types): + data = inputs[0] + output_size = _infer_shape(inputs[1]) + + return _op.contrib.contrib.adaptive_avg_pool2d( + data, + output_size=output_size) + return _impl + +def _adaptive_max_2d(): + def _impl(inputs, input_types): + data = inputs[0] + output_size = _infer_shape(inputs[1]) + + return _op.contrib.contrib.adaptive_max_pool2d( + data, + output_size=output_size) + return _impl + +def _maxpool_2d(): + def _impl(inputs, input_types): + data = inputs[0] + + pool_size = _infer_shape(inputs[1]) + strides = _infer_shape(inputs[2]) + padding = _infer_shape(inputs[3]) + + ceil_mode = int(inputs[5]) + + return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) + return _impl + +def _hardtanh(): + def _impl(inputs, input_types): + a = inputs[0] + tanh_min = float(inputs[1]) + tanh_max = float(inputs[2]) + return _op.tensor.clip(a, tanh_min, tanh_max) + return _impl + +def _convolution(): + def _impl(inputs, input_types): + # Use transpose or normal + use_transpose = True if inputs[6] == "1" else False + + 
data = inputs[0] + weight = inputs[1] + bias = inputs[2] + strides = inputs[3] + padding = inputs[4] + dilation = inputs[5] + + if isinstance(weight, _expr.Expr): + inferred_shape = _infer_shape(weight) + weight_shape = [] + for infer in inferred_shape: + weight_shape.append(infer) + else: + assert "data type {} could not be parsed in conv op" % (type(weight)) + + channels = weight_shape[0] Review comment: Yes, it does make sense now that I did some research on this topic. I now understand why we need a reshape when channel_multiplier > 1. Let me explain my findings: First, some background on torch: * According to their [API doc](https://pytorch.org/docs/stable/nn.html#conv2d), it says ``` When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also termed in literature as depthwise convolution. ``` * In torch, weight_shape[1] is always in_channels // groups ([source](https://github.com/pytorch/pytorch/blob/00c6b903278d574d94171ebbf01986d698e22716/torch/nn/modules/conv.py#L46-L47)). So for depthwise conv, weight_shape[1] is always 1. However, in Relay the meaning of `groups` seems different: depthwise conv occurs when the number of **output** channels is equal to `groups`. See https://github.com/apache/incubator-tvm/blob/2c0c18494e00dd3beb71527b3f2ccde5df678440/src/relay/op/nn/convolution.h#L164-L168 So in the current implementation, if the channel multiplier is > 1, this condition is always false and depthwise convolution is never used even though it could be. But the output is still correct, because we are doing **group** convolution and comparing its output with torch's depthwise convolution output. https://github.com/apache/incubator-tvm/blob/a5661611472c8e92b20bbe4d074333b8183f2878/python/tvm/relay/op/nn/_nn.py#L224 For depthwise convolution to be used correctly, I agree with @icemelon9 that we should add a reshape when the channel multiplier is > 1.
We also need to multiply `groups` by the channel multiplier to account for the difference in the meaning of the `groups` parameter. This way, `wshape` at https://github.com/apache/incubator-tvm/blob/2c0c18494e00dd3beb71527b3f2ccde5df678440/src/relay/op/nn/convolution.h#L166-L167 has the correct value expected by Relay and TOPI. In summary, the necessary change is to add the following. Note that the `3, 3` in the snippet hardcodes a 3x3 kernel; in general it should be `weight_shape[2], weight_shape[3]`. Also, the reshape to `(groups, channel_multiplier, kh, kw)` only preserves the element count when `weight_shape[1] == 1` (true depthwise). The condition should therefore check that too, because `channels % groups == 0` alone also holds for ordinary grouped convolution. ```Python channels = weight_shape[0] if groups > 1 and channels % groups == 0: # in torch, groups == in_channels for depth wise conv channel_multiplier = channels // groups weight = _op.transform.reshape(inputs[1], (groups, channel_multiplier, 3, 3)) groups *= channel_multiplier ``` @icemelon9 Does this sound about right? I also wonder whether other frontends handle the depthwise conv + multiplier > 1 case correctly. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
