This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f9171f1  [MetaSchedule] Schedule Rule: Add RFactor (#9975)
f9171f1 is described below

commit f9171f16e657d30fde3a366388b4fef837e5187f
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Jan 20 03:48:02 2022 +0800

    [MetaSchedule] Schedule Rule: Add RFactor (#9975)
    
    * add rfactor
    
    * format
    
    * fix ci
---
 include/tvm/meta_schedule/schedule_rule.h          |  10 +
 python/tvm/meta_schedule/schedule_rule/__init__.py |   1 +
 .../tvm/meta_schedule/schedule_rule/add_rfactor.py |  49 ++
 python/tvm/meta_schedule/testing/schedule_rule.py  |   8 +
 python/tvm/meta_schedule/testing/te_workload.py    | 877 +++++++++++++++++++++
 src/meta_schedule/schedule_rule/add_rfactor.cc     | 122 +++
 src/meta_schedule/utils.h                          |  20 +
 src/target/target_kind.cc                          |   1 +
 src/tir/schedule/analysis.h                        |  38 +
 src/tir/schedule/analysis/analysis.cc              | 186 +++++
 src/tir/schedule/utils.h                           |  54 ++
 ...test_meta_schedule_schedule_rule_add_rfactor.py |  80 ++
 12 files changed, 1446 insertions(+)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 6ee3947..95fce13 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -153,6 +153,16 @@ class ScheduleRule : public runtime::ObjectRef {
                                                Optional<Map<String, 
ObjectRef>> reuse_read,  //
                                                Optional<Map<String, 
ObjectRef>> reuse_write);
   /*!
+   * \brief Create a rule that adds rfactor to some blocks if needed
+   * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
+   * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
+   * parallelism.
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core,  //
+                                         Optional<Integer> max_innermost_factor);
+  /*!
    * \brief A rule that randomly select a compute-at location for a free block
    * \return The rule created
    */
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 9ad3c06..475c43a 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -16,6 +16,7 @@ The tvm.meta_schedule.schedule_rule package.
 Meta Schedule schedule rules are used for modification of
 blocks in a schedule. See also PostOrderApply.
 """
+from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
 from .schedule_rule import PyScheduleRule, ScheduleRule
 from .random_compute_location import RandomComputeLocation
diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
new file mode 100644
index 0000000..72f9fc9
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add-rfactor Rule that add-rfactor to some blocks if needed"""
+from typing import Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.AddRFactor")
+class AddRFactor(ScheduleRule):
+    """Rule that adds rfactor to some blocks if needed.
+
+    Parameters
+    ----------
+    max_jobs_per_core: int = 16
+        The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
+        parallelism, i.e. `num_cores * max_jobs_per_core`.
+        Use -1 to disable parallelism.
+    max_innermost_factor: Optional[int] = None
+        The maximum size of the innermost factor. None means no limit.
+    """
+
+    def __init__(
+        self,
+        max_jobs_per_core: int = 16,
+        max_innermost_factor: Optional[int] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleAddRFactor,  # type: ignore # pylint: disable=no-member
+            max_jobs_per_core,
+            max_innermost_factor,
+        )
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index e69be13..020869d 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -16,6 +16,7 @@
 # under the License.
 """Default schedule rules"""
 from tvm.meta_schedule.schedule_rule import (
+    AddRFactor,
     AutoInline,
     ScheduleRule,
 )
@@ -45,3 +46,10 @@ def auto_inline(target: Target) -> ScheduleRule:
             disallow_op=None,
         )
     raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
+def add_rfactor(target: Target) -> ScheduleRule:
+    """Default add_rfactor schedule rule for `target` (only LLVM targets are supported)"""
+    if target.kind.name == "llvm":
+        return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
+    raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py
new file mode 100644
index 0000000..49a60a2
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/te_workload.py
@@ -0,0 +1,877 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Workloads in TE"""
+# pylint: disable=missing-docstring
+from typing import Tuple
+
+from tvm import te, tir, topi
+
+
+def batch_matmul_nkkm(  # pylint: disable=invalid-name,missing-docstring
+    B: int,
+    N: int,
+    M: int,
+    K: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Batched matmul Z[b, i, j] = sum_k X[b, i, k] * Y[b, k, j]; returns (X, Y, Z)."""
+    lhs = te.placeholder((B, N, K), name="X")
+    rhs = te.placeholder((B, K, M), name="Y")
+    red_k = te.reduce_axis((0, K), name="k")
+
+    # te.compute names the loop variables after these parameters, so keep them as b, i, j.
+    def _body(b, i, j):
+        return te.sum(lhs[b, i, red_k] * rhs[b, red_k, j], axis=[red_k])
+
+    out = te.compute((B, N, M), _body, name="Z")
+    return lhs, rhs, out
+
+
+def conv1d_nlc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    L: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """1-D (grouped) convolution in NLC layout; returns (inputs, weight, output)."""
+    inputs = te.placeholder((N, L, CI), name="inputs")
+    weight = te.placeholder((kernel_size, CI // groups, CO), name="weight")
+
+    # Read sizes back from the tensor shapes.
+    batch_size, in_len, _ = inputs.shape
+    k_len, ci_per_group, out_channel = weight.shape
+    co_per_group = out_channel // groups
+    out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1
+
+    rc = te.reduce_axis((0, ci_per_group), name="rc")
+    rl = te.reduce_axis((0, k_len), name="rl")
+    padded = topi.nn.pad(inputs, [0, padding, 0])
+
+    def _conv(n, l, co):
+        # Input channel inside the group that output channel `co` belongs to.
+        in_channel = co // co_per_group * ci_per_group + rc
+        position = l * stride + rl * dilation
+        return te.sum(padded[n, position, in_channel] * weight[rl, rc, co], axis=[rl, rc])
+
+    output = te.compute((batch_size, out_len, out_channel), _conv, name="conv1d_nlc")
+    return inputs, weight, output
+
+
+def conv2d_nhwc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """2-D (grouped) convolution in NHWC layout; returns (inputs, weight, output)."""
+    inputs = te.placeholder((N, H, W, CI), name="inputs")
+    weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight")
+
+    # Read sizes back from the tensor shapes.
+    batch_size, in_h, in_w, _ = inputs.shape
+    k_h, k_w, ci_per_group, out_channel = weight.shape
+    co_per_group = out_channel // groups
+    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    rc = te.reduce_axis((0, ci_per_group), name="rc")
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+
+    def _conv(n, h, w, co):
+        # Input channel inside the group that output channel `co` belongs to.
+        in_channel = co // co_per_group * ci_per_group + rc
+        row = h * stride + rh * dilation
+        col = w * stride + rw * dilation
+        return te.sum(
+            padded[n, row, col, in_channel] * weight[rh, rw, rc, co],
+            axis=[rh, rw, rc],
+        )
+
+    output = te.compute((batch_size, out_h, out_w, out_channel), _conv, name="conv2d_nhwc")
+    return inputs, weight, output
+
+
+def conv3d_ndhwc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    D: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """3-D (grouped) convolution in NDHWC layout; returns (inputs, weight, output)."""
+    inputs = te.placeholder((N, D, H, W, CI), name="inputs")
+    weight_shape = (kernel_size, kernel_size, kernel_size, CI // groups, CO)
+    weight = te.placeholder(weight_shape, name="weight")
+
+    # Read sizes back from the tensor shapes.
+    batch_size, in_d, in_h, in_w, _ = inputs.shape
+    k_d, k_h, k_w, ci_per_group, out_channel = weight.shape
+    co_per_group = out_channel // groups
+
+    def _out_size(extent, k):
+        return (extent + 2 * padding - dilation * (k - 1) - 1) // stride + 1
+
+    out_d, out_h, out_w = _out_size(in_d, k_d), _out_size(in_h, k_h), _out_size(in_w, k_w)
+
+    rd = te.reduce_axis((0, k_d), name="rd")
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    rc = te.reduce_axis((0, ci_per_group), name="rc")
+    padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0])
+
+    def _conv(n, d, h, w, co):
+        # Input channel inside the group that output channel `co` belongs to.
+        in_channel = co // co_per_group * ci_per_group + rc
+        return te.sum(
+            padded[
+                n,
+                d * stride + rd * dilation,
+                h * stride + rh * dilation,
+                w * stride + rw * dilation,
+                in_channel,
+            ]
+            * weight[rd, rh, rw, rc, co],
+            axis=[rd, rh, rw, rc],
+        )
+
+    output = te.compute(
+        (batch_size, out_d, out_h, out_w, out_channel), _conv, name="conv3d_ndhwc"
+    )
+    return inputs, weight, output
+
+
+def depthwise_conv2d_nhwc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    C: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    factor: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Depthwise 2-D convolution in NHWC layout; returns (inputs, weight, output).
+
+    Only a channel multiplier `factor` of 1 is accepted; other values trip the assert.
+    """
+    inputs = te.placeholder((N, H, W, C))
+    weight = te.placeholder((factor, kernel_size, kernel_size, C))
+    batch_size, in_h, in_w, _ = inputs.shape
+    # From here on `factor` and `in_channel` are the weight's shape entries.
+    factor, k_h, k_w, in_channel = weight.shape
+    out_channel = in_channel * factor
+    assert int(factor) == 1, "Not optimized for factor != 1"
+
+    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+
+    def _conv(n, h, w, c):
+        row = h * stride + rh * dilation
+        col = w * stride + rw * dilation
+        return te.sum(
+            padded[n, row, col, c // factor] * weight[c % factor, rh, rw, c // factor],
+            axis=[rh, rw],
+        )
+
+    output = te.compute(
+        (batch_size, out_h, out_w, out_channel), _conv, name="depth_conv2d_nhwc"
+    )
+    return inputs, weight, output
+
+
+def conv2d_transpose_nhwc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """2-D transposed convolution in NHWC layout; returns (inputs, weight, output).
+
+    The input is padded, dilated by `stride` on the fly inside the compute (see `_dilate`),
+    and convolved with the spatially flipped (kernel_size, kernel_size, CI, CO) weight.
+    """
+    inputs = te.placeholder((N, H, W, CI), name="inputs")
+    weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight")
+
+    # Sizes are read back from the tensor shapes; `in_c` is rebound to the weight's CI extent.
+    batch, in_h, in_w, in_c = inputs.shape
+    filter_h, filter_w, in_c, out_c = weight.shape
+    stride_h, stride_w = (stride, stride)
+
+    # compute padding: `fpad_*` is the padding of the corresponding forward convolution and
+    # `bpad_* = filter - 1 - fpad_*` is the padding applied to the transposed conv's input
+    fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(
+        padding, (filter_h, filter_w)
+    )
+    bpad_top = filter_h - 1 - fpad_top
+    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_left = filter_w - 1 - fpad_left
+    bpad_right = filter_w - 1 - fpad_right
+
+    # padding stage: `padded` is indexed in undilated coordinates (see `_dilate`), so each
+    # bpad_* is converted to ceil(bpad_* / stride) undilated elements
+    padded = topi.nn.pad(
+        inputs,
+        [
+            0,
+            (bpad_top + stride_h - 1) // stride_h,
+            (bpad_left + stride_w - 1) // stride_w,
+            0,
+        ],
+        [
+            0,
+            (bpad_bottom + stride_h - 1) // stride_h,
+            (bpad_right + stride_w - 1) // stride_w,
+            0,
+        ],
+    )
+
+    # remove extra padding introduced by dilation: the ceil above can over-pad by up to
+    # stride - 1 dilated positions; border_* = ceil(bpad_* / stride) * stride - bpad_* is
+    # that excess, and the conv indices below are offset by it
+    idx_div = te.indexdiv
+    idx_mod = te.indexmod
+    border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h)
+    border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w)
+
+    # dilation stage
+    strides = [1, stride_h, stride_w, 1]
+    # Rank of `padded` (4 here). `_dilate` closes over this `n`, not over the batch index
+    # `n` of the compute lambda further below.
+    n = len(padded.shape)
+
+    # We should embed this dilation directly into te.compute rather than creating a new
+    # te.compute. Only in this way can we use unroll to eliminate the multiplication of zeros.
+    def _dilate(*indices):
+        # Reads padded[indices / strides] when every strided index is divisible by its
+        # stride, and yields 0 (in padded's dtype) otherwise.
+        not_zero = []
+        index_tuple = []
+        for i in range(n):
+            if not strides[i] == 1:
+                index_tuple.append(idx_div(indices[i], strides[i]))
+                not_zero.append(idx_mod(indices[i], strides[i]).equal(0))
+            else:
+                index_tuple.append(indices[i])
+        if not_zero:
+            not_zero = te.all(*not_zero)
+            return te.if_then_else(not_zero, padded(*index_tuple), 
tir.const(0.0, padded.dtype))
+        # Every stride is 1: plain read, no predicate needed.
+        return padded(*index_tuple)
+
+    # convolution stage: output size is (in - 1) * stride - fpad_begin - fpad_end + filter;
+    # the weight is read spatially flipped (filter - 1 - r)
+    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
+    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+    rc = te.reduce_axis((0, in_c), name="rc")
+    rh = te.reduce_axis((0, filter_h), name="rh")
+    rw = te.reduce_axis((0, filter_w), name="rw")
+
+    output = te.compute(
+        (batch, out_h, out_w, out_c),
+        lambda n, h, w, co: te.sum(
+            _dilate(n, h + rh + border_h, w + rw + border_w, rc)
+            * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co],
+            axis=[rh, rw, rc],
+        ),
+        name="conv2d_transpose_nhwc",
+    )
+    return (inputs, weight, output)
+
+
+def conv2d_capsule_nhwijc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    capsule_size: int = 4,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Capsule 2-D convolution in NHWIJC layout; returns (inputs, weight, output).
+
+    Every (n, h, w, channel) position holds a capsule_size x capsule_size matrix; input and
+    weight capsules are multiplied as matrices by contracting the `cap_k` axis.
+    """
+    inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs")
+    weight = te.placeholder(
+        (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight"
+    )
+    batch_size, in_h, in_w, _, _, in_channel = inputs.shape
+    k_h, k_w, *_, out_channel = weight.shape
+    out_h, out_w = (
+        (extent + 2 * padding - kernel_size) // stride + 1 for extent in (in_h, in_w)
+    )
+
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    cap_k = te.reduce_axis((0, capsule_size), name="cap_k")
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0])
+
+    def _conv(n, h, w, cap_i, cap_j, co):
+        lhs = padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc]
+        rhs = weight[rh, rw, cap_k, cap_j, rc, co]
+        return te.sum(lhs * rhs, axis=[rh, rw, cap_k, rc])
+
+    output = te.compute(
+        (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel),
+        _conv,
+        name="conv2d_capsule_nhwijc",
+    )
+    return inputs, weight, output
+
+
+def norm_bmn(  # pylint: disable=invalid-name,missing-docstring
+    B: int,
+    M: int,
+    N: int,
+) -> Tuple[te.Tensor, te.Tensor]:
+    """Per-batch norm D[b] = sqrt(sum_{i, j} A[b, i, j] ** 2); returns (A, D)."""
+    a = te.placeholder((B, M, N), name="A")
+    i = te.reduce_axis((0, M), name="i")
+    j = te.reduce_axis((0, N), name="j")
+
+    def _sum_of_squares(b):
+        return te.sum(a[b, i, j] * a[b, i, j], axis=[i, j])
+
+    sum_sq = te.compute((B,), _sum_of_squares, name="C")
+    norm = te.compute((B,), lambda b: te.sqrt(sum_sq[b]), name="D")
+    return a, norm
+
+
+def conv2d_nhwc_without_layout_rewrite(  # pylint: disable=invalid-name
+    Input: te.Tensor,
+    Filter: te.Tensor,
+    stride: int,
+    padding: int,
+    dilation: int,
+    out_dtype: str = "float32",
+) -> te.Tensor:
+    """A copy of `topi.nn.conv2d_nhwc` but without the `layout_free` attribute.
+
+    We use this in single op and subgraph evaluation
+    because we don't want to introduce graph level optimization.
+
+    Parameters
+    ----------
+    Input : te.Tensor
+        NHWC input of shape (batch, in_height, in_width, in_channel).
+    Filter : te.Tensor
+        Filter of shape (kernel_h, kernel_w, in_channel, num_filter).
+    stride : int
+        An int, or a (stride_h, stride_w) pair.
+    padding : int
+        Padding spec forwarded to `topi.nn.get_pad_tuple`.
+    dilation : int
+        An int, or a (dilation_h, dilation_w) pair.
+    out_dtype : str
+        dtype both operands are cast to before the multiply-accumulate.
+
+    Returns
+    -------
+    Output : te.Tensor
+        The "Conv2dOutput" tensor of shape (batch, out_height, out_width, num_filter).
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_height, in_width, in_channel = Input.shape
+    kernel_h, kernel_w, _channel, num_filter = Filter.shape
+
+    # compute the output shape
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
+    out_channel = num_filter
+    out_height = topi.utils.simplify(
+        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
+    )
+    out_width = topi.utils.simplify(
+        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
+    )
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput")
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    Output = te.compute(
+        (batch, out_height, out_width, out_channel),
+        lambda nn, yy, xx, ff: te.sum(
+            PaddedInput[
+                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
+            ].astype(out_dtype)
+            * Filter[ry, rx, rc, ff].astype(out_dtype),  # type: ignore
+            axis=[ry, rx, rc],
+        ),
+        name="Conv2dOutput",
+        tag="conv2d_nhwc",
+    )
+    return Output
+
+
+def conv2d_nhwc_bn_relu(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    strides: int,
+    padding: int,
+    dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]:
+    """NHWC conv2d, then bias add, per-channel scale and offset, then ReLU.
+
+    Returns (data, kernel, bias, bn_offset, bn_scale, out); note that bn_offset comes
+    before bn_scale in the returned tuple.
+    """
+    data = te.placeholder((N, H, W, CI), name="data")
+    kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel")
+    bias = te.placeholder((CO,), name="bias")
+    bn_scale = te.placeholder((CO,), name="bn_scale")
+    bn_offset = te.placeholder((CO,), name="bn_offset")
+
+    def _out_size(extent):
+        return (extent + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
+
+    out_shape = (N, _out_size(H), _out_size(W), CO)
+    conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation)
+    biased = te.compute(
+        out_shape, lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add"
+    )
+    scaled = te.compute(
+        out_shape, lambda i, j, k, l: biased[i, j, k, l] * bn_scale[l], name="bn_mul"
+    )
+    shifted = te.compute(
+        out_shape, lambda i, j, k, l: scaled[i, j, k, l] + bn_offset[l], name="bn_add"
+    )
+    return (data, kernel, bias, bn_offset, bn_scale, topi.nn.relu(shifted))
+
+
+def transpose_batch_matmul(  # pylint: disable=invalid-name,missing-docstring
+    batch: int,
+    seq_len: int,
+    n_head: int,
+    n_dim: int,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Per-head batched matmul C[b, h, i, j] = sum_k query[b, i, h, k] * value[b, j, h, k].
+
+    Both inputs are first materialized in head-major layout ("query_T", "value_T").
+    Returns (query, value, C).
+    """
+    query = te.placeholder((batch, seq_len, n_head, n_dim), name="query")
+    value = te.placeholder((batch, seq_len, n_head, n_dim), name="value")
+
+    def _query_layout(b, h, l, d):
+        return query[b, l, h, d]
+
+    def _value_layout(b, h, d, l):
+        return value[b, l, h, d]
+
+    query_T = te.compute((batch, n_head, seq_len, n_dim), _query_layout, name="query_T")
+    value_T = te.compute((batch, n_head, n_dim, seq_len), _value_layout, name="value_T")
+    k = te.reduce_axis((0, n_dim), name="k")
+
+    def _scores(b, h, i, j):
+        return te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k])
+
+    out = te.compute((batch, n_head, seq_len, seq_len), _scores, name="C")
+    return query, value, out
+
+
+def conv2d_winograd_nhwc(  # pylint: disable=invalid-name,missing-docstring
+    N: int,
+    H: int,
+    W: int,
+    CI: int,
+    CO: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Winograd 2-D convolution in NHWC layout; returns (inputs, kernel_pack, output).
+
+    Uses F(m, r) Winograd with a fixed output tile m = 4 and r = kernel_size, so each input
+    tile is alpha x alpha with alpha = m + r - 1. Only stride 1 and dilation 1 are accepted
+    (enforced by asserts). The kernel placeholder ("weight") has shape (alpha, alpha, CI, CO)
+    and is consumed directly in the transformed domain; no kernel transform is built here.
+    """
+    tile_size = 4  # hard-coded; cf. _infer_tile_size(data, kernel)
+    inputs = te.placeholder((N, H, W, CI), name="inputs")
+    # Re-read the shape as a tuple of constants.
+    N, H, W, CI = topi.utils.get_const_tuple(inputs.shape)
+    # `dilation` may be an int or an (h, w) pair; only (1, 1) is supported.
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
+
+    KH = KW = kernel_size
+    HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW))
+    HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride
+    # KH == KW always holds here since both come from `kernel_size`.
+    assert HSTR == 1 and WSTR == 1 and KH == KW
+
+    # Symmetric padding: only the top/left amounts from get_pad_tuple are used, on both sides.
+    data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), 
name="data_pad")
+
+    # Winograd parameters: output tile m, kernel size r, input tile alpha = m + r - 1.
+    r = KW
+    m = tile_size
+    alpha = m + r - 1
+    # A is used by the inverse (output) transform and B by the input transform below;
+    # the kernel transform matrix `_G` is unused here.
+    A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, 
"float32")
+
+    # From here on H and W denote the output spatial sizes.
+    H = (H + 2 * HPAD - KH) // HSTR + 1
+    W = (W + 2 * WPAD - KW) // WSTR + 1
+    # Tiles per spatial dim (ceil division by m) and total tile count P over the batch.
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+    # NOTE(review): _rkh / _rkw are created but never used below.
+    _rkh = te.reduce_axis((0, KH), name="r_kh")
+    _rkw = te.reduce_axis((0, KW), name="r_kw")
+    kshape = (alpha, alpha, CI, CO)
+    kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
+
+    idxdiv = te.indexdiv
+    idxmod = te.indexmod
+    # pack input tile: p enumerates (n, tile_row, tile_col) as
+    # n * nH * nW + tile_row * nW + tile_col. Tile starts are m apart but each tile is
+    # alpha = m + r - 1 wide, so neighbouring tiles overlap by r - 1.
+    input_tile = te.compute(
+        (alpha, alpha, P, CI),
+        lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, 
nW), nH) * m + eps][
+            idxmod(p, nW) * m + nu
+        ][ci],
+        name="input_tile",
+    )
+
+    # transform data: data_pack[eps, nu] = sum_{a, b} B[a, eps] * tile[a, b] * B[b, nu],
+    # i.e. B^T d B for every tile d and input channel.
+    # NOTE(review): the "auto_scheduler_simplify_const_tensor_indices" attrs presumably let
+    # the auto-scheduler fold the constant B / A lookups -- confirm.
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    data_pack = te.compute(
+        (alpha, alpha, P, CI),
+        lambda eps, nu, p, ci: te.sum(
+            input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, 
r_b]
+        ),
+        name="data_pack",
+        attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", 
"r_a", "r_b"]},
+    )
+
+    # do batch gemm: for each (eps, nu), a (P x CI) @ (CI x CO) matrix product.
+    ci = te.reduce_axis((0, CI), name="ci")
+    bgemm = te.compute(
+        (alpha, alpha, P, CO),
+        lambda eps, nu, p, co: te.sum(
+            data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci]
+        ),
+        name="bgemm",
+    )
+
+    # inverse transform: A^T M A maps each alpha x alpha tile back to an m x m output tile.
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    inverse = te.compute(
+        (m, m, P, CO),
+        lambda vh, vw, p, co: te.sum(
+            bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
+        ),
+        name="inverse",
+        attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", 
"r_a", "r_b"]},
+    )
+
+    # output: scatter tiles back to (n, h, w); h // m and w // m select the tile,
+    # h % m and w % m the position inside it.
+    output = te.compute(
+        (N, H, W, CO),
+        lambda n, h, w, co: inverse[
+            idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + 
idxdiv(w, m), co
+        ],
+        name="conv2d_winograd",
+    )
+
+    return (inputs, kernel_pack, output)
+
+
+def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Plain matmul C[i, j] = sum_k A[i, k] * B[k, j]; returns (A, B, C)."""
+    lhs = te.placeholder((n, k), name="A")
+    rhs = te.placeholder((k, m), name="B")
+    red = te.reduce_axis((0, k), name="k")
+
+    def _dot(i, j):
+        return te.sum(lhs[i, red] * rhs[red, j], axis=[red])
+
+    return lhs, rhs, te.compute((n, m), _dot, name="C")
+
+
+def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """float16 matmul whose operands are cast to float32 before the sum; returns (A, B, C)."""
+    a = te.placeholder((n, k), name="A", dtype="float16")
+    b = te.placeholder((k, m), name="B", dtype="float16")
+    red = te.reduce_axis((0, k), name="k")
+
+    def _to_f32(value):
+        return tir.Cast(dtype="float32", value=value)
+
+    c = te.compute(
+        (n, m),
+        lambda i, j: te.sum(_to_f32(a[i, red]) * _to_f32(b[red, j]), axis=[red]),
+        name="C",
+    )
+    return a, b, c
+
+
+def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """Matmul followed by ReLU: D = relu(A @ B); returns (A, B, D).
+
+    A has shape (n, k) and B has shape (k, m), matching `matmul` and `matmul_relu_fp16`.
+    """
+    a = te.placeholder((n, k), name="A")
+    # B is indexed as b[k, j] below, so its shape must be (k, m); (m, k) only worked if m == k.
+    b = te.placeholder((k, m), name="B")
+    k = te.reduce_axis((0, k), name="k")
+    c = te.compute(
+        (n, m),
+        lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]),
+        name="C",
+    )
+    d = topi.nn.relu(c)  # pylint: disable=invalid-name
+    return (a, b, d)
+
+
+def matmul_relu_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """float16 matmul (operands cast to float32) followed by ReLU; returns (A, B, D)."""
+    a = te.placeholder((n, k), name="A", dtype="float16")
+    b = te.placeholder((k, m), name="B", dtype="float16")
+    red = te.reduce_axis((0, k), name="k")
+
+    def f_compute(i, j):
+        lhs = tir.Cast(dtype="float32", value=a[i, red])
+        rhs = tir.Cast(dtype="float32", value=b[red, j])
+        return te.sum(lhs * rhs, axis=[red])
+
+    return a, b, topi.nn.relu(te.compute((n, m), f_compute, name="C"))
+
+
+def conv2d_nchw(  # pylint: disable=invalid-name
+    n: int,
+    h: int,
+    w: int,
+    ci: int,
+    co: int,
+    kh: int,
+    kw: int,
+    stride: int,
+    padding: int,
+    dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
+    """NCHW conv2d built with `topi.nn.conv2d_nchw`; returns (X, W, output)."""
+    data = te.placeholder((n, ci, h, w), name="X")
+    kernel = te.placeholder((co, ci, kh, kw), name="W")
+    out = topi.nn.conv2d_nchw(
+        Input=data, Filter=kernel, stride=stride, padding=padding, dilation=dilation
+    )
+    return data, kernel, out
+
+
+def conv2d_nchw_bias_bn_relu(  # pylint: disable=invalid-name
+    n: int,
+    h: int,
+    w: int,
+    ci: int,
+    co: int,
+    kh: int,
+    kw: int,
+    stride: int,
+    padding: int,
+    dilation: int = 1,
+) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]:
+    """NCHW conv2d, then bias add, per-channel scale and offset, then ReLU.
+
+    Returns (X, W, B, bn_scale, bn_offset, output).
+    """
+    out_shape = (
+        n,
+        co,
+        (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1,
+        (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1,
+    )
+    data = te.placeholder((n, ci, h, w), name="X")
+    weight = te.placeholder((co, ci, kh, kw), name="W")
+    bias = te.placeholder((co, 1, 1), name="B")
+    bn_scale = te.placeholder((co, 1, 1), name="bn_scale")
+    bn_offset = te.placeholder((co, 1, 1), name="bn_offset")
+    conv = topi.nn.conv2d_nchw(
+        Input=data, Filter=weight, stride=stride, padding=padding, dilation=dilation
+    )
+    biased = te.compute(
+        out_shape, lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], name="bias_add"
+    )
+    scaled = te.compute(
+        out_shape, lambda i, j, k, l: biased[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul"
+    )
+    shifted = te.compute(
+        out_shape, lambda i, j, k, l: scaled[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add"
+    )
+    return data, weight, bias, bn_scale, bn_offset, topi.nn.relu(shifted)
+
+
+def max_pool2d_nchw(  # pylint: disable=invalid-name
+    n: int,
+    h: int,
+    w: int,
+    ci: int,
+    padding: int,
+) -> Tuple[te.Tensor, te.Tensor]:  # pylint: disable=invalid-name
+    """2x2 max pooling over an NCHW tensor built with `topi.nn.pool2d`; returns (X, output)."""
+    data = te.placeholder((n, ci, h, w), name="X")
+    kernel, strides, dilation = [2, 2], [1, 1], [1, 1]
+    pads = [padding] * 4
+    return data, topi.nn.pool2d(data, kernel, strides, dilation, pads, "max")
+
+
+def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]:  # pylint: disable=invalid-name
+    """Softmax along axis 1 of an (m, n) matrix; returns (A, output)."""
+    a = te.placeholder((m, n), name="A")
+    return a, topi.nn.softmax(a, axis=1)
+
+
+def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
+    """Build the PrimFunc for parameter set `idx` of workload `name` in CONFIGS.
+
+    Unknown names or out-of-range indices surface as the KeyError / IndexError of the lookup.
+    """
+    builder, param_sets = CONFIGS[name]
+    tensors = builder(*param_sets[idx])
+    return te.create_prim_func(tensors)  # type: ignore
+
+
+# Registry of benchmark workloads: maps a short workload key to a pair of
+# (TE builder function, list of argument tuples). ``create_te_workload(name, idx)``
+# calls the builder with ``*params[idx]``. Commented-out tuples are additional
+# shapes from the source networks that are not enabled.
+CONFIGS = {
+    "C1D": (
+        conv1d_nlc,
+        [
+            # derived from conv2d_shapes
+            (1, 256, 64, 128, 3, 2, 1),
+            #    (1, 256, 64, 128, 1, 2, 0),
+            #    (1, 256, 64, 64, 1, 1, 0),
+            #    (1, 128, 128, 256, 3, 2, 1),
+            (1, 128, 128, 256, 1, 2, 0),
+            #    (1, 128, 128, 128, 3, 1, 1),
+            #    (1, 64, 256, 512, 3, 2, 1),
+            #    (1, 64, 256, 512, 1, 2, 0),
+            (1, 64, 256, 256, 5, 1, 2),
+            (1, 32, 512, 512, 3, 1, 1),
+        ],
+    ),
+    "C2D": (
+        conv2d_nhwc,
+        [
+            # all conv2d layers in resnet-18
+            (1, 224, 224, 3, 64, 7, 2, 3),
+            #    (1, 56, 56, 64, 128, 3, 2, 1),
+            #    (1, 56, 56, 64, 128, 1, 2, 0),
+            #    (1, 56, 56, 64, 64, 3, 1, 1),
+            (1, 56, 56, 64, 64, 1, 1, 0),
+            #    (1, 28, 28, 128, 256, 3, 2, 1),
+            #    (1, 28, 28, 128, 256, 1, 2, 0),
+            #    (1, 28, 28, 128, 128, 3, 1, 1),
+            #    (1, 14, 14, 256, 512, 3, 2, 1),
+            #    (1, 14, 14, 256, 512, 1, 2, 0),
+            (1, 14, 14, 256, 256, 3, 1, 1),
+            (1, 7, 7, 512, 512, 3, 1, 1),
+        ],
+    ),
+    "C3D": (
+        conv3d_ndhwc,
+        [
+            # Derived from conv2d_shapes. Use depth=16 for all configurations
+            (1, 16, 224, 224, 3, 64, 7, 2, 3),
+            #    (1, 16, 56, 56, 64, 128, 3, 2, 1),
+            #    (1, 16, 56, 56, 64, 128, 1, 2, 0),
+            #    (1, 16, 56, 56, 64, 64, 3, 1, 1),
+            (1, 16, 56, 56, 64, 64, 1, 1, 0),
+            #    (1, 16, 28, 28, 128, 256, 3, 2, 1),
+            #    (1, 16, 28, 28, 128, 256, 1, 2, 0),
+            #    (1, 16, 28, 28, 128, 128, 3, 1, 1),
+            #    (1, 16, 14, 14, 256, 512, 3, 2, 1),
+            #    (1, 16, 14, 14, 256, 512, 1, 2, 0),
+            (1, 16, 14, 14, 256, 256, 3, 1, 1),
+            (1, 16, 7, 7, 512, 512, 3, 1, 1),
+        ],
+    ),
+    "GMM": (
+        batch_matmul_nkkm,
+        [
+            (1, 128, 128, 128),
+            (1, 512, 32, 512),
+            (1, 512, 512, 512),
+            (1, 1024, 1024, 1024),
+        ],
+    ),
+    "GRP": (
+        conv2d_nhwc,
+        [
+            # Derived from conv2d_shapes. Use group=4 for all configurations
+            (1, 56, 56, 64, 128, 3, 2, 1, 1, 4),
+            #    (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4),
+            #    (1, 56, 56, 64, 64, 3, 1, 1  , 1, 4),
+            (1, 56, 56, 64, 64, 1, 1, 0, 1, 4),
+            #    (1, 28, 28, 128, 256, 3, 2, 1, 1, 4),
+            #    (1, 28, 28, 128, 256, 1, 2, 0, 1, 4),
+            #    (1, 28, 28, 128, 128, 3, 1, 1, 1, 4),
+            #    (1, 14, 14, 256, 512, 3, 2, 1, 1, 4),
+            #    (1, 14, 14, 256, 512, 1, 2, 0, 1, 4),
+            (1, 14, 14, 256, 256, 3, 1, 1, 1, 4),
+            (1, 7, 7, 512, 512, 3, 1, 1, 1, 4),
+        ],
+    ),
+    "DIL": (
+        conv2d_nhwc,
+        [
+            # Derived from conv2d_shapes. Use dilation=2 for all configurations
+            (1, 224, 224, 3, 64, 7, 2, 3, 2),
+            #    (1, 56, 56, 64, 128, 3, 2, 1 , 2),
+            #    (1, 56, 56, 64, 128, 1, 2, 0 , 2),
+            #    (1, 56, 56, 64, 64, 3, 1, 1  , 2),
+            (1, 56, 56, 64, 64, 1, 1, 0, 2),
+            #    (1, 28, 28, 128, 256, 3, 2, 1, 2),
+            #    (1, 28, 28, 128, 256, 1, 2, 0, 2),
+            #    (1, 28, 28, 128, 128, 3, 1, 1, 2),
+            #    (1, 14, 14, 256, 512, 3, 2, 1, 2),
+            #    (1, 14, 14, 256, 512, 1, 2, 0, 2),
+            (1, 14, 14, 256, 256, 3, 1, 1, 2),
+            (1, 7, 7, 512, 512, 3, 1, 1, 2),
+        ],
+    ),
+    "DEP": (
+        depthwise_conv2d_nhwc,
+        [
+            # all depthwise conv2d layers in mobilenet
+            (1, 112, 112, 32, 3, 1, 1),
+            (1, 112, 112, 64, 3, 2, 1),
+            #    (1,  56,  56, 128, 3, 1, 1),
+            #    (1,  56,  56, 128, 3, 2, 1),
+            #    (1,  28,  28, 256, 3, 1, 1),
+            #    (1,  28,  28, 256, 3, 2, 1),
+            #    (1,  14,  14, 512, 3, 1, 1),
+            (1, 14, 14, 512, 3, 2, 1),
+            (1, 7, 7, 1024, 3, 1, 1),
+        ],
+    ),
+    "T2D": (
+        conv2d_transpose_nhwc,
+        [
+            # all conv2d transpose layers in DCGAN
+            (1, 4, 4, 512, 256, 4, 2, 1),
+            (1, 8, 8, 256, 128, 4, 2, 1),
+            (1, 16, 16, 128, 64, 4, 2, 1),
+            (1, 32, 32, 64, 3, 4, 2, 1),
+        ],
+    ),
+    "CAP": (
+        conv2d_capsule_nhwijc,
+        [
+            # all conv2d capsule layers in "Matrix Capsules with EM Routing" (ICLR 2018)
+            (1, 16, 16, 32, 32, 3, 2, 1),
+            (1, 8, 8, 32, 32, 3, 1, 1),
+            (1, 16, 16, 8, 16, 3, 2, 1),
+            (1, 8, 8, 16, 16, 3, 1, 1),
+        ],
+    ),
+    "NRM": (
+        norm_bmn,
+        [
+            (1, 256, 256),
+            (1, 512, 512),
+            (1, 1024, 1024),
+            (1, 4096, 1024),
+        ],
+    ),
+    "SFM": (
+        softmax_mn,
+        [
+            (256, 256),
+            (512, 512),
+            (1024, 1024),
+            (2048, 2048),
+        ],
+    ),
+    # NOTE(review): key casing ("C2d") differs from the other keys; kept as-is since
+    # callers look workloads up by this exact string.
+    "C2d-BN-RELU": (
+        conv2d_nhwc_bn_relu,
+        [
+            (1, 224, 224, 3, 64, 7, 2, 3),
+            (1, 56, 56, 64, 128, 3, 2, 1),
+            (1, 28, 28, 128, 256, 1, 2, 0),
+            (1, 7, 7, 512, 512, 3, 1, 1),
+        ],
+    ),
+    "TBG": (
+        transpose_batch_matmul,
+        [
+            (1, 128, 12, 64),
+            (1, 128, 16, 64),
+            (1, 64, 12, 128),
+            (1, 128, 12, 128),
+        ],
+    ),
+}
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc
new file mode 100644
index 0000000..5ef2ac3
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class AddRFactorNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    // A target is mandatory: the parallelism limits are derived from its core count.
+    ICHECK(context->target.defined());
+    Target target = context->target.value();
+    max_parallel_basic_ = GetTargetNumCores(target);
+    if (max_jobs_per_core == -1) {
+      // Parallelism disabled: max_parallel_extent_ keeps its previous value.
+      return;
+    }
+    max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
+  }
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
+
+ public:
+  /*!
+   * \brief Maximum number of jobs launched per core; the parallelism upper limit is
+   * `num_cores * max_jobs_per_core`. A value of -1 disables parallelism.
+   */
+  int max_jobs_per_core;
+  /*! \brief Upper bound on the innermost split factor, forwarded to SamplePerfectTile. */
+  int max_innermost_factor;
+  /*! \brief Parallelism upper limit, i.e. num_cores * max_jobs_per_core. */
+  int max_parallel_extent_;
+  /*! \brief Number of cores of the target. */
+  int max_parallel_basic_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("max_jobs_per_core", &max_jobs_per_core);
+    v->Visit("max_innermost_factor", &max_innermost_factor);
+    // The derived fields `max_parallel_extent_` and `max_parallel_basic_` are computed
+    // from the target in InitializeWithTuneContext and are therefore not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.AddRFactor";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
+                                      Optional<Integer> max_innermost_factor) {
+  ObjectPtr<AddRFactorNode> node = make_object<AddRFactorNode>();
+  node->max_jobs_per_core = max_jobs_per_core;
+  // An absent factor is encoded as -1.
+  node->max_innermost_factor =
+      max_innermost_factor.defined() ? max_innermost_factor.value()->value : -1;
+  // Target-dependent limits stay at -1 until InitializeWithTuneContext fills them in.
+  node->max_parallel_extent_ = -1;
+  node->max_parallel_basic_ = -1;
+  return ScheduleRule(node);
+}
+
+// Returns `{sch}` unchanged when rfactor is not considered beneficial for the block.
+// Otherwise the reduction loops are moved innermost, fused and split in two. One rfactor
+// attempt is made per resulting loop. The successful candidates are returned, followed by
+// an untransformed copy of the input schedule so the search can still choose not to rfactor.
+Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+  if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
+                                          max_parallel_basic_)) {
+    return {sch};
+  }
+
+  // Make a copy of the original schedule.
+  // The copy is taken before any transformation, and it is returned last as the "no rfactor"
+  // alternative. Every copy is reseeded from `sch->ForkSeed()`, so the order of these calls
+  // presumably determines each schedule's random stream.
+  tir::Schedule ori_sch = sch->Copy();
+  ori_sch->Seed(sch->ForkSeed());
+
+  // Reorder the loop axes if reduction loops are not innermost.
+  // After the reordering, fuse all the reduction loops.
+  size_t num_spatial_loops;
+  tir::LoopRV fused_reduce_loop;
+  ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
+
+  // Split the fused reduction loop into two loops with sampled perfect-tile factors.
+  Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
+  const Array<tir::LoopRV>& split_loops =
+      sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
+
+  Array<tir::Schedule> res;
+  // Try rfactor on each of the split loops, each on its own copy of the schedule.
+  for (const tir::LoopRV& split_loop : split_loops) {
+    tir::Schedule sch_tmp = sch->Copy();
+    sch_tmp->Seed(sch->ForkSeed());
+    try {
+      // factor_axis = num_spatial_loops places the new axis right after the spatial axes.
+      const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
+      Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
+      // Sanity check: the rfactor block must have more loops than the spatial loops alone.
+      ICHECK_GT(axes.size(), num_spatial_loops);
+
+      // Annotate the original block, which now consumes the rfactor block, so that its
+      // producer is considered by the rule Random-Compute-Location.
+      sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true));
+      res.push_back(sch_tmp);
+    } catch (const tvm::runtime::Error& e) {
+      // Best-effort by design: if RFactor (or the checks above) fails for this split loop,
+      // the candidate is simply dropped.
+    }
+  }
+
+  res.push_back(ori_sch);
+  return res;
+}
+
+// Register AddRFactorNode, and expose the ScheduleRule::AddRFactor factory as the global
+// function "meta_schedule.ScheduleRuleAddRFactor".
+TVM_REGISTER_NODE_TYPE(AddRFactorNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
+    .set_body_typed(ScheduleRule::AddRFactor);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index ef15f49..5b49769 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -318,6 +318,26 @@ struct ThreadedTraceApply {
   Item* items_;
 };
 
+/*!
+ * \brief Get the number of CPU cores from the target's "num-cores" attribute.
+ * \param target The target
+ * \return The value of the "num-cores" attribute.
+ * \note A missing attribute is a fatal error. In that case the locally detected core count
+ *       (from the "meta_schedule.cpu_count" packed function) is used only to build an
+ *       error message suggesting the target string to use.
+ */
+inline int GetTargetNumCores(const Target& target) {
+  int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1);
+  if (num_cores == -1) {
+    static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
+    // The message must name the same key that was looked up above.
+    ICHECK(f_cpu_count)
+        << "ValueError: Cannot find the packed function \"meta_schedule.cpu_count\"";
+    num_cores = (*f_cpu_count)(false);
+    LOG(FATAL)
+        << "Target does not have attribute \"num-cores\", physical core number must be "
+           "defined! For example, on the local machine, the target must be \"llvm -num-cores "
+        << num_cores << "\"";
+  }
+  return num_cores;
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index e4bf48b..c562c78 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -254,6 +254,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
     .add_attr_option<String>("mabi")
     .add_attr_option<Bool>("system-lib")
     .add_attr_option<String>("runtime")
+    .add_attr_option<Integer>("num-cores")
     .add_attr_option<Bool>("link-params", Bool(false))
     .add_attr_option<Bool>("unpacked-api")
     .add_attr_option<String>("interface-api")
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 9622e2d..636cc7d 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -520,6 +520,44 @@ std::tuple</*exists=*/bool,
            /*no_shift_read=*/bool>
 AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);
 
