This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e814f798ed [Adreno] Change compute/schedule for ToMixedPrecision pass 
(#12537)
e814f798ed is described below

commit e814f798edc5bf6977a4f4f74ec8d1d7e363c608
Author: Andrey Malyshev <[email protected]>
AuthorDate: Thu Sep 1 19:33:15 2022 +0300

    [Adreno] Change compute/schedule for ToMixedPrecision pass (#12537)
    
    * [Adreno] Change compute/schedule for ToMixedPrecision pass
    
    * Address CI fails
    
    * address PR comments
    
    * Fix AutoTVM flow
---
 python/tvm/relay/op/strategy/adreno.py           | 142 +++++++----------------
 python/tvm/topi/adreno/conv2d_alter_op.py        |  48 +++++---
 python/tvm/topi/adreno/conv2d_nchw.py            | 117 +++++++++----------
 python/tvm/topi/adreno/conv2d_nchw_winograd.py   |  45 +------
 python/tvm/topi/adreno/conv2d_nhwc.py            | 111 +++++++++---------
 python/tvm/topi/adreno/conv2d_nhwc_winograd.py   |  45 +------
 python/tvm/topi/adreno/conv2d_winograd_common.py |  19 ++-
 python/tvm/topi/adreno/depthwise_conv2d_nchw.py  |  42 ++-----
 python/tvm/topi/adreno/depthwise_conv2d_nhwc.py  |  38 +-----
 tests/python/relay/test_conv2d_nchw_texture.py   |   4 +-
 tests/python/relay/test_conv2d_nhwc_texture.py   |   2 +-
 11 files changed, 218 insertions(+), 395 deletions(-)

diff --git a/python/tvm/relay/op/strategy/adreno.py 
b/python/tvm/relay/op/strategy/adreno.py
index a537fa1e7b..9429fd71e1 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -36,8 +36,10 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
         raise ValueError("dilation should be positive value")
 
     if groups == 1:
-        if (data_layout == "NCHW" and kernel_layout == "OIHW") or (
-            data_layout == "NCHW4c" and kernel_layout == "OIHW4o"
+        if (
+            (data_layout == "NCHW" and kernel_layout == "OIHW")
+            or (data_layout == "NCHW4c" and kernel_layout == "OIHW4o")
+            or (data_layout == "NCHW" and kernel_layout == "OIHW4o")
         ):
             if len(kernel.shape) == 4:
                 _, _, kh, kw = get_const_tuple(kernel.shape)
@@ -47,35 +49,24 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
                 (2 < kh < 8 and 2 < kw < 8 and kh == kw)
                 and (stride_h == 1 and stride_w == 1)
                 and (dilation_h == 1 and dilation_w == 1)
+                and not (data_layout == "NCHW" and kernel_layout == "OIHW4o")
             ):
-                if out_type.dtype == "float16":
-                    strategy.add_implementation(
-                        wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd),
-                        
wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd),
-                        name="conv2d_nchw_winograd.image2d",
-                        plevel=5,
-                    )
                 strategy.add_implementation(
-                    
wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_acc32),
-                    
wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd_acc32),
-                    name="conv2d_nchw_winograd_acc32.image2d",
-                    plevel=7,
-                )
-            if out_type.dtype == "float16":
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.adreno.conv2d_nchwc),
-                    wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc),
-                    name="conv2d_nchwc.image2d",
-                    plevel=10,
+                    wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd),
+                    
wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd),
+                    name="conv2d_nchw_winograd.image2d",
+                    plevel=5,
                 )
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32),
-                wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32),
-                name="conv2d_nchwc_acc32.image2d",
-                plevel=20,
+                wrap_compute_conv2d(topi.adreno.conv2d_nchwc),
+                wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc),
+                name="conv2d_nchwc.image2d",
+                plevel=10,
             )
-        elif (data_layout == "NHWC" and kernel_layout == "HWIO") or (
-            data_layout == "NHWC4c" and kernel_layout == "HWIO4o"
+        elif (
+            (data_layout == "NHWC" and kernel_layout == "HWIO")
+            or (data_layout == "NHWC4c" and kernel_layout == "HWIO4o")
+            or (data_layout == "NHWC" and kernel_layout == "HWIO4o")
         ):
             if len(kernel.shape) == 4:
                 kh, kw, _, _ = get_const_tuple(kernel.shape)
@@ -85,32 +76,19 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
                 (2 < kh < 8 and 2 < kw < 8 and kh == kw)
                 and (stride_h == 1 and stride_w == 1)
                 and (dilation_h == 1 and dilation_w == 1)
+                and not (data_layout == "NHWC" and kernel_layout == "HWIO4o")
             ):
-                if out_type.dtype == "float16":
-                    strategy.add_implementation(
-                        wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd),
-                        
wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd),
-                        name="conv2d_nhwc_winograd.image2d",
-                        plevel=5,
-                    )
                 strategy.add_implementation(
-                    
wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_acc32),
-                    
wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd_acc32),
-                    name="conv2d_nhwc_winograd_acc32.image2d",
-                    plevel=7,
-                )
-            if out_type.dtype == "float16":
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.adreno.conv2d_nhwc),
-                    wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc),
-                    name="conv2d_nhwc.image2d",
-                    plevel=10,
+                    wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd),
+                    
wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd),
+                    name="conv2d_nhwc_winograd.image2d",
+                    plevel=5,
                 )
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.adreno.conv2d_nhwc_acc32),
-                wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_acc32),
-                name="conv2d_nhwc_acc32.image2d",
-                plevel=20,
+                wrap_compute_conv2d(topi.adreno.conv2d_nhwc),
+                wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.image2d",
+                plevel=10,
             )
         else:
             raise RuntimeError(
@@ -149,35 +127,21 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, 
target):
             if (data_layout == "NCHW" and kernel_layout == "OIHW") or (
                 data_layout == "NCHW4c" and kernel_layout == "OIHW4o"
             ):
