comaniac commented on a change in pull request #6987:
URL: https://github.com/apache/tvm/pull/6987#discussion_r534367774
##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
- # todo(merrymercy, minminsun): port layout rewrite
- raise NotImplementedError
+ # in prepare_layout_rewrite mode
+ if enable_layout_rewrite and has_layout_free:
+ dispatch_ctx = DispatchContext.current
+ state = dispatch_ctx.query(target, key, has_complex_op, dag)
+ if state is None:
+ return te.create_schedule([x.op for x in outs])
+
+ # rewrite the layout and update the context for the new dag
+ dag = ComputeDAG(outs)
+ new_dag = dag.rewrite_layout_from_state(state)
+ new_key = json.dumps((new_dag.hash_key(),))
+ if new_key != key:
+ dispatch_ctx.update(target, new_key, state)
+ return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)
return schedule
+
+
+def tensor_no_check_call(self, *indices):
+ """An indexing function without any check.
+ This is the same as `tvm.te.Tensor::__call__` except that the safety
+ check is removed.
+ """
+ indices = convert_to_object(indices)
+ args = []
+ for x in indices:
+ if isinstance(x, _expr.PrimExpr):
+ args.append(x)
+ elif isinstance(x, _expr.IterVar):
+ args.append(x.var)
+ else:
+ raise ValueError("The indices must be expression")
+
+ return _expr.ProducerLoad(self, args)
+
+
+def remove_index_check(tensor):
+ """Remove the safety check in the indexing function for a tensor.
+ This is done by monkey patching its indexing function.
+ After removing the check, we are allowed to create a
+ temporary wrong IR and fix it later in other places.
+
+ Parameters
+ ----------
+ tensor: Tensor
+ The tensor to remove index check.
+ """
+ # Monkey patch the indexing function
+ tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)
+
+
+def rewrite_compute_body(compute_tensor, new_layout):
+ """Rewrite the body of a ComputeOp according to a new layout of a
placeholder"""
+ op = compute_tensor.op
+
+ # Get layout free placeholders
+ layout_free_placeholders = op.attrs["layout_free_placeholders"]
+ assert len(layout_free_placeholders) == 1
Review comment:
Add some messages?
##########
File path: tests/python/relay/test_auto_scheduler_layout_rewrite.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test layout rewrite support for whole neural networks"""
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm import relay, auto_scheduler
+from tvm.contrib import graph_runtime
+import tvm.testing
+
+
+def get_np_array(var, dtype):
+ return np.random.randn(*[int(x) for x in
var.type_annotation.shape]).astype(dtype)
+
+
+def get_relay_conv2d(
+ outc=128,
+ inc=64,
+ height=14,
+ width=14,
+ kh=3,
+ kw=3,
+ batch=1,
+ pad=0,
+ stride=1,
+ dilation=1,
+ layout="NHWC",
+):
+ dtype = "float32"
+ if layout == "NHWC":
+ kernel_layout = "HWIO"
+ d = relay.var("data", shape=(batch, height, width, inc), dtype=dtype)
+ w = relay.var("weight", shape=(kh, kw, inc, outc), dtype=dtype)
+ elif layout == "NCHW":
+ kernel_layout = "OIHW"
+ d = relay.var("data", shape=(batch, inc, height, width), dtype=dtype)
+ w = relay.var("weight", shape=(outc, inc, kh, kw), dtype=dtype)
+
+ y = relay.nn.conv2d(
+ d,
+ w,
+ padding=pad,
+ kernel_size=(kh, kw),
+ strides=(stride, stride),
+ dilation=(dilation, dilation),
+ channels=outc,
+ groups=1,
+ data_layout=layout,
+ kernel_layout=kernel_layout,
+ )
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([d, w], y)
+ data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+ return mod, data, weight
+
+
+def tune_and_check(mod, data, weight):
Review comment:
If we only want to test layout rewrite, we should be able to hard-code a
log without running a tuning on the fly?
##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -162,6 +162,23 @@ def infer_bound_from_state(self, state):
updated_state.stage_id_map[k] = v
return updated_state
+ def rewrite_layout_from_state(self, state):
Review comment:
This might be confusing with `apply_steps_from_state(state,
layout_rewrite=True)`. IIUC, this function returns a DAG while
`apply_steps_from_state` returns a schedule. To differentiate, maybe we need a
better name such as `rewrite_dag_layout_from_state`, or improve the docstring.
##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
- # todo(merrymercy, minminsun): port layout rewrite
- raise NotImplementedError
+ # in prepare_layout_rewrite mode
+ if enable_layout_rewrite and has_layout_free:
+ dispatch_ctx = DispatchContext.current
+ state = dispatch_ctx.query(target, key, has_complex_op, dag)
+ if state is None:
+ return te.create_schedule([x.op for x in outs])
Review comment:
Should we return `None` in this case as the final build mode?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]