This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1bdedfb466 [OpStrategy] Support MetaSchedule Layout (#11848)
1bdedfb466 is described below

commit 1bdedfb466a6aa5c0176d92123709733e0d25f97
Author: Hongyi Jin <[email protected]>
AuthorDate: Fri Jun 24 01:50:29 2022 +0800

    [OpStrategy] Support MetaSchedule Layout (#11848)
---
 python/tvm/auto_scheduler/relay_integration.py |   3 -
 python/tvm/meta_schedule/__init__.py           |   4 +-
 python/tvm/meta_schedule/relay_integration.py  |  14 +++
 python/tvm/relay/op/strategy/arm_cpu.py        |  13 ++-
 python/tvm/relay/op/strategy/cuda.py           |  32 ++++++-
 python/tvm/relay/op/strategy/generic.py        |  53 ++++++++++--
 python/tvm/relay/op/strategy/mali.py           | 113 +++++++++++++++++--------
 python/tvm/relay/op/strategy/x86.py            | 110 ++++++++++++++++++------
 8 files changed, 260 insertions(+), 82 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index ee166e8679..9541232a6a 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -483,7 +483,4 @@ def is_auto_scheduler_enabled():
     return PassContext.current().config.get(
         "relay.backend.use_auto_scheduler",
         False,
-    ) or PassContext.current().config.get(
-        "relay.backend.use_meta_schedule",
-        False,
     )
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 26cf446b10..eb40b32e7c 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -30,10 +30,10 @@ from . import (
     search_strategy,
     space_generator,
 )
-from .profiler import Profiler
 from .apply_history_best import ApplyHistoryBest
 from .extracted_task import ExtractedTask
-from .relay_integration import extract_task_from_relay
+from .profiler import Profiler
+from .relay_integration import extract_task_from_relay, 
is_meta_schedule_enabled
 from .search_strategy import MeasureCandidate
 from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, 
tune_tir
 from .tune_context import TuneContext
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index 833f100a0d..84a6c55956 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -103,3 +103,17 @@ def extract_task_from_relay(
         disabled_pass=disabled_pass,
     ):
         return list(extract_task_func(mod, target, relay_params, 
te_filter_func))
+
+
+def is_meta_schedule_enabled() -> bool:
+    """Return whether the meta-schedule is enabled.
+
+    Returns
+    -------
+    enabled: bool
+        Whether the meta schedule is enabled
+    """
+    return transform.PassContext.current().config.get(
+        "relay.backend.use_meta_schedule",
+        False,
+    )
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 6ccb449d0e..4c5af610d7 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -15,16 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """Definition of ARM CPU operator strategy."""
+import logging
+
 # pylint: 
disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re
-import logging
 
 from tvm import relay, topi
+
+from ....auto_scheduler import is_auto_scheduler_enabled
+from ....meta_schedule import is_meta_schedule_enabled
 from ....target import arm_isa
 from ....topi.generic import conv2d as conv2d_generic
-from ....auto_scheduler import is_auto_scheduler_enabled
-from .generic import *
 from .. import op as _op
+from .generic import *
 
 logger = logging.getLogger("strategy")
 
@@ -477,7 +480,9 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
         logger.warning("dense is not optimized for arm cpu.")
         strategy.add_implementation(
             wrap_compute_dense(
-                topi.nn.dense, 
need_auto_scheduler_layout=is_auto_scheduler_enabled()
+                topi.nn.dense,
+                need_auto_scheduler_layout=is_auto_scheduler_enabled(),
+                need_meta_schedule_layout=is_meta_schedule_enabled(),
             ),
             wrap_topi_schedule(topi.generic.schedule_dense),
             name="dense.generic",
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 4a7cff5f3f..072b958da2 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -20,11 +20,12 @@ from tvm import topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.contrib import nvcc
 from tvm.contrib.thrust import can_use_thrust
+from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.te import SpecializedCondition
 
-from .. import op as _op
 from ....target import Target
 from ....tir import IntImm
+from .. import op as _op
 from .generic import *
 
 
@@ -251,7 +252,17 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 )
 
             # register auto-scheduler implementations
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
+            if (
+                is_auto_scheduler_enabled() or is_meta_schedule_enabled()
+            ) and judge_winograd_auto_scheduler:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
+                    naive_schedule,  # this implementation should never be 
picked by autotvm
+                    name="conv2d_nhwc.winograd",
+                    plevel=15,
+                )
+            # register meta-schedule implementations
+            if is_meta_schedule_enabled() and judge_winograd_auto_scheduler:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
                     naive_schedule,  # this implementation should never be 
picked by autotvm
@@ -534,7 +545,14 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
                 
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
             )
 
