This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 25cf489b04 [TOPI] Layout Rewriting in TE (#11844)
25cf489b04 is described below
commit 25cf489b0410dc8cf4c938e9337a31b9c5ddd3b6
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Jun 23 11:43:45 2022 +0800
[TOPI] Layout Rewriting in TE (#11844)
---
python/tvm/auto_scheduler/relay_integration.py |  5 ++++
python/tvm/topi/cuda/conv2d_winograd.py        |  1 +
python/tvm/topi/nn/batch_matmul.py             | 14 +++++++++-
python/tvm/topi/nn/conv2d.py                   | 30 ++++++++++++++++++---
python/tvm/topi/nn/conv3d.py                   |  7 ++++-
python/tvm/topi/nn/dense.py                    | 36 +++++++++++++++++++++++---
src/auto_scheduler/compute_dag.cc              | 10 +++++++
7 files changed, 95 insertions(+), 8 deletions(-)
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index e9bf1ccfd7..ee166e8679 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -467,6 +467,11 @@ def rewrite_compute_body(compute_tensor, new_layout):
return outputs[0] if num == 1 else outputs
+def rewrite_tensor_shape(tensor, shape):
+    """Rewrite the tensor shape in place."""
+    _ffi_api.RewriteTensorShape(tensor, shape)
+
+
def is_auto_scheduler_enabled():
"""Return whether the auto-scheduler is enabled.
diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py
index d2b373ba87..89a21f5c02 100644
--- a/python/tvm/topi/cuda/conv2d_winograd.py
+++ b/python/tvm/topi/cuda/conv2d_winograd.py
@@ -379,6 +379,7 @@ def conv2d_winograd_nhwc_cuda(
out_dtype,
pre_computed=False,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Conv2D Winograd in NHWC layout.
This is a clean version to be used by the auto-scheduler for both CPU and GPU.
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 26d45feb03..2156fe11ed 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -17,8 +17,10 @@
"""Batch matrix multiplication"""
# pylint: disable=invalid-name
import logging
+
import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
from ..utils import get_const_tuple
logger = logging.getLogger("topi")
@@ -32,6 +34,7 @@ def batch_matmul(
transpose_a=False,
transpose_b=True,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
@@ -62,6 +65,9 @@ def batch_matmul(
auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the tensor.
+
Returns
-------
output : tvm.te.Tensor
@@ -78,6 +84,12 @@ def batch_matmul(
auto_scheduler_rewritten_layout, ["b", "k", "j"]
)
auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            YB, YJ, YK = get_const_tuple(tensor_b.shape)
+        else:
+            YB, YK, YJ = get_const_tuple(tensor_b.shape)
else:
assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim"
if transpose_b:
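
With this branch, MetaSchedule can hand batch_matmul a tensor_b whose layout has already been rewritten, plus its pre-rewrite shape. A hedged sketch of the call; the shapes and the 5-D rewritten layout are illustrative, not from the commit:

    from tvm import te, topi

    A = te.placeholder((4, 32, 64), name="A")  # (batch, M, K)
    # B carries a rewritten layout; its original shape was (4, 16, 64),
    # i.e. (batch, N, K) since transpose_b defaults to True.
    B = te.placeholder((4, 4, 64, 4, 1), name="B")
    C = topi.nn.batch_matmul(A, B, meta_schedule_original_shape=[4, 16, 64])
    # batch_matmul restores B to (4, 16, 64) in place before unpacking YB, YJ, YK.
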
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index b7ae9b3e1c..5db752f6d5 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -280,6 +280,7 @@ def conv2d_nhwc(
dilation,
out_dtype="float32",
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Convolution operator in NHWC layout.
@@ -308,6 +309,9 @@ def conv2d_nhwc(
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
+
Returns
-------
output : tvm.te.Tensor
@@ -323,6 +327,7 @@ def conv2d_nhwc(
"NHWC",
out_dtype,
auto_scheduler_rewritten_layout,
+ meta_schedule_original_shape,
auto_scheduler_should_rewrite_layout=True,
)
@@ -716,6 +721,7 @@ def conv(
order: str,
out_dtype: Union[str, None] = None,
auto_scheduler_rewritten_layout: Optional[str] = None,
+ meta_schedule_original_shape=None,
auto_scheduler_should_rewrite_layout: bool = False,
):
"""Convolution operator in NCHW or NHWC layout.
@@ -755,14 +761,17 @@ def conv(
Elements are converted to this type before elementwise multiplication
and summation.
+ auto_scheduler_rewritten_layout: str
+ Layout from the auto-scheduler's layout rewriting.
+
+ meta_schedule_original_shape : Optional[List[PrimExpr]]
+ The original shape of the input tensor.
+
auto_scheduler_should_rewrite_layout : bool
Should auto scheduler be allowed to rewrite the layout of the filter
tensor. Defaults to false. This can cause errors if used with grouped
convs.
- auto_scheduler_rewritten_layout: str
- Layout from autoscheduler's layout rewritting.
-
Returns
-------
Output : tvm.te.Tensor
@@ -802,6 +811,8 @@ def conv(
permutation_to_kernel = [dim + 1, dim] + list(range(dim))
permutation_from_kernel = np.argsort(permutation_to_kernel)
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape)
    batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist()
    num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
        permutation_to_kernel
@@ -959,6 +970,7 @@ def _conv2d_winograd_nhwc_impl(
tile_size,
pre_computed=False,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Conv2D Winograd implementation in NHWC layout.
This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -983,6 +995,8 @@ def _conv2d_winograd_nhwc_impl(
Whether the kernel is precomputed
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
Returns
-------
@@ -994,6 +1008,8 @@ def _conv2d_winograd_nhwc_impl(
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(weight, meta_schedule_original_shape)
assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
if not pre_computed:
@@ -1136,6 +1152,7 @@ def conv2d_winograd_nhwc(
out_dtype,
pre_computed=False,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Conv2D Winograd in NHWC layout.
This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1158,6 +1175,8 @@ def conv2d_winograd_nhwc(
Whether the kernel is precomputed
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
Returns
-------
@@ -1176,6 +1195,7 @@ def conv2d_winograd_nhwc(
tile_size,
pre_computed,
auto_scheduler_rewritten_layout,
+ meta_schedule_original_shape,
)
@@ -1187,6 +1207,7 @@ def conv2d_winograd_nhwc_without_weight_transform(
dilation,
out_dtype,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""Conv2D Winograd without layout transform in NHWC layout.
This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1207,6 +1228,8 @@ def conv2d_winograd_nhwc_without_weight_transform(
Specifies the output data type.
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
Returns
-------
@@ -1223,4 +1246,5 @@ def conv2d_winograd_nhwc_without_weight_transform(
out_dtype,
pre_computed=True,
auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout,
+ meta_schedule_original_shape=meta_schedule_original_shape,
)
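
conv2d.py threads the new argument from every public entry point down to the shared conv() helper, which restores the filter's shape in place before reading its dimensions. A sketch under the same assumptions (shapes illustrative):

    from tvm import te, topi

    data = te.placeholder((1, 56, 56, 64), name="data")  # NHWC
    # Kernel in a rewritten layout; its original HWIO shape was (3, 3, 64, 128).
    kernel = te.placeholder((3, 3, 64, 8, 16), name="kernel")
    out = topi.nn.conv2d_nhwc(
        data, kernel, 1, 1, 1, meta_schedule_original_shape=[3, 3, 64, 128]
    )
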
diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py
index 2915b886a5..591c643a95 100644
--- a/python/tvm/topi/nn/conv3d.py
+++ b/python/tvm/topi/nn/conv3d.py
@@ -21,8 +21,8 @@ import tvm
from tvm import te
from ..utils import get_const_tuple
-from .winograd_util import winograd_transform_matrices
from .conv2d import conv
+from .winograd_util import winograd_transform_matrices
def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
@@ -65,6 +65,7 @@ def conv3d_ndhwc(
groups,
out_dtype="float32",
auto_scheduler_rewritten_layout="",
+ meta_schedule_origin_shape=None,
):
"""Convolution operator in NDHWC layout.
@@ -94,6 +95,9 @@ def conv3d_ndhwc(
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_origin_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
+
Returns
-------
Output : tvm.te.Tensor
@@ -109,6 +113,7 @@ def conv3d_ndhwc(
"NDHWC",
out_dtype,
auto_scheduler_rewritten_layout,
+ meta_schedule_origin_shape,
)
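
conv3d_ndhwc forwards the shape to the same conv() helper; note the parameter is spelled meta_schedule_origin_shape in this file. A parallel sketch (shapes illustrative):

    from tvm import te, topi

    out3d = topi.nn.conv3d_ndhwc(
        te.placeholder((1, 8, 56, 56, 64), name="data3d"),      # NDHWC
        te.placeholder((3, 3, 3, 64, 4, 32), name="kernel3d"),  # rewritten layout
        1, 1, 1, 1,  # stride, padding, dilation, groups
        meta_schedule_origin_shape=[3, 3, 3, 64, 128],          # original DHWIO
    )
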
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 69fac92c7c..61f9c4e17c 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -17,7 +17,8 @@
# pylint: disable=invalid-name,unused-argument
"""TVM operator fully connected compute."""
import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
from .. import tag
@@ -29,6 +30,7 @@ def matmul(
transpose_a=False,
transpose_b=False,
auto_scheduler_rewritten_layout="",
+ meta_schedule_original_shape=None,
):
"""The default implementation of matmul in topi.
@@ -55,6 +57,9 @@ def matmul(
auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
+
Returns
-------
output : tvm.te.Tensor
@@ -77,6 +82,12 @@ def matmul(
auto_scheduler_rewritten_layout, ["j", "k"]
)
auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            out_dim, red_dim = tensor_b.shape
+        else:
+            red_dim, out_dim = tensor_b.shape
elif transpose_b:
out_dim, red_dim = tensor_b.shape
else:
@@ -156,7 +167,14 @@ def matmul_legalize(attrs, inputs, types):
return None
-def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
+def dense(
+    data,
+    weight,
+    bias=None,
+    out_dtype=None,
+    auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
+):
"""The default implementation of dense in topi.
This is an alias of matmul_nt operator for data tensor in non-transposed format and weight
tensor in transposed format.
@@ -178,12 +196,24 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
+ meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+ The original shape of the input tensor.
+
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
-    return matmul(data, weight, bias, out_dtype, False, True, auto_scheduler_rewritten_layout)
+    return matmul(
+        data,
+        weight,
+        bias,
+        out_dtype,
+        False,
+        True,
+        auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
+    )
@tvm.target.generic_func
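
Since dense() is an alias of matmul(..., transpose_b=True), the new argument simply rides along. A sketch; the shapes and the weight's rewritten layout are illustrative:

    from tvm import te, topi

    data = te.placeholder((16, 64), name="data")
    # Weight in a rewritten layout; its original (out_dim, red_dim) was (32, 64).
    weight = te.placeholder((8, 64, 4), name="weight")
    out = topi.nn.dense(data, weight, meta_schedule_original_shape=[32, 64])
    # matmul restores weight to (32, 64) in place and emits a (16, 32) output.
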
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index e82830fa4d..dad55db030 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1517,6 +1517,16 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
return index_rewriter.Rewrite(body);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.RewriteTensorShape")
+ .set_body_typed([](te::Tensor tensor, Array<PrimExpr> new_shape) -> void {
+ ICHECK(tensor->op->IsInstance<te::PlaceholderOpNode>());
+ te::PlaceholderOpNode* op =
+
const_cast<te::PlaceholderOpNode*>(tensor->op.as<te::PlaceholderOpNode>());
+ te::TensorNode* t = const_cast<te::TensorNode*>(tensor.get());
+ op->shape = new_shape;
+ t->shape = new_shape;
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
.set_body_typed(GetShapeFromRewrittenLayout);
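
This registration is the target of _ffi_api.RewriteTensorShape in relay_integration.py above; it mutates both the TensorNode and its PlaceholderOpNode so later shape queries agree. For reference, the packed function can also be fetched directly from Python (a sketch):

    import tvm
    from tvm import te

    f = tvm.get_global_func("auto_scheduler.RewriteTensorShape")
    t = te.placeholder((8, 4), name="t")
    f(t, [32])  # the ICHECK rejects tensors whose op is not a PlaceholderOpNode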