merrymercy commented on a change in pull request #6987:
URL: https://github.com/apache/tvm/pull/6987#discussion_r534440217



##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
             env.add_workload_key(key, ccache_key)
         schedule = te.create_schedule([x.op for x in outs])
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
-        # todo(merrymercy, minminsun): port layout rewrite
-        raise NotImplementedError
+        # in prepare_layout_rewrite mode
+        if enable_layout_rewrite and has_layout_free:
+            dispatch_ctx = DispatchContext.current
+            state = dispatch_ctx.query(target, key, has_complex_op, dag)
+            if state is None:
+                return te.create_schedule([x.op for x in outs])
+
+            # rewrite the layout and update the context for the new dag
+            dag = ComputeDAG(outs)
+            new_dag = dag.rewrite_layout_from_state(state)
+            new_key = json.dumps((new_dag.hash_key(),))
+            if new_key != key:
+                dispatch_ctx.update(target, new_key, state)
+        return te.create_schedule([x.op for x in outs])
     else:
         raise ValueError("Invalid tracing mode: " + env.tracing_mode)
 
     return schedule
+
+
+def tensor_no_check_call(self, *indices):
+    """An indexing function without any check.
+    This is the same as `tvm.te.Tensor::__call__` except that the safety
+    check is removed.
+    """
+    indices = convert_to_object(indices)
+    args = []
+    for x in indices:
+        if isinstance(x, _expr.PrimExpr):
+            args.append(x)
+        elif isinstance(x, _expr.IterVar):
+            args.append(x.var)
+        else:
+            raise ValueError("The indices must be expression")
+
+    return _expr.ProducerLoad(self, args)
+
+
+def remove_index_check(tensor):
+    """Remove the safety check in the indexing function for a tensor.
+    This is done by monkey patching its indexing function.
+    After removing the check, we are allowed to create a
+    temporary wrong IR and fix it later in other places.
+
+    Parameters
+    ----------
+    tensor: Tensor
+      The tensor to remove index check.
+    """
+    # Monkey patch the indexing function
+    tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)
+
+
+def rewrite_compute_body(compute_tensor, new_layout):
+    """Rewrite the body of a ComputeOp according to a new layout of a placeholder"""
+    op = compute_tensor.op
+
+    # Get layout free placeholders
+    layout_free_placeholders = op.attrs["layout_free_placeholders"]
+    assert len(layout_free_placeholders) == 1

Review comment:
       ```suggestion
       assert len(layout_free_placeholders) == 1, "Only support one layout free placeholder"
       ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to