-        if is_auto_scheduler_enabled():
+        if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+            strategy.add_implementation(
+                
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
+                naive_schedule,  # this implementation should never be picked 
by autotvm
+                name="conv2d_nhwc_winograd_without_weight_transform",
+                plevel=15,
+            )
+        if is_meta_schedule_enabled():
             strategy.add_implementation(
                 
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
                 naive_schedule,  # this implementation should never be picked 
by autotvm
@@ -805,7 +823,13 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target):
     """Matmul cuda strategy."""
     strategy = _op.OpStrategy()
 
-    if is_auto_scheduler_enabled():
+    if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.cuda",
+        )
+    elif is_meta_schedule_enabled():
         strategy.add_implementation(
             wrap_compute_matmul(topi.nn.matmul),
             naive_schedule,
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 2bb009dbc8..4ff7490b89 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -21,7 +21,12 @@ import re
 
 from tvm import _ffi, ir, te, topi
 from tvm.target import generic_func, override_native_generic_func
-from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, 
get_float_tuple
+from tvm.topi.utils import (
+    get_const_float,
+    get_const_int,
+    get_const_tuple,
+    get_float_tuple,
+)
 
 from .. import op as _op
 
@@ -211,6 +216,9 @@ def schedule_bitpack(attrs, outs, target):
 get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
     "relay.attrs.get_auto_scheduler_rewritten_layout"
 )