-                if out_type.dtype == "float16":
-                    strategy.add_implementation(
-                        
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc),
-                        
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc),
-                        name="depthwise_conv2d_nchwc.image2d",
-                        plevel=10,
-                    )
                 strategy.add_implementation(
-                    
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc_acc32),
-                    
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc_acc32),
-                    name="depthwise_conv2d_nchwc_acc32.image2d",
-                    plevel=20,
+                    wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc),
+                    
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc),
+                    name="depthwise_conv2d_nchwc.image2d",
+                    plevel=10,
                 )
             elif (data_layout == "NHWC" and kernel_layout == "HWOI") or (
                 data_layout == "NHWC4c" and kernel_layout == "HWOI4o"
             ):
                 if data.shape[-1] >= 4:
-                    if out_type.dtype == "float16":
-                        strategy.add_implementation(
-                            
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nhwc),
-                            
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nhwc),
-                            name="depthwise_conv2d_nhwc.image2d",
-                            plevel=10,
-                        )
                     strategy.add_implementation(
-                        
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nhwc_acc32),
-                        
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nhwc_acc32),
-                        name="depthwise_conv2d_nhwc_acc32.image2d",
-                        plevel=20,
+                        wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nhwc),
+                        
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nhwc),
+                        name="depthwise_conv2d_nhwc.image2d",
+                        plevel=10,
                     )
                 else:
                     strategy.add_implementation(
@@ -208,40 +172,18 @@ def 
conv2d_winograd_without_weight_transfrom_strategy_adreno(attrs, inputs, out_
     assert groups == 1, "Do not supoort arbitrary group number"
     strategy = _op.OpStrategy()
     if layout in ("NCHW", "NCHW4c"):
-        if out_type.dtype == "float16":
-            strategy.add_implementation(
-                
wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform),
-                wrap_topi_schedule(
-                    
topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform
-                ),
-                name="conv2d_nchw_winograd_without_weight_transform.image2d",
-                plevel=5,
-            )
         strategy.add_implementation(
-            
wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform_acc32),
-            wrap_topi_schedule(
-                
topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform_acc32
-            ),
-            name="conv2d_nchw_winograd_without_weight_transform_acc32.image2d",
-            plevel=7,
+            
wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform),
+            
wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform),
+            name="conv2d_nchw_winograd_without_weight_transform.image2d",
+            plevel=5,
         )
     elif layout in ("NHWC", "NHWC4c"):
