This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f9171f1 [MetaSchedule] Schedule Rule: Add RFactor (#9975)
f9171f1 is described below
commit f9171f16e657d30fde3a366388b4fef837e5187f
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Jan 20 03:48:02 2022 +0800
[MetaSchedule] Schedule Rule: Add RFactor (#9975)
* add rfactor
* format
* fix ci
---
include/tvm/meta_schedule/schedule_rule.h | 10 +
python/tvm/meta_schedule/schedule_rule/__init__.py | 1 +
.../tvm/meta_schedule/schedule_rule/add_rfactor.py | 49 ++
python/tvm/meta_schedule/testing/schedule_rule.py | 8 +
python/tvm/meta_schedule/testing/te_workload.py | 877 +++++++++++++++++++++
src/meta_schedule/schedule_rule/add_rfactor.cc | 122 +++
src/meta_schedule/utils.h | 20 +
src/target/target_kind.cc | 1 +
src/tir/schedule/analysis.h | 38 +
src/tir/schedule/analysis/analysis.cc | 186 +++++
src/tir/schedule/utils.h | 54 ++
...test_meta_schedule_schedule_rule_add_rfactor.py | 80 ++
12 files changed, 1446 insertions(+)
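The rule is built on the `rfactor` schedule primitive, which rewrites a reduction into a partial-reduction stage plus a final combining stage so that part of the reduction can run in parallel. A minimal sketch of the primitive itself, assuming a TVM build that includes TensorIR schedules (the workload here is illustrative):

    import tvm
    from tvm import te

    # A row-sum with a long reduction axis.
    A = te.placeholder((128, 512), name="A")
    r = te.reduce_axis((0, 512), name="r")
    B = te.compute((128,), lambda i: te.sum(A[i, r], axis=r), name="B")

    sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
    i, r = sch.get_loops(sch.get_block("B"))
    ro, ri = sch.split(r, factors=[None, 16])
    # rfactor turns `ri` into a spatial axis of an intermediate buffer of
    # partial sums (each reducing over `ro`); the original block is rewritten
    # to combine the 16 partial results.
    rf_block = sch.rfactor(ri, factor_axis=0)
    print(sch.mod.script())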
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 6ee3947..95fce13 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -153,6 +153,16 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_read,   //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
+ * \brief Create a rule: add-rfactor to some blocks if needed
+ * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
+ * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
+ * parallelism.
+ * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit.
+ * \return The schedule rule created
+ */
+ TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core,  //
+                                        Optional<Integer> max_innermost_factor);
+ /*!
* \brief A rule that randomly selects a compute-at location for a free block
* \return The rule created
*/
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 9ad3c06..475c43a 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -16,6 +16,7 @@ The tvm.meta_schedule.schedule_rule package.
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
+from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .schedule_rule import PyScheduleRule, ScheduleRule
from .random_compute_location import RandomComputeLocation
diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
new file mode 100644
index 0000000..72f9fc9
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add-rfactor Rule that add-rfactor to some blocks if needed"""
+from typing import Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.AddRFactor")
+class AddRFactor(ScheduleRule):
+ """Rules for add-rfactor to some blocks if needed.
+
+ Parameters
+ ----------
+ max_jobs_per_core: int
+ The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
+ parallelism, i.e. `num_cores * max_jobs_per_core`.
+ Use -1 to disable parallelism.
+ max_innermost_factor: Optional[int] = None
+ The maximum size of the innermost factor. None means no limit.
+ """
+
+ def __init__(
+ self,
+ max_jobs_per_core: int = 16,
+ max_innermost_factor: Optional[int] = None,
+ ) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.ScheduleRuleAddRFactor,  # type: ignore  # pylint: disable=no-member
+ max_jobs_per_core,
+ max_innermost_factor,
+ )
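A short usage sketch for the class above (illustrative; these are the same values the testing helper below uses):

    from tvm.meta_schedule.schedule_rule import AddRFactor

    # Cap CPU parallelism at num_cores * 16 and the innermost tile size at 64.
    rule = AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)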
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index e69be13..020869d 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -16,6 +16,7 @@
# under the License.
"""Default schedule rules"""
from tvm.meta_schedule.schedule_rule import (
+ AddRFactor,
AutoInline,
ScheduleRule,
)
@@ -45,3 +46,10 @@ def auto_inline(target: Target) -> ScheduleRule:
disallow_op=None,
)
raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
+def add_rfactor(target: Target) -> ScheduleRule:
+ """Default schedule rules for with add_rfactor"""
+ if target.kind.name == "llvm":
+ return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
+ raise NotImplementedError(f"{target.kind.name} is not supported")
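For example (illustrative; this matches the new unit test at the end of this patch):

    from tvm.target import Target
    from tvm.meta_schedule.testing.schedule_rule import add_rfactor

    rule = add_rfactor(Target("llvm --num-cores=32"))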
diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py
new file mode 100644
index 0000000..49a60a2
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/te_workload.py
@@ -0,0 +1,877 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Workloads in TE"""
+# pylint: disable=missing-docstring
+from typing import Tuple
+
+from tvm import te, tir, topi
+
+
+def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring
+ B: int,
+ N: int,
+ M: int,
+ K: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ x = te.placeholder((B, N, K), name="X")
+ y = te.placeholder((B, K, M), name="Y")
+ k = te.reduce_axis((0, K), name="k")
+ z = te.compute( # pylint: disable=invalid-name
+ (B, N, M),
+ lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]),
+ name="Z",
+ )
+ return (x, y, z)
+
+
+def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ L: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, L, CI), name="inputs")
+ weight = te.placeholder((kernel_size, CI // groups, CO), name="weight")
+
+ batch_size, in_len, _ = inputs.shape
+ k_len, channel_per_group, out_channel = weight.shape
+ out_channel_per_group = out_channel // groups
+ out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1
+ rc = te.reduce_axis((0, channel_per_group), name="rc")
+ rl = te.reduce_axis((0, k_len), name="rl")
+
+ padded = topi.nn.pad(inputs, [0, padding, 0])
+ output = te.compute(
+ (batch_size, out_len, out_channel),
+ lambda n, l, co: te.sum(
+ (
+ padded[
+ n,
+ l * stride + rl * dilation,
+ co // out_channel_per_group * channel_per_group + rc,
+ ]
+ * weight[rl, rc, co]
+ ),
+ axis=[rl, rc],
+ ),
+ name="conv1d_nlc",
+ )
+ return (inputs, weight, output)
+
+
+def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, H, W, CI), name="inputs")
+ weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight")
+ batch_size, in_h, in_w, _ = inputs.shape
+ k_h, k_w, channel_per_group, out_channel = weight.shape
+ out_channel_per_group = out_channel // groups
+
+ out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+ out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+ rh = te.reduce_axis((0, k_h), name="rh")
+ rw = te.reduce_axis((0, k_w), name="rw")
+ rc = te.reduce_axis((0, channel_per_group), name="rc")
+
+ padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+ output = te.compute(
+ (batch_size, out_h, out_w, out_channel),
+ lambda n, h, w, co: te.sum(
+ (
+ padded[
+ n,
+ h * stride + rh * dilation,
+ w * stride + rw * dilation,
+ co // out_channel_per_group * channel_per_group + rc,
+ ]
+ * weight[rh, rw, rc, co]
+ ),
+ axis=[rh, rw, rc],
+ ),
+ name="conv2d_nhwc",
+ )
+ return (inputs, weight, output)
+
+
+def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ D: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, D, H, W, CI), name="inputs")
+ weight = te.placeholder(
+ (kernel_size, kernel_size, kernel_size, CI // groups, CO), name="weight"
+ )
+ batch_size, in_d, in_h, in_w, _ = inputs.shape
+ k_d, k_h, k_w, channel_per_group, out_channel = weight.shape
+ out_channel_per_group = out_channel // groups
+
+ out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1
+ out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+ out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+ rd = te.reduce_axis((0, k_d), name="rd")
+ rh = te.reduce_axis((0, k_h), name="rh")
+ rw = te.reduce_axis((0, k_w), name="rw")
+ rc = te.reduce_axis((0, channel_per_group), name="rc")
+
+ padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0])
+ output = te.compute(
+ (batch_size, out_d, out_h, out_w, out_channel),
+ lambda n, d, h, w, co: te.sum(
+ (
+ padded[
+ n,
+ d * stride + rd * dilation,
+ h * stride + rh * dilation,
+ w * stride + rw * dilation,
+ co // out_channel_per_group * channel_per_group + rc,
+ ]
+ * weight[rd, rh, rw, rc, co]
+ ),
+ axis=[rd, rh, rw, rc],
+ ),
+ name="conv3d_ndhwc",
+ )
+ return (inputs, weight, output)
+
+
+def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ C: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ factor: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, H, W, C))
+ weight = te.placeholder((factor, kernel_size, kernel_size, C))
+ batch_size, in_h, in_w, in_channel = inputs.shape
+ factor, k_h, k_w, in_channel = weight.shape
+ out_channel = in_channel * factor
+ assert int(factor) == 1, "Not optimized for factor != 1"
+ out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+ out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+ rh = te.reduce_axis((0, k_h), name="rh")
+ rw = te.reduce_axis((0, k_w), name="rw")
+ padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+ output = te.compute(
+ (batch_size, out_h, out_w, out_channel),
+ lambda n, h, w, c: te.sum(
+ (
+ padded[
+ n,
+ h * stride + rh * dilation,
+ w * stride + rw * dilation,
+ c // factor,
+ ]
+ * weight[c % factor, rh, rw, c // factor]
+ ),
+ axis=[rh, rw],
+ ),
+ name="depth_conv2d_nhwc",
+ )
+ return (inputs, weight, output)
+
+
+def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, H, W, CI), name="inputs")
+ weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight")
+
+ batch, in_h, in_w, in_c = inputs.shape
+ filter_h, filter_w, in_c, out_c = weight.shape
+ stride_h, stride_w = (stride, stride)
+
+ # compute padding
+ fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(
+ padding, (filter_h, filter_w)
+ )
+ bpad_top = filter_h - 1 - fpad_top
+ bpad_bottom = filter_h - 1 - fpad_bottom
+ bpad_left = filter_w - 1 - fpad_left
+ bpad_right = filter_w - 1 - fpad_right
+
+ # padding stage
+ padded = topi.nn.pad(
+ inputs,
+ [
+ 0,
+ (bpad_top + stride_h - 1) // stride_h,
+ (bpad_left + stride_w - 1) // stride_w,
+ 0,
+ ],
+ [
+ 0,
+ (bpad_bottom + stride_h - 1) // stride_h,
+ (bpad_right + stride_w - 1) // stride_w,
+ 0,
+ ],
+ )
+
+ # remove extra padding introduced by dilation
+ idx_div = te.indexdiv
+ idx_mod = te.indexmod
+ border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h)
+ border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w)
+
+ # dilation stage
+ strides = [1, stride_h, stride_w, 1]
+ n = len(padded.shape)
+
+ # We should embed this dilation directly into te.compute rather than creating a new te.compute.
+ # Only in this way can we use unroll to eliminate the multiplication of zeros.
+ def _dilate(*indices):
+ not_zero = []
+ index_tuple = []
+ for i in range(n):
+ if not strides[i] == 1:
+ index_tuple.append(idx_div(indices[i], strides[i]))
+ not_zero.append(idx_mod(indices[i], strides[i]).equal(0))
+ else:
+ index_tuple.append(indices[i])
+ if not_zero:
+ not_zero = te.all(*not_zero)
+ return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype))
+ return padded(*index_tuple)
+
+ # convolution stage
+ out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
+ out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+ rc = te.reduce_axis((0, in_c), name="rc")
+ rh = te.reduce_axis((0, filter_h), name="rh")
+ rw = te.reduce_axis((0, filter_w), name="rw")
+
+ output = te.compute(
+ (batch, out_h, out_w, out_c),
+ lambda n, h, w, co: te.sum(
+ _dilate(n, h + rh + border_h, w + rw + border_w, rc)
+ * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co],
+ axis=[rh, rw, rc],
+ ),
+ name="conv2d_transpose_nhwc",
+ )
+ return (inputs, weight, output)
+
+
+def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ capsule_size: int = 4,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs")
+ weight = te.placeholder(
+ (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight"
+ )
+ batch_size, in_h, in_w, _, _, in_channel = inputs.shape
+ k_h, k_w, _, _, _, out_channel = weight.shape
+
+ out_h = (in_h + 2 * padding - kernel_size) // stride + 1
+ out_w = (in_w + 2 * padding - kernel_size) // stride + 1
+
+ rh = te.reduce_axis((0, k_h), name="rh")
+ rw = te.reduce_axis((0, k_w), name="rw")
+ cap_k = te.reduce_axis((0, capsule_size), name="cap_k")
+ rc = te.reduce_axis((0, in_channel), name="rc")
+
+ padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0])
+ output = te.compute(
+ (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel),
+ lambda n, h, w, cap_i, cap_j, co: te.sum(
+ (
+ padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc]
+ * weight[rh, rw, cap_k, cap_j, rc, co]
+ ),
+ axis=[rh, rw, cap_k, rc],
+ ),
+ name="conv2d_capsule_nhwijc",
+ )
+ return (inputs, weight, output)
+
+
+def norm_bmn( # pylint: disable=invalid-name,missing-docstring
+ B: int,
+ M: int,
+ N: int,
+) -> Tuple[te.Tensor, te.Tensor]:
+ a = te.placeholder((B, M, N), name="A")
+ i = te.reduce_axis((0, M), name="i")
+ j = te.reduce_axis((0, N), name="j")
+ c = te.compute(
+ (B,),
+ lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]),
+ name="C",
+ )
+ d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D")
+ return (a, d)
+
+
+def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name
+ Input: te.Tensor,
+ Filter: te.Tensor,
+ stride: int,
+ padding: int,
+ dilation: int,
+ out_dtype="float32",
+):
+ """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute.
+ We use this in single op and subgraph evaluation
+ because we don't want to introduce graph level optimization.
+ """
+ assert isinstance(stride, int) or len(stride) == 2
+ assert isinstance(dilation, int) or len(dilation) == 2
+
+ if isinstance(stride, int):
+ stride_h = stride_w = stride
+ else:
+ stride_h, stride_w = stride
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ batch, in_height, in_width, in_channel = Input.shape # type: ignore
+ kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore
+
+ # compute the output shape
+ dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+ dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+ pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple(
+ padding, (dilated_kernel_h, dilated_kernel_w)
+ )
+ out_channel = num_filter
+ out_height = topi.utils.simplify(
+ (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
+ )
+ out_width = topi.utils.simplify(
+ (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
+ )
+ pad_before = [0, pad_top, pad_left, 0]
+ pad_after = [0, pad_down, pad_right, 0]
+ PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput")
+ rc = te.reduce_axis((0, in_channel), name="rc")
+ ry = te.reduce_axis((0, kernel_h), name="ry")
+ rx = te.reduce_axis((0, kernel_w), name="rx")
+ Output = te.compute(
+ (batch, out_height, out_width, out_channel),
+ lambda nn, yy, xx, ff: te.sum(
+ PaddedInput[
+ nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+ ].astype(out_dtype)
+ * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore
+ axis=[ry, rx, rc],
+ ),
+ name="Conv2dOutput",
+ tag="conv2d_nhwc",
+ )
+ return Output
+
+
+def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ strides: int,
+ padding: int,
+ dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]:
+ data = te.placeholder((N, H, W, CI), name="data")
+ kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel")
+ bias = te.placeholder((CO,), name="bias")
+ bn_scale = te.placeholder((CO,), name="bn_scale")
+ bn_offset = te.placeholder((CO,), name="bn_offset")
+ OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+ OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+ conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation)
+ conv = te.compute(
+ (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add"
+ )
+ conv = te.compute(
+ (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul"
+ )
+ conv = te.compute(
+ (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add"
+ )
+ out = topi.nn.relu(conv)
+ return (data, kernel, bias, bn_offset, bn_scale, out)
+
+
+def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring
+ batch: int,
+ seq_len: int,
+ n_head: int,
+ n_dim: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ query = te.placeholder((batch, seq_len, n_head, n_dim), name="query")
+ value = te.placeholder((batch, seq_len, n_head, n_dim), name="value")
+ query_T = te.compute(
+ (batch, n_head, seq_len, n_dim),
+ lambda b, h, l, d: query[b, l, h, d],
+ name="query_T",
+ )
+ value_T = te.compute(
+ (batch, n_head, n_dim, seq_len),
+ lambda b, h, d, l: value[b, l, h, d],
+ name="value_T",
+ )
+ k = te.reduce_axis((0, n_dim), name="k")
+ out = te.compute(
+ (batch, n_head, seq_len, seq_len),
+ lambda b, h, i, j: te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]),
+ name="C",
+ )
+ return (query, value, out)
+
+
+def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring
+ N: int,
+ H: int,
+ W: int,
+ CI: int,
+ CO: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ tile_size = 4 # _infer_tile_size(data, kernel)
+ inputs = te.placeholder((N, H, W, CI), name="inputs")
+ N, H, W, CI = topi.utils.get_const_tuple(inputs.shape)
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
+
+ KH = KW = kernel_size
+ HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW))
+ HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride
+ assert HSTR == 1 and WSTR == 1 and KH == KW
+
+ data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad")
+
+ r = KW
+ m = tile_size
+ alpha = m + r - 1
+ A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32")
+
+ H = (H + 2 * HPAD - KH) // HSTR + 1
+ W = (W + 2 * WPAD - KW) // WSTR + 1
+ nH, nW = (H + m - 1) // m, (W + m - 1) // m
+ P = N * nH * nW
+ _rkh = te.reduce_axis((0, KH), name="r_kh")
+ _rkw = te.reduce_axis((0, KW), name="r_kw")
+ kshape = (alpha, alpha, CI, CO)
+ kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
+
+ idxdiv = te.indexdiv
+ idxmod = te.indexmod
+ # pack input tile
+ input_tile = te.compute(
+ (alpha, alpha, P, CI),
+ lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][
+ idxmod(p, nW) * m + nu
+ ][ci],
+ name="input_tile",
+ )
+
+ # transform data
+ r_a = te.reduce_axis((0, alpha), "r_a")
+ r_b = te.reduce_axis((0, alpha), "r_b")
+ data_pack = te.compute(
+ (alpha, alpha, P, CI),
+ lambda eps, nu, p, ci: te.sum(
+ input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+ ),
+ name="data_pack",
+ attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu",
"r_a", "r_b"]},
+ )
+
+ # do batch gemm
+ ci = te.reduce_axis((0, CI), name="ci")
+ bgemm = te.compute(
+ (alpha, alpha, P, CO),
+ lambda eps, nu, p, co: te.sum(
+ data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci]
+ ),
+ name="bgemm",
+ )
+
+ # inverse transform
+ r_a = te.reduce_axis((0, alpha), "r_a")
+ r_b = te.reduce_axis((0, alpha), "r_b")
+ inverse = te.compute(
+ (m, m, P, CO),
+ lambda vh, vw, p, co: te.sum(
+ bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+ ),
+ name="inverse",
+ attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw",
"r_a", "r_b"]},
+ )
+
+ # output
+ output = te.compute(
+ (N, H, W, CO),
+ lambda n, h, w, co: inverse[
+ idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co
+ ],
+ name="conv2d_winograd",
+ )
+
+ return (inputs, kernel_pack, output)
+
+
+def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ a = te.placeholder((n, k), name="A")
+ b = te.placeholder((k, m), name="B")
+ k = te.reduce_axis((0, k), name="k")
+ c = te.compute(
+ (n, m),
+ lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]),
+ name="C",
+ )
+ return (a, b, c)
+
+
+def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ a = te.placeholder((n, k), name="A", dtype="float16")
+ b = te.placeholder((k, m), name="B", dtype="float16")
+ k = te.reduce_axis((0, k), name="k")
+
+ def f_compute(i, j):
+ v_a = tir.Cast(dtype="float32", value=a[i, k])
+ v_b = tir.Cast(dtype="float32", value=b[k, j])
+ return te.sum(v_a * v_b, axis=[k])
+
+ c = te.compute((n, m), f_compute, name="C")
+ return (a, b, c)
+
+
+def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ a = te.placeholder((n, k), name="A")
+ b = te.placeholder((m, k), name="B")
+ k = te.reduce_axis((0, k), name="k")
+ c = te.compute(
+ (n, m),
+ lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]),
+ name="C",
+ )
+ d = topi.nn.relu(c) # pylint: disable=invalid-name
+ return (a, b, d)
+
+
+def matmul_relu_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ a = te.placeholder((n, k), name="A", dtype="float16")
+ b = te.placeholder((k, m), name="B", dtype="float16")
+ k = te.reduce_axis((0, k), name="k")
+
+ def f_compute(i, j):
+ v_a = tir.Cast(dtype="float32", value=a[i, k])
+ v_b = tir.Cast(dtype="float32", value=b[k, j])
+ return te.sum(v_a * v_b, axis=[k])
+
+ c = te.compute((n, m), f_compute, name="C")
+ d = topi.nn.relu(c) # pylint: disable=invalid-name
+ return (a, b, d)
+
+
+def conv2d_nchw( # pylint: disable=invalid-name
+ n: int,
+ h: int,
+ w: int,
+ ci: int,
+ co: int,
+ kh: int,
+ kw: int,
+ stride: int,
+ padding: int,
+ dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+ x = te.placeholder((n, ci, h, w), name="X")
+ w = te.placeholder((co, ci, kh, kw), name="W")
+ y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation)
+ return (x, w, y)
+
+
+def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name
+ n: int,
+ h: int,
+ w: int,
+ ci: int,
+ co: int,
+ kh: int,
+ kw: int,
+ stride: int,
+ padding: int,
+ dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]:
+ oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1  # pylint: disable=invalid-name
+ ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1  # pylint: disable=invalid-name
+ x = te.placeholder((n, ci, h, w), name="X")
+ w = te.placeholder((co, ci, kh, kw), name="W")
+ b = te.placeholder((co, 1, 1), name="B")
+ bn_scale = te.placeholder((co, 1, 1), name="bn_scale")
+ bn_offset = te.placeholder((co, 1, 1), name="bn_offset")
+ y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation)
+ y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add")
+ y = te.compute(
+ (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul"
+ )
+ y = te.compute(
+ (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add"
+ )
+ y = topi.nn.relu(y)
+ return (x, w, b, bn_scale, bn_offset, y)
+
+
+def max_pool2d_nchw( # pylint: disable=invalid-name
+ n: int,
+ h: int,
+ w: int,
+ ci: int,
+ padding: int,
+) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name
+ x = te.placeholder((n, ci, h, w), name="X")
+ y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max")
+ return (x, y)
+
+
+def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]:  # pylint: disable=invalid-name
+ a = te.placeholder((m, n), name="A")
+ b = topi.nn.softmax(a, axis=1)
+
+ return (a, b)
+
+
+def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
+ workload_func, params = CONFIGS[name]
+ return te.create_prim_func(workload_func(*params[idx])) # type: ignore
+
+
+CONFIGS = {
+ "C1D": (
+ conv1d_nlc,
+ [
+ # derived from conv2d_shapes
+ (1, 256, 64, 128, 3, 2, 1),
+ # (1, 256, 64, 128, 1, 2, 0),
+ # (1, 256, 64, 64, 1, 1, 0),
+ # (1, 128, 128, 256, 3, 2, 1),
+ (1, 128, 128, 256, 1, 2, 0),
+ # (1, 128, 128, 128, 3, 1, 1),
+ # (1, 64, 256, 512, 3, 2, 1),
+ # (1, 64, 256, 512, 1, 2, 0),
+ (1, 64, 256, 256, 5, 1, 2),
+ (1, 32, 512, 512, 3, 1, 1),
+ ],
+ ),
+ "C2D": (
+ conv2d_nhwc,
+ [
+ # all conv2d layers in resnet-18
+ (1, 224, 224, 3, 64, 7, 2, 3),
+ # (1, 56, 56, 64, 128, 3, 2, 1),
+ # (1, 56, 56, 64, 128, 1, 2, 0),
+ # (1, 56, 56, 64, 64, 3, 1, 1),
+ (1, 56, 56, 64, 64, 1, 1, 0),
+ # (1, 28, 28, 128, 256, 3, 2, 1),
+ # (1, 28, 28, 128, 256, 1, 2, 0),
+ # (1, 28, 28, 128, 128, 3, 1, 1),
+ # (1, 14, 14, 256, 512, 3, 2, 1),
+ # (1, 14, 14, 256, 512, 1, 2, 0),
+ (1, 14, 14, 256, 256, 3, 1, 1),
+ (1, 7, 7, 512, 512, 3, 1, 1),
+ ],
+ ),
+ "C3D": (
+ conv3d_ndhwc,
+ [
+ # Derived from conv2d_shapes. Use depth=16 for all configurations
+ (1, 16, 224, 224, 3, 64, 7, 2, 3),
+ # (1, 16, 56, 56, 64, 128, 3, 2, 1),
+ # (1, 16, 56, 56, 64, 128, 1, 2, 0),
+ # (1, 16, 56, 56, 64, 64, 3, 1, 1),
+ (1, 16, 56, 56, 64, 64, 1, 1, 0),
+ # (1, 16, 28, 28, 128, 256, 3, 2, 1),
+ # (1, 16, 28, 28, 128, 256, 1, 2, 0),
+ # (1, 16, 28, 28, 128, 128, 3, 1, 1),
+ # (1, 16, 14, 14, 256, 512, 3, 2, 1),
+ # (1, 16, 14, 14, 256, 512, 1, 2, 0),
+ (1, 16, 14, 14, 256, 256, 3, 1, 1),
+ (1, 16, 7, 7, 512, 512, 3, 1, 1),
+ ],
+ ),
+ "GMM": (
+ batch_matmul_nkkm,
+ [
+ (1, 128, 128, 128),
+ (1, 512, 32, 512),
+ (1, 512, 512, 512),
+ (1, 1024, 1024, 1024),
+ ],
+ ),
+ "GRP": (
+ conv2d_nhwc,
+ [
+ # Derived from conv2d_shapes. Use group=4 for all configurations
+ (1, 56, 56, 64, 128, 3, 2, 1, 1, 4),
+ # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4),
+ # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4),
+ (1, 56, 56, 64, 64, 1, 1, 0, 1, 4),
+ # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4),
+ # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4),
+ # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4),
+ # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4),
+ # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4),
+ (1, 14, 14, 256, 256, 3, 1, 1, 1, 4),
+ (1, 7, 7, 512, 512, 3, 1, 1, 1, 4),
+ ],
+ ),
+ "DIL": (
+ conv2d_nhwc,
+ [
+ # Derived from conv2d_shapes. Use dilation=2 for all configurations
+ (1, 224, 224, 3, 64, 7, 2, 3, 2),
+ # (1, 56, 56, 64, 128, 3, 2, 1 , 2),
+ # (1, 56, 56, 64, 128, 1, 2, 0 , 2),
+ # (1, 56, 56, 64, 64, 3, 1, 1 , 2),
+ (1, 56, 56, 64, 64, 1, 1, 0, 2),
+ # (1, 28, 28, 128, 256, 3, 2, 1, 2),
+ # (1, 28, 28, 128, 256, 1, 2, 0, 2),
+ # (1, 28, 28, 128, 128, 3, 1, 1, 2),
+ # (1, 14, 14, 256, 512, 3, 2, 1, 2),
+ # (1, 14, 14, 256, 512, 1, 2, 0, 2),
+ (1, 14, 14, 256, 256, 3, 1, 1, 2),
+ (1, 7, 7, 512, 512, 3, 1, 1, 2),
+ ],
+ ),
+ "DEP": (
+ depthwise_conv2d_nhwc,
+ [
+ # all depthwise conv2d layers in mobilenet
+ (1, 112, 112, 32, 3, 1, 1),
+ (1, 112, 112, 64, 3, 2, 1),
+ # (1, 56, 56, 128, 3, 1, 1),
+ # (1, 56, 56, 128, 3, 2, 1),
+ # (1, 28, 28, 256, 3, 1, 1),
+ # (1, 28, 28, 256, 3, 2, 1),
+ # (1, 14, 14, 512, 3, 1, 1),
+ (1, 14, 14, 512, 3, 2, 1),
+ (1, 7, 7, 1024, 3, 1, 1),
+ ],
+ ),
+ "T2D": (
+ conv2d_transpose_nhwc,
+ [
+ # all conv2d transpose layers in DCGAN
+ (1, 4, 4, 512, 256, 4, 2, 1),
+ (1, 8, 8, 256, 128, 4, 2, 1),
+ (1, 16, 16, 128, 64, 4, 2, 1),
+ (1, 32, 32, 64, 3, 4, 2, 1),
+ ],
+ ),
+ "CAP": (
+ conv2d_capsule_nhwijc,
+ [
+ # all conv2d capsule layers in "Matrix Capsules with EM Routing" (ICLR 2018)
+ (1, 16, 16, 32, 32, 3, 2, 1),
+ (1, 8, 8, 32, 32, 3, 1, 1),
+ (1, 16, 16, 8, 16, 3, 2, 1),
+ (1, 8, 8, 16, 16, 3, 1, 1),
+ ],
+ ),
+ "NRM": (
+ norm_bmn,
+ [
+ (1, 256, 256),
+ (1, 512, 512),
+ (1, 1024, 1024),
+ (1, 4096, 1024),
+ ],
+ ),
+ "SFM": (
+ softmax_mn,
+ [
+ (256, 256),
+ (512, 512),
+ (1024, 1024),
+ (2048, 2048),
+ ],
+ ),
+ "C2d-BN-RELU": (
+ conv2d_nhwc_bn_relu,
+ [
+ (1, 224, 224, 3, 64, 7, 2, 3),
+ (1, 56, 56, 64, 128, 3, 2, 1),
+ (1, 28, 28, 128, 256, 1, 2, 0),
+ (1, 7, 7, 512, 512, 3, 1, 1),
+ ],
+ ),
+ "TBG": (
+ transpose_batch_matmul,
+ [
+ (1, 128, 12, 64),
+ (1, 128, 16, 64),
+ (1, 64, 12, 128),
+ (1, 128, 12, 128),
+ ],
+ ),
+}
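A usage sketch for the workload registry above (illustrative):

    from tvm.meta_schedule.testing.te_workload import create_te_workload

    # "GMM" entry 0 is batch_matmul_nkkm(1, 128, 128, 128); the result is a
    # tir.PrimFunc ready for scheduling.
    func = create_te_workload("GMM", 0)
    print(func.script())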
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc
new file mode 100644
index 0000000..5ef2ac3
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class AddRFactorNode : public ScheduleRuleNode {
+ public:
+ // Inherited from ScheduleRuleNode
+ void InitializeWithTuneContext(const TuneContext& context) final {
+ ICHECK(context->target.defined());
+ Target target = context->target.value();
+ this->max_parallel_basic_ = GetTargetNumCores(target);
+ if (this->max_jobs_per_core != -1) {
+ this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
+ }
+ }
+
+ // Inherited from ScheduleRuleNode
+ Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
+
+ public:
+ /*!
+ * \brief The maximum number of jobs to be launched per core.
+ * It sets the upper limit of parallelism, i.e. `num_cores * max_jobs_per_core`.
+ * Use -1 to disable parallelism.
+ */
+ int max_jobs_per_core;
+ /*! \brief The maximum size of the innermost factor */
+ int max_innermost_factor;
+ /*! \brief The upper limit of parallelism. */
+ int max_parallel_extent_;
+ /*! \brief The number of cores. */
+ int max_parallel_basic_;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("max_jobs_per_core", &max_jobs_per_core);
+ v->Visit("max_innermost_factor", &max_innermost_factor);
+ // `max_parallel_extent_` is not visited
+ // `max_parallel_basic_` is not visited
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.AddRFactor";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
+ Optional<Integer> max_innermost_factor) {
+ ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
+ n->max_jobs_per_core = max_jobs_per_core;
+ n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
+ n->max_parallel_extent_ = -1;
+ n->max_parallel_basic_ = -1;
+ return ScheduleRule(n);
+}
+
+Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+ tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+ if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
+                                         max_parallel_basic_)) {
+ return {sch};
+ }
+
+ // Make a copy of the original schedule.
+ tir::Schedule ori_sch = sch->Copy();
+ ori_sch->Seed(sch->ForkSeed());
+
+ // Reorder the loop axes if reduction loops are not innermost.
+ // After the reordering, fuse all the reduction loops.
+ size_t num_spatial_loops;
+ tir::LoopRV fused_reduce_loop;
+ ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
+
+ // Split the fused reduction loop.
+ Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
+ const Array<tir::LoopRV>& split_loops =
+ sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
+
+ Array<tir::Schedule> res;
+ for (const tir::LoopRV& split_loop : split_loops) {
+ tir::Schedule sch_tmp = sch->Copy();
+ sch_tmp->Seed(sch->ForkSeed());
+ try {
+ const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
+ Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
+ ICHECK_GT(axes.size(), num_spatial_loops);
+
+ // Annotate that the rfactor block, which is now the producer of the original block, needs to
+ // be considered by the rule Random-Compute-Location.
+ sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true));
+ res.push_back(sch_tmp);
+ } catch (const tvm::runtime::Error& e) {
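+ // RFactor may fail for a given split loop; skip that candidate and keep the rest.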
+ }
+ }
+
+ res.push_back(ori_sch);
+ return res;
+}
+
+TVM_REGISTER_NODE_TYPE(AddRFactorNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
+ .set_body_typed(ScheduleRule::AddRFactor);
+
+} // namespace meta_schedule
+} // namespace tvm
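For intuition, one of the candidates Apply returns for the matmul in the new unit test corresponds to the following hand-written schedule (a sketch, assuming this patch is applied; the split factor 16 stands in for one sample of SamplePerfectTile):

    import tvm
    from tvm import te
    from tvm.meta_schedule.testing import te_workload

    sch = tvm.tir.Schedule(te.create_prim_func(te_workload.matmul(n=4, m=4, k=512)))
    block = sch.get_block("C")
    i, j, k = sch.get_loops(block)             # k is the only reduction loop
    ko, ki = sch.split(k, factors=[None, 16])  # cf. SamplePerfectTile with n=2
    rf = sch.rfactor(ki, factor_axis=2)        # factor_axis = num_spatial_loops
    sch.annotate(block, "meta_schedule.random_compute_producer", 1)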
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index ef15f49..5b49769 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -318,6 +318,26 @@ struct ThreadedTraceApply {
Item* items_;
};
+/*!
+ * \brief Get the number of CPU cores of the target
+ * \param target The target
+ * \return The number of cores.
+ */
+inline int GetTargetNumCores(const Target& target) {
+ int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1);
+ if (num_cores == -1) {
+ static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
+ ICHECK(f_cpu_count) << "ValueError: Cannot find the packed function \"meta_schedule.cpu_count\"";
+ num_cores = (*f_cpu_count)(false);
+ LOG(FATAL) << "Target does not have attribute \"num-cores\", physical core number must be "
+               "defined! For example, on the local machine, the target must be \"llvm -num-cores "
+            << num_cores << "\"";
+ }
+ return num_cores;
+}
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index e4bf48b..c562c78 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -254,6 +254,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<String>("mabi")
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
+ .add_attr_option<Integer>("num-cores")
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Bool>("unpacked-api")
.add_attr_option<String>("interface-api")
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 9622e2d..636cc7d 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -520,6 +520,44 @@ std::tuple</*exists=*/bool,
/*no_shift_read=*/bool>
AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);
+/*!
+ * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel
+ * \param block_sref The block to be checked
+ * \return A boolean flag indicating if the block is a data parallel block
+ */
+bool IsSpatial(const StmtSRef& block_sref);
+
+/*!
+ * \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer loop,
+ * from outer to inner.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean flag indicating if the block has a trivial binding
+ */
+bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is
+ * beneficial.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block has data reuse opportunity
+ */
+bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
+ * \param self The schedule state.
+ * \param block_sref The block to be checked.
+ * \param max_parallel_extent The maximum parallel jobs on the target.
+ * \param max_parallel_basic The maximum cores on the target.
+ * \return A boolean indicating whether the operation is beneficial.
+ */
+bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
+ const tir::StmtSRef& block_sref, //
+ int64_t max_parallel_extent, //
+ int64_t max_parallel_basic);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 0520973..2053f8d 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1661,5 +1661,191 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
}
}
+bool IsSpatial(const StmtSRef& block_sref) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ for (const IterVar& iter_var : block->iter_vars) {
+ if (iter_var->iter_type != IterVarType::kDataPar) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ Array<StmtSRef> loops = GetLoops(block_sref);
+ Array<PrimExpr> binds = GetBlockRealize(self, block_sref)->iter_values;
+ if (loops.size() != binds.size()) {
+ return false;
+ }
+ for (int i = 0, n = loops.size(); i < n; ++i) {
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
+ if (binds[i].get() != loop->loop_var.get()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
+     !IsTrivialBinding(self, block_sref)) {
+ return false;
+ }
+ const BufferNode* write_buffer = block->writes[0]->buffer.get();
+ // Step 1. Sort out spatial block variables
+ std::vector<const VarNode*> spatial_block_vars;
+ spatial_block_vars.reserve(block->iter_vars.size());
+ for (const IterVar& block_var : block->iter_vars) {
+ if (block_var->iter_type == IterVarType::kDataPar) {
+ spatial_block_vars.push_back(block_var->var.get());
+ }
+ }
+ // Step 2. Enumerate each read region, check the number of block vars that are not used
+ // to index the read region
+ int total_unused_block_vars = 0;
+ std::unordered_set<const BufferNode*> read_buffers;
+ read_buffers.reserve(block->reads.size());
+ for (const BufferRegion& buffer_region : block->reads) {
+ const BufferNode* buffer = buffer_region->buffer.get();
+ const Array<Range>& regions = buffer_region->region;
+ // Step 2.1. Duplicate read buffers are not allowed
+ if (read_buffers.insert(buffer).second == false) {
+ return false;
+ }
+ // Step 2.2. Skip the reduction buffer
+ if (buffer == write_buffer) {
+ continue;
+ }
+ // Step 2.3. Collect the block vars that are used to index the read region
+ std::unordered_set<const VarNode*> vars;
+ for (const Range& range : regions) {
+ if (as_const_int(range->extent) == nullptr) {
+ return false;
+ }
+ for (const Var& var : UndefinedVars(range->min)) {
+ vars.insert(var.get());
+ }
+ }
+ // Step 2.4. Check if the block vars are not used to index the read region
+ int n_unused_block_vars = 0;
+ for (const VarNode* block_var : spatial_block_vars) {
+ if (vars.count(block_var) == 0) {
+ ++n_unused_block_vars;
+ }
+ }
+ total_unused_block_vars += n_unused_block_vars;
+ }
+ return total_unused_block_vars >= 1;
+}
+
+std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
+                                                                 const tir::StmtSRef& block_sref) {
+ Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
+ int64_t cum_space_len = 1, cum_reduce_len = 1;
+ /*
+ * Return (-1, -1) if
+ * 1. there is some loop with type other than kDataPar and kCommReduce;
+ * 2. there is some loop which is dynamic.
+ */
+ for (const tir::StmtSRef& loop_sref : loops) {
+ tir::IterVarType type = GetLoopIterType(loop_sref);
+ if (type == tir::kDataPar) {
+ const int64_t* extent = GetLoopIntExtent(loop_sref);
+ if (*extent != -1) {
+ cum_space_len *= *extent;
+ } else {
+ return std::make_pair(-1, -1);
+ }
+ } else if (type == tir::kCommReduce) {
+ const int64_t* extent = GetLoopIntExtent(loop_sref);
+ if (*extent != -1) {
+ cum_reduce_len *= *extent;
+ } else {
+ return std::make_pair(-1, -1);
+ }
+ } else {
+ return std::make_pair(-1, -1);
+ }
+ }
+ return std::make_pair(cum_space_len, cum_reduce_len);
+}
+
+bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
+ const tir::StmtSRef& block_sref, //
+ int64_t max_parallel_extent, //
+ int64_t max_parallel_basic) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
+
+ // Cond 1. The block has only one write buffer
+ if (block->writes.size() != 1) {
+ return false;
+ }
+
+ // Cond 2. The block is a reduction block and has trivial binding.
+ const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,                  //
+                                           /*require_stage_pipeline=*/false,  //
+                                           /*require_subtree_compact_dataflow=*/false);
+ if (!(IsReductionBlock(self, block_sref, scope_sref) && //
+ IsTrivialBinding(self, block_sref))) {
+ return false;
+ }
+
+ // Cond 3. Every loop axis must be either a spatial axis or a reduction axis.
+ for (const tir::StmtSRef& loop_sref : loops) {
+ const tir::IterVarType& type = GetLoopIterType(loop_sref);
+ if (type != tir::kDataPar && type != tir::kCommReduce) {
+ return false;
+ }
+ }
+
+ // Cond 4. Whether there is at least one reduction loop.
+ // Cond 5. The loops are continuous, and the body of the innermost loop is exactly the block.
+ bool has_reduction_loop = false;
+ for (size_t i = 0; i < loops.size(); ++i) {
+ // Cond 4.
+ if (GetLoopIterType(loops[i]) == tir::kCommReduce) {
+ has_reduction_loop = true;
+ }
+
+ // Cond 5.
+ const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
+ if (i < loops.size() - 1) {
+ const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
+ if (loop_i->body.get() != loop_i1) {
+ return false;
+ }
+ } else {
+ const auto* block_realize = loop_i->body.as<tir::BlockRealizeNode>();
+ if (!block_realize || block_realize->block.get() != block) {
+ return false;
+ }
+ }
+ }
+ if (!has_reduction_loop) {
+ return false;
+ }
+
+ // Cond 6. The cumulative loop lengths can be computed successfully.
+ int64_t cum_space_len, cum_reduce_len;
+ std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref);
+ if (cum_space_len == -1 || cum_reduce_len == -1) {
+ return false;
+ }
+
+ // Cond 7.
+ if (NeedsMultiLevelTiling(self, block_sref)) {
+ // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
+ return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
+ } else if (cum_reduce_len > 1) {
+ // Always try rfactor/cross-thread-reduction for other reduction blocks.
+ return cum_reduce_len > max_parallel_basic;
+ } else {
+ return false;
+ }
+}
+
} // namespace tir
} // namespace tvm
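A worked instance of Cond 7, plugging in the numbers from the new unit test (matmul with n=4, m=4, k=512 on "llvm --num-cores=32" and max_jobs_per_core=16):

    cum_space_len = 4 * 4          # product of the spatial loop extents
    cum_reduce_len = 512           # product of the reduction loop extents
    max_parallel_extent = 32 * 16  # num_cores * max_jobs_per_core
    # The matmul block needs multi-level tiling, so the first branch applies:
    assert not (cum_space_len >= cum_reduce_len or cum_space_len > max_parallel_extent)
    # 16 < 512 on both counts, so rfactor/cross-thread reduction is worth trying.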
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 2acab38..be6d5a1 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -320,6 +320,60 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) {
return result.defined() && result.value()->value == ann_val;
}
+/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/
+
+/*!
+ * \brief Reorder the reduction loops to innermost positions if needed.
+ * \param sch The schedule
+ * \param block_rv The block where to apply the reorder
+ * \param fused_reduce_loop The fusion-generated loop to return.
+ * \param num_spatial_loops The number of spatial loops to return.
+ * \note Before invoking this helper function, make sure that the block has only spatial and
+ * reduction loop axes.
+ */
+inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                         tir::LoopRV* fused_reduce_loop,
+                                         size_t* num_spatial_loops) {
+ Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
+ Array<tir::StmtSRef> loop_srefs;
+ for (const tir::LoopRV& loop_rv : loops) {
+ loop_srefs.push_back(sch->GetSRef(loop_rv));
+ }
+
+ Array<tir::LoopRV> new_order;
+ // Step 1. Add spatial loops.
+ *num_spatial_loops = 0;
+ for (size_t i = 0; i < loops.size(); ++i) {
+ if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
+ new_order.push_back(loops[i]);
+ (*num_spatial_loops)++;
+ }
+ }
+ // Step 2. Add reduction loops.
+ Array<tir::LoopRV> reduction_loops;
+ for (size_t i = 0; i < loops.size(); ++i) {
+ if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
+ new_order.push_back(loops[i]);
+ reduction_loops.push_back(loops[i]);
+ }
+ }
+ // Step 3. Apply reordering if new_order differs from the original order.
+ ICHECK_EQ(new_order.size(), loops.size());
+ for (size_t i = 0; i < loops.size(); ++i) {
+ if (!new_order[i].same_as(loops[i])) {
+ sch->Reorder(new_order);
+ break;
+ }
+ }
+ // Step 4. Fuse all the reduction loops if there are multiple reduction loops.
+ CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
+ if (reduction_loops.size() > 1) {
+ *fused_reduce_loop = sch->Fuse(reduction_loops);
+ } else {
+ *fused_reduce_loop = reduction_loops[0];
+ }
+}
+
} // namespace tir
} // namespace tvm
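A rough Python counterpart of ReorderAndFuseReductionLoops for a block whose reduction loops are already innermost (a sketch using the public tir.Schedule API; the helper itself is C++-only):

    import tvm
    from tvm import te

    A = te.placeholder((128, 32, 32), name="A")
    r0 = te.reduce_axis((0, 32), name="r0")
    r1 = te.reduce_axis((0, 32), name="r1")
    B = te.compute((128,), lambda i: te.sum(A[i, r0, r1], axis=[r0, r1]), name="B")

    sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
    i, l0, l1 = sch.get_loops(sch.get_block("B"))
    # Steps 1-3 (classify and reorder) are no-ops here: (i, l0, l1) is already
    # spatial-then-reduction. Step 4 fuses the two reduction loops into one.
    fused = sch.fuse(l0, l1)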
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
new file mode 100644
index 0000000..5a80312
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import add_rfactor
+from tvm.meta_schedule.testing.space_generation import check_trace
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.target import Target
+from tvm.te.operation import create_prim_func
+
+
+def _create_context(mod, target, rule) -> TuneContext:
+ ctx = TuneContext(
+ mod=mod,
+ target=target,
+ space_generator=PostOrderApply(),
+ sch_rules=[rule],
+ task_name="test",
+ )
+ ctx.space_generator.initialize_with_tune_context(ctx)
+ for sch_rule in ctx.sch_rules:
+ sch_rule.initialize_with_tune_context(ctx)
+ return ctx
+
+
+def test_cpu_matmul():
+ expected = [
+ [],
+ [
+ 'b0 = sch.get_block(name="C", func_name="main")',
+ "l1, l2, l3 = sch.get_loops(block=b0)",
+ "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2,
max_innermost_factor=64)",
+ "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+ "b8 = sch.rfactor(loop=l7, factor_axis=2)",
+ 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
+ ],
+ [
+ 'b0 = sch.get_block(name="C", func_name="main")',
+ "l1, l2, l3 = sch.get_loops(block=b0)",
+ "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2,
max_innermost_factor=64)",
+ "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+ "b8 = sch.rfactor(loop=l6, factor_axis=2)",
+ 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
+ ],
+ ]
+ target = Target("llvm --num-cores=32")
+ ctx = _create_context(
+ create_prim_func(
+ te_workload.matmul(
+ n=4,
+ m=4,
+ k=512,
+ )
+ ),
+ target=target,
+ rule=add_rfactor(target=target),
+ )
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 3
+ check_trace(spaces, expected)
+
+
+if __name__ == "__main__":
+ test_cpu_matmul()