+get_meta_schedule_original_shape = _ffi.get_global_func(
+    "relay.attrs.get_meta_schedule_original_shape"
+)
 
 # conv2d
 def wrap_compute_conv2d(
@@ -219,6 +227,7 @@ def wrap_compute_conv2d(
     need_out_layout=False,
     has_groups=False,
     need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
 ):
     """Wrap conv2d topi compute"""
 
@@ -240,6 +249,9 @@ def wrap_compute_conv2d(
         args.append(out_dtype)
         if need_auto_scheduler_layout:
             args.append(get_auto_scheduler_rewritten_layout(attrs))
+        elif need_meta_schedule_layout:
+            args.append("")
+            args.append(get_meta_schedule_original_shape(attrs))
         return [topi_compute(*args)]
 
     return _compute_conv2d
@@ -530,7 +542,12 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
 
 
 # conv3d
-def wrap_compute_conv3d(topi_compute, need_layout=False, 
need_auto_scheduler_layout=False):
+def wrap_compute_conv3d(
+    topi_compute,
+    need_layout=False,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
+):
     """wrap conv3d topi compute"""
 
     def _compute_conv3d(attrs, inputs, out_type):
@@ -552,6 +569,9 @@ def wrap_compute_conv3d(topi_compute, need_layout=False, need_auto_scheduler_lay
         args.append(out_dtype)
         if need_auto_scheduler_layout:
             args.append(get_auto_scheduler_rewritten_layout(attrs))
+        elif need_meta_schedule_layout:
+            args.append("")
+            args.append(get_meta_schedule_original_shape(attrs))
         return [topi_compute(*args)]
 
     return _compute_conv3d
@@ -782,7 +802,11 @@ def copy_if_identical(tensor_a, tensor_b):
 
 
 # matmul
-def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
+def wrap_compute_matmul(
+    topi_compute,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
+):
     """wrap matmul topi compute"""
 
     def _compute_matmul(attrs, inputs, out_type):
@@ -799,6 +823,9 @@ def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
         ]
         if need_auto_scheduler_layout:
             args.append(get_auto_scheduler_rewritten_layout(attrs))
+        elif need_meta_schedule_layout:
+            args.append("")
+            args.append(get_meta_schedule_original_shape(attrs))
         args[1] = copy_if_identical(inputs[0], inputs[1])
         return [topi_compute(*args)]
 
@@ -819,7 +846,11 @@ def matmul_strategy(attrs, inputs, out_type, target):
 
 
 # dense
-def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
+def wrap_compute_dense(
+    topi_compute,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
+):
     """wrap dense topi compute"""
 
     def _compute_dense(attrs, inputs, out_type):
@@ -829,6 +860,9 @@ def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
         args = [inputs[0], inputs[1], None, out_dtype]
         if need_auto_scheduler_layout:
             args.append(get_auto_scheduler_rewritten_layout(attrs))
+        elif need_meta_schedule_layout:
+            args.append("")
+            args.append(get_meta_schedule_original_shape(attrs))
         args[1] = copy_if_identical(inputs[0], inputs[1])
         return [topi_compute(*args)]
 
@@ -862,7 +896,13 @@ def dense_pack_strategy(attrs, inputs, out_type, target):
 
 
 # batch_matmul
-def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, 
need_out_dtype=False):
+def wrap_compute_batch_matmul(
+    topi_compute,
+    *,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
+    need_out_dtype=False,
+):
     """wrap batch_matmul topi compute"""
 
     def _compute_batch_matmul(attrs, inputs, out_type):
@@ -872,6 +912,9 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne
         args.append(attrs.transpose_b)
         if need_auto_scheduler_layout:
             args.append(get_auto_scheduler_rewritten_layout(attrs))
+        elif need_meta_schedule_layout:
+            args.append("")
+            args.append(get_meta_schedule_original_shape(attrs))
         args[1] = copy_if_identical(inputs[0], inputs[1])
         return [topi_compute(*args)]
 
diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py
index e5f4b4e585..dca684835b 100644
--- a/python/tvm/relay/op/strategy/mali.py
+++ b/python/tvm/relay/op/strategy/mali.py
@@ -17,10 +17,13 @@
 """Definition of mali operator strategy."""
 # pylint: 
disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re
+
 from tvm import topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
-from .generic import *
+from tvm.meta_schedule import is_meta_schedule_enabled
+
 from .. import op as _op
+from .generic import *
 
 
 @conv2d_strategy.register("mali")
@@ -72,15 +75,15 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                 )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
-            if not is_auto_scheduler_enabled():
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack),
-                    
wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack),
-                    name="conv2d_nhwc_spatial_pack.mali",
-                )
-            else:
+            need_auto_scheduler_layout = is_auto_scheduler_enabled()
+            need_meta_schedule_layout = is_meta_schedule_enabled()
+            if need_auto_scheduler_layout or need_meta_schedule_layout:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_nhwc, 
need_auto_scheduler_layout=True),
+                    wrap_compute_conv2d(
+                        topi.nn.conv2d_nhwc,
+                        need_auto_scheduler_layout=need_auto_scheduler_layout,
+                        need_meta_schedule_layout=need_meta_schedule_layout,
+                    ),
                     naive_schedule,
                     name="conv2d_nhwc.mali",
                 )
@@ -98,14 +101,36 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                         and dilation_w == 1
                     )
                 if is_winograd_applicable:
-                    strategy.add_implementation(
-                        wrap_compute_conv2d(
-                            topi.nn.conv2d_winograd_nhwc, 
need_auto_scheduler_layout=True
-                        ),
-                        naive_schedule,  # this implementation should never be 
picked by autotvm
-                        name="conv2d_nhwc.winograd",
-                        plevel=15,
-                    )
+                    if need_meta_schedule_layout:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(
+                                topi.nn.conv2d_winograd_nhwc,
+                                need_auto_scheduler_layout=False,
+                                need_meta_schedule_layout=True,
+                            ),
+                            naive_schedule,  # this implementation should 
never be picked by autotvm
+                            name="conv2d_nhwc.winograd",
+                            plevel=15,
+                        )
+                    elif need_auto_scheduler_layout:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(
+                                topi.nn.conv2d_winograd_nhwc,
+                                need_auto_scheduler_layout=True,
+                                need_meta_schedule_layout=False,
+                            ),
+                            naive_schedule,  # this implementation should 
never be picked by autotvm
+                            name="conv2d_nhwc.winograd",
+                            plevel=15,
+                        )
+                    else:
+                        raise RuntimeError("Both AutoScheduler and 
MetaSchedule are not enabled")
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack),
+                    
wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack),
+                    name="conv2d_nhwc_spatial_pack.mali",
+                )
 
         else:
             raise RuntimeError("Unsupported conv2d layout {} for 
mali".format(layout))
@@ -119,18 +144,24 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            if not is_auto_scheduler_enabled():
+            if is_auto_scheduler_enabled():
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),
-                    
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc),
+                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                    naive_schedule,
                     name="depthwise_conv2d_nhwc.mali",
                 )
-            else:
+            elif is_meta_schedule_enabled():
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                     naive_schedule,
                     name="depthwise_conv2d_nhwc.mali",
                 )
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),
+                    
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.mali",
+                )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {} for 
mali".format(layout))
     else:  # group_conv2d
@@ -158,19 +189,23 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty
             name="conv2d_nchw_winograd.mali",
         )
     elif layout == "NHWC":
-        if not is_auto_scheduler_enabled():
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+        need_meta_schedule_layout = is_meta_schedule_enabled()
+        if need_auto_scheduler_layout or need_meta_schedule_layout:
+            strategy.add_implementation(
+                wrap_compute_conv2d(
+                    topi.nn.conv2d_winograd_nhwc_without_weight_transform,
+                    need_auto_scheduler_layout=need_auto_scheduler_layout,
+                    need_meta_schedule_layout=need_meta_schedule_layout,
+                ),
+                naive_schedule,  # this implementation should never be picked 
by autotvm
+                name="conv2d_nhwc_winograd_without_weight_transform",
+                plevel=15,
+            )
+        else:
             raise RuntimeError(
                 "Winograd conv2d NHWC is not enabled for mali without 
auto_scheduler."
             )
-        strategy.add_implementation(
-            wrap_compute_conv2d(
-                topi.nn.conv2d_winograd_nhwc_without_weight_transform,
-                need_auto_scheduler_layout=True,
-            ),
-            naive_schedule,  # this implementation should never be picked by 
autotvm
-            name="conv2d_nhwc_winograd_without_weight_transform",
-            plevel=15,
-        )
     else:
         raise RuntimeError(
             "Unsupported conv2d_winograd_without_weight_transfrom layout 
{}".format(layout)
@@ -182,16 +217,22 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty
 def dense_strategy_mali(attrs, inputs, out_type, target):
     """dense mali strategy"""
     strategy = _op.OpStrategy()
-    if not is_auto_scheduler_enabled():
+    if is_auto_scheduler_enabled():
         strategy.add_implementation(
-            wrap_compute_dense(topi.mali.dense),
-            wrap_topi_schedule(topi.mali.schedule_dense),
+            wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+            naive_schedule,
             name="dense.mali",
         )
-    else:
+    elif is_meta_schedule_enabled():
         strategy.add_implementation(
-            wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+            wrap_compute_dense(topi.nn.dense, need_meta_schedule_layout=True),
             naive_schedule,
             name="dense.mali",
         )
+    else:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.mali.dense),
+            wrap_topi_schedule(topi.mali.schedule_dense),
+            name="dense.mali",
+        )
     return strategy
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index a032fd00bf..abbc9d9a4c 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -17,16 +17,18 @@
 """Definition of x86 operator strategy."""
 # pylint: 
disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import logging
-
 import re
-from tvm import topi, tir
-from tvm.topi.x86.utils import target_has_vnni
+
+from tvm import tir, topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
-from tvm.te import SpecializedCondition
+from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.relay.ty import is_dynamic
 from tvm.target import Target
-from .generic import *
+from tvm.te import SpecializedCondition
+from tvm.topi.x86.utils import target_has_vnni
+
 from .. import op as _op
+from .generic import *
 
 logger = logging.getLogger("strategy")
 
@@ -111,6 +113,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
     if dilation_h < 1 or dilation_w < 1:
         raise ValueError("dilation should be positive value")
 
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
+
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
@@ -137,7 +142,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
-            if not is_auto_scheduler_enabled():
+            if (not need_auto_scheduler_layout) and (not 
need_meta_schedule_layout):
                 logger.warning("conv2d NHWC layout is not optimized for x86 
with autotvm.")
             if "dnnl" in target.libs:
                 strategy.add_implementation(
@@ -147,7 +152,11 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 )
             else:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_nhwc, 
need_auto_scheduler_layout=True),
+                    wrap_compute_conv2d(
+                        topi.nn.conv2d_nhwc,
+                        need_auto_scheduler_layout=need_auto_scheduler_layout,
+                        need_meta_schedule_layout=need_meta_schedule_layout,
+                    ),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
                     name="conv2d_nhwc.x86",
                 )