-        if out_type.dtype == "float16":
-            strategy.add_implementation(
-                
wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform),
-                wrap_topi_schedule(
-                    
topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform
-                ),
-                name="conv2d_nhwc_winograd_without_weight_transform.image2d",
-                plevel=5,
-            )
         strategy.add_implementation(
-            
wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform_acc32),
-            wrap_topi_schedule(
-                
topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform_acc32
-            ),
-            name="conv2d_nhwc_winograd_without_weight_transform_acc32.image2d",
-            plevel=7,
+            
wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform),
+            
wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform),
+            name="conv2d_nhwc_winograd_without_weight_transform.image2d",
+            plevel=5,
         )
     else:
         raise RuntimeError(
diff --git a/python/tvm/topi/adreno/conv2d_alter_op.py 
b/python/tvm/topi/adreno/conv2d_alter_op.py
index 16573991e0..6cf749a62b 100644
--- a/python/tvm/topi/adreno/conv2d_alter_op.py
+++ b/python/tvm/topi/adreno/conv2d_alter_op.py
@@ -304,7 +304,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
                 num_filter_block = 4
 
             # no support yet for tensors that cannot be divisible by factor 4
-            if in_channel_block != 4 or num_filter_block != 4:
+            if num_filter_block != 4:
                 return None
 
             batch_size, in_channel, height, width = 
get_const_tuple(data_tensor.shape)
@@ -312,16 +312,22 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
             # update new attrs
             new_attrs["channels"] = out_channel
-            new_attrs["data_layout"] = "NCHW%dc" % in_channel_block
+            if in_channel_block == 4:
+                new_attrs["data_layout"] = "NCHW%dc" % in_channel_block
+            else:
+                new_attrs["data_layout"] = "NCHW"
             # (oc, ic, h, w) -> (OC, ic, h, w, oc)
             new_attrs["kernel_layout"] = "OIHW%do" % num_filter_block
             new_attrs["out_layout"] = "NCHW%dc" % num_filter_block
 
             # Store altered operator's config for applying of tuned AutoTVM 
statistics
-            new_data = te.placeholder(
-                (batch_size, in_channel // in_channel_block, height, width, 
in_channel_block),
-                dtype=data_dtype,
-            )
+            if in_channel_block == 4:
+                new_data = te.placeholder(
+                    (batch_size, in_channel // in_channel_block, height, 
width, in_channel_block),
+                    dtype=data_dtype,
+                )
+            else:
+                new_data = data_tensor
             new_kernel = te.placeholder(
                 (out_channel // num_filter_block, in_filter_channel, kh, kw, 
num_filter_block),
                 dtype=kernel_tensor.dtype,
@@ -361,12 +367,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
                 num_filter_block = 4
 
             # no support yet for tensors cannot be divisible by factor 4
-            if in_channel_block != 4 or num_filter_block != 4:
+            if num_filter_block != 4:
                 return None
 
             # update new attrs
             new_attrs["channels"] = out_channles
-            new_attrs["data_layout"] = "NHWC%dc" % in_channel_block
+            if in_channel_block == 4:
+                new_attrs["data_layout"] = "NHWC%dc" % in_channel_block
+            else:
+                new_attrs["data_layout"] = "NHWC"
             # (h, w, ic, oc) -> (h, w, ic, OC, oc)
             if kernel_layout == "HWIO":
                 new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block
@@ -375,16 +384,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             new_attrs["out_layout"] = "NHWC%dc" % num_filter_block
 
             # Store altered operator's config for applying of tuned AutoTVM 
statistics
-            new_data = te.placeholder(
-                (
-                    batch_size,
-                    in_height,
-                    in_width,
-                    in_channels // in_channel_block,
-                    in_channel_block,
-                ),
-                dtype=data_dtype,
-            )
+            if in_channel_block == 4:
+                new_data = te.placeholder(
+                    (
+                        batch_size,
+                        in_height,
+                        in_width,
+                        in_channels // in_channel_block,
+                        in_channel_block,
+                    ),
+                    dtype=data_dtype,
+                )
+            else:
+                new_data = data_tensor
             if kernel_layout == "HWIO":
                 new_kernel = te.placeholder(
                     (
diff --git a/python/tvm/topi/adreno/conv2d_nchw.py 
b/python/tvm/topi/adreno/conv2d_nchw.py
index 65cd8e0150..082f71364a 100644
--- a/python/tvm/topi/adreno/conv2d_nchw.py
+++ b/python/tvm/topi/adreno/conv2d_nchw.py
@@ -33,48 +33,22 @@ from .utils import (
 )
 
 
-@autotvm.register_topi_compute("conv2d_nchwc.image2d")
-def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float16"}
-    return compute_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
-@autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d")
-def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float32"}
-    return compute_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
 @autotvm.register_topi_schedule("conv2d_nchwc.image2d")
 def schedule_conv2d_nchwc(cfg, outs):
-    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16")
-
-
-@autotvm.register_topi_schedule("conv2d_nchwc_acc32.image2d")
-def schedule_conv2d_nchwc_acc32(cfg, outs):
-    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32")
-
-
-def schedule_conv2d_nchwc_impl(cfg, outs, tag):
     """Create the schedule for conv2d_nchw"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == tag:
+        if op.tag == "adreno_conv2d_latest_op":
             schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, 
out_dtype, args):
+@autotvm.register_topi_compute("conv2d_nchwc.image2d")
+def conv2d_nchwc(cfg, Input, Filter, stride, padding, dilation, out_dtype):
     """
     Convolution operator in NCHWc layout.
     Algo:
@@ -109,18 +83,12 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, 
padding, dilation, out_dty
     convert_from4d = False
     if len(Input.shape) == 4:
         batch, in_channels, in_height, in_width = Input.shape
-        out_channles, in_filter_channels, kernel_h, kernel_w = Filter.shape
-
         in_channel_chunks, in_channel_block, in_channel_tail = 
split_to_chunks(in_channels, 4)
-        out_channel_chunks, out_channel_block, out_channel_tail = 
split_to_chunks(out_channles, 4)
 
         if autotvm.GLOBAL_SCOPE.in_tuning:
             dshape = (batch, in_channel_chunks, in_height, in_width, 
in_channel_block)
             Input = tvm.te.placeholder(dshape, Input.dtype, 
name="data_placeholder")
-            kshape = (out_channel_chunks, in_filter_channels, kernel_h, 
kernel_w, out_channel_block)
-            Filter = tvm.te.placeholder(kshape, Filter.dtype, 
name="kernel_placeholder")
         else:
-            convert_from4d = True
             Input = pack_input(
                 Input,
                 "NCHW",
@@ -131,6 +99,18 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, 
padding, dilation, out_dty
                 in_height,
                 in_width,
             )
+    else:
+        batch, in_channel_chunks, in_height, in_width, in_channel_block = 
Input.shape
+
+    if len(Filter.shape) == 4:
+        out_channles, in_filter_channels, kernel_h, kernel_w = Filter.shape
+        out_channel_chunks, out_channel_block, out_channel_tail = 
split_to_chunks(out_channles, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kshape = (out_channel_chunks, in_filter_channels, kernel_h, 
kernel_w, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, 
name="kernel_placeholder")
+        else:
+            convert_from4d = True
             Filter = pack_filter(
                 Filter,
                 "OIHW",
@@ -144,9 +124,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, 
padding, dilation, out_dty
                 kernel_h,
                 kernel_w,
             )
-
     else:
-        batch, in_channel_chunks, in_height, in_width, in_channel_block = 
Input.shape
         out_channel_chunks, in_filter_channels, kernel_h, kernel_w, 
out_channel_block = Filter.shape
 
     out_height_orig, out_height, out_width_orig, out_width = 
expand_spatial_dimensions(
@@ -178,7 +156,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, 
padding, dilation, out_dty
             (
                 temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + 
rx * dilation_w, rcb]
                 * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]
-            ).astype(args["accumulator"]),
+            ).astype(out_dtype),
             axis=[rcc, rcb, ry, rx],
         ),
         tag="conv2d_nchwc",
@@ -193,13 +171,13 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, 
padding, dilation, out_dty
         return te.compute(
             (batch, out_channles, out_height_orig, out_width_orig),
             lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % 
out_channel_block],
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_conv2d_latest_op",
         )
     else:
         return te.compute(
             (batch, out_channel_chunks, out_height_orig, out_width_orig, 
out_channel_block),
             lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, 
ffb].astype(out_dtype),
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_conv2d_latest_op",
         )
 
 
@@ -234,6 +212,20 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
         conv = output.op.input_tensors[0]
         latest_blocked = latest
 