+/*!
+ * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel
+ * \param block_sref The block to be checked
+ * \return A boolean flag indicating if the block is a data parallel block
+ */
+bool IsSpatial(const StmtSRef& block_sref);
+
+/*!
+ * \brief Check whether a block has a trivial binding. That means there are exactly as many block
+ * vars as enclosing loops, and the i-th block var is bound to the i-th loop var, from outer to
+ * inner.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean flag indicating if the block has a trivial binding
+ */
+bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Check if the given block has a data reuse opportunity, which makes multi-level
+ * tiling beneficial.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block has a data reuse opportunity
+ */
+bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Check if rfactor or cross-thread reduction is beneficial for the given block.
+ * \param self The schedule state.
+ * \param block_sref The block to be checked.
+ * \param max_parallel_extent The maximum number of parallel jobs on the target.
+ * \param max_parallel_basic The number of cores on the target.
+ * \return A boolean indicating whether the operation is beneficial.
+ */
+bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
+                                        const tir::StmtSRef& block_sref,  //
+                                        int64_t max_parallel_extent,      //
+                                        int64_t max_parallel_basic);
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 0520973..2053f8d 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1661,5 +1661,191 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
   }
 }
 
+bool IsSpatial(const StmtSRef& block_sref) {
+  // The macro also verifies that `block_sref` points to a Block.
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  // Spatial means every block iteration variable is data parallel.
+  bool spatial = true;
+  for (const IterVar& iter_var : block->iter_vars) {
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      spatial = false;
+      break;
+    }
+  }
+  return spatial;
+}
+
+bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
+  // Only used for its type check: the macro verifies `block_sref` points to a Block.
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const Array<StmtSRef> loops = GetLoops(block_sref);
+  const Array<PrimExpr> iter_values = GetBlockRealize(self, block_sref)->iter_values;
+  // Trivial: one binding per enclosing loop, and the i-th binding is exactly the i-th loop
+  // variable (outermost first).
+  if (iter_values.size() != loops.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < loops.size(); ++i) {
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
+    if (!iter_values[i].same_as(loop->loop_var)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Heuristic for data reuse. A non-spatial block with exactly one write buffer, at least one read
+// and a trivial binding qualifies when some read buffer other than the written one is not
+// indexed by at least one spatial block var. Duplicate read buffers and read regions with
+// non-constant extents make the block ineligible.
+bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
+      !IsTrivialBinding(self, block_sref)) {
+    return false;
+  }
+  const BufferNode* write_buffer = block->writes[0]->buffer.get();
+  // Step 1. Collect the spatial (data-parallel) block variables
+  std::vector<const VarNode*> spatial_block_vars;
+  spatial_block_vars.reserve(block->iter_vars.size());
+  for (const IterVar& block_var : block->iter_vars) {
+    if (block_var->iter_type == IterVarType::kDataPar) {
+      spatial_block_vars.push_back(block_var->var.get());
+    }
+  }
+  // Step 2. For each read region, count the spatial block vars that are not used
+  // to index it
+  int total_unused_block_vars = 0;
+  std::unordered_set<const BufferNode*> read_buffers;
+  read_buffers.reserve(block->reads.size());
+  for (const BufferRegion& buffer_region : block->reads) {
+    const BufferNode* buffer = buffer_region->buffer.get();
+    const Array<Range>& regions = buffer_region->region;
+    // Step 2.1. Duplicate read buffers are not allowed
+    if (read_buffers.insert(buffer).second == false) {
+      return false;
+    }
+    // Step 2.2. Skip the reduction buffer (the buffer that is also written)
+    if (buffer == write_buffer) {
+      continue;
+    }
+    // Step 2.3. Collect the vars used in the region minimums; every extent must be constant
+    std::unordered_set<const VarNode*> vars;
+    for (const Range& range : regions) {
+      if (as_const_int(range->extent) == nullptr) {
+        return false;
+      }
+      for (const Var& var : UndefinedVars(range->min)) {
+        vars.insert(var.get());
+      }
+    }
+    // Step 2.4. Count the spatial block vars that do not index this read region
+    int n_unused_block_vars = 0;
+    for (const VarNode* block_var : spatial_block_vars) {
+      if (vars.count(block_var) == 0) {
+        ++n_unused_block_vars;
+      }
+    }
+    total_unused_block_vars += n_unused_block_vars;
+  }
+  return total_unused_block_vars >= 1;
+}
+
+/*!
+ * \brief Compute the products of the extents of the spatial (data-parallel) loops and of the
+ *        reduction loops enclosing the given block.
+ * \param self The schedule state (currently unused).
+ * \param block_sref The block whose enclosing loops are inspected.
+ * \return (cumulative spatial length, cumulative reduction length), or (-1, -1) if
+ *         1. some loop has an iterator type other than kDataPar and kCommReduce, or
+ *         2. some loop extent is not a constant integer (dynamic).
+ */
+std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
+                                                                 const tir::StmtSRef& block_sref) {
+  Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
+  int64_t cum_space_len = 1, cum_reduce_len = 1;
+  for (const tir::StmtSRef& loop_sref : loops) {
+    tir::IterVarType type = GetLoopIterType(loop_sref);
+    if (type != tir::kDataPar && type != tir::kCommReduce) {
+      return std::make_pair(-1, -1);
+    }
+    // GetLoopIntExtent yields nullptr for a non-constant (dynamic) extent, so it must be
+    // checked before dereferencing. A constant extent of -1 is also treated as a failure.
+    const int64_t* extent = GetLoopIntExtent(loop_sref);
+    if (extent == nullptr || *extent == -1) {
+      return std::make_pair(-1, -1);
+    }
+    if (type == tir::kDataPar) {
+      cum_space_len *= *extent;
+    } else {
+      cum_reduce_len *= *extent;
+    }
+  }
+  return std::make_pair(cum_space_len, cum_reduce_len);
+}
+
+// Decide whether rfactor / cross-thread reduction should be tried on `block_sref`. The checks
+// run in order and the first failing one returns false. Conditions 1-6 are structural; condition 7
+// compares loop lengths against the target's parallelism limits.
+bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
+                                        const tir::StmtSRef& block_sref,  //
+                                        int64_t max_parallel_extent,      //
+                                        int64_t max_parallel_basic) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
+
+  // Cond 1. The block has only one write buffer
+  if (block->writes.size() != 1) {
+    return false;
+  }
+
+  // Cond 2. The block is a reduction block and has trivial binding.
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,                  //
+                                            /*require_stage_pipeline=*/false,  //
+                                            /*require_subtree_compact_dataflow=*/false);
+  if (!(IsReductionBlock(self, block_sref, scope_sref) &&  //
+        IsTrivialBinding(self, block_sref))) {
+    return false;
+  }
+
+  // Cond 3. Every loop axis must be either a spatial axis or a reduction axis.
+  for (const tir::StmtSRef& loop_sref : loops) {
+    const tir::IterVarType& type = GetLoopIterType(loop_sref);
+    if (type != tir::kDataPar && type != tir::kCommReduce) {
+      return false;
+    }
+  }
+
+  // Cond 4. There is at least one reduction loop.
+  // Cond 5. The loops are perfectly nested, and the body of the innermost loop is exactly the
+  // block.
+  bool has_reduction_loop = false;
+  for (size_t i = 0; i < loops.size(); ++i) {
+    // Cond 4.
+    if (GetLoopIterType(loops[i]) == tir::kCommReduce) {
+      has_reduction_loop = true;
+    }
+
+    // Cond 5.
+    const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
+    if (i < loops.size() - 1) {
+      const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
+      if (loop_i->body.get() != loop_i1) {
+        return false;
+      }
+    } else {
+      const auto* block_realize = loop_i->body.as<tir::BlockRealizeNode>();
+      if (!block_realize || block_realize->block.get() != block) {
+        return false;
+      }
+    }
+  }
+  if (!has_reduction_loop) {
+    return false;
+  }
+
+  // Cond 6. The cumulative spatial and reduction loop lengths can be computed.
+  int64_t cum_space_len, cum_reduce_len;
+  std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref);
+  if (cum_space_len == -1 || cum_reduce_len == -1) {
+    return false;
+  }
+
+  // Cond 7. Compare loop lengths against the parallelism limits.
+  if (NeedsMultiLevelTiling(self, block_sref)) {
+    // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
+    // Equivalently: require cum_space_len < cum_reduce_len && cum_space_len <= max_parallel_extent.
+    // With max_parallel_extent == -1 this branch always returns false.
+    return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
+  } else if (cum_reduce_len > 1) {
+    // For other reduction blocks, try rfactor/cross-thread-reduction when the reduction is
+    // longer than the number of cores.
+    return cum_reduce_len > max_parallel_basic;
+  } else {
+    return false;
+  }
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 2acab38..be6d5a1 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -320,6 +320,60 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) {
   return result.defined() && result.value()->value == ann_val;
 }
 