@@ -171,10 +180,14 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 )
 
             # register auto-scheduler implementations
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
+            if (
+                need_auto_scheduler_layout or need_meta_schedule_layout
+            ) and judge_winograd_auto_scheduler:
                 strategy.add_implementation(
                     wrap_compute_conv2d(
-                        topi.nn.conv2d_winograd_nhwc, 
need_auto_scheduler_layout=True
+                        topi.nn.conv2d_winograd_nhwc,
+                        need_auto_scheduler_layout=need_auto_scheduler_layout,
+                        need_meta_schedule_layout=need_meta_schedule_layout,
                     ),
                     naive_schedule,  # this implementation should never be 
picked by autotvm
                     name="conv2d_nhwc.winograd",
@@ -182,7 +195,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
-            if not is_auto_scheduler_enabled():
+            if (not need_auto_scheduler_layout) or (not 
need_meta_schedule_layout):
                 logger.warning("conv2d HWCN layout is not optimized for x86 
with autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_hwcn),
@@ -216,7 +229,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, 
out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            if not is_auto_scheduler_enabled():
+            if (not need_auto_scheduler_layout) and (not 
need_meta_schedule_layout):
                 logger.warning(
                     "depthwise_conv2d NHWC layout is not optimized for x86 
with autotvm."
                 )
@@ -237,7 +250,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
-            if not is_auto_scheduler_enabled():
+            if (not need_auto_scheduler_layout) and (not 
need_meta_schedule_layout):
                 logger.warning("group_conv2d is not optimized for x86 with 
autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, 
has_groups=True),
@@ -328,7 +341,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
     """conv3d generic strategy"""
     strategy = _op.OpStrategy()
     layout = attrs.data_layout
-    if is_auto_scheduler_enabled():
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
+    if need_auto_scheduler_layout or need_meta_schedule_layout:
         # Use auto-scheduler. We should provide clear compute definition 
without autotvm templates
         # or packed layouts.
         if layout == "NCDHW":
@@ -339,7 +354,11 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NDHWC":
             strategy.add_implementation(
-                wrap_compute_conv3d(topi.nn.conv3d_ndhwc, 
need_auto_scheduler_layout=True),
+                wrap_compute_conv3d(
+                    topi.nn.conv3d_ndhwc,
+                    need_auto_scheduler_layout=need_auto_scheduler_layout,
+                    need_meta_schedule_layout=need_meta_schedule_layout,
+                ),
                 naive_schedule,
                 name="conv3d_ndhwc.x86",
             )
@@ -456,9 +475,15 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target):
         if length_before == length_after:
             logger.warning("Currently dnnl only support the data type to be 
float32. Skip.")
 
-    if is_auto_scheduler_enabled():
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
+    if need_auto_scheduler_layout or need_meta_schedule_layout:
         strategy.add_implementation(
-            wrap_compute_matmul(topi.nn.matmul, 
need_auto_scheduler_layout=True),
+            wrap_compute_matmul(
+                topi.nn.matmul,
+                need_auto_scheduler_layout=need_auto_scheduler_layout,
+                need_meta_schedule_layout=need_meta_schedule_layout,
+            ),
             naive_schedule,
             name="matmul.generic",
             plevel=11,
@@ -499,9 +524,16 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
         plevel=10,
     )
 
-    if is_auto_scheduler_enabled():
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
+
+    if need_auto_scheduler_layout or need_meta_schedule_layout:
         strategy.add_implementation(
-            wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+            wrap_compute_dense(
+                topi.nn.dense,
+                need_auto_scheduler_layout=need_auto_scheduler_layout,
+                need_meta_schedule_layout=need_meta_schedule_layout,
+            ),
             naive_schedule,
             name="dense.generic",
             plevel=11,
@@ -568,6 +600,9 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     mcpu = Target.current().mcpu
 
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
+
     if (
         not attrs.transpose_a
         and attrs.transpose_b
@@ -583,10 +618,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
             name="batch_matmul_vnni.x86",
             plevel=10,
         )
-    elif is_dynamic(out_type) or is_auto_scheduler_enabled():
+    elif is_dynamic(out_type) or need_auto_scheduler_layout or 
need_meta_schedule_layout:
         strategy.add_implementation(
             wrap_compute_batch_matmul(
-                topi.nn.batch_matmul, need_auto_scheduler_layout=True, 
need_out_dtype=True
+                topi.nn.batch_matmul,
+                need_out_dtype=True,
+                need_auto_scheduler_layout=need_auto_scheduler_layout,
+                need_meta_schedule_layout=need_meta_schedule_layout,
             ),
             wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
             name="batch_matmul.generic",
@@ -733,15 +771,31 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ
     assert strides == (1, 1), "Do not support strides now"
     assert groups == 1, "Do not supoort arbitrary group number"
     strategy = _op.OpStrategy()
+    need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    need_meta_schedule_layout = is_meta_schedule_enabled()
     if layout == "NHWC":
-        strategy.add_implementation(
-            wrap_compute_conv2d(
-                topi.nn.conv2d_winograd_nhwc_without_weight_transform,
-                need_auto_scheduler_layout=True,
-            ),
-            naive_schedule,
-            name="ansor.winograd",
-        )
+        if need_meta_schedule_layout:
+            strategy.add_implementation(
+                wrap_compute_conv2d(
+                    topi.nn.conv2d_winograd_nhwc_without_weight_transform,
+                    need_auto_scheduler_layout=False,
+                    need_meta_schedule_layout=True,
+                ),
+                naive_schedule,
+                name="ansor.winograd",
+            )
+        elif need_auto_scheduler_layout:
+            strategy.add_implementation(
+                wrap_compute_conv2d(
+                    topi.nn.conv2d_winograd_nhwc_without_weight_transform,
+                    need_auto_scheduler_layout=True,
+                    need_meta_schedule_layout=False,
+                ),
+                naive_schedule,
+                name="ansor.winograd",
+            )
+        else:
+            raise RuntimeError("Both AutoScheduler and MetaSchedule are not 
enabled")
     else:
         raise RuntimeError(
             "Unsupported conv2d_winograd_without_weight_transfrom layout 
{}".format(layout)

Reply via email to