+    pad_data, kernel = s[conv].op.input_tensors
+    filter_pack_rt = bool(
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in 
kernel.op.tag
+    )
+
+    if "pad_temp" in pad_data.op.name:
+        input_pad_temp = pad_data.op.input_tensors[0]
+    else:
+        input_pad_temp = pad_data
+
+    input_pack_rt = bool(
+        isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in 
input_pad_temp.op.tag
+    )
+
     ##### space definition begin #####
     n, fc, y, x, fb = s[conv].op.axis
     rcc, rcb, ry, rx = s[conv].op.reduce_axis
@@ -274,37 +266,40 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
-    if (
-        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in 
kernel.op.tag
-    ):  # len(latest.op.axis) == 4:
-        # manage scheduling of datacopy
-        pad_data, kernel = s[conv].op.input_tensors
-        if "pad_temp" in pad_data.op.name:
-            pack_data = pad_data.op.input_tensors[0]
-            bind_data_copy(s[pack_data])
+    # There are several conditions that have to be handled:
+    # 1. If we are in the tuning, we always add cache read for data to main 
conv kernel
+    #    to get texture in tuning opencl kernel
+    # 2. If we are repacking input in runtime, we should always explicit 
schedule this one more
+    #    stage of data copy from 4d to 5d (referred as pack_data).
+    # 3. If we have pad (independently if we have runtime repack or not) we 
should inline it in the
+    #    cache_read("texture")
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            if "pad_temp" in pad_data.op.name:
+                s[pad_data].compute_inline()
         else:
-            bind_data_copy(s[pad_data])
-        bind_data_copy(s[kernel])
-
-    pad_data, kernel = s[conv].op.input_tensors
+            if "pad_temp" in pad_data.op.name:
+                pack_data = pad_data.op.input_tensors[0]
+                bind_data_copy(s[pack_data])
+                s[pad_data].compute_inline()
+            else:
+                pack_data = pad_data
+                bind_data_copy(s[pack_data])
 
-    if (
-        autotvm.GLOBAL_SCOPE.in_tuning
-        or isinstance(kernel.op, tvm.te.ComputeOp)
-        and "filter_pack" in kernel.op.tag
-    ):
-        if "pad_temp" in pad_data.op.name:
-            s[pad_data].compute_inline()
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
         bind_data_copy(s[AT])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
     elif "pad_temp" in pad_data.op.name:
         s[pad_data].compute_inline()
         # create cache stage
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
         bind_data_copy(s[AT])
 
+    if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
+        if not autotvm.GLOBAL_SCOPE.in_tuning:
+            bind_data_copy(s[kernel])
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
+
     s[conv].set_scope("local")
     if latest_blocked == latest and output != latest:
         s[output].compute_inline()
diff --git a/python/tvm/topi/adreno/conv2d_nchw_winograd.py 
b/python/tvm/topi/adreno/conv2d_nchw_winograd.py
index 16f7cb8b19..0ddc0e7f2c 100644
--- a/python/tvm/topi/adreno/conv2d_nchw_winograd.py
+++ b/python/tvm/topi/adreno/conv2d_nchw_winograd.py
@@ -27,62 +27,32 @@ logger = logging.getLogger("conv2d_nchw_winograd")
 
 @autotvm.register_topi_compute("conv2d_nchw_winograd.image2d")
 def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, 
out_dtype):
-    args = {"shared": False, "accumulator": "float16"}
     return conv2d_nchw_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=False
-    )
-
-
-@autotvm.register_topi_compute("conv2d_nchw_winograd_acc32.image2d")
-def conv2d_nchw_winograd_acc32(cfg, data, kernel, strides, padding, dilation, 
out_dtype):
-    args = {"shared": False, "accumulator": "float32"}
-    return conv2d_nchw_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=False
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed=False
     )
 
 
 @autotvm.register_topi_schedule("conv2d_nchw_winograd.image2d")
 def schedule_conv2d_nchw_winograd(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16")
-
-
-@autotvm.register_topi_schedule("conv2d_nchw_winograd_acc32.image2d")
-def schedule_conv2d_nchw_winograd_acc32(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32")
+    return schedule_conv2d_winograd_impl(cfg, outs, tag="dummy_compute_at")
 
 
 
@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.image2d")
 def conv2d_nchw_winograd_without_weight_transform(
     cfg, data, kernel, strides, padding, dilation, out_dtype
 ):
-    args = {"shared": False, "accumulator": "float16"}
     return conv2d_nchw_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=True
-    )
-
-
-@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform_acc32.image2d")
-def conv2d_nchw_winograd_without_weight_transform_acc32(
-    cfg, data, kernel, strides, padding, dilation, out_dtype
-):
-    args = {"shared": False, "accumulator": "float32"}
-    return conv2d_nchw_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=True
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed=True
     )
 
 
 
@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.image2d")
 def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", 
pre_computed=True)
-
-
-@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform_acc32.image2d")
-def schedule_conv2d_nchw_winograd_without_weight_transform_acc32(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", 
pre_computed=True)
+    return schedule_conv2d_winograd_impl(cfg, outs, tag="dummy_compute_at", 
pre_computed=True)
 
 
 def conv2d_nchw_winograd_comp(
-    cfg, data, kernel, strides, padding, dilation, out_dtype, args, 
pre_computed
+    cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed
 ):
     """Compute declaration for winograd
 
@@ -111,9 +81,6 @@ def conv2d_nchw_winograd_comp(
     out_dtype: str
         The output type. This is used for mixed precision.
 
-    args: dict
-        Dictionary with additional arguments, e.g. accumulator type
-
     pre_computed: bool
         Flag if weights were pre computed if true or the weights should be
         computed in runtime
@@ -124,5 +91,5 @@ def conv2d_nchw_winograd_comp(
         4-D or 5-D with shape NCHW or NCHW4c
     """
     return conv2d_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args, 