+/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/
+
+/*!
+ * \brief Move all reduction loops of a block to the innermost positions, keeping the relative
+ *        order among the spatial loops and among the reduction loops. Then fuse the reduction
+ *        loops into a single loop.
+ * \param sch The schedule
+ * \param block_rv The block whose loops are reordered and fused
+ * \param fused_reduce_loop Output: the single reduction loop (fused if there were several).
+ * \param num_spatial_loops Output: the number of spatial loops.
+ * \note Before invoking this helper function, make sure that the block has only spatial and
+ *       reduction loop axes, and at least one reduction loop; otherwise a check fails.
+ */
+inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                         tir::LoopRV* fused_reduce_loop,
+                                         size_t* num_spatial_loops) {
+  const Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
+
+  // Partition the loops by iterator type, preserving their original relative order.
+  Array<tir::LoopRV> spatial_loops;
+  Array<tir::LoopRV> reduction_loops;
+  for (const tir::LoopRV& loop_rv : loops) {
+    const tir::IterVarType type = GetLoopIterType(sch->GetSRef(loop_rv));
+    if (type == tir::kDataPar) {
+      spatial_loops.push_back(loop_rv);
+    } else if (type == tir::kCommReduce) {
+      reduction_loops.push_back(loop_rv);
+    }
+  }
+  *num_spatial_loops = spatial_loops.size();
+
+  // Target order: all spatial loops first, then all reduction loops.
+  Array<tir::LoopRV> new_order = spatial_loops;
+  for (const tir::LoopRV& loop_rv : reduction_loops) {
+    new_order.push_back(loop_rv);
+  }
+  ICHECK_EQ(new_order.size(), loops.size());
+
+  // Only issue a Reorder when the target order actually differs from the current one.
+  bool already_ordered = true;
+  for (size_t i = 0; i < loops.size() && already_ordered; ++i) {
+    already_ordered = new_order[i].same_as(loops[i]);
+  }
+  if (!already_ordered) {
+    sch->Reorder(new_order);
+  }
+
+  CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
+  *fused_reduce_loop =
+      reduction_loops.size() > 1 ? sch->Fuse(reduction_loops) : reduction_loops[0];
+}
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
new file mode 100644
index 0000000..5a80312
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import add_rfactor
+from tvm.meta_schedule.testing.space_generation import check_trace
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.target import Target
+from tvm.te.operation import create_prim_func
+
+
+def _create_context(mod, target, rule) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        space_generator=PostOrderApply(),
+        sch_rules=[rule],
+        task_name="test",
+    )
+    ctx.space_generator.initialize_with_tune_context(ctx)
+    for sch_rule in ctx.sch_rules:
+        sch_rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+def test_cpu_matmul():
+    expected = [
+        [],
+        [
+            'b0 = sch.get_block(name="C", func_name="main")',
+            "l1, l2, l3 = sch.get_loops(block=b0)",
+            "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, 
max_innermost_factor=64)",
+            "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+            "b8 = sch.rfactor(loop=l7, factor_axis=2)",
+            'sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
+        ],
+        [
+            'b0 = sch.get_block(name="C", func_name="main")',
+            "l1, l2, l3 = sch.get_loops(block=b0)",
+            "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, 
max_innermost_factor=64)",
+            "l6, l7 = sch.split(loop=l3, factors=[v4, v5])",
+            "b8 = sch.rfactor(loop=l6, factor_axis=2)",
+            'sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
+        ],
+    ]
+    target = Target("llvm --num-cores=32")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul(
+                n=4,
+                m=4,
+                k=512,
+            )
+        ),
+        target=target,
+        rule=add_rfactor(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 3
+    check_trace(spaces, expected)
+
+
+if __name__ == "__main__":
+    # Allow running this file directly, without pytest.
+    test_cpu_matmul()

Reply via email to