This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch revert-15311-pool-2d-sched
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 32ddd274abb1e4913c30a8c69cc6a7d83c1f504e
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jul 20 10:29:41 2023 -0700

    Revert "[topi] Add `arm_cpu` specific pooling schedules (#15311)"
    
    This reverts commit 0a3ad644e55f7b07852c52a573580baac95b110e.
---
 python/tvm/relay/op/strategy/arm_cpu.py          | 17 +++--
 python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py |  2 -
 python/tvm/topi/arm_cpu/mprofile/dsp/pool.py     | 30 +++-----
 python/tvm/topi/arm_cpu/pooling.py               | 93 +-----------------------
 tests/python/topi/python/test_topi_pooling.py    |  1 -
 5 files changed, 25 insertions(+), 118 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 3a9f7e1c11..dc3b16aa82 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -26,7 +26,6 @@ from tvm import relay, topi, tir
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
-from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *
 
@@ -64,11 +63,19 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if target.features.has_dsp:
-            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
-            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
-        return topi.arm_cpu.schedule_pool(outs, layout)
+        if (
+            avg_pool
+            and target.features.has_dsp
+            and layout in ("NCW", "NCHW")
+            or not avg_pool
+            and target.features.has_dsp
+            and layout in ("NWC", "NHWC")
+        ):
+            return topi.arm_cpu.schedule_pool(outs, layout)
+        logger.warning("pool is not optimized for arm cpu.")
+        return topi.generic.schedule_pool(outs, layout)
 
 
 def _get_padding_width(padding):
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py b/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
index 35e3f35a10..13a83393a9 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
@@ -14,5 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Schedule for arm_cpu targets supporting DSP"""
-from .pool import schedule_pool
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
index 90da8072bc..4416831124 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
@@ -23,9 +23,15 @@ import tvm
 from tvm import te
 from tvm.topi.utils import traverse_inline
 
-from .micro_kernel.max_pool import intrin_max, max_impl
-from .micro_kernel.avg_pool import intrin_sum, sum_impl
-from .... import generic
+from .micro_kernel.max_pool import (
+    intrin_max,
+    max_impl,
+)
+
+from .micro_kernel.avg_pool import (
+    intrin_sum,
+    sum_impl,
+)
 
 logger = logging.getLogger("topi")
 
@@ -94,24 +100,8 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))
 
 
-def schedule_pool(outs, layout, is_avg_pool):
+def pool_dsp_schedule(outs, layout):
     """Schedule function for v7e-m DSP instructions of pooling."""
-
-    if is_avg_pool and layout not in ["NCW", "NCHW"]:
-        logger.warning(
-            "avg pool not support for NCW or NCHW layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return generic.schedule_pool(outs, layout)
-    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
-        logger.warning(
-            "max pool not support for NWC or NHWC layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return generic.schedule_pool(outs, layout)
-
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
diff --git a/python/tvm/topi/arm_cpu/pooling.py b/python/tvm/topi/arm_cpu/pooling.py
index fee6983cc4..f09f008934 100644
--- a/python/tvm/topi/arm_cpu/pooling.py
+++ b/python/tvm/topi/arm_cpu/pooling.py
@@ -17,96 +17,9 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""
 
-import logging
-from tvm import te
-from tvm.target import Target
-
-from .. import tag
-from .. import generic
+from .mprofile.dsp.pool import pool_dsp_schedule
 
 
 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool"""
-
-    if layout != "NHWC":
-        logger = logging.getLogger("topi")
-        logger.warning(
-            """We currently only support NHWC target specific pools on arm_cpu,
-               falling back on generic pool scheduling"""
-        )
-        return generic.schedule_pool(outs, layout)
-
-    return schedule_pool_2d(outs)
-
-
-def schedule_pool_2d(outs):
-    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
-
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    schedule_ops = [x.op for x in outs]
-    schedule = te.create_schedule(schedule_ops)
-    scheduled_ops = []
-
-    def traverse(op):
-        # Recursively inline any injective operation that isn't the pooling
-        # operation or hasn't already been scheduled.
-        if tag.is_injective(op.tag):
-            if op not in schedule.outputs:
-                schedule[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-        # schedule the actual pooling operation
-        elif op.tag.startswith("pool"):
-            n, height, width, channel = schedule[op].op.axis
-            # Average pool consists of two parts; a sum then a division.
-            # We can schedule the division loop to parallelize across height and
-            # vectorize across width.
-            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
-            if op != outs[0].op:
-                output = outs[0]
-                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
-                schedule[output].parallel(output_fused)
-                vectorization_factor = (
-                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
-                )
-                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
-                schedule[output].vectorize(inner)
-
-            padded_input = op.input_tensors[0]
-            if isinstance(padded_input.op, te.tensor.ComputeOp):
-                schedule[padded_input].compute_inline()
-
-            # For targets without SVE try explicitly vectorizing the channel
-            # loop, For SVE targets leave the loop in place for LLVM to convert
-            # into a scalable vector loop.
-            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
-            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
-            schedule[op].vectorize(channel_inner)
-            schedule[op].parallel(height)
-            if len(schedule[op].op.reduce_axis) > 0:
-                filter_height, filter_width = schedule[op].op.reduce_axis
-                # We consider any filter of area < 10 to be small enough to
-                # unroll; 3x3 filters have shown better performance when
-                # unrolled.
-                if filter_height.dom.extent * filter_width.dom.extent <= 9:
-                    # For small filters, unrolling the filter loops allows us to
-                    # vectorize over channels without reordering anything.
-                    schedule[op].unroll(filter_width)
-                    schedule[op].unroll(filter_height)
-                else:
-                    # Reordering so that channels is the fastest moving axis allows
-                    # LLVM to vectorize across contiguous memory in the NHWC
-                    # ordering.
-                    schedule[op].reorder(
-                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
-                    )
-            else:
-                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % op.tag)
-
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
-    return schedule
+    """Create schedule for avgpool/maxpool with dsp"""
+    return pool_dsp_schedule(outs, layout)
diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py
index 0d0ee65ad4..5f8aebabc2 100644
--- a/tests/python/topi/python/test_topi_pooling.py
+++ b/tests/python/topi/python/test_topi_pooling.py
@@ -28,7 +28,6 @@ from tvm.topi.utils import get_const_tuple
 
 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
-    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,

Reply via email to