pre_computed, "NCHW"
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed, "NCHW"
     )
diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py 
b/python/tvm/topi/adreno/conv2d_nhwc.py
index b377169ca8..993b632525 100644
--- a/python/tvm/topi/adreno/conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/conv2d_nhwc.py
@@ -33,44 +33,22 @@ from .utils import (
 )
 
 
-@autotvm.register_topi_compute("conv2d_nhwc.image2d")
-def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float16"}
-    return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, 
out_dtype, args=args)
-
-
-@autotvm.register_topi_compute("conv2d_nhwc_acc32.image2d")
-def conv2d_nhwc_acc32(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float32"}
-    return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, 
out_dtype, args=args)
-
-
 @autotvm.register_topi_schedule("conv2d_nhwc.image2d")
 def schedule_conv2d_nhwc(cfg, outs):
-    return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc16")
-
-
-@autotvm.register_topi_schedule("conv2d_nhwc_acc32.image2d")
-def schedule_conv2d_nhwc_acc32(cfg, outs):
-    return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc32")
-
-
-def schedule_conv2d_nhwc_impl(cfg, outs, tag):
     """Create the schedule for conv2d_nhwc"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == tag:
+        if op.tag == "adreno_conv2d_latest_op":
             schedule_conv2d_NHWC(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def compute_conv2d_NHWC_HWIO(Input, Filter, stride, padding, dilation, 
out_dtype, args):
+@autotvm.register_topi_compute("conv2d_nhwc.image2d")
+def conv2d_nhwc(cfg, Input, Filter, stride, padding, dilation, out_dtype):
     """
     Convolution operator in NHWC layout.
     Algo:
@@ -105,18 +83,12 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, 
padding, dilation, out_dtype
     convert_from4d = False
     if len(Input.shape) == 4:
         batch, in_height, in_width, in_channels = Input.shape
-        kernel_h, kernel_w, in_filter_channels, out_channles = Filter.shape
-
         in_channel_chunks, in_channel_block, in_channel_tail = 
split_to_chunks(in_channels, 4)
-        out_channel_chunks, out_channel_block, out_channel_tail = 
split_to_chunks(out_channles, 4)
 
         if autotvm.GLOBAL_SCOPE.in_tuning:
             dshape = (batch, in_height, in_width, in_channel_chunks, 
in_channel_block)
             Input = tvm.te.placeholder(dshape, Input.dtype, 
name="data_placeholder")
-            kshape = (kernel_h, kernel_w, in_filter_channels, 
out_channel_chunks, out_channel_block)
-            Filter = tvm.te.placeholder(kshape, Filter.dtype, 
name="kernel_placeholder")
         else:
-            convert_from4d = True
             Input = pack_input(
                 Input,
                 "NHWC",
@@ -127,6 +99,17 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, 
padding, dilation, out_dtype
                 in_height,
                 in_width,
             )
+    else:
+        batch, in_height, in_width, in_channel_chunks, in_channel_block = 
Input.shape
+
+    if len(Filter.shape) == 4:
+        kernel_h, kernel_w, in_filter_channels, out_channles = Filter.shape
+        out_channel_chunks, out_channel_block, out_channel_tail = 
split_to_chunks(out_channles, 4)
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kshape = (kernel_h, kernel_w, in_filter_channels, 
out_channel_chunks, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, 
name="kernel_placeholder")
+        else:
+            convert_from4d = True
             Filter = pack_filter(
                 Filter,
                 "HWIO",
@@ -140,9 +123,7 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, 
padding, dilation, out_dtype
                 kernel_h,
                 kernel_w,
             )
-
     else:
-        batch, in_height, in_width, in_channel_chunks, in_channel_block = 
Input.shape
         kernel_h, kernel_w, in_filter_channels, out_channel_chunks, 
out_channel_block = Filter.shape
 
     out_height_orig, out_height, out_width_orig, out_width = 
expand_spatial_dimensions(
@@ -173,7 +154,7 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, 
padding, dilation, out_dtype
             (
                 temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * 
dilation_w, rcc, rcb]
                 * Filter[ry, rx, rcc * in_channel_block + rcb, fc, fb]
-            ).astype(args["accumulator"]),
+            ).astype(out_dtype),
             axis=[ry, rx, rcc, rcb],
         ),
         tag="conv2d_nhwc",
@@ -188,13 +169,13 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, 
padding, dilation, out_dtype
         return te.compute(
             (batch, out_height_orig, out_width_orig, out_channles),
             lambda n, y, x, c: dummy_cast[n, y, x, c // out_channel_block, c % 
out_channel_block],
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_conv2d_latest_op",
         )
     else:
         return te.compute(
             (batch, out_height_orig, out_width_orig, out_channel_chunks, 
out_channel_block),
             lambda n, y, x, ffc, ffb: conv[n, y, x, ffc, 
ffb].astype(out_dtype),
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_conv2d_latest_op",
         )
 
 
@@ -229,6 +210,19 @@ def schedule_conv2d_NHWC(cfg, s, output):
         conv = output.op.input_tensors[0]
         latest_blocked = latest
 
+    pad_data, kernel = s[conv].op.input_tensors
+    filter_pack_rt = bool(
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in 
kernel.op.tag
+    )
+
+    if "pad_temp" in pad_data.op.name:
+        input_pad_temp = pad_data.op.input_tensors[0]
+    else:
+        input_pad_temp = pad_data
+
+    input_pack_rt = bool(
+        isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in 
input_pad_temp.op.tag
+    )
     ##### space definition begin #####
     n, y, x, fc, fb = s[conv].op.axis
     ry, rx, rcc, rcb = s[conv].op.reduce_axis
