This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 093720230b [Unity] Introduce Default GPU Schedule Pass (#14182)
093720230b is described below

commit 093720230bc2403d278e820bd164b93158ab8823
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Mar 6 10:19:19 2023 -0800

    [Unity] Introduce Default GPU Schedule Pass (#14182)
    
    * Implement default schedule.
    
    * Add test.
    
    * Add tests.
    
    * Fix linting.
    
    * Skip scheduled blocks.
    
    * Address issues.
    
    * Use Target::Current() to obtain the target.
    
    * Minor fixes.
    
    * Remove Mutator.
    
    * Move the pass to TIR, since it only schedules PrimFuncs.
    
    * Move unit tests.
    
    * Remove redundant headers and adjust comments.
    
    * Add more explanation to the pass documents.
    
    * Modify pass name.
    
    * Fix comment.
---
 include/tvm/tir/transform.h                        |  12 +
 python/tvm/tir/transform/transform.py              |  18 +
 .../space_generator/post_order_apply.cc            |  62 ---
 src/meta_schedule/utils.h                          |  62 +++
 src/tir/transforms/default_gpu_schedule.cc         | 116 ++++++
 .../test_transform_default_gpu_schedule.py         | 417 +++++++++++++++++++++
 6 files changed, 625 insertions(+), 62 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d212578b8d..4653fa3640 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -717,6 +717,18 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Set default thread bindings for the PrimFuncs of a module, including functions with
+ *  symbolic shapes, so that they can be built and executed on GPU devices.
+ *
+ *  For every block that does not already sit under a thread-bound loop, the pass fuses the
+ *  block's data-parallel loops, splits the fused loop according to its extent, and binds the
+ *  pieces to `blockIdx.x` / `threadIdx.x`. The number of threads per block is taken from the
+ *  `max_num_threads` attribute of the current target (`Target::Current()`); the number of
+ *  thread blocks is capped by a fixed default (256) when the extent is large or symbolic.
+ *  Blocks without data-parallel loops are left unchanged.
+ *
+ * \note The primary objective of this pass is not to optimize performance, but rather to
+ *  generate a valid GPU kernel for unscheduled or symbolic-shape PrimFuncs. The pass currently
+ *  only acts on CUDA targets; for any other target kind the module is returned unchanged.
+ * \note A target must be set in the current context, otherwise the pass fails.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultGPUSchedule();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index a6e0cf06cb..69d3119531 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1055,3 +1055,21 @@ def InstallDebugSpans():
         The result pass
     """
     return _ffi_api.InstallDebugSpans()  # type: ignore
+
+
+def DefaultGPUSchedule():
+    """Set default thread bindings for PrimFuncs, including symbolic-shape functions,
+    so that they can be built and executed on GPU devices.
+
+    For every block that does not already sit under a thread-bound loop, the pass fuses
+    the block's data-parallel loops, splits the fused loop according to its extent, and
+    binds the pieces to ``blockIdx.x`` / ``threadIdx.x``. The number of threads per block
+    comes from the ``max_num_threads`` attribute of the current target; the number of
+    thread blocks is capped by a fixed default (256) when the extent is large or symbolic.
+    Blocks without data-parallel loops are left unchanged.
+
+    The primary objective of this pass is not to optimize performance, but rather to
+    generate a valid GPU kernel for unscheduled or symbolic-shape PrimFuncs. The pass
+    currently only acts on CUDA targets; for other targets the module is returned
+    unchanged. A target must be set in the current context (e.g. ``with target:``).
+
+    Returns
+    -------
+    ret: tvm.transform.Pass
+        The result pass.
+    """
+    return _ffi_api.DefaultGPUSchedule()  # type: ignore
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc
index 491af6e28f..1bde99869e 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -21,68 +21,6 @@
 namespace tvm {
 namespace meta_schedule {
 
-/*! \brief Collecting all the blocks */
-class BlockCollector : public tir::StmtVisitor {
- public:
-  static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
-                                     const runtime::PackedFunc f_block_filter 
= nullptr) {  //
-    return BlockCollector(sch, f_block_filter).Run();
-  }
-
- private:
-  /*! \brief Entry point */
-  Array<tir::BlockRV> Run() {
-    std::vector<tir::BlockRV> results;
-    for (const auto& kv : sch_->mod()->functions) {
-      const GlobalVar& gv = kv.first;         // `gv->name_hint` is the name 
of the function
-      const BaseFunc& base_func = kv.second;  // this can be PrimFunc or 
relay::Function
-      if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
-        func_name_ = gv->name_hint;
-        block_names_.clear();
-        blocks_to_collect_.clear();
-        VisitStmt(func->body);
-        for (const String& name : blocks_to_collect_) {
-          results.push_back(sch_->GetBlock(name, func_name_));
-        }
-      }
-    }
-    return results;
-  }
-  /*! \brief Constructor */
-  explicit BlockCollector(const tir::Schedule& sch,
-                          const runtime::PackedFunc f_block_filter = nullptr)
-      : sch_(sch), f_block_filter_(f_block_filter) {}
-  /*! \brief Override the Stmt visiting behaviour */
-  void VisitStmt_(const tir::BlockNode* block) override {
-    tir::StmtVisitor::VisitStmt_(block);
-    CHECK(block_names_.count(block->name_hint) == 0)
-        << "Duplicated block name " << block->name_hint << " in function " << 
func_name_
-        << " not supported!";
-    block_names_.insert(block->name_hint);
-
-    // If filter function is provided, use it to selectively collect blocks.
-    // Otherwise collect all blocks.
-    Bool collect_block = Bool(true);
-    if (f_block_filter_ != nullptr) {
-      collect_block = f_block_filter_(GetRef<tir::Block>(block));
-    }
-    if (collect_block) {
-      blocks_to_collect_.push_back(block->name_hint);
-    }
-  }
-
-  /*! \brief The schedule to be collected */
-  const tir::Schedule& sch_;
-  /*! \brief An optional packed func that allows only certain blocks to be 
collected. */
-  const runtime::PackedFunc f_block_filter_;
-  /*! \brief The set of func name and block name pair */
-  std::unordered_set<String> block_names_;
-  /* \brief The list of blocks to collect in order */
-  Array<String> blocks_to_collect_;
-  /*! \brief Name of the current PrimFunc */
-  String func_name_;
-};
-
 /*!
  * \brief Design Space Generator that generates design spaces by applying 
schedule rules to blocks
  *  in post-DFS order.
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 9a372dde8f..955381b740 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -554,6 +554,68 @@ inline double Sum(const Array<FloatImm>& arr) {
   return sum;
 }
 
+/*!
+ * \brief Collect the blocks of every PrimFunc in a schedule's module, in post-order: a block is
+ *  collected after all blocks nested inside it.
+ */
+class BlockCollector : public tir::StmtVisitor {
+ public:
+  /*!
+   * \brief Collect the blocks of all PrimFuncs in the module of `sch`.
+   * \param sch The schedule whose module is scanned.
+   * \param f_block_filter An optional filter; when set, a block is collected only if the filter
+   *  returns true for it. When null, every block is collected.
+   * \return The collected blocks, as BlockRVs of `sch`.
+   * \note Fails if one PrimFunc contains two blocks with the same name.
+   */
+  static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
+                                     const runtime::PackedFunc& f_block_filter = nullptr) {
+    return BlockCollector(sch, f_block_filter).Run();
+  }
+
+ private:
+  /*! \brief Entry point: visit every PrimFunc and turn the collected names into BlockRVs. */
+  Array<tir::BlockRV> Run() {
+    std::vector<tir::BlockRV> results;
+    for (const auto& [gv, base_func] : sch_->mod()->functions) {
+      // `gv->name_hint` is the name of the function;
+      // `base_func` may also be a non-TIR function, which is skipped.
+      if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+        func_name_ = gv->name_hint;
+        block_names_.clear();
+        blocks_to_collect_.clear();
+        VisitStmt(func->body);
+        for (const String& name : blocks_to_collect_) {
+          results.push_back(sch_->GetBlock(name, func_name_));
+        }
+      }
+    }
+    return results;
+  }
+  /*! \brief Constructor */
+  explicit BlockCollector(const tir::Schedule& sch,
+                          const runtime::PackedFunc& f_block_filter = nullptr)
+      : sch_(sch), f_block_filter_(f_block_filter) {}
+  /*! \brief Override the Stmt visiting behaviour */
+  void VisitStmt_(const tir::BlockNode* block) override {
+    // Visit nested blocks first so that the collection order is post-order.
+    tir::StmtVisitor::VisitStmt_(block);
+    // Blocks are looked up by name later (GetBlock), so names must be unique per function.
+    CHECK(block_names_.count(block->name_hint) == 0)
+        << "Duplicated block name " << block->name_hint << " in function " << func_name_
+        << " not supported!";
+    block_names_.insert(block->name_hint);
+
+    // If filter function is provided, use it to selectively collect blocks.
+    // Otherwise collect all blocks.
+    Bool collect_block = Bool(true);
+    if (f_block_filter_ != nullptr) {
+      collect_block = f_block_filter_(GetRef<tir::Block>(block));
+    }
+    if (collect_block) {
+      blocks_to_collect_.push_back(block->name_hint);
+    }
+  }
+
+  /*! \brief The schedule to be collected; only referenced for the duration of Collect(). */
+  const tir::Schedule& sch_;
+  /*! \brief An optional packed func that allows only certain blocks to be collected. */
+  const runtime::PackedFunc f_block_filter_;
+  /*! \brief The names of the blocks seen so far in the current PrimFunc */
+  std::unordered_set<String> block_names_;
+  /*! \brief The names of the blocks to collect from the current PrimFunc, in order */
+  Array<String> blocks_to_collect_;
+  /*! \brief Name of the current PrimFunc */
+  String func_name_;
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
new file mode 100644
index 0000000000..7877c86d0a
--- /dev/null
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../meta_schedule/utils.h"
+
+namespace tvm {
+namespace tir {
+namespace transform {
+/*!
+ * \brief A helper function to do default thread binding for a block.
+ *
+ *  The loops that drive only data-parallel block iterators are fused (after being moved next to
+ *  each other if needed), and the fused loop is split and bound to `blockIdx.x` and
+ *  `threadIdx.x`. When the fused extent is symbolic or larger than
+ *  `max_thread_per_block * max_threadblocks`, the remainder is covered by a serial loop placed
+ *  inside the two thread-bound loops. Blocks that already have a thread-bound loop, or that have
+ *  no data-parallel loop, are left untouched.
+ *
+ * \param sch The schedule to work on.
+ * \param block The block to be scheduled.
+ * \param max_thread_per_block The maximum number of threads per block.
+ * \param max_threadblocks The maximum number of threadblocks.
+ */
+void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread_per_block,
+                int64_t max_threadblocks = 256) {
+  // fetch the loops
+  Array<tir::LoopRV> loops = sch->GetLoops(block);
+  for (const tir::LoopRV& loop : loops) {
+    // skip block if already scheduled
+    if (sch->Get(loop)->thread_binding.defined()) {
+      return;
+    }
+  }
+  // Classify loops through the block's iterator bindings instead of by position: the i-th loop
+  // does not necessarily drive the i-th block iterator (e.g. `T.axis.remap` over a permuted
+  // loop order, or loops that bind no iterator at all), so matching by index can mistake a
+  // reduction loop for a data-parallel one, or fail outright when the counts differ.
+  tir::BlockRealize realize = tir::GetBlockRealize(sch->state(), sch->GetSRef(block));
+  const Array<tir::IterVar>& iters = realize->block->iter_vars;
+  const Array<PrimExpr>& bindings = realize->iter_values;
+  std::vector<char> is_data_par(loops.size(), 0);
+  Array<tir::LoopRV> data_parallel_loops;
+  for (size_t i = 0; i < loops.size(); ++i) {
+    const tir::VarNode* loop_var = sch->Get(loops[i])->loop_var.get();
+    auto is_loop_var = [loop_var](const tir::VarNode* var) { return var == loop_var; };
+    bool feeds_data_par = false;
+    bool feeds_other = false;
+    for (size_t j = 0; j < iters.size(); ++j) {
+      if (!tir::UsesVar(bindings[j], is_loop_var)) {
+        continue;
+      }
+      if (iters[j]->iter_type == tir::IterVarType::kDataPar) {
+        feeds_data_par = true;
+      } else {
+        feeds_other = true;
+      }
+    }
+    // only fuse loops that exclusively drive data parallel iterators
+    if (feeds_data_par && !feeds_other) {
+      is_data_par[i] = 1;
+      data_parallel_loops.push_back(loops[i]);
+    }
+  }
+  // skip if no data parallel loops
+  if (data_parallel_loops.empty()) {
+    return;
+  }
+  // Fuse requires adjacent loops. If other loops sit between the data parallel ones, move the
+  // data parallel loops above them (keeping their relative order) within that span. When the
+  // data parallel loops are already adjacent, the loop nest is not touched.
+  size_t first = 0;
+  while (!is_data_par[first]) {
+    ++first;
+  }
+  size_t last = loops.size() - 1;
+  while (!is_data_par[last]) {
+    --last;
+  }
+  if (last - first + 1 != data_parallel_loops.size()) {
+    Array<tir::LoopRV> ordered = data_parallel_loops;
+    for (size_t i = first; i <= last; ++i) {
+      if (!is_data_par[i]) {
+        ordered.push_back(loops[i]);
+      }
+    }
+    sch->Reorder(/*ordered_loop_rvs=*/ordered);
+  }
+  // fuse all data parallel loops
+  tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
+  // A symbolic extent is treated as "larger than any bound".
+  int64_t product = std::numeric_limits<int64_t>::max();
+  if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+    product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+  }
+  // schedule the fused loop
+  if (product > max_thread_per_block * max_threadblocks) {
+    // [serial, blockIdx, threadIdx] -> [blockIdx, threadIdx, serial]
+    Array<tir::LoopRV> splits =
+        sch->Split(fused,
+                   /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)});
+    sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
+    sch->Bind(splits[1], "blockIdx.x");
+    sch->Bind(splits[2], "threadIdx.x");
+  } else {
+    Array<tir::LoopRV> splits =
+        sch->Split(fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))});
+    sch->Bind(splits[0], "blockIdx.x");
+    sch->Bind(splits[1], "threadIdx.x");
+  }
+}
+
+/*!
+ * \brief Create the DefaultGPUSchedule module pass.
+ * \return A module pass that applies ThreadBind to every block of every PrimFunc, using the
+ *  `max_num_threads` attribute of the current target. For non-CUDA targets the module is
+ *  returned unchanged.
+ */
+Pass DefaultGPUSchedule() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [](IRModule m, PassContext pc) {
+        // get the target from context.
+        tvm::Target target = tvm::Target::Current();
+        // A missing target / attribute is a user configuration error rather than a broken
+        // internal invariant, so report it with CHECK instead of ICHECK.
+        CHECK(target.defined()) << "Target is not set in current context";
+        // skip non-cuda targets.
+        if (target->kind->name != "cuda") {
+          return m;
+        }
+        // get the max thread per block from target.
+        Optional<Integer> opt_max_thread_per_block = target->GetAttr<Integer>("max_num_threads");
+        CHECK(opt_max_thread_per_block.defined())
+            << "max_num_threads is not set for target " << target;
+        int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue();
+        tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
+                                                  tir::ScheduleErrorRenderLevel::kDetail);
+        for (const auto& [gv, func] : m->functions) {
+          if (func->IsInstance<tir::PrimFuncNode>()) {
+            sch->WorkOn(gv->name_hint);
+            Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
+            for (const tir::BlockRV& block : blocks) {
+              ThreadBind(sch, block, max_thread_per_block);
+            }
+          }
+        }
+        return sch->mod();
+      };
+  return CreateModulePass(/*pass_function=*/pass_func,         //
+                          /*opt_level=*/0,                     //
+                          /*pass_name=*/"DefaultGPUSchedule",  //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py
new file mode 100644
index 0000000000..644a9aede0
--- /dev/null
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -0,0 +1,417 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+import tvm
+from tvm.tir.transform import DefaultGPUSchedule
+from tvm.script import tir as T
+import tvm.testing
+
+
+def test_broadcast_to_symbolic():
+    """Symbolic-shape loops are fused and bound to (blockIdx.x, threadIdx.x), with the
+    remainder of the unknown extent x_0 * x_1 handled by an inner serial loop and a
+    T.where guard."""
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def broadcast_to(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
+            var_T_broadcast_to: T.handle,
+        ):
+            T.func_attr({"tir.noalias": True})
+            x_0 = T.int64()
+            x_1 = T.int64()
+            T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(x_0, x_1):
+                with T.block("T_broadcast_to"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, T.int64(0)])
+                    T.writes(T_broadcast_to[v_ax0, v_ax1])
+                    T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def broadcast_to(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
+            var_T_broadcast_to: T.handle,
+        ):
+            T.func_attr({"tir.noalias": True})
+            x_0 = T.int64()
+            x_1 = T.int64()
+            T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
+            # with T.block("root"):
+            for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+                for ax0_ax1_fused_2 in T.thread_binding(
+                    T.int64(1024), thread="threadIdx.x"
+                ):
+                    for ax0_ax1_fused_0 in range(
+                        (x_0 * x_1 + T.int64(262143)) // T.int64(262144)
+                    ):
+                        with T.block("T_broadcast_to"):
+                            v_ax0 = T.axis.spatial(
+                                x_0,
+                                (
+                                    (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
+                                    * T.int64(1024)
+                                    + ax0_ax1_fused_2
+                                )
+                                // x_1,
+                            )
+                            v_ax1 = T.axis.spatial(
+                                x_1,
+                                (
+                                    (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
+                                    * T.int64(1024)
+                                    + ax0_ax1_fused_2
+                                )
+                                % x_1,
+                            )
+                            T.where(
+                                (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
+                                * T.int64(1024)
+                                + ax0_ax1_fused_2
+                                < x_0 * x_1
+                            )
+                            T.reads(rxplaceholder[v_ax0, T.int64(0)])
+                            T.writes(T_broadcast_to[v_ax0, v_ax1])
+                            T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        After = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_matmul():
+    """Only the spatial loops (i, j) are fused and thread-bound; the reduction loop k stays a
+    serial loop inside the bound loops."""
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def matmul(
+            A: T.Buffer((32, 32), "float16"),
+            B: T.Buffer((32, 32), "float16"),
+            C: T.Buffer((32, 32), "float16"),
+        ):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            # with T.block("root"):
+            for i, j, k in T.grid(32, 32, 32):
+                with T.block("C"):
+                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[v_i, v_k], B[v_k, v_j])
+                    T.writes(C[v_i, v_j])
+                    with T.init():
+                        C[v_i, v_j] = T.float16(0)
+                    C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def matmul(
+            A: T.Buffer((32, 32), "float16"),
+            B: T.Buffer((32, 32), "float16"),
+            C: T.Buffer((32, 32), "float16"),
+        ):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            # with T.block("root"):
+            for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+                for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                    for k in range(32):
+                        with T.block("C"):
+                            v_i = T.axis.spatial(
+                                32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32
+                            )
+                            v_j = T.axis.spatial(
+                                32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32
+                            )
+                            v_k = T.axis.reduce(32, k)
+                            T.reads(A[v_i, v_k], B[v_k, v_j])
+                            T.writes(C[v_i, v_j])
+                            with T.init():
+                                C[v_i, v_j] = T.float16(0)
+                            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        After = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_add():
+    """A static-shape elementwise op whose fused extent (72) fits in one thread block is
+    split into a unit blockIdx.x loop and a 72-wide threadIdx.x loop."""
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+                    T.writes(T_add[ax0, ax1, ax2, ax3])
+                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"),
+            rxplaceholder_1: T.Buffer(
+                (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
+            ),
+            T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for i0_i1_i2_i3_fused_1 in T.thread_binding(
+                    T.int64(72), thread="threadIdx.x"
+                ):
+                    with T.block("T_add"):
+                        ax0 = T.axis.spatial(
+                            T.int64(4),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
+                            // T.int64(18),
+                        )
+                        ax1 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
+                            % T.int64(18)
+                            // T.int64(6),
+                        )
+                        ax2 = T.axis.spatial(
+                            T.int64(2),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
+                            % T.int64(6)
+                            // T.int64(3),
+                        )
+                        ax3 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1)
+                            % T.int64(3),
+                        )
+                        T.reads(
+                            rxplaceholder[T.int64(0), ax2, ax3],
+                            rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
+                        )
+                        T.writes(T_add[ax0, ax1, ax2, ax3])
+                        T_add[ax0, ax1, ax2, ax3] = (
+                            rxplaceholder[T.int64(0), ax2, ax3]
+                            + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+                        )
+
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        After = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_full():
+    """A small static-shape fill (fused extent 6) is bound to one block of 6 threads."""
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_full"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[()])
+                    T.writes(T_full[ax0, ax1])
+                    T_full[ax0, ax1] = rxplaceholder[()]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def full(
+            rxplaceholder: T.Buffer((), "int32"),
+            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"):
+                    with T.block("T_full"):
+                        ax0 = T.axis.spatial(
+                            T.int64(2),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3),
+                        )
+                        ax1 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3),
+                        )
+                        T.reads(rxplaceholder[()])
+                        T.writes(T_full[ax0, ax1])
+                        T_full[ax0, ax1] = rxplaceholder[()]
+
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        After = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_scheduled():
+    """A PrimFunc whose block already sits under thread-bound loops is returned unchanged."""
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+
+    @tvm.script.ir_module
+    class Scheduled:
+        @T.prim_func
+        def full(
+            rxplaceholder: T.Buffer((), "int32"),
+            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"):
+                    with T.block("T_full"):
+                        ax0 = T.axis.spatial(
+                            T.int64(2),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3),
+                        )
+                        ax1 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % T.int64(3),
+                        )
+                        T.reads(rxplaceholder[()])
+                        T.writes(T_full[ax0, ax1])
+                        T_full[ax0, ax1] = rxplaceholder[()]
+
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        # should do nothing
+        After = DefaultGPUSchedule()(Scheduled)
+    tvm.ir.assert_structural_equal(After, Scheduled)
+
+
+def test_multiple():
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), 
T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), 
T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), 
T.int64(3)):
+                with T.block("T_add"):
+                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], 
rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+                    T.writes(T_add[ax0, ax1, ax2, ax3])
+                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, 
ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+        @T.prim_func
+        def full(rxplaceholder: T.Buffer((), "int32"), T_full: 
T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_full"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[()])
+                    T.writes(T_full[ax0, ax1])
+                    T_full[ax0, ax1] = rxplaceholder[()]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), 
"float32"),
+            rxplaceholder_1: T.Buffer(
+                (T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"
+            ),
+            T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), 
"float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_i2_i3_fused_1 in T.thread_binding(
+                    T.int64(72), thread="threadIdx.x"
+                ):
+                    with T.block("T_add"):
+                        ax0 = T.axis.spatial(
+                            T.int64(4),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + 
i0_i1_i2_i3_fused_1)
+                            // T.int64(18),
+                        )
+                        ax1 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + 
i0_i1_i2_i3_fused_1)
+                            % T.int64(18)
+                            // T.int64(6),
+                        )
+                        ax2 = T.axis.spatial(
+                            T.int64(2),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + 
i0_i1_i2_i3_fused_1)
+                            % T.int64(6)
+                            // T.int64(3),
+                        )
+                        ax3 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_i2_i3_fused_0 * T.int64(72) + 
i0_i1_i2_i3_fused_1)
+                            % T.int64(3),
+                        )
+                        T.reads(
+                            rxplaceholder[T.int64(0), ax2, ax3],
+                            rxplaceholder_1[ax0, ax1, ax2, T.int64(0)],
+                        )
+                        T.writes(T_add[ax0, ax1, ax2, ax3])
+                        T_add[ax0, ax1, ax2, ax3] = (
+                            rxplaceholder[T.int64(0), ax2, ax3]
+                            + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+                        )
+
+        @T.prim_func
+        def full(
+            rxplaceholder: T.Buffer((), "int32"),
+            T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(6), 
thread="threadIdx.x"):
+                    with T.block("T_full"):
+                        ax0 = T.axis.spatial(
+                            T.int64(2),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // 
T.int64(3),
+                        )
+                        ax1 = T.axis.spatial(
+                            T.int64(3),
+                            (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) % 
T.int64(3),
+                        )
+                        T.reads(rxplaceholder[()])
+                        T.writes(T_full[ax0, ax1])
+                        T_full[ax0, ax1] = rxplaceholder[()]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=3):
+        After = DefaultGPUSchedule()(Before)
+    assert tvm.ir.structural_equal(After, Expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()

Reply via email to