csullivan commented on code in PR #12537:
URL: https://github.com/apache/tvm/pull/12537#discussion_r952839198


##########
python/tvm/topi/adreno/conv2d_nhwc.py:
##########
@@ -270,37 +264,27 @@ def schedule_conv2d_NHWC(cfg, s, output):
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
-    if (
-        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
-    ):  # len(latest.op.axis) == 4:
-        # manage scheduling of datacopy
-        pad_data, kernel = s[conv].op.input_tensors
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
         if "pad_temp" in pad_data.op.name:
+            s[pad_data].compute_inline()
             pack_data = pad_data.op.input_tensors[0]
             bind_data_copy(s[pack_data])
         else:
             bind_data_copy(s[pad_data])
-        bind_data_copy(s[kernel])
-
-    pad_data, kernel = s[conv].op.input_tensors
 
-    if (
-        autotvm.GLOBAL_SCOPE.in_tuning
-        or isinstance(kernel.op, tvm.te.ComputeOp)
-        and "filter_pack" in kernel.op.tag
-    ):
-        if "pad_temp" in pad_data.op.name:
-            s[pad_data].compute_inline()
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
         bind_data_copy(s[AT])
-        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-        bind_data_copy(s[WT])
     elif "pad_temp" in pad_data.op.name:
         s[pad_data].compute_inline()
         # create cache stage
         AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
         bind_data_copy(s[AT])
 
+    if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
+        bind_data_copy(s[kernel])
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])

Review Comment:
   Why do we want a cache_read when using filter packing?



##########
python/tvm/topi/adreno/conv2d_nchw.py:
##########
@@ -33,48 +33,22 @@
 )
 
 
-@autotvm.register_topi_compute("conv2d_nchwc.image2d")
-def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float16"}
-    return compute_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
-@autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d")
-def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
-    """Compute conv2d with NCHWc layout"""
-    args = {"shared": False, "accumulator": "float32"}
-    return compute_conv2d_NCHWc_KCRSk(
-        data, kernel, strides, padding, dilation, out_dtype, args=args
-    )
-
-
 @autotvm.register_topi_schedule("conv2d_nchwc.image2d")
 def schedule_conv2d_nchwc(cfg, outs):
-    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16")
-
-
-@autotvm.register_topi_schedule("conv2d_nchwc_acc32.image2d")
-def schedule_conv2d_nchwc_acc32(cfg, outs):
-    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32")
-
-
-def schedule_conv2d_nchwc_impl(cfg, outs, tag):
     """Create the schedule for conv2d_nchw"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if op.tag == tag:
+        if op.tag == "dummy_compute_at":

Review Comment:
   nit: A better op.tag than `dummy_compute_at` is probably desirable here and 
throughout.



##########
python/tvm/topi/adreno/conv2d_nhwc.py:
##########
@@ -270,37 +264,27 @@ def schedule_conv2d_NHWC(cfg, s, output):
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
-    if (
-        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
-    ):  # len(latest.op.axis) == 4:
-        # manage scheduling of datacopy
-        pad_data, kernel = s[conv].op.input_tensors
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
         if "pad_temp" in pad_data.op.name:
+            s[pad_data].compute_inline()

Review Comment:
   I read this as: If there is input packing we inline the padding into the 
packing step. Is this correct?



##########
tests/python/relay/test_conv2d_nchw_texture.py:
##########
@@ -437,7 +437,7 @@ def test_conv2d_vgg16_winograd_4d():
     stat_file = temp.relpath("stat.log")
     with open(stat_file, "w") as f:
         f.write(
-            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 512, 28, 28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}\n'
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd.image2d", [["TENSOR", [1, 512, 28, 28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}\n'

Review Comment:
   🎉 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to