@@ -270,37 +264,40 @@ def schedule_conv2d_NHWC(cfg, s, output):
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
-    if (
-        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in 
kernel.op.tag
-    ):  # len(latest.op.axis) == 4:
-        # manage scheduling of datacopy
-        pad_data, kernel = s[conv].op.input_tensors
-        if "pad_temp" in pad_data.op.name:
-            pack_data = pad_data.op.input_tensors[0]
-            bind_data_copy(s[pack_data])
+    # There are several conditions that have to be handled:
+    # 1. If we are in the tuning, we always add cache read for data to main 
conv kernel
+    #    to get texture in tuning opencl kernel
+    # 2. If we are repacking input in runtime, we should always explicit 
schedule this one more
+    #    stage of data copy from 4d to 5d (referred as pack_data).
+    # 3. If we have pad (independently if we have runtime repack or not) we 
should inline it in the
+    #    cache_read("texture")
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            if "pad_temp" in pad_data.op.name:
+                s[pad_data].compute_inline()
         else:
-            bind_data_copy(s[pad_data])
-        bind_data_copy(s[kernel])
-
-    pad_data, kernel = s[conv].op.input_tensors
+            if "pad_temp" in pad_data.op.name:
+                s[pad_data].compute_inline()
+                pack_data = pad_data.op.input_tensors[0]
+                bind_data_copy(s[pack_data])
+            else:
+                pack_data = pad_data
+                bind_data_copy(s[pack_data])
 
-    if (
-        autotvm.GLOBAL_SCOPE.in_tuning
-        or isinstance(kernel.op, tvm.te.ComputeOp)
-        and "filter_pack" in kernel.op.tag
-    ):
-        if "pad_temp" in pad_data.op.name:
-            s[pad_data].compute_inline()
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
         bind_data_copy(s[AT])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
     elif "pad_temp" in pad_data.op.name:
         s[pad_data].compute_inline()
         # create cache stage
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
         bind_data_copy(s[AT])
 
+    if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
+        if not autotvm.GLOBAL_SCOPE.in_tuning:
+            bind_data_copy(s[kernel])
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
+
     s[conv].set_scope("local")
     if latest_blocked == latest and output != latest:
         s[output].compute_inline()
diff --git a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py 
b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py
index bfe385f210..b055b388e1 100644
--- a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py
+++ b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py
@@ -27,62 +27,32 @@ logger = logging.getLogger("conv2d_nhwc_winograd")
 
 @autotvm.register_topi_compute("conv2d_nhwc_winograd.image2d")
 def conv2d_nhwc_winograd(cfg, data, kernel, strides, padding, dilation, 
out_dtype):
-    args = {"shared": False, "accumulator": "float16"}
     return conv2d_nhwc_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=False
-    )
-
-
[email protected]_topi_compute("conv2d_nhwc_winograd_acc32.image2d")
-def conv2d_nhwc_winograd_acc32(cfg, data, kernel, strides, padding, dilation, 
out_dtype):
-    args = {"shared": False, "accumulator": "float32"}
-    return conv2d_nhwc_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=False
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed=False
     )
 
 
 @autotvm.register_topi_schedule("conv2d_nhwc_winograd.image2d")
 def schedule_conv2d_nhwc_winograd(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16")
-
-
[email protected]_topi_schedule("conv2d_nhwc_winograd_acc32.image2d")
-def schedule_conv2d_nhwc_winograd_acc32(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32")
+    return schedule_conv2d_winograd_impl(cfg, outs, tag="dummy_compute_at")
 
 
 
@autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform.image2d")
 def conv2d_nhwc_winograd_without_weight_transform(
     cfg, data, kernel, strides, padding, dilation, out_dtype
 ):
-    args = {"shared": False, "accumulator": "float16"}
     return conv2d_nhwc_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=True
-    )
-
-
[email protected]_topi_compute("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d")
-def conv2d_nhwc_winograd_without_weight_transform_acc32(
-    cfg, data, kernel, strides, padding, dilation, out_dtype
-):
-    args = {"shared": False, "accumulator": "float32"}
-    return conv2d_nhwc_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, 
pre_computed=True
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed=True
     )
 
 
 
@autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform.image2d")
 def schedule_conv2d_nhwc_winograd_without_weight_transform(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", 
pre_computed=True)
-
-
[email protected]_topi_schedule("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d")
-def schedule_conv2d_nhwc_winograd_without_weight_transform_acc32(cfg, outs):
-    return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", 
pre_computed=True)
+    return schedule_conv2d_winograd_impl(cfg, outs, tag="dummy_compute_at", 
pre_computed=True)
 
 
 def conv2d_nhwc_winograd_comp(
-    cfg, data, kernel, strides, padding, dilation, out_dtype, args, 
pre_computed
+    cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed
 ):
     """Compute declaration for winograd
 
@@ -111,9 +81,6 @@ def conv2d_nhwc_winograd_comp(
     out_dtype: str
         The output type. This is used for mixed precision.
 
-    args: dict
-        Dictionary with additional arguments, e.g. accumulator type
-
     pre_computed: bool
         Flag if weights were pre computed if true or the weights should be
         computed in runtime
@@ -124,5 +91,5 @@ def conv2d_nhwc_winograd_comp(
         4-D or 5-D with shape NCHW or NCHW4c
     """
     return conv2d_winograd_comp(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, args, 
pre_computed, "NHWC"
+        cfg, data, kernel, strides, padding, dilation, out_dtype, 
pre_computed, "NHWC"
     )
