krishnaraj36 commented on code in PR #15786: URL: https://github.com/apache/tvm/pull/15786#discussion_r1382990507
########## tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py: ########## @@ -0,0 +1,315 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm +import pytest + + +executor_type = tvm.testing.parameter("ge", "vm") +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): + # Conv2d transpose test cases lists + trials = [ + [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], + [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False)], + [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 100, 100), (True, True)], + [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False)], + [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False)], + [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True)], + ] + ge_texture_scopes = [ + ["", "global.texture", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "", 
""], + ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-nhwc", "", ""], + ] + vm_texture_scopes = [ + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM 
VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc + """, + ] + + for i, ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + ) in enumerate(trials): + shape = (1, *shape) + has_bias = composite[0] + has_activation = composite[1] + input_shape = shape + filter_shape = (shape[1], out_channels, kernel_w, kernel_h) + x = relay.var("data", shape=input_shape, dtype=dtype) + w = relay.var("weight", shape=filter_shape, dtype=dtype) + inputs = [x, w] + y = relay.nn.conv2d_transpose( + x, + w, + channels=out_channels, + kernel_size=(kernel_w, kernel_h), + strides=stride, + padding=pad, + kernel_layout="IOHW", + data_layout="NCHW", + dilation=dilation, + ) + + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + initializer("weight", filter_data) + params1 = { + "weight": tvm.nd.array(filter_data), + } + + if has_bias: + b = relay.var("bias", shape=(out_channels,), dtype=dtype) + y = relay.nn.bias_add(y, b, axis=1) + inputs.append(b) + bias_data = np.zeros((out_channels,)).astype(dtype) + initializer("bias", bias_data) + params1["bias"] = tvm.nd.array(bias_data) + if has_activation: + y = relay.nn.relu(y) + + mod = relay.Function(inputs, y) + if executor_type == "ge": + build_run_compare( + remote, + mod, + params1, + {"data": input_shape}, + {"data": dtype}, + target, + ge_texture_scopes[i], + gpu_preprocess, + ) + else: + build_run_compare_vm( + remote, + mod, + params1, + {"data": input_shape}, + {"data": dtype}, + target, + vm_texture_scopes[i], + gpu_preprocess, Review Comment: Enabled a few test cases without explicit preprocessing.
Thanks for the review. ########## python/tvm/relay/op/strategy/adreno.py: ########## @@ -215,6 +215,58 @@ def conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_ return strategy +@conv2d_transpose_strategy.register("adreno") +def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): + """conv2d_transpose adreno strategy""" + strategy = _op.OpStrategy() + _, kernel = inputs + dilation = attrs.get_int_tuple("dilation") + groups = attrs.groups + data_layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + assert dilation == (1, 1), "not support dilate now" + + if (groups == 1) and ( + (data_layout == "NCHW" and kernel_layout == "IOHW") + or (data_layout == "NCHW4c" and kernel_layout == "IOHW4o") + or (data_layout == "NCHW" and kernel_layout == "IOHW4o") + ): + if len(kernel.shape) == 4: + oc, _, _, _ = get_const_tuple(kernel.shape) + else: + oc, _, _, _, _ = get_const_tuple(kernel.shape) + # We cannot use textures for case than number of channels is less than 4. + # So, we use compute functions from cuda. + if len(kernel.shape) == 4 and oc < 4: Review Comment: Enabled the test cases for oc < 4 with the fix. Thanks for the review. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
