This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a7bf979 [AutoScheduler] Support layout rewrite for whole networks (#6987)
a7bf979 is described below
commit a7bf97936ce51e5ece2f66d7fa9602b35a3aa1fa
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Dec 2 18:05:35 2020 -0800
[AutoScheduler] Support layout rewrite for whole networks (#6987)
* [AutoScheduler] Add layout rewrite pass in relay
* fix
* fix lint
* fix attrs
* trigger CI
* Apply suggestions from code review
* trigger CI
* Update python/tvm/auto_scheduler/relay_integration.py
* Update python/tvm/auto_scheduler/relay_integration.py
* Update python/tvm/auto_scheduler/compute_dag.py
* Trigger CI
* Apply suggestions from code review
---
include/tvm/ir/transform.h | 7 +
include/tvm/relay/attrs/nn.h | 1 +
include/tvm/relay/attrs/transform.h | 14 ++
include/tvm/relay/transform.h | 14 ++
include/tvm/topi/transform.h | 68 +++++++++
python/tvm/auto_scheduler/__init__.py | 2 +-
python/tvm/auto_scheduler/compute_dag.py | 17 +++
python/tvm/auto_scheduler/measure.py | 4 +-
python/tvm/auto_scheduler/relay_integration.py | 103 ++++++++++++-
python/tvm/relay/op/_transform.py | 2 +
python/tvm/relay/op/strategy/generic.py | 15 +-
python/tvm/relay/op/strategy/x86.py | 3 +-
python/tvm/te/tensor.py | 2 +-
python/tvm/topi/nn/conv2d.py | 41 +++++-
src/auto_scheduler/compute_dag.cc | 20 ++-
src/ir/transform.cc | 54 +++----
src/relay/backend/build_module.cc | 17 +++
src/relay/backend/compile_engine.cc | 26 ++--
src/relay/backend/compile_engine.h | 9 ++
src/relay/backend/utils.h | 9 ++
src/relay/op/make_op.h | 2 +
src/relay/op/nn/convolution.h | 12 +-
src/relay/op/tensor/transform.cc | 50 ++++++-
.../transforms/auto_scheduler_layout_rewrite.cc | 160 +++++++++++++++++++++
.../transforms/auto_scheduler_layout_rewrite.h | 49 +++++++
.../relay/test_auto_scheduler_layout_rewrite.py | 121 ++++++++++++++++
26 files changed, 751 insertions(+), 71 deletions(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d293112..56905de 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -198,6 +198,13 @@ class PassContext : public ObjectRef {
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool
is_before) const;
/*!
+ * \brief Check whether a pass is enabled.
+ * \param info The pass information.
+ * \return true if the pass is enabled. Otherwise, false.
+ */
+ TVM_DLL bool PassEnabled(const PassInfo& info) const;
+
+ /*!
* \brief Register a valid configuration option and its ValueType for
validation.
*
* \param key The configuration key.
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index e697ac4..f8aa1fc 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
+ std::string auto_scheduler_rewritten_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
diff --git a/include/tvm/relay/attrs/transform.h
b/include/tvm/relay/attrs/transform.h
index 3ed6b83..cbe989f 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -358,6 +358,20 @@ struct LayoutTransformAttrs : public
tvm::AttrsNode<LayoutTransformAttrs> {
}
};
+/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
+struct AutoSchedulerLayoutTransformAttrs
+ : public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
+ std::string src_layout;
+ std::string dst_layout;
+
+ TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
+ "relay.attrs.AutoSchedulerLayoutTransformAttrs") {
+ TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor.
(e.g. 1N32C112H112W)");
+ TVM_ATTR_FIELD(dst_layout)
+ .describe("The destination layout of the tensor. (e.g.
1N2C112H112W16c)");
+ }
+};
+
/*! \brief Attributes for ShapeOf operator */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
DataType dtype;
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index a9a45b5..e4b39da 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -107,6 +107,14 @@ TVM_DLL Pass FoldConstant();
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
/*!
+ * \brief The inverse operation of FuseOps. It transforms a fused program
returned by
+ * FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x)))
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DefuseOps();
+
+/*!
* \brief Rewrite the annotated program.
*
* \param fallback_device The fallback device which is the default device for
@@ -316,6 +324,12 @@ TVM_DLL Pass CanonicalizeOps();
TVM_DLL Pass AlterOpLayout();
/*!
+ * \brief Do layout rewrite according to the tile structure created by
auto-scheduler.
+ * \return The pass
+ */
+TVM_DLL Pass AutoSchedulerLayoutRewrite();
+
+/*!
* \brief Given a dest layout, this pass transforms the expr such that most of
the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2
layout transforms, one
* at the start and one at the end.
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index c866dfb..c2a4843 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const
std::string& src_layout,
name, tag);
}
+/*! \brief Utility function for auto_scheduler_layout_transform */
+inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>*
shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+/*!
+ * \brief Transform the auto-scheduler generated layout according to
+ * \p src_layout and \p dst_layout
+ * \param src the source input.
+ * \param src_layout the source layout.
+ * \param dst_layout the destination layout.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return A tensor with shape in \p dst_layout
+ */
+inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String&
src_layout,
+ const String& dst_layout,
+ const String name =
"T_auto_scheduler_layout_trans",
+ const String tag = kInjective) {
+ Array<PrimExpr> src_shape;
+ std::vector<std::string> src_axes;
+ Array<PrimExpr> dst_shape;
+ std::vector<std::string> dst_axes;
+
+ parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
+ parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
+ return compute(
+ dst_shape,
+ [&](const Array<Var>& dst_indices) {
+ Array<PrimExpr> dst_indices_expr(dst_indices.begin(),
dst_indices.end());
+ Array<PrimExpr> src_indices;
+ for (const std::string& src_axis : src_axes) {
+ PrimExpr src_index = 0;
+ CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
+ for (size_t i = 0; i < dst_axes.size(); ++i) {
+ if (dst_axes[i] == src_axis) {
+ src_index = src_index * dst_shape[i] + dst_indices_expr[i];
+ }
+ }
+ src_indices.push_back(src_index);
+ }
+ return src(src_indices);
+ },
+ name, tag);
+}
+
/*!
* \brief Get the shape of input tensor.
* \param src the input tensor.
diff --git a/python/tvm/auto_scheduler/__init__.py
b/python/tvm/auto_scheduler/__init__.py
index f0d076e..5bf2335 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -44,7 +44,7 @@ from .measure import (
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best,
load_records, save_records
-from .relay_integration import extract_tasks
+from .relay_integration import extract_tasks, remove_index_check,
rewrite_compute_body
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
diff --git a/python/tvm/auto_scheduler/compute_dag.py
b/python/tvm/auto_scheduler/compute_dag.py
index 3427709..cba3600 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -162,6 +162,23 @@ class ComputeDAG(Object):
updated_state.stage_id_map[k] = v
return updated_state
+ def rewrite_layout_from_state(self, state):
+ """
+ Rewrite the layout of the DAG according to the history transform steps
of a state.
+
+ Parameters
+ ----------
+ state : Union[State, StateObject]
+ The state from which we get transform steps.
+
+ Returns
+ -------
+ updated_dag : ComputeDAG
+ The compute dag with rewritten layout.
+ """
+ state_obj = state if isinstance(state, StateObject) else
state.state_object
+ return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)
+
def hash_key(self):
"""Return the hash key of this compute DAG.
diff --git a/python/tvm/auto_scheduler/measure.py
b/python/tvm/auto_scheduler/measure.py
index 117cd4f..b9d7148 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -544,7 +544,9 @@ def _timed_func(inp_serialized, build_func, verbose):
args = []
try:
- sch, args = task.compute_dag.apply_steps_from_state(inp.state,
layout_rewrite=True)
+ sch, args = task.compute_dag.apply_steps_from_state(
+ inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
+ )
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.INSTANTIATION_ERROR
diff --git a/python/tvm/auto_scheduler/relay_integration.py
b/python/tvm/auto_scheduler/relay_integration.py
index 6864bcc..25b8881 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -23,11 +23,15 @@ Integrate auto_scheduler into relay. It implements the
following items:
"""
import logging
+import json
import threading
import tvm
from tvm import autotvm, te, transform
-from tvm.te.tensor import ComputeOp, PlaceholderOp
+from tvm.runtime import convert_to_object
+from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
+from tvm.tir import expr as _expr
+from . import _ffi_api
from .compute_dag import ComputeDAG
from .dispatcher import DispatchContext
from .search_task import SearchTask
@@ -46,7 +50,11 @@ def call_all_topi_funcs(mod, params, target):
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True
- with transform.PassContext(opt_level=3,
config={"relay.backend.use_auto_scheduler": True}):
+ with transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_auto_scheduler": True},
+ disabled_pass={"AutoSchedulerLayoutRewrite"},
+ ):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
@@ -158,6 +166,20 @@ class TracingEnvironment:
self.wkl_key_to_ccache_key[workload_key] = ccache_key
+@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
+def enter_layout_rewrite():
+ """Enter layout rewrite tracing environment"""
+ env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
+ env.__enter__()
+
+
+@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
+def exit_layout_rewrite():
+ """Exit layout rewrite tracing environment"""
+ env = TracingEnvironment.current
+ env.__exit__(None, None, None)
+
+
def traverse_to_get_io_tensors(outs):
"""Traverse from a list of output tensors to get both input and output
tensors
@@ -230,11 +252,13 @@ def auto_schedule_topi(outs, has_complex_op):
key = register_workload_tensors(dag.hash_key(), io_tensors)
# only enable layout rewrite for cpu backend
- enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys
+ target = tvm.target.Target.current()
+ enable_layout_rewrite = "cpu" in target.keys
env = TracingEnvironment.current
- if env is None: # in the final build mode
- state = DispatchContext.current.query(tvm.target.Target.current(),
key, has_complex_op, dag)
+ if env is None:
+ # in the final build mode
+ state = DispatchContext.current.query(target, key, has_complex_op, dag)
if state is None:
return None
@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
- # todo(merrymercy, minminsun): port layout rewrite
- raise NotImplementedError
+ # in prepare_layout_rewrite mode
+ if enable_layout_rewrite and has_layout_free:
+ dispatch_ctx = DispatchContext.current
+ state = dispatch_ctx.query(target, key, has_complex_op, dag)
+ if state is None:
+ return None
+
+ # rewrite the layout and update the context for the new dag
+ dag = ComputeDAG(outs)
+ new_dag = dag.rewrite_layout_from_state(state)
+ new_key = json.dumps((new_dag.hash_key(),))
+ if new_key != key:
+ dispatch_ctx.update(target, new_key, state)
+ return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)
return schedule
+
+
+def tensor_no_check_call(self, *indices):
+ """An indexing function without any check.
+ This is the same as `tvm.te.Tensor::__call__` except that the safety
+ check is removed.
+ """
+ indices = convert_to_object(indices)
+ args = []
+ for x in indices:
+ if isinstance(x, _expr.PrimExpr):
+ args.append(x)
+ elif isinstance(x, _expr.IterVar):
+ args.append(x.var)
+ else:
+ raise ValueError("The indices must be expression")
+
+ return _expr.ProducerLoad(self, args)
+
+
+def remove_index_check(tensor):
+ """Remove the safety check in the indexing function for a tensor.
+ This is done by monkey patching its indexing function.
+ After removing the check, we are allowed to create a
+ temporary wrong IR and fix it later in other places.
+
+ Parameters
+ ----------
+ tensor: Tensor
+ The tensor to remove index check.
+ """
+ # Monkey patch the indexing function
+ tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)
+
+
+def rewrite_compute_body(compute_tensor, new_layout):
+ """Rewrite the body of a ComputeOp according to a new layout of a
placeholder"""
+ op = compute_tensor.op
+
+ # Get layout free placeholders
+ layout_free_placeholders = op.attrs["layout_free_placeholders"]
+ assert len(layout_free_placeholders) == 1, "Only support one layout free
placeholder"
+ placeholder_op = layout_free_placeholders[0].op
+
+ # Rewrite the index expression in body
+ body = []
+ for b in op.body:
+ body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op,
new_layout, b))
+ op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis,
body)
+
+ num = op_node.num_outputs
+ outputs = tuple(op_node.output(i) for i in range(num))
+ return outputs[0] if num == 1 else outputs
diff --git a/python/tvm/relay/op/_transform.py
b/python/tvm/relay/op/_transform.py
index e1cb9e9..38d27e3 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -79,6 +79,8 @@ _reg.register_injective_schedule("strided_set")
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
+_reg.register_injective_schedule("auto_scheduler_layout_transform")
+_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
# argwhere
@_reg.register_compute("argwhere")
diff --git a/python/tvm/relay/op/strategy/generic.py
b/python/tvm/relay/op/strategy/generic.py
index ac9d3b1..a03c517 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
import logging
import re
-from tvm import topi
+from tvm import topi, _ffi
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple,
get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
@@ -166,9 +166,17 @@ def schedule_bitpack(attrs, outs, target):
return topi.generic.schedule_bitpack(outs)
+get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
+ "relay.attrs.get_auto_scheduler_rewritten_layout"
+)
+
# conv2d
def wrap_compute_conv2d(
- topi_compute, need_data_layout=False, need_out_layout=False,
has_groups=False
+ topi_compute,
+ need_data_layout=False,
+ need_out_layout=False,
+ has_groups=False,
+ need_auto_scheduler_layout=False,
):
"""Wrap conv2d topi compute"""
@@ -179,6 +187,7 @@ def wrap_compute_conv2d(
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
+ auto_scheduler_rewritten_layout =
get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
@@ -188,6 +197,8 @@ def wrap_compute_conv2d(
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
+ if need_auto_scheduler_layout:
+ args.append(auto_scheduler_rewritten_layout)
return [topi_compute(*args)]
return _compute_conv2d
diff --git a/python/tvm/relay/op/strategy/x86.py
b/python/tvm/relay/op/strategy/x86.py
index 3f129c4..98b56ef 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -117,9 +117,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
- logger.warning("For x86 target, NCHW layout is recommended for
conv2d.")
strategy.add_implementation(
- wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+ wrap_compute_conv2d(topi.nn.conv2d_nhwc,
need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index 6294eab..bdf3954 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -40,7 +40,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
def asobject(self):
"""Convert slice to object."""
- return self.tensor(*self.indices)
+ return self.tensor.__call__(*self.indices)
@property
def dtype(self):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 7c9cef6..8d591a2 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
-from tvm import te
+from tvm import te, auto_scheduler
from .pad import pad
from .utils import get_pad_tuple
@@ -331,7 +331,15 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation,
out_dtype=None):
return Output
-def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"):
+def conv2d_nhwc(
+ Input,
+ Filter,
+ stride,
+ padding,
+ dilation,
+ out_dtype="float32",
+ auto_scheduler_rewritten_layout="",
+):
"""Convolution operator in NHWC layout.
Parameters
@@ -371,8 +379,30 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation,
out_dtype="float32"):
else:
dilation_h, dilation_w = dilation
+ if auto_scheduler_rewritten_layout:
+ # Infer shape for the rewritten layout
+ # todo(merrymercy): wrap this with a more general interface.
+ if len(Filter.shape) >= 10:
+ # For cpu tile structure SSRSRS
+ base = len(Filter.shape) - 10
+ kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
+ kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
+ channel = Filter.shape[4 + base] * Filter.shape[8 + base]
+ num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
+ for i in range(base + 2):
+ num_filter *= Filter.shape[i]
+ elif len(Filter.shape) == 4:
+ num_filter, kernel_h, kernel_w, channel = Filter.shape
+ else:
+ raise ValueError(
+ "Don't know how to infer the layout for filter shape: %s. "
+ "Please add a new branch to handle this case." % str(Filter)
+ )
+ auto_scheduler.remove_index_check(Filter)
+ else:
+ kernel_h, kernel_w, channel, num_filter = Filter.shape
+
batch, in_height, in_width, in_channel = Input.shape
- kernel_h, kernel_w, channel, num_filter = Filter.shape
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
@@ -399,7 +429,12 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation,
out_dtype="float32"):
),
name="Conv2dOutput",
tag="conv2d_nhwc",
+ attrs={"layout_free_placeholders": [Filter]},
)
+
+ if auto_scheduler_rewritten_layout:
+ Output = auto_scheduler.rewrite_compute_body(Output,
auto_scheduler_rewritten_layout)
+
return Output
diff --git a/src/auto_scheduler/compute_dag.cc
b/src/auto_scheduler/compute_dag.cc
index caaed6f..ca59979 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -42,6 +42,7 @@
#include <vector>
#include "../arith/pattern_match.h"
+#include "../relay/transforms/auto_scheduler_layout_rewrite.h"
#include "search_policy/utils.h"
#include "utils.h"
@@ -813,8 +814,7 @@ std::string GetOrigLayout(std::set<std::string>*
placeholder_axis_names, const t
ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
std::string orig_layout = os.str();
os.str("");
- // TODO(minmin): uncomment this line for relay integration
- //
::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout);
+
::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout);
return orig_layout;
}
@@ -878,8 +878,7 @@ std::string GetNewLayout(const State& state, const int
stage_id, const Stage& st
}
std::string new_layout = os.str();
os.str("");
- // TODO(minmin): uncomment this line for relay integration
- //
::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+
::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout);
return new_layout;
}
@@ -1440,5 +1439,18 @@
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState")
return dag.InferBound(state);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState")
+ .set_body_typed([](const ComputeDAG& dag, const State& state) {
+ Array<Step>* transform_steps =
const_cast<Array<Step>*>(&state->transform_steps);
+ return dag.RewriteLayout(transform_steps,
LayoutRewriteOption::RewriteForPreTransformed);
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
+ .set_body_typed([](const te::Operation& placeholder_op, const std::string&
new_layout,
+ const PrimExpr& body) {
+ IndexRewriter index_rewriter(placeholder_op, new_layout);
+ return index_rewriter.Rewrite(body);
+ });
+
} // namespace auto_scheduler
} // namespace tvm
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 3b77446..f4516d5 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -74,6 +74,26 @@ PassContext PassContext::Current() {
}
}
+// linearly scan the pass array to match pass_name
+bool PassArrayContains(const Array<runtime::String>& pass_array, const
std::string& pass_name) {
+ for (auto x : pass_array) {
+ if (x == pass_name) return true;
+ }
+ return false;
+}
+
+bool PassContext::PassEnabled(const PassInfo& info) const {
+ if (PassArrayContains(operator->()->disabled_pass, info->name)) {
+ return false;
+ }
+
+ if (PassArrayContains(operator->()->required_pass, info->name)) {
+ return true;
+ }
+
+ return operator->()->opt_level >= info->opt_level;
+}
+
class PassConfigManager {
public:
void Register(std::string key, uint32_t value_type_index) {
@@ -225,15 +245,6 @@ class SequentialNode : public PassNode {
PassInfo Info() const override { return pass_info; }
/*!
- * \brief Check if a pass is enabled.
- *
- * \param info The pass information.
- *
- * \return true if the pass is enabled. Otherwise, false.
- */
- bool PassEnabled(const PassInfo& info) const;
-
- /*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
@@ -344,29 +355,6 @@ void SequentialNode::ResolveDependency(const IRModule&
mod) {
<< "\n";
}
-// linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<runtime::String>& pass_array,
- const std::string& pass_name) {
- for (auto x : pass_array) {
- if (x == pass_name) return true;
- }
- return false;
-}
-
-bool SequentialNode::PassEnabled(const PassInfo& info) const {
- PassContext ctx = PassContext::Current();
-
- if (PassArrayContains(ctx->disabled_pass, info->name)) {
- return false;
- }
-
- if (PassArrayContains(ctx->required_pass, info->name)) {
- return true;
- }
-
- return ctx->opt_level >= info->opt_level;
-}
-
Pass GetPass(const String& pass_name) {
using tvm::runtime::Registry;
const runtime::PackedFunc* f = nullptr;
@@ -387,7 +375,7 @@ IRModule SequentialNode::operator()(IRModule mod, const
PassContext& pass_ctx) c
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
- if (!PassEnabled(pass_info)) continue;
+ if (!pass_ctx.PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(std::move(mod), pass_ctx);
diff --git a/src/relay/backend/build_module.cc
b/src/relay/backend/build_module.cc
index 82ac1c5..a0828d1 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -338,7 +338,24 @@ class RelayBuildModule : public runtime::ModuleNode {
// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
+
+ // Do layout rewrite for auto-scheduler.
+ if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
+ const auto& target = (*targets.begin()).second;
+ Pass major_pass = transform::AutoSchedulerLayoutRewrite();
+
+ if (target->kind->device_type == kDLCPU &&
pass_ctx.PassEnabled(major_pass->Info())) {
+ With<Target> tctx(target);
+ relay_module = major_pass(relay_module);
+ // Defuse ops to fold constants, then fuse them again
+ relay_module = transform::DefuseOps()(relay_module);
+ relay_module = transform::FoldConstant()(relay_module);
+ relay_module = transform::FuseOps()(relay_module);
+ }
+ }
+
relay_module = transform::InferType()(relay_module);
+
// Inline the functions that have been lifted by the module scope.
//
// TODO(@zhiics) Note that we need to be careful about the subgraphs with
diff --git a/src/relay/backend/compile_engine.cc
b/src/relay/backend/compile_engine.cc
index 1559d7e..98d9136 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -101,9 +101,7 @@ class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {
// Whether to use auto_scheduler schedule.
- use_auto_scheduler_ = transform::PassContext::Current()
-
->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
- .value();
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
}
CachedFunc Create(const Function& prim_func) {
@@ -322,6 +320,17 @@ class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
const Op& device_copy_op_;
};
+/*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return Pair of schedule and cache.
+ * The funcs field in cache is not yet populated.
+ */
+CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
+ return ScheduleGetter(target).Create(source_func);
+}
+
// Creates shape function from functor.
class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
@@ -680,17 +689,6 @@ class CompileEngineImpl : public CompileEngineNode {
*/
CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
- /*!
- * \brief Create schedule for target.
- * \param source_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
- * \return Pair of schedule and cache.
- * The funcs field in cache is not yet populated.
- */
- CachedFunc CreateSchedule(const Function& source_func, const Target& target)
{
- return ScheduleGetter(target).Create(source_func);
- }
-
private:
// implement lowered func
CCacheValue LowerInternal(const CCacheKey& key) {
diff --git a/src/relay/backend/compile_engine.h
b/src/relay/backend/compile_engine.h
index 5582291..d7628e7 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -242,6 +242,15 @@ class CompileEngine : public ObjectRef {
};
/*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return Pair of schedule and cache.
+ * The funcs field in cache is not yet populated.
+ */
+CachedFunc CreateSchedule(const Function& source_func, const Target& target);
+
+/*!
* \brief Check if the type is dynamic.
* \param ty The type to be checked.
* \return The result.
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index ccb8611..e167720 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -294,6 +294,15 @@ inline std::string GetExtSymbol(const Function& func) {
return std::string(name_node.value());
}
+/*!
+ * \brief Return whether the auto scheduler is enabled in the pass context.
+ */
+inline bool IsAutoSchedulerEnabled() {
+ return transform::PassContext::Current()
+ ->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
+ .value();
+}
+
} // namespace backend
} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 34bff0f..d2fb6aa 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -52,6 +52,8 @@ Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType
dtype);
Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
+Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String
dst_layout);
+
Expr MakeOnes(Array<Integer> shape, DataType dtype);
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value,
String pad_mode);
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index f011222..13e87a5 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -212,8 +212,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
- // assign result to reporter
- reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+
+ if (param->auto_scheduler_rewritten_layout.size() == 0) {
+ // Normal case: assign result to reporter
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+ } else {
+ // If the layout is rewritten by auto-scheduler,
+ // we just forcly apply the layout provided by auto-scheduler and
+ // skip the normal inference logic.
+ {} // do nothing
+ }
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 5a13e9a..a3a9280 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2806,7 +2806,55 @@ the input array by output[n, c, h, w, C] = data[n,
C*16+c, h, w]
.set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
-/* relay._contrib_reverse_reshape */
+// relay.auto_scheduler_layout_transform
+TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs);
+
+Array<te::Tensor> AutoSchedulerLayoutTransformCompute(const Attrs& attrs,
+ const Array<te::Tensor>&
inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<AutoSchedulerLayoutTransformAttrs>();
+ CHECK(param != nullptr);
+ return Array<te::Tensor>{
+ topi::auto_scheduler_layout_transform(inputs[0], param->src_layout,
param->dst_layout)};
+}
+
+bool AutoSchedulerLayoutTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
+ const TypeReporter& reporter) {
+ const auto* data = types[0].as<TensorTypeNode>();
+ CHECK(data != nullptr);
+ const AutoSchedulerLayoutTransformAttrs* params =
attrs.as<AutoSchedulerLayoutTransformAttrs>();
+
+ Array<IndexExpr> dst_shape;
+ std::vector<std::string> dst_axes;
+
+ topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes);
+
+ reporter->Assign(types[1], TensorType(dst_shape, data->dtype));
+ return true;
+}
+
+Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String
dst_layout) {
+ auto attrs = make_object<AutoSchedulerLayoutTransformAttrs>();
+ attrs->src_layout = std::move(src_layout);
+ attrs->dst_layout = std::move(dst_layout);
+ static const Op& op = Op::Get("auto_scheduler_layout_transform");
+ return Call(op, {data}, Attrs(attrs), {});
+}
+
// Expose the Relay constructor to Python as relay.op._make.auto_scheduler_layout_transform.
TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform")
    .set_body_typed(MakeAutoSchedulerLayoutTransform);