diff --git a/python/tvm/topi/adreno/conv2d_winograd_common.py 
b/python/tvm/topi/adreno/conv2d_winograd_common.py
index b0cec0f702..501773ad46 100644
--- a/python/tvm/topi/adreno/conv2d_winograd_common.py
+++ b/python/tvm/topi/adreno/conv2d_winograd_common.py
@@ -35,7 +35,7 @@ from .utils import (
 
 
 def conv2d_winograd_comp(
-    cfg, data, kernel, strides, padding, dilation, out_dtype, args, 
pre_computed, layout
+    cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed, 
layout
 ):
     """Compute declaration for winograd
 
@@ -64,9 +64,6 @@ def conv2d_winograd_comp(
     out_dtype: str
         The output type. This is used for mixed precision.
 
-    args: dict
-        Dictionary with additional arguments, e.g. accumulator type
-
     pre_computed: bool
         Flag if weights were pre computed if true or the weights should be
         computed in runtime
@@ -186,7 +183,7 @@ def conv2d_winograd_comp(
 
     r = KW
     m = tile_size
-    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+    A, B, G = winograd_transform_matrices(m, r, data.dtype)
 
     H = (H + pt + pb - KH) // HSTR + 1
     W = (W + pl + pr - KW) // WSTR + 1
@@ -268,7 +265,7 @@ def conv2d_winograd_comp(
         lambda eps, nu, co, p, cob: te.sum(
             (
                 kernel_pack[eps][nu][ci * CB + cb][co][cob] * 
data_pack_trans[eps][nu][ci][p][cb]
-            ).astype(args["accumulator"]),
+            ).astype(out_dtype),
             axis=[ci, cb],
         ),
         name="bgemm",
@@ -280,7 +277,7 @@ def conv2d_winograd_comp(
     inverse = te.compute(
         (CO, P, m, m, COB),
         lambda co, p, vh, vw, cob: te.sum(
-            bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * 
A[r_b][vw]).astype(args["accumulator"]),
+            bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * 
A[r_b][vw]).astype(out_dtype),
             axis=[r_a, r_b],
         ),
         name="inverse",
@@ -295,7 +292,7 @@ def conv2d_winograd_comp(
                     idxmod(h, m)
                 ][idxmod(w, m)][c % CB].astype(out_dtype),
                 name="output",
-                tag="cast_from_acc" + args["accumulator"][-2:],
+                tag="dummy_compute_at",
             )
         else:
             output = te.compute(
@@ -304,7 +301,7 @@ def conv2d_winograd_comp(
                     n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                 ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype),
                 name="output",
-                tag="cast_from_acc" + args["accumulator"][-2:],
+                tag="dummy_compute_at",
             )
     else:
         if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False:
@@ -314,7 +311,7 @@ def conv2d_winograd_comp(
                     idxmod(h, m)
                 ][idxmod(w, m)][c % CB].astype(out_dtype),
                 name="output",
-                tag="cast_from_acc" + args["accumulator"][-2:],
+                tag="dummy_compute_at",
             )
         else:
             output = te.compute(
@@ -323,7 +320,7 @@ def conv2d_winograd_comp(
                     n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                 ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype),
                 name="output",
-                tag="cast_from_acc" + args["accumulator"][-2:],
+                tag="dummy_compute_at",
             )
 
     if isinstance(N, int):
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py 
b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
index 37713b4584..eb998bdbcd 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -33,50 +33,22 @@ from .utils import (
 )
 
 
[email protected]_topi_compute("depthwise_conv2d_nchwc.image2d")
-def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute depthwise_conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float16"}
-    return compute_depthwise_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
[email protected]_topi_compute("depthwise_conv2d_nchwc_acc32.image2d")
-def depthwise_conv2d_nchwc_acc32(
-    cfg, data, kernel, strides, padding, dilation, out_dtype="float16"
-):
-    """Compute depthwise_conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float32"}
-    return compute_depthwise_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
 @autotvm.register_topi_schedule("depthwise_conv2d_nchwc.image2d")
 def schedule_depthwise_conv2d_nchwc(cfg, outs):
-    return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, 
tag="cast_from_acc16")
-
-
[email protected]_topi_schedule("depthwise_conv2d_nchwc_acc32.image2d")
-def schedule_depthwise_conv2d_nchwc_acc32(cfg, outs):
-    return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, 
tag="cast_from_acc32")
-
-
-def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag):
     """Create the schedule for depthwise conv2d_nchw4c_ohwi4o"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == tag:
+        if op.tag == "adreno_dw_conv2d_latest_op":
             schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, 
dilation, out_dtype, args):
[email protected]_topi_compute("depthwise_conv2d_nchwc.image2d")
+def depthwise_conv2d_nchwc(cfg, Input, Filter, stride, padding, dilation, 
out_dtype):
     """
     Depthwise convolution operator in NCHWc layout.
     Algo:
@@ -183,10 +155,10 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, 
stride, padding, dilatio
                     ffb,
                 ]
                 * Filter[ffc // in_filter_channels, ffc % in_filter_channels, 
ry, rx, ffb]
-            ).astype(args["accumulator"]),
+            ).astype(out_dtype),
             axis=[ry, rx],
         ),
-        tag="depthwise_conv2d_nchwc_kcrsk",
+        tag="depthwise_conv2d_nchwc",
     )
 
     if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
@@ -198,13 +170,13 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, 
stride, padding, dilatio
         return te.compute(
             (batch, out_channles, out_height_orig, out_width_orig),
             lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % 
out_channel_block],
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_dw_conv2d_latest_op",
         )
     else:
         return te.compute(
             (batch, out_channel_chunks, out_height_orig, out_width_orig, 
out_channel_block),
             lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, 
ffb].astype(out_dtype),
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_dw_conv2d_latest_op",
         )
 
 
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py 
b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
index 2b228b444f..c27f2a9eae 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -33,48 +33,22 @@ from .utils import (
 )
 
 
