This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a7bf979  [AutoScheduler] Support layout rewrite for whole networks (#6987)
a7bf979 is described below

commit a7bf97936ce51e5ece2f66d7fa9602b35a3aa1fa
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Dec 2 18:05:35 2020 -0800

    [AutoScheduler] Support layout rewrite for whole networks (#6987)
    
    * [AutoScheduler] Add layout rewrite pass in relay
    
    * fix
    
    * fix lint
    
    * fix attrs
    
    * trigger CI
    
    * Apply suggestions from code review
    
    * trigger CI
    
    * Update python/tvm/auto_scheduler/relay_integration.py
    
    * Update python/tvm/auto_scheduler/relay_integration.py
    
    * Update python/tvm/auto_scheduler/compute_dag.py
    
    * Trigger CI
    
    * Apply suggestions from code review
---
 include/tvm/ir/transform.h                         |   7 +
 include/tvm/relay/attrs/nn.h                       |   1 +
 include/tvm/relay/attrs/transform.h                |  14 ++
 include/tvm/relay/transform.h                      |  14 ++
 include/tvm/topi/transform.h                       |  68 +++++++++
 python/tvm/auto_scheduler/__init__.py              |   2 +-
 python/tvm/auto_scheduler/compute_dag.py           |  17 +++
 python/tvm/auto_scheduler/measure.py               |   4 +-
 python/tvm/auto_scheduler/relay_integration.py     | 103 ++++++++++++-
 python/tvm/relay/op/_transform.py                  |   2 +
 python/tvm/relay/op/strategy/generic.py            |  15 +-
 python/tvm/relay/op/strategy/x86.py                |   3 +-
 python/tvm/te/tensor.py                            |   2 +-
 python/tvm/topi/nn/conv2d.py                       |  41 +++++-
 src/auto_scheduler/compute_dag.cc                  |  20 ++-
 src/ir/transform.cc                                |  54 +++----
 src/relay/backend/build_module.cc                  |  17 +++
 src/relay/backend/compile_engine.cc                |  26 ++--
 src/relay/backend/compile_engine.h                 |   9 ++
 src/relay/backend/utils.h                          |   9 ++
 src/relay/op/make_op.h                             |   2 +
 src/relay/op/nn/convolution.h                      |  12 +-
 src/relay/op/tensor/transform.cc                   |  50 ++++++-
 .../transforms/auto_scheduler_layout_rewrite.cc    | 160 +++++++++++++++++++++
 .../transforms/auto_scheduler_layout_rewrite.h     |  49 +++++++
 .../relay/test_auto_scheduler_layout_rewrite.py    | 121 ++++++++++++++++
 26 files changed, 751 insertions(+), 71 deletions(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d293112..56905de 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -198,6 +198,13 @@ class PassContext : public ObjectRef {
   TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool 
is_before) const;
 
   /*!
+   * \brief Check whether a pass is enabled.
+   * \param info The pass information.
+   * \return true if the pass is enabled. Otherwise, false.
+   */
+  TVM_DLL bool PassEnabled(const PassInfo& info) const;
+
+  /*!
    * \brief Register a valid configuration option and its ValueType for 
validation.
    *
    * \param key The configuration key.
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index e697ac4..f8aa1fc 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
+  std::string auto_scheduler_rewritten_layout;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 3ed6b83..cbe989f 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -358,6 +358,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   }
 };
 
+/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
+struct AutoSchedulerLayoutTransformAttrs
+    : public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
+  std::string src_layout;
+  std::string dst_layout;
+
+  TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
+                    "relay.attrs.AutoSchedulerLayoutTransformAttrs") {
+    TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. 
(e.g. 1N32C112H112W)");
+    TVM_ATTR_FIELD(dst_layout)
+        .describe("The destination layout of the tensor. (e.g. 
1N2C112H112W16c)");
+  }
+};
+
 /*! \brief Attributes for ShapeOf operator */
 struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
   DataType dtype;
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index a9a45b5..e4b39da 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -107,6 +107,14 @@ TVM_DLL Pass FoldConstant();
 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
 
 /*!
+ * \brief The inverse operation of FuseOps. It transforms a fused program 
returned by
+ * FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x)))
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DefuseOps();
+
+/*!
  * \brief Rewrite the annotated program.
  *
  * \param fallback_device The fallback device which is the default device for
@@ -316,6 +324,12 @@ TVM_DLL Pass CanonicalizeOps();
 TVM_DLL Pass AlterOpLayout();
 
 /*!
+ * \brief Do layout rewrite according to the tile structure created by 
auto-scheduler.
+ * \return The pass
+ */
+TVM_DLL Pass AutoSchedulerLayoutRewrite();
+
+/*!
  * \brief Given a dest layout, this pass transforms the expr such that most of 
the ops input data
  * layout is changed to the dest layout. In ideal situation, there are only 2 
layout transforms, one
  * at the start and one at the end.
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index c866dfb..c2a4843 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
       name, tag);
 }
 
+/*! \brief Utility function for auto_scheduler_layout_transform */
+inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* 
shape,
+                                        std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+/*!
+ * \brief Transform the auto-scheduler generated layout according to
+ *        \p src_layout and \p dst_layout
+ * \param src the source input.
+ * \param src_layout the source layout.
+ * \param dst_layout the destination layout.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return A tensor with shape in \p dst_layout
+ */
+inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& 
src_layout,
+                                              const String& dst_layout,
+                                              const String name = 
"T_auto_scheduler_layout_trans",
+                                              const String tag = kInjective) {
+  Array<PrimExpr> src_shape;
+  std::vector<std::string> src_axes;
+  Array<PrimExpr> dst_shape;
+  std::vector<std::string> dst_axes;
+
+  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
+  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
+  return compute(
+      dst_shape,
+      [&](const Array<Var>& dst_indices) {
+        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), 
dst_indices.end());
+        Array<PrimExpr> src_indices;
+        for (const std::string& src_axis : src_axes) {
+          PrimExpr src_index = 0;
+          CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
+          for (size_t i = 0; i < dst_axes.size(); ++i) {
+            if (dst_axes[i] == src_axis) {
+              src_index = src_index * dst_shape[i] + dst_indices_expr[i];
+            }
+          }
+          src_indices.push_back(src_index);
+        }
+        return src(src_indices);
+      },
+      name, tag);
+}
+
 /*!
  * \brief Get the shape of input tensor.
  * \param src the input tensor.
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index f0d076e..5bf2335 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -44,7 +44,7 @@ from .measure import (
     LocalRPCMeasureContext,
 )
 from .measure_record import RecordToFile, RecordReader, load_best, 
load_records, save_records
-from .relay_integration import extract_tasks
+from .relay_integration import extract_tasks, remove_index_check, 
rewrite_compute_body
 from .search_task import SearchTask
 from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
 from .task_scheduler import TaskScheduler
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 3427709..cba3600 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -162,6 +162,23 @@ class ComputeDAG(Object):
                 updated_state.stage_id_map[k] = v
         return updated_state
 
+    def rewrite_layout_from_state(self, state):
+        """
+        Rewrite the layout of the DAG according to the history transform steps 
of a state.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The state from which we get transform steps.
+
+        Returns
+        -------
+        updated_dag : ComputeDAG
+            The compute dag with rewritten layout.
+        """
+        state_obj = state if isinstance(state, StateObject) else 
state.state_object
+        return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)
+
     def hash_key(self):
         """Return the hash key of this compute DAG.
 
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 117cd4f..b9d7148 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -544,7 +544,9 @@ def _timed_func(inp_serialized, build_func, verbose):
     args = []
 
     try:
-        sch, args = task.compute_dag.apply_steps_from_state(inp.state, 
layout_rewrite=True)
+        sch, args = task.compute_dag.apply_steps_from_state(
+            inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
+        )
     # pylint: disable=broad-except
     except Exception:
         error_no = MeasureErrorNo.INSTANTIATION_ERROR
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 6864bcc..25b8881 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -23,11 +23,15 @@ Integrate auto_scheduler into relay. It implements the following items:
 """
 
 import logging
+import json
 import threading
 
 import tvm
 from tvm import autotvm, te, transform
-from tvm.te.tensor import ComputeOp, PlaceholderOp
+from tvm.runtime import convert_to_object
+from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
+from tvm.tir import expr as _expr
+from . import _ffi_api
 from .compute_dag import ComputeDAG
 from .dispatcher import DispatchContext
 from .search_task import SearchTask
@@ -46,7 +50,11 @@ def call_all_topi_funcs(mod, params, target):
     old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
     autotvm.GLOBAL_SCOPE.silent = True
 
-    with transform.PassContext(opt_level=3, 
config={"relay.backend.use_auto_scheduler": True}):
+    with transform.PassContext(
+        opt_level=3,
+        config={"relay.backend.use_auto_scheduler": True},
+        disabled_pass={"AutoSchedulerLayoutRewrite"},
+    ):
         opt_mod, _ = relay.optimize(mod, target, params)
         grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
         grc.codegen(opt_mod["main"])
@@ -158,6 +166,20 @@ class TracingEnvironment:
         self.wkl_key_to_ccache_key[workload_key] = ccache_key
 
 
+@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
+def enter_layout_rewrite():
+    """Enter layout rewrite tracing environment"""
+    env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
+    env.__enter__()
+
+
+@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
+def exit_layout_rewrite():
+    """Exit layout rewrite tracing environment"""
+    env = TracingEnvironment.current
+    env.__exit__(None, None, None)
+
+
 def traverse_to_get_io_tensors(outs):
     """Traverse from a list of output tensors to get both input and output 
tensors
 
@@ -230,11 +252,13 @@ def auto_schedule_topi(outs, has_complex_op):
     key = register_workload_tensors(dag.hash_key(), io_tensors)
 
     # only enable layout rewrite for cpu backend
-    enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys
+    target = tvm.target.Target.current()
+    enable_layout_rewrite = "cpu" in target.keys
 
     env = TracingEnvironment.current
-    if env is None:  # in the final build mode
-        state = DispatchContext.current.query(tvm.target.Target.current(), 
key, has_complex_op, dag)
+    if env is None:
+        # in the final build mode
+        state = DispatchContext.current.query(target, key, has_complex_op, dag)
         if state is None:
             return None
 
@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
             env.add_workload_key(key, ccache_key)
         schedule = te.create_schedule([x.op for x in outs])
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
-        # todo(merrymercy, minminsun): port layout rewrite
-        raise NotImplementedError
+        # in prepare_layout_rewrite mode
+        if enable_layout_rewrite and has_layout_free:
+            dispatch_ctx = DispatchContext.current
+            state = dispatch_ctx.query(target, key, has_complex_op, dag)
+            if state is None:
+                return None
+
+            # rewrite the layout and update the context for the new dag
+            dag = ComputeDAG(outs)
+            new_dag = dag.rewrite_layout_from_state(state)
+            new_key = json.dumps((new_dag.hash_key(),))
+            if new_key != key:
+                dispatch_ctx.update(target, new_key, state)
+        return te.create_schedule([x.op for x in outs])
     else:
         raise ValueError("Invalid tracing mode: " + env.tracing_mode)
 
     return schedule
+
+
+def tensor_no_check_call(self, *indices):
+    """An indexing function without any check.
+    This is the same as `tvm.te.Tensor::__call__` except that the safety
+    check is removed.
+    """
+    indices = convert_to_object(indices)
+    args = []
+    for x in indices:
+        if isinstance(x, _expr.PrimExpr):
+            args.append(x)
+        elif isinstance(x, _expr.IterVar):
+            args.append(x.var)
+        else:
+            raise ValueError("The indices must be expression")
+
+    return _expr.ProducerLoad(self, args)
+
+
+def remove_index_check(tensor):
+    """Remove the safety check in the indexing function for a tensor.
+    This is done by monkey patching its indexing function.
+    After removing the check, we are allowed to create a
+    temporary wrong IR and fix it later in other places.
+
+    Parameters
+    ----------
+    tensor: Tensor
+      The tensor to remove index check.
+    """
+    # Monkey patch the indexing function
+    tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)
+
+
+def rewrite_compute_body(compute_tensor, new_layout):
+    """Rewrite the body of a ComputeOp according to a new layout of a 
placeholder"""
+    op = compute_tensor.op
+
+    # Get layout free placeholders
+    layout_free_placeholders = op.attrs["layout_free_placeholders"]
+    assert len(layout_free_placeholders) == 1, "Only support one layout free 
placeholder"
+    placeholder_op = layout_free_placeholders[0].op
+
+    # Rewrite the index expression in body
+    body = []
+    for b in op.body:
+        body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op, 
new_layout, b))
+    op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, 
body)
+
+    num = op_node.num_outputs
+    outputs = tuple(op_node.output(i) for i in range(num))
+    return outputs[0] if num == 1 else outputs
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index e1cb9e9..38d27e3 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -79,6 +79,8 @@ _reg.register_injective_schedule("strided_set")
 # layout_transform
 _reg.register_injective_schedule("layout_transform")
 _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
+_reg.register_injective_schedule("auto_scheduler_layout_transform")
+_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
 
 # argwhere
 @_reg.register_compute("argwhere")
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index ac9d3b1..a03c517 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
 import logging
 
 import re
-from tvm import topi
+from tvm import topi, _ffi
 from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, 
get_float_tuple
 from tvm.target import generic_func, override_native_generic_func
 from .. import op as _op
@@ -166,9 +166,17 @@ def schedule_bitpack(attrs, outs, target):
         return topi.generic.schedule_bitpack(outs)
 
 
+get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
+    "relay.attrs.get_auto_scheduler_rewritten_layout"
+)
+
 # conv2d
 def wrap_compute_conv2d(
-    topi_compute, need_data_layout=False, need_out_layout=False, 
has_groups=False
+    topi_compute,
+    need_data_layout=False,
+    need_out_layout=False,
+    has_groups=False,
+    need_auto_scheduler_layout=False,
 ):
     """Wrap conv2d topi compute"""
 
@@ -179,6 +187,7 @@ def wrap_compute_conv2d(
         data_layout = attrs.get_str("data_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
+        auto_scheduler_rewritten_layout = 
get_auto_scheduler_rewritten_layout(attrs)
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         args = [inputs[0], inputs[1], strides, padding, dilation]
         if has_groups:
@@ -188,6 +197,8 @@ def wrap_compute_conv2d(
         if need_out_layout:
             args.append(out_layout)
         args.append(out_dtype)
+        if need_auto_scheduler_layout:
+            args.append(auto_scheduler_rewritten_layout)
         return [topi_compute(*args)]
 
     return _compute_conv2d
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 3f129c4..98b56ef 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -117,9 +117,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
-            logger.warning("For x86 target, NCHW layout is recommended for 
conv2d.")
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc, 
need_auto_scheduler_layout=True),
                 wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
                 name="conv2d_nhwc.x86",
             )
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index 6294eab..bdf3954 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -40,7 +40,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
 
     def asobject(self):
         """Convert slice to object."""
-        return self.tensor(*self.indices)
+        return self.tensor.__call__(*self.indices)
 
     @property
     def dtype(self):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 7c9cef6..8d591a2 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -20,7 +20,7 @@
 from __future__ import absolute_import as _abs
 from collections import namedtuple
 import tvm
-from tvm import te
+from tvm import te, auto_scheduler
 
 from .pad import pad
 from .utils import get_pad_tuple
@@ -331,7 +331,15 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
     return Output
 
 
-def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
+def conv2d_nhwc(
+    Input,
+    Filter,
+    stride,
+    padding,
+    dilation,
+    out_dtype="float32",
+    auto_scheduler_rewritten_layout="",
+):
     """Convolution operator in NHWC layout.
 
     Parameters
@@ -371,8 +379,30 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
     else:
         dilation_h, dilation_w = dilation
 
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        # todo(merrymercy): wrap this with a more general interface.
+        if len(Filter.shape) >= 10:
+            # For cpu tile structure SSRSRS
+            base = len(Filter.shape) - 10
+            kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
+            kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
+            channel = Filter.shape[4 + base] * Filter.shape[8 + base]
+            num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
+            for i in range(base + 2):
+                num_filter *= Filter.shape[i]
+        elif len(Filter.shape) == 4:
+            num_filter, kernel_h, kernel_w, channel = Filter.shape
+        else:
+            raise ValueError(
+                "Don't know how to infer the layout for filter shape: %s. "
+                "Please add a new branch to handle this case." % str(Filter)
+            )
+        auto_scheduler.remove_index_check(Filter)
+    else:
+        kernel_h, kernel_w, channel, num_filter = Filter.shape
+
     batch, in_height, in_width, in_channel = Input.shape
-    kernel_h, kernel_w, channel, num_filter = Filter.shape
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
@@ -399,7 +429,12 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
         ),
         name="Conv2dOutput",
         tag="conv2d_nhwc",
+        attrs={"layout_free_placeholders": [Filter]},
     )
+
+    if auto_scheduler_rewritten_layout:
+        Output = auto_scheduler.rewrite_compute_body(Output, 
auto_scheduler_rewritten_layout)
+
     return Output
 
 
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index caaed6f..ca59979 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -42,6 +42,7 @@
 #include <vector>
 
 #include "../arith/pattern_match.h"
+#include "../relay/transforms/auto_scheduler_layout_rewrite.h"
 #include "search_policy/utils.h"
 #include "utils.h"
 
@@ -813,8 +814,7 @@ std::string GetOrigLayout(std::set<std::string>* placeholder_axis_names, const t
   ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
   std::string orig_layout = os.str();
   os.str("");
-  // TODO(minmin): uncomment this line for relay integration
-  // 
::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout);
+  
::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout);
   return orig_layout;
 }
 
@@ -878,8 +878,7 @@ std::string GetNewLayout(const State& state, const int stage_id, const Stage& st
   }
   std::string new_layout = os.str();
   os.str("");
-  // TODO(minmin): uncomment this line for relay integration
-  // 
::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  
::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout);
   return new_layout;
 }
 
@@ -1440,5 +1439,18 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState")
       return dag.InferBound(state);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState")
+    .set_body_typed([](const ComputeDAG& dag, const State& state) {
+      Array<Step>* transform_steps = 
const_cast<Array<Step>*>(&state->transform_steps);
+      return dag.RewriteLayout(transform_steps, 
LayoutRewriteOption::RewriteForPreTransformed);
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
+    .set_body_typed([](const te::Operation& placeholder_op, const std::string& 
new_layout,
+                       const PrimExpr& body) {
+      IndexRewriter index_rewriter(placeholder_op, new_layout);
+      return index_rewriter.Rewrite(body);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 3b77446..f4516d5 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -74,6 +74,26 @@ PassContext PassContext::Current() {
   }
 }
 
+// linearly scan the pass array to match pass_name
+bool PassArrayContains(const Array<runtime::String>& pass_array, const 
std::string& pass_name) {
+  for (auto x : pass_array) {
+    if (x == pass_name) return true;
+  }
+  return false;
+}
+
+bool PassContext::PassEnabled(const PassInfo& info) const {
+  if (PassArrayContains(operator->()->disabled_pass, info->name)) {
+    return false;
+  }
+
+  if (PassArrayContains(operator->()->required_pass, info->name)) {
+    return true;
+  }
+
+  return operator->()->opt_level >= info->opt_level;
+}
+
 class PassConfigManager {
  public:
   void Register(std::string key, uint32_t value_type_index) {
@@ -225,15 +245,6 @@ class SequentialNode : public PassNode {
   PassInfo Info() const override { return pass_info; }
 
   /*!
-   * \brief Check if a pass is enabled.
-   *
-   * \param info The pass information.
-   *
-   * \return true if the pass is enabled. Otherwise, false.
-   */
-  bool PassEnabled(const PassInfo& info) const;
-
-  /*!
    * \brief Resolve the pass dependency. It globs all required passes by
    *        a given pass and executes them.
    *
@@ -344,29 +355,6 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
              << "\n";
 }
 
-// linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<runtime::String>& pass_array,
-                              const std::string& pass_name) {
-  for (auto x : pass_array) {
-    if (x == pass_name) return true;
-  }
-  return false;
-}
-
-bool SequentialNode::PassEnabled(const PassInfo& info) const {
-  PassContext ctx = PassContext::Current();
-
-  if (PassArrayContains(ctx->disabled_pass, info->name)) {
-    return false;
-  }
-
-  if (PassArrayContains(ctx->required_pass, info->name)) {
-    return true;
-  }
-
-  return ctx->opt_level >= info->opt_level;
-}
-
 Pass GetPass(const String& pass_name) {
   using tvm::runtime::Registry;
   const runtime::PackedFunc* f = nullptr;
@@ -387,7 +375,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
   for (const Pass& pass : passes) {
     ICHECK(pass.defined()) << "Found undefined pass for optimization.";
     const PassInfo& pass_info = pass->Info();
-    if (!PassEnabled(pass_info)) continue;
+    if (!pass_ctx.PassEnabled(pass_info)) continue;
     // resolve dependencies
     for (const auto& it : pass_info->required) {
       mod = GetPass(it)(std::move(mod), pass_ctx);
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 82ac1c5..a0828d1 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -338,7 +338,24 @@ class RelayBuildModule : public runtime::ModuleNode {
 
     // Fuse the operations if it is needed.
     relay_module = transform::FuseOps()(relay_module);
+
+    // Do layout rewrite for auto-scheduler.
+    if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
+      const auto& target = (*targets.begin()).second;
+      Pass major_pass = transform::AutoSchedulerLayoutRewrite();
+
+      if (target->kind->device_type == kDLCPU && 
pass_ctx.PassEnabled(major_pass->Info())) {
+        With<Target> tctx(target);
+        relay_module = major_pass(relay_module);
+        // Defuse ops to fold constants, then fuse them again
+        relay_module = transform::DefuseOps()(relay_module);
+        relay_module = transform::FoldConstant()(relay_module);
+        relay_module = transform::FuseOps()(relay_module);
+      }
+    }
+
     relay_module = transform::InferType()(relay_module);
+
     // Inline the functions that have been lifted by the module scope.
     //
     // TODO(@zhiics) Note that we need to be careful about the subgraphs with
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 1559d7e..98d9136 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -101,9 +101,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   explicit ScheduleGetter(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {
     // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = transform::PassContext::Current()
-                              
->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
-                              .value();
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
   }
 
   CachedFunc Create(const Function& prim_func) {
@@ -322,6 +320,17 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   const Op& device_copy_op_;
 };
 
+/*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return Pair of schedule and cache.
+ *  The funcs field in cache is not yet populated.
+ */
+CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
+  return ScheduleGetter(target).Create(source_func);
+}
+
 // Creates shape function from functor.
 class MakeShapeFunc : public 
backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
@@ -680,17 +689,6 @@ class CompileEngineImpl : public CompileEngineNode {
    */
   CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
 
-  /*!
-   * \brief Create schedule for target.
-   * \param source_func The primitive function to be lowered.
-   * \param target The target we want to create schedule for.
-   * \return Pair of schedule and cache.
-   *  The funcs field in cache is not yet populated.
-   */
-  CachedFunc CreateSchedule(const Function& source_func, const Target& target) 
{
-    return ScheduleGetter(target).Create(source_func);
-  }
-
  private:
   // implement lowered func
   CCacheValue LowerInternal(const CCacheKey& key) {
diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h
index 5582291..d7628e7 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -242,6 +242,15 @@ class CompileEngine : public ObjectRef {
 };
 
 /*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return The CachedFunc describing the schedule.
+ *  Its funcs field is not yet populated.
+ */
+CachedFunc CreateSchedule(const Function& source_func, const Target& target);
+
+/*!
  * \brief Check if the type is dynamic.
  * \param ty The type to be checked.
  * \return The result.
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index ccb8611..e167720 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -294,6 +294,15 @@ inline std::string GetExtSymbol(const Function& func) {
   return std::string(name_node.value());
 }
 
+/*!
+ * \brief Query the current PassContext for the auto-scheduler switch.
+ * \return The value of the "relay.backend.use_auto_scheduler" config option,
+ *         or false when the option is not set.
+ */
+inline bool IsAutoSchedulerEnabled() {
+  transform::PassContext ctx = transform::PassContext::Current();
+  return ctx->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false)).value();
+}
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 34bff0f..d2fb6aa 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -52,6 +52,8 @@ Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType 
dtype);
 
 Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
 
+Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String 
dst_layout);
+
 Expr MakeOnes(Array<Integer> shape, DataType dtype);
 
 Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, 
String pad_mode);
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index f011222..13e87a5 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -212,8 +212,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
     if (weight != nullptr) {
       weight_dtype = weight->dtype;
     }
-    // assign result to reporter
-    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+
+    if (param->auto_scheduler_rewritten_layout.size() == 0) {
+      // Normal case: assign result to reporter
+      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+    } else {
+      // If the layout was rewritten by the auto-scheduler, we forcibly apply
+      // the layout provided by the auto-scheduler and skip the normal weight
+      // shape inference (types[1] is intentionally left unconstrained here).
+      {}  // do nothing
+    }
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 5a13e9a..a3a9280 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2806,7 +2806,55 @@ the input array by output[n, c, h, w, C] = data[n, 
C*16+c, h, w]
     .set_support_level(5)
     .set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
 
-/* relay._contrib_reverse_reshape */
+// relay.auto_scheduler_layout_transform
+TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs);
+
+/*!
+ * \brief FTVMCompute of auto_scheduler_layout_transform.
+ * \param attrs Expected to hold AutoSchedulerLayoutTransformAttrs (checked).
+ * \param inputs Input tensors; only inputs[0] is used.
+ * \param out_type Unused.
+ * \return A single-element array holding the layout-transformed tensor.
+ */
+Array<te::Tensor> AutoSchedulerLayoutTransformCompute(const Attrs& attrs,
+                                                      const Array<te::Tensor>& inputs,
+                                                      const Type& out_type) {
+  const AutoSchedulerLayoutTransformAttrs* param = attrs.as<AutoSchedulerLayoutTransformAttrs>();
+  CHECK(param != nullptr);
+  te::Tensor input = inputs[0];
+  auto transformed =
+      topi::auto_scheduler_layout_transform(input, param->src_layout, param->dst_layout);
+  return Array<te::Tensor>{transformed};
+}
+
+/*!
+ * \brief Type relation of auto_scheduler_layout_transform.
+ *
+ * The output tensor type has the shape parsed from the dst_layout attribute
+ * and the same dtype as the input.
+ *
+ * \param types [input type, output type].
+ * \param num_inputs Unused (the op is registered with exactly one input).
+ * \param attrs Expected to hold AutoSchedulerLayoutTransformAttrs (checked).
+ * \param reporter Used to assign the output type.
+ * \return false while the input type is not yet a tensor type, true otherwise.
+ */
+bool AutoSchedulerLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  // The input type may still be unresolved while the solver iterates; defer
+  // instead of aborting (same convention as Conv2DRel).
+  if (data == nullptr) return false;
+  const auto* params = attrs.as<AutoSchedulerLayoutTransformAttrs>();
+  CHECK(params != nullptr);
+
+  Array<IndexExpr> dst_shape;
+  std::vector<std::string> dst_axes;
+
+  topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes);
+
+  reporter->Assign(types[1], TensorType(dst_shape, data->dtype));
+  return true;
+}
+
+/*!
+ * \brief Build a call to the auto_scheduler_layout_transform op.
+ * \param data The expression whose layout is transformed.
+ * \param src_layout The source layout string.
+ * \param dst_layout The destination layout string (the output shape is parsed from it).
+ * \return The resulting Call expression.
+ */
+Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout) {
+  static const Op& op = Op::Get("auto_scheduler_layout_transform");
+  auto node = make_object<AutoSchedulerLayoutTransformAttrs>();
+  node->src_layout = std::move(src_layout);
+  node->dst_layout = std::move(dst_layout);
+  return Call(op, Array<Expr>{data}, Attrs(node), Array<Type>());
+}
+
+// Register the constructor as the global packed function
+// "relay.op._make.auto_scheduler_layout_transform".
+TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform")
+    .set_body_typed(MakeAutoSchedulerLayoutTransform);
+
+// Register the op itself: one input tensor, typed by AutoSchedulerLayoutTransformRel
+// and computed by AutoSchedulerLayoutTransformCompute.
+// NOTE(review): keep set_num_inputs before add_type_rel; the type relation is
+// presumably built from the declared input count -- confirm before reordering.
+RELAY_REGISTER_OP("auto_scheduler_layout_transform")
+    .describe(R"code(Transform the input kernel layout.
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<AutoSchedulerLayoutTransformAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("auto_scheduler_layout_transform", 
AutoSchedulerLayoutTransformRel)
+    .set_support_level(5)
+    .set_attr<FTVMCompute>("FTVMCompute", AutoSchedulerLayoutTransformCompute);
+
+// relay._contrib_reverse_reshape
 Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
   auto attrs = make_object<ReshapeAttrs>();
   attrs->newshape = std::move(newshape);
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc 
b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
new file mode 100644
index 0000000..c9875ef
--- /dev/null
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler_layout_rewrite.cc
+ * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in
+ * conv2d and dense layers) according to the tile structure generated by the auto-scheduler.
+ */
+
+#include "auto_scheduler_layout_rewrite.h"
+
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+#include <deque>
+#include <functional>
+#include <vector>
+
+#include "../backend/compile_engine.h"
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// Definitions of the two static queues that receive layout information from
+// Python. While AutoSchedulerLayoutRewriter::VisitExpr_ runs CreateSchedule,
+// ComputeDAG::RewriteLayout records the original layouts into
+// global_ori_layouts_queue and the rewritten layouts into
+// global_new_layouts_queue; FuncMutator later consumes them pairwise.
+// NOTE(review): process-wide and unsynchronized, so concurrent runs of the
+// pass would presumably interfere -- confirm the pass only runs serially.
+std::deque<std::string> AutoSchedulerLayoutRewriter::global_ori_layouts_queue;
+std::deque<std::string> AutoSchedulerLayoutRewriter::global_new_layouts_queue;
+
+/*!
+ * \brief Clone an attrs node, overriding its auto_scheduler_rewritten_layout field.
+ * \tparam T An attrs node type that has an auto_scheduler_rewritten_layout member.
+ * \param ptr The attrs node to copy; it is not modified.
+ * \param layout The layout string stored in the copy.
+ * \return An Attrs handle owning the modified copy.
+ */
+template <typename T>
+Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) {
+  auto copied = make_object<T>(*ptr);
+  copied->auto_scheduler_rewritten_layout = layout;
+  return Attrs(copied);
+}
+
+/*!
+ * \brief Rewrites the target ops (currently nn.conv2d) inside a function.
+ *
+ * Each matched call consumes one (original layout, new layout) pair from the
+ * queues, in visiting order. Its kernel argument (args[1]) is wrapped in an
+ * auto_scheduler_layout_transform call. Its attrs are copied with
+ * auto_scheduler_rewritten_layout set to the new layout.
+ */
+class FuncMutator : public ExprMutator {
+ public:
+  FuncMutator(const std::deque<std::string>& ori_layouts_queue,
+              const std::deque<std::string>& new_layouts_queue)
+      : ExprMutator(),
+        ori_layouts_queue_(ori_layouts_queue),
+        new_layouts_queue_(new_layouts_queue) {}
+
+  Expr VisitExpr_(const CallNode* n) final {
+    auto new_n = ExprMutator::VisitExpr_(n);
+
+    const auto* call = new_n.as<CallNode>();
+    if (call == nullptr || ori_layouts_queue_.empty() || new_layouts_queue_.empty()) {
+      return new_n;
+    }
+    // Look the op up on the mutated call, the same node the guard above inspected.
+    const auto* op = call->op.as<OpNode>();
+    if (op == nullptr ||
+        std::find(target_ops_.begin(), target_ops_.end(), op->name) == target_ops_.end()) {
+      return new_n;
+    }
+
+    // Pop the next (original, new) layout pair; both queues advance in lockstep.
+    const std::string ori_layout = ori_layouts_queue_.front();
+    const std::string new_layout = new_layouts_queue_.front();
+    ori_layouts_queue_.pop_front();
+    new_layouts_queue_.pop_front();
+
+    // Insert a new op to do layout transform. (This will be simplified by FoldConstant later).
+    Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1], ori_layout, new_layout);
+    Array<Expr> updated_args = {call->args[0], updated_kernel};
+
+    // Update the attrs
+    Attrs updated_attrs;
+    if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
+      updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+    }
+    // Keep the original type arguments and source span on the rebuilt call.
+    return Call(call->op, updated_args, updated_attrs, call->type_args, call->span);
+  }
+
+ private:
+  std::deque<std::string> ori_layouts_queue_;
+  std::deque<std::string> new_layouts_queue_;
+
+  std::vector<std::string> target_ops_{"nn.conv2d"};
+};
+
+/*!
+ * \brief For each call whose callee is a Relay function, lower that function in
+ * layout-rewrite mode to learn the layouts chosen by the auto-scheduler. Then
+ * apply them to the call with FuncMutator. Other calls are only visited recursively.
+ */
+Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
+  auto new_n = ExprMutator::VisitExpr_(n);
+
+  const auto* call = new_n.as<CallNode>();
+  if (call == nullptr) return new_n;
+  const auto* func = call->op.as<FunctionNode>();
+  if (func == nullptr) return new_n;
+
+  global_ori_layouts_queue.clear();
+  global_new_layouts_queue.clear();
+
+  // Resolve both hooks up front so we never enter layout-rewrite mode without
+  // being able to leave it again.
+  const auto* enter_f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite");
+  CHECK(enter_f) << "Could not find auto_scheduler.enter_layout_rewrite function.";
+  const auto* exit_f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite");
+  CHECK(exit_f) << "Could not find auto_scheduler.exit_layout_rewrite function.";
+
+  // Use ScheduleGetter (through CreateSchedule) to call the python lower functions.
+  // This is used to get the layout transform information: it is recorded to
+  // global_ori_layouts_queue and global_new_layouts_queue in ComputeDAG::RewriteLayout.
+  (*enter_f)();
+  try {
+    CreateSchedule(GetRef<Function>(func), Target::Current());
+  } catch (...) {
+    // Leave layout-rewrite mode even when lowering fails, then propagate the error.
+    (*exit_f)();
+    throw;
+  }
+  (*exit_f)();
+
+  // Mutate the called function
+  if (!global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) {
+    return FuncMutator(global_ori_layouts_queue, global_new_layouts_queue).VisitExpr(new_n);
+  }
+  return new_n;
+}
+
+/*!
+ * \brief Apply the auto-scheduler layout rewrite to an expression.
+ * \param expr The expression to rewrite.
+ * \return The rewritten expression.
+ */
+Expr AutoSchedulerLayoutRewrite(const Expr& expr) {
+  AutoSchedulerLayoutRewriter rewriter;
+  return rewriter.Mutate(expr);
+}
+
+namespace transform {
+
+/*!
+ * \brief Function pass running relay::AutoSchedulerLayoutRewrite on each function.
+ * Registered at opt level 3 and requires InferType.
+ */
+Pass AutoSchedulerLayoutRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(relay::AutoSchedulerLayoutRewrite(f));
+      };
+  return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite")
+    .set_body_typed(AutoSchedulerLayoutRewrite);
+
+// Return the auto_scheduler_rewritten_layout stored in `attrs`, or an empty
+// string when the attrs are undefined or are not Conv2DAttrs.
+TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
+    .set_body_typed([](const Attrs& attrs) -> std::string {
+      // Attrs::as<T>() returns nullptr for an undefined handle, whereas
+      // attrs->IsInstance<T>() would dereference it.
+      if (const auto* conv2d_attrs = attrs.as<Conv2DAttrs>()) {
+        return conv2d_attrs->auto_scheduler_rewritten_layout;
+      }
+      return std::string();
+    });
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.h 
b/src/relay/transforms/auto_scheduler_layout_rewrite.h
new file mode 100644
index 0000000..d0d89db
--- /dev/null
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler_layout_rewrite.h
+ * \brief Rewrite the layout of "layout free" tensors (e.g., the weight 
tensors in
+ * conv2d and dense layers) according to the tile structure generated by the 
auto-scheduler.
+ */
+
+#ifndef TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
+#define TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
+
+#include <tvm/relay/expr_functor.h>
+
+#include <deque>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Mutator that rewrites the layout of "layout free" tensors in the
+ * functions called by an expression, following the layouts chosen by the
+ * auto-scheduler.
+ */
+class AutoSchedulerLayoutRewriter : public ExprMutator {
+ public:
+  /*! \brief Rewrite calls whose callee is a function; other calls are visited recursively. */
+  Expr VisitExpr_(const CallNode* n) final;
+
+  // Two global variables for receiving layout information from python:
+  // the original and the rewritten layouts, consumed pairwise.
+  // NOTE(review): static and unsynchronized, i.e. shared by all instances.
+  static std::deque<std::string> global_ori_layouts_queue;
+  static std::deque<std::string> global_new_layouts_queue;
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py 
b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
new file mode 100644
index 0000000..299fcb8
--- /dev/null
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test layout rewrite support for whole neural networks"""
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm import relay, auto_scheduler
+from tvm.contrib import graph_runtime
+import tvm.testing
+
+
+def get_np_array(var, dtype):
+    return np.random.randn(*[int(x) for x in 
var.type_annotation.shape]).astype(dtype)
+
+
+def get_relay_conv2d(
+    outc=128,
+    inc=64,
+    height=14,
+    width=14,
+    kh=3,
+    kw=3,
+    batch=1,
+    pad=0,
+    stride=1,
+    dilation=1,
+    layout="NHWC",
+):
+    dtype = "float32"
+    if layout == "NHWC":
+        kernel_layout = "HWIO"
+        d = relay.var("data", shape=(batch, height, width, inc), dtype=dtype)
+        w = relay.var("weight", shape=(kh, kw, inc, outc), dtype=dtype)
+    elif layout == "NCHW":
+        kernel_layout = "OIHW"
+        d = relay.var("data", shape=(batch, inc, height, width), dtype=dtype)
+        w = relay.var("weight", shape=(outc, inc, kh, kw), dtype=dtype)
+
+    y = relay.nn.conv2d(
+        d,
+        w,
+        padding=pad,
+        kernel_size=(kh, kw),
+        strides=(stride, stride),
+        dilation=(dilation, dilation),
+        channels=outc,
+        groups=1,
+        data_layout=layout,
+        kernel_layout=kernel_layout,
+    )
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([d, w], y)
+    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+    return mod, data, weight
+
+
+def tune_and_check(mod, data, weight):
+    # Extract tasks from a relay program
+    target = tvm.target.Target("llvm")
+    tasks, task_weights = auto_scheduler.extract_tasks(mod, target=target, 
params={})
+
+    with tempfile.NamedTemporaryFile() as fp:
+        log_file = fp.name
+
+        # Tune tasks
+        tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=1,
+            num_measures_per_round=1,
+            builder=auto_scheduler.LocalBuilder(timeout=60),
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        )
+        tuner.tune(tune_option, search_policy="sketch.random")
+
+        # Compile and run
+        def compile_and_run(disabled_pass={}):
+            with auto_scheduler.ApplyHistoryBest(log_file):
+                with tvm.transform.PassContext(
+                    opt_level=3,
+                    config={"relay.backend.use_auto_scheduler": True},
+                    disabled_pass=disabled_pass,
+                ):
+                    lib = relay.build(mod, target=target, params={"weight": 
weight})
+
+            ctx = tvm.cpu()
+            module = graph_runtime.GraphModule(lib["default"](ctx))
+            module.set_input("data", data)
+            module.run()
+
+            return module.get_output(0).asnumpy()
+
+        # Check correctness
+        actual_output = compile_and_run()
+        expected_output = 
compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"})
+
+        tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4)
+
+
+def test_conv2d():
+    mod, data, weight = get_relay_conv2d(kh=1, kw=1)
+    tune_and_check(mod, data, weight)
+
+
+if __name__ == "__main__":
+    test_conv2d()

Reply via email to