// Register the op: one tensor input, shape inferred by AutoSchedulerLayoutTransformRel,
// lowered by AutoSchedulerLayoutTransformCompute.
// NOTE(review): keep add_argument before add_type_rel -- presumably the type relation is
// built from the argument list registered so far; confirm before reordering the chain.
RELAY_REGISTER_OP("auto_scheduler_layout_transform")
    .describe(R"code(Transform the input kernel layout.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<AutoSchedulerLayoutTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_type_rel("auto_scheduler_layout_transform", AutoSchedulerLayoutTransformRel)
    .set_support_level(5)
    .set_attr<FTVMCompute>("FTVMCompute", AutoSchedulerLayoutTransformCompute);
+
+// relay._contrib_reverse_reshape
Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
new file mode 100644
index 0000000..c9875ef
--- /dev/null
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
/*!
 * \file auto_scheduler_layout_rewrite.cc
 * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in
 * conv2d and dense layers) according to the tile structure generated by the auto-scheduler.
 */
+
#include "auto_scheduler_layout_rewrite.h"

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <algorithm>
#include <deque>
#include <functional>
#include <string>
#include <vector>

#include "../backend/compile_engine.h"
#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// Two global variables for receiving layout information from python
+std::deque<std::string> AutoSchedulerLayoutRewriter::global_ori_layouts_queue;
+std::deque<std::string> AutoSchedulerLayoutRewriter::global_new_layouts_queue;
+
+// Copy an Attrs but with a new auto_scheduler_rewritten_layout filed.
+template <typename T>
+Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) {
+ auto n = make_object<T>(*ptr);
+ n->auto_scheduler_rewritten_layout = layout;
+ return Attrs(n);
+}
+
+// Mutate ops in a function
+class FuncMutator : public ExprMutator {
+ public:
+ FuncMutator(const std::deque<std::string>& ori_layouts_queue,
+ const std::deque<std::string>& new_layouts_queue)
+ : ExprMutator(),
+ ori_layouts_queue_(ori_layouts_queue),
+ new_layouts_queue_(new_layouts_queue) {}
+
+ Expr VisitExpr_(const CallNode* n) {
+ auto new_n = ExprMutator::VisitExpr_(n);
+
+ const auto* call = new_n.as<CallNode>();
+ if (call && call->op.as<OpNode>() &&
+ (std::find(target_ops_.begin(), target_ops_.end(),
n->op.as<OpNode>()->name) !=
+ target_ops_.end()) &&
+ !ori_layouts_queue_.empty() && !new_layouts_queue_.empty()) {
+ // Pop a new layout from the queue
+ const std::string ori_layout = ori_layouts_queue_.front();
+ const std::string new_layout = new_layouts_queue_.front();
+ ori_layouts_queue_.pop_front();
+ new_layouts_queue_.pop_front();
+
+ // Insert a new op to do layout transform. (This will be simplified by
FoldConstant later).
+ Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1],
ori_layout, new_layout);
+ Array<Expr> updated_args = {call->args[0], updated_kernel};
+
+ // Update the attrs
+ Attrs updated_attrs;
+ if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
+ updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+ }
+ new_n = Call(call->op, updated_args, updated_attrs);
+ }
+ return new_n;
+ }
+
+ private:
+ std::deque<std::string> ori_layouts_queue_;
+ std::deque<std::string> new_layouts_queue_;
+
+ std::vector<std::string> target_ops_{"nn.conv2d"};
+};
+
+Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
+ auto new_n = ExprMutator::VisitExpr_(n);
+
+ if (const auto* call = new_n.as<CallNode>()) {
+ if (const auto* func = call->op.as<FunctionNode>()) {
+ global_ori_layouts_queue.clear();
+ global_new_layouts_queue.clear();
+
+ // Use ScheduleGetter to call python lower functions.
+ // This is used to get the layout transform information.
+ // The layout transformation will be recorded to global_ori_layout_queue
+ // and global_new_layouts_queue in ComputeDAG::RewriteLayout.
+ auto f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite");
+ CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite
function.";
+ (*f)();
+
+ CreateSchedule(GetRef<Function>(func), Target::Current());
+
+ f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite");
+ CHECK(f) << "Could not find ansor.exit_layout_rewrite function.";
+ (*f)();
+
+ // Mutate the called function
+ if (!global_ori_layouts_queue.empty() &&
!global_new_layouts_queue.empty()) {
+ auto ret = FuncMutator(global_ori_layouts_queue,
global_new_layouts_queue).VisitExpr(new_n);
+ return ret;
+ }
+ }
+ }
+
+ return new_n;
+}
+
+Expr AutoSchedulerLayoutRewrite(const Expr& expr) {
+ return AutoSchedulerLayoutRewriter().Mutate(expr);
+}
+
+namespace transform {
+
+Pass AutoSchedulerLayoutRewrite() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(relay::AutoSchedulerLayoutRewrite(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite",
{"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite")
+ .set_body_typed(AutoSchedulerLayoutRewrite);
+
+TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
+ .set_body_typed([](const Attrs& attrs) {
+ if (attrs->IsInstance<Conv2DAttrs>()) {
+ return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
+ }
+ return std::string();
+ });
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.h
b/src/relay/transforms/auto_scheduler_layout_rewrite.h
new file mode 100644
index 0000000..d0d89db
--- /dev/null
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler_layout_rewrite.h
+ * \brief Rewrite the layout of "layout free" tensors (e.g., the weight
tensors in
+ * conv2d and dense layers) according to the tile structure generated by the
auto-scheduler.
+ */
+
+#ifndef TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
+#define TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
+
+#include <tvm/relay/expr_functor.h>
+
+#include <deque>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+class AutoSchedulerLayoutRewriter : public ExprMutator {
+ public:
+ Expr VisitExpr_(const CallNode* n) final;
+
+ // Two global variables for receiving layout information from python
+ static std::deque<std::string> global_ori_layouts_queue;
+ static std::deque<std::string> global_new_layouts_queue;
+};
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
new file mode 100644
index 0000000..299fcb8
--- /dev/null
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test layout rewrite support for whole neural networks"""
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm import relay, auto_scheduler
+from tvm.contrib import graph_runtime
+import tvm.testing
+
+
def get_np_array(var, dtype):
    """Return a standard-normal random array shaped like ``var``'s type annotation.

    Parameters
    ----------
    var : relay.Var
        A variable whose ``type_annotation.shape`` gives the array shape.
    dtype : str
        Target numpy dtype of the returned array.
    """
    dims = [int(dim) for dim in var.type_annotation.shape]
    samples = np.random.randn(*dims)
    return samples.astype(dtype)
+
+
def get_relay_conv2d(
    outc=128,
    inc=64,
    height=14,
    width=14,
    kh=3,
    kw=3,
    batch=1,
    pad=0,
    stride=1,
    dilation=1,
    layout="NHWC",
    dtype="float32",
):
    """Build a relay module holding one conv2d, plus random inputs for it.

    Parameters
    ----------
    outc, inc : int
        Output / input channel counts.
    height, width : int
        Spatial size of the input.
    kh, kw : int
        Kernel height / width.
    batch, pad, stride, dilation : int
        Usual conv2d hyper-parameters (stride and dilation are applied to both axes).
    layout : str
        Data layout, ``"NHWC"`` (kernel ``"HWIO"``) or ``"NCHW"`` (kernel ``"OIHW"``).
    dtype : str
        Dtype of the data, weight and generated numpy arrays.

    Returns
    -------
    (tvm.IRModule, numpy.ndarray, numpy.ndarray)
        The module and random data / weight arrays matching its inputs.

    Raises
    ------
    ValueError
        If ``layout`` is not one of the supported layouts.
    """
    if layout == "NHWC":
        kernel_layout = "HWIO"
        data_shape = (batch, height, width, inc)
        weight_shape = (kh, kw, inc, outc)
    elif layout == "NCHW":
        kernel_layout = "OIHW"
        data_shape = (batch, inc, height, width)
        weight_shape = (outc, inc, kh, kw)
    else:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(f'Unsupported layout: {layout}. Expected "NHWC" or "NCHW".')

    d = relay.var("data", shape=data_shape, dtype=dtype)
    w = relay.var("weight", shape=weight_shape, dtype=dtype)

    y = relay.nn.conv2d(
        d,
        w,
        padding=pad,
        kernel_size=(kh, kw),
        strides=(stride, stride),
        dilation=(dilation, dilation),
        channels=outc,
        groups=1,
        data_layout=layout,
        kernel_layout=kernel_layout,
    )
    mod = tvm.IRModule()
    mod["main"] = relay.Function([d, w], y)
    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
    return mod, data, weight
+
+
def tune_and_check(mod, data, weight):
    """Tune ``mod`` briefly, then check that layout rewrite does not change its output.

    The module is compiled twice with the tuned log: once with the
    AutoSchedulerLayoutRewrite pass enabled and once with it disabled. The two
    outputs must match.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` takes ``data`` and ``weight`` inputs.
    data, weight : numpy.ndarray
        Input values; ``weight`` is bound as a constant parameter.
    """
    # Extract tasks from a relay program
    target = tvm.target.Target("llvm")
    tasks, task_weights = auto_scheduler.extract_tasks(mod, target=target, params={})

    # Everything that reads or writes the log stays inside this block, so the
    # temporary file exists while it is used and is removed afterwards.
    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Tune tasks
        tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=1,
            num_measures_per_round=1,
            builder=auto_scheduler.LocalBuilder(timeout=60),
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        tuner.tune(tune_option, search_policy="sketch.random")

        # Compile and run
        def compile_and_run(disabled_pass=None):
            # Avoid a shared mutable default; an empty list disables nothing.
            if disabled_pass is None:
                disabled_pass = []
            with auto_scheduler.ApplyHistoryBest(log_file):
                with tvm.transform.PassContext(
                    opt_level=3,
                    config={"relay.backend.use_auto_scheduler": True},
                    disabled_pass=disabled_pass,
                ):
                    lib = relay.build(mod, target=target, params={"weight": weight})

            ctx = tvm.cpu()
            module = graph_runtime.GraphModule(lib["default"](ctx))
            module.set_input("data", data)
            module.run()

            return module.get_output(0).asnumpy()

        # Check correctness
        actual_output = compile_and_run()
        expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"})

        tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4)
+
+
def test_conv2d():
    """Layout rewrite must not change the result of a tuned 1x1 conv2d."""
    conv_mod, data_np, weight_np = get_relay_conv2d(kh=1, kw=1)
    tune_and_check(conv_mod, data_np, weight_np)


if __name__ == "__main__":
    # Allow running this file directly, without pytest.
    test_conv2d()