[email protected]_topi_compute("depthwise_conv2d_nhwc.image2d")
-def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute depthwise_conv2d with NHWC layout"""
-    args = {"shared": False, "accumulator": "float16"}
-    return compute_depthwise_conv2d_NHWC_HWOI(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
[email protected]_topi_compute("depthwise_conv2d_nhwc_acc32.image2d")
-def depthwise_conv2d_nhwc_acc32(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float16"):
-    """Compute depthwise_conv2d with NHWC layout"""
-    args = {"shared": False, "accumulator": "float32"}
-    return compute_depthwise_conv2d_NHWC_HWOI(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
 @autotvm.register_topi_schedule("depthwise_conv2d_nhwc.image2d")
 def schedule_depthwise_conv2d_nhwc(cfg, outs):
-    return schedule_depthwise_conv2d_nhwc_impl(cfg, outs, 
tag="cast_from_acc16")
-
-
[email protected]_topi_schedule("depthwise_conv2d_nhwc_acc32.image2d")
-def schedule_depthwise_conv2d_nhwc_acc32(cfg, outs):
-    return schedule_depthwise_conv2d_nhwc_impl(cfg, outs, 
tag="cast_from_acc32")
-
-
-def schedule_depthwise_conv2d_nhwc_impl(cfg, outs, tag):
     """Create the schedule for depthwise conv2d_nchw4c_ohwi4o"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == tag:
+        if op.tag == "adreno_dw_conv2d_latest_op":
             schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def compute_depthwise_conv2d_NHWC_HWOI(Input, Filter, stride, padding, 
dilation, out_dtype, args):
[email protected]_topi_compute("depthwise_conv2d_nhwc.image2d")
+def depthwise_conv2d_nhwc(cfg, Input, Filter, stride, padding, dilation, 
out_dtype):
     """
     Depthwise convolution operator in NCHWc layout.
     Algo:
@@ -175,7 +149,7 @@ def compute_depthwise_conv2d_NHWC_HWOI(Input, Filter, 
stride, padding, dilation,
             (
                 temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * 
dilation_w, ffc, ffb]
                 * Filter[ry, rx, ffc, 0, ffb]
-            ).astype(args["accumulator"]),
+            ).astype(out_dtype),
             axis=[ry, rx],
         ),
         tag="depthwise_conv2d_nhwc",
@@ -190,13 +164,13 @@ def compute_depthwise_conv2d_NHWC_HWOI(Input, Filter, 
stride, padding, dilation,
         return te.compute(
             (batch, out_height_orig, out_width_orig, out_channles),
             lambda n, y, x, c: dummy_cast[n, y, x, c // out_channel_block, c % 
out_channel_block],
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_dw_conv2d_latest_op",
         )
     else:
         return te.compute(
             (batch, out_height_orig, out_width_orig, out_channel_chunks, 
out_channel_block),
             lambda n, y, x, ffc, ffb: conv[n, y, x, ffc, 
ffb].astype(out_dtype),
-            tag="cast_from_acc" + args["accumulator"][-2:],
+            tag="adreno_dw_conv2d_latest_op",
         )
 
 
diff --git a/tests/python/relay/test_conv2d_nchw_texture.py 
b/tests/python/relay/test_conv2d_nchw_texture.py
index 6eadd8fc1c..ab12e40b39 100644
--- a/tests/python/relay/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/test_conv2d_nchw_texture.py
@@ -437,7 +437,7 @@ def test_conv2d_vgg16_winograd_4d():
     stat_file = temp.relpath("stat.log")
     with open(stat_file, "w") as f:
         f.write(
-            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 
512, 28, 28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 
1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, 
"entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], 
["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": 
[[0.0037244], 0, 7.06374192237854, 165 [...]
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 512, 28, 
28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], 
[1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 
7.06374192237854, 165389862 [...]
         )
     graph = build_run_compare(
         mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
@@ -486,7 +486,7 @@ def test_conv2d_winograd_conv():
     stat_file = temp.relpath("stat.log")
     with open(stat_file, "w") as f:
         f.write(
-            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 4, 
3, 3], "float16"], ["TENSOR", [8, 4, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], 
[1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 
7.06374192237854, 1653898629. [...]
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 4, 3, 
3], "float16"], ["TENSOR", [8, 4, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 
1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": 
[["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", 
"sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 
7.06374192237854, 1653898629.742793 [...]
         )
     graph = build_run_compare(
         mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py 
b/tests/python/relay/test_conv2d_nhwc_texture.py
index be5cefd460..cf8116c076 100644
--- a/tests/python/relay/test_conv2d_nhwc_texture.py
+++ b/tests/python/relay/test_conv2d_nhwc_texture.py
@@ -598,7 +598,7 @@ def test_conv2d_vgg16_winograd_4d():
     stat_file = temp.relpath("stat.log")
     with open(stat_file, "w") as f:
         f.write(
-            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nhwc_winograd_acc32.image2d", [["TENSOR", [1, 
28, 28, 512], "float16"], ["TENSOR", [3, 3, 512, 512], "float16"], [1, 1], [1, 
1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, 
"entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], 
["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": 
[[0.0037244], 0, 7.06374192237854, 165 [...]
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno 
-max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 28, 28, 
512], "float16"], ["TENSOR", [3, 3, 512, 512], "float16"], [1, 1], [1, 1, 1, 
1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, 
"entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], 
["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": 
[[0.0037244], 0, 7.06374192237854, 165389862 [...]
         )
     graph = build_run_compare(
         mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file

Reply via email to