This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c279b94  [TE] Light refactoring of TE -> TIR paths. (#9263)
c279b94 is described below

commit c279b94a659837ca90b83ee1f5b8d500a4f6d5ac
Author: Lunderberg <[email protected]>
AuthorDate: Fri Oct 15 23:01:01 2021 -0500

    [TE] Light refactoring of TE -> TIR paths. (#9263)
    
    * [TE] Light refactoring of TE -> TIR paths.
    
    - Added ScheduleToPrimFunc, extracting out common behavior in
      ScheduleToModule and auto_scheduler's feature extraction.
    
    - Added `tvm.driver.build_module.schedule_to_module`, to avoid the
      4-line boilerplate otherwise needed to convert a schedule into an
      IRModule.  Also makes deviations from the usual path
      (e.g. `debug_keep_trivial_loop`) much more explicit.
    
    * Removed schedule_to_primfunc, replaced usage with schedule_to_module.
    
    * Moved the C++ function ScheduleToPrimFunc back inside ScheduleToModule.
---
 python/tvm/autotvm/feature.py                      | 10 ++---
 python/tvm/driver/build_module.py                  |  5 +++
 .../relay/backend/contrib/ethosu/tir/compiler.py   | 15 +++----
 src/auto_scheduler/feature.cc                      | 38 +++++------------
 src/driver/driver_api.cc                           | 19 ++++-----
 tests/python/integration/test_reduce.py            |  8 ++--
 tests/python/unittest/test_te_schedule_ops.py      | 26 +++++-------
 .../test_tir_transform_inject_copy_intrin.py       | 18 ++-------
 .../unittest/test_tir_transform_make_packed_api.py |  9 ++---
 ...form_merge_dynamic_shared_memory_allocations.py | 13 +++---
 .../unittest/test_tir_transform_narrow_datatype.py |  9 ++---
 .../unittest/test_tir_transform_storage_flatten.py | 13 ++----
 .../unittest/test_tir_transform_storage_rewrite.py | 47 +++++-----------------
 13 files changed, 75 insertions(+), 155 deletions(-)

diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py
index 8d2591d..f73c65f 100644
--- a/python/tvm/autotvm/feature.py
+++ b/python/tvm/autotvm/feature.py
@@ -31,7 +31,6 @@ import numpy as np
 import tvm._ffi
 
 from tvm.target import Target
-from tvm.te import schedule
 from tvm.driver import build_module
 
 
@@ -39,13 +38,12 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
     """Do lower while keeping all axes in IR
     i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or 
inject virtual threads
     """
-    binds, _ = build_module.get_binds(args, compact=False, binds=binds)
     sch = sch.normalize()
     # Phase 0
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds, True)
-    func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
-    mod = tvm.IRModule.from_expr(func._move())
+    context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": 
True})
+    with context:
+        mod = build_module.schedule_to_module(sch, args, binds=binds)
+
     mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
     mod = tvm.tir.transform.Simplify()(mod._move())
     assert simple_mode
diff --git a/python/tvm/driver/build_module.py 
b/python/tvm/driver/build_module.py
index 429b3e1..29fff77 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -67,6 +67,11 @@ def schedule_to_module(
     binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
 ) -> IRModule:
     """According to the given schedule, form a function.
+
+    This is a low-level function intended for testing purposes, and
+    does not apply any optimization passes.  In general, `tvm.lower`
+    and `tvm.build` should be used instead.
+
     Parameters
     ----------
     sch : tvm.te.schedule.Schedule
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index 3283e05..c792ade 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import relay
 from tvm.relay.expr_functor import ExprMutator
-from tvm.driver.build_module import get_binds
+from tvm.driver.build_module import schedule_to_module
 
 from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
 from .scheduler import schedule
@@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"):
             "no_unroll_loop_with_extent_one": True,
         },
         "tir.UnrollLoop": {"auto_max_depth": -1},
+        "tir.noalias": True,
+        "tir.debug_keep_trivial_loop": True,
     }
     # Merge two configs
     curr_cfg = {**curr_cfg, **tir_compiler_cfg}
 
     sch = sch.normalize()
-    bounds = tvm.te.schedule.InferBound(sch)
-    stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
 
-    compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
-    binds, arg_list = get_binds(args, compact, None)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-
-    func = func.with_attr("global_symbol", name)
-    func = func.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: func})
     with tvm.transform.PassContext(config=curr_cfg):
+        mod = schedule_to_module(sch, args, name)
+
         mod = tvm.tir.transform.Simplify()(mod)
         mod = tvm.tir.transform.StorageFlatten(64)(mod)
         mod = tvm.tir.transform.UnrollLoop()(mod)
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index be78bc4..aaf7d48 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -26,6 +26,7 @@
 #include <tvm/auto_scheduler/feature.h>
 #include <tvm/auto_scheduler/measure.h>
 #include <tvm/auto_scheduler/measure_record.h>
+#include <tvm/driver/driver_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/support/parallel_for.h>
 #include <tvm/te/operation.h>
@@ -45,13 +46,6 @@
 #include "utils.h"
 
 namespace tvm {
-// import the function from driver_api.cc
-void GetBinds(const Array<te::Tensor>& args, bool compact,
-              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* 
out_arg_list);
-}  // namespace tvm
-
-namespace tvm {
 namespace auto_scheduler {
 
 using namespace tvm::tir;
@@ -1268,35 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& 
task, const State& state, i
   Array<te::Tensor> tensors;
 
   std::tie(sch, tensors) = 
task->compute_dag.ApplySteps(state->transform_steps);
+
+  // When inlining, replace const matrices with const values.
+  // Produces wrong IR, but good enough for feature extraction, and
+  // can improve the speed of feature extraction/search.  Must be
+  // called before ScheduleToModule to have an effect.
   sch = sch.normalize_for_feature_extraction();
-  auto bounds = te::InferBound(sch);
 
   try {
-    auto stmt = te::ScheduleOps(sch, bounds, false);
-    Map<te::Tensor, te::Buffer> out_binds;
-    Array<ObjectRef> out_arg_list;
-    bool compact = te::VerifyCompactBuffer(stmt);
     const std::string& name = "main";
-    GlobalVar global_var(name);
-
-    // Copied from driver_api.cc::lower
     auto pass_ctx = tvm::transform::PassContext::Current();
-    GetBinds(tensors, compact, std::unordered_map<te::Tensor, te::Buffer>(), 
&out_binds,
-             &out_arg_list);
-    tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-    f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
 
-    bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", 
Bool(true)).value();
+    auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), 
tensors.end()}, name,
+                                std::unordered_map<te::Tensor, te::Buffer>());
+
     bool disable_vectorize =
         pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
     bool instrument_bound_checkers =
         pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-    if (noalias) {
-      f = WithAttr(std::move(f), "tir.noalias", Bool(true));
-    }
-    auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
-
     if (IsGPUTask(task)) {
       auto pass_list = Array<tvm::transform::Pass>();
       // Phase 0
@@ -1323,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& 
task, const State& state, i
     const auto& optimize =
         
tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
     mod = optimize(std::move(mod));
-    const auto& it = mod->functions.find(global_var);
-    ICHECK(it != mod->functions.end());
-    const auto& prim_func = (*it).second.as<PrimFuncNode>();
+    PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
     GetPerStoreFeature(prim_func->body, 
task->hardware_params->cache_line_bytes, max_n_bufs,
                        feature);
   } catch (Error& e) {
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e659421..2d57d6e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
@@ -287,24 +288,24 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential 
seq) {
   return mod;
 }
 
+// Convert te schedule to IRModule
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
                           const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
-  // Convert te schedule to IRModule
-  Array<ObjectRef> out_arg_list;
-  transform::PassContext pass_ctx = transform::PassContext::Current();
-
   sch = sch.normalize();
 
+  transform::PassContext pass_ctx = transform::PassContext::Current();
+  bool debug_keep_trivial_loop =
+      pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop", 
Bool(false)).value();
+
   // Before TIR transformation.
-  Map<tir::IterVar, Range> bounds = te::InferBound(sch);
-  tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false);
+  tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), 
debug_keep_trivial_loop);
   bool compact = te::VerifyCompactBuffer(stmt);
 
   Map<te::Tensor, tir::Buffer> out_binds;
+  Array<ObjectRef> out_arg_list;
   GetBinds(args, compact, binds, &out_binds, &out_arg_list);
 
-  // Build the function
-  // At this point binds is only te::Tensors
+  // Build the function, converting from te::Tensor to tir::Buffer
   tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
   f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
 
@@ -325,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
                        const Map<te::Tensor, tir::Buffer>& binds) {
       std::unordered_map<te::Tensor, tir::Buffer> c_binds;
       // Check to make sure binds is not null before doing the conversion;
-      if (binds.get() != nullptr) {
+      if (binds.defined()) {
         for (auto kv : binds) {
           c_binds.insert({kv.first, kv.second});
         }
diff --git a/tests/python/integration/test_reduce.py 
b/tests/python/integration/test_reduce.py
index ca09773..a40164d 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -15,10 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import numpy as np
 
 import tvm
 from tvm import te, topi
-import numpy as np
+from tvm.driver.build_module import schedule_to_module
 import tvm.testing
 import tvm.topi.testing
 
@@ -532,10 +533,7 @@ def test_reduce_storage_reuse():
     target = tvm.target.Target("cuda")
 
     def run_passes(sch, args):
-        bounds = tvm.te.schedule.InferBound(sch)
-        stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
-        func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
-        mod = tvm.IRModule.from_expr(func)
+        mod = schedule_to_module(sch, args)
         mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
target))(mod)
         return tvm.transform.Sequential(
             [
diff --git a/tests/python/unittest/test_te_schedule_ops.py 
b/tests/python/unittest/test_te_schedule_ops.py
index bc4bc4f..ca3ab3a 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -14,9 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+
 import tvm
 from tvm import te
-import numpy as np
+from tvm.driver.build_module import schedule_to_module
 
 
 def test_schedule0():
@@ -26,11 +28,8 @@ def test_schedule0():
     A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
     s = te.create_schedule(A1.op)
 
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
-    assert isinstance(func, tvm.tir.PrimFunc)
+    mod = schedule_to_module(s, [A, A1])
+    assert isinstance(mod["main"], tvm.tir.PrimFunc)
 
 
 def test_schedule1():
@@ -42,12 +41,9 @@ def test_schedule1():
     s = te.create_schedule(A1.op)
     xo, xi = s[A1].split(A1.op.axis[0], 8)
     s[A1].pragma(xo, "auto_unroll_max_step", 10)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
-    assert isinstance(func, tvm.tir.PrimFunc)
+    mod = schedule_to_module(s, [A, A1])
+    assert isinstance(mod["main"], tvm.tir.PrimFunc)
 
 
 def test_schedule2():
@@ -60,11 +56,9 @@ def test_schedule2():
     s = te.create_schedule(A2.op)
     xo, xi = s[A2].split(A2.op.axis[0], 8)
     s[A1].compute_at(s[A2], xo)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
-    assert isinstance(func, tvm.tir.PrimFunc)
+
+    mod = schedule_to_module(s, [A, A2])
+    assert isinstance(mod["main"], tvm.tir.PrimFunc)
 
 
 def test_schedule_scan():
diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py 
b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
index 86bf87d..aa0448c 100644
--- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
+++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
@@ -17,6 +17,7 @@
 import tvm
 import tvm.testing
 from tvm import te
+from tvm.driver.build_module import schedule_to_module
 
 
 def test_copy2d():
@@ -53,11 +54,7 @@ def test_copy_pad():
     )
     s = te.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     def cb(src, dst, pad_before, pad_after, pad_value):
@@ -77,11 +74,7 @@ def test_single_point_test():
     B = te.compute((1,), lambda i: A[i], name="B")
     s = te.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     def cb(src, dst, pad_before, pad_after, pad_value):
@@ -105,11 +98,8 @@ def test_copy_pad_split():
     xo, xi = s[B].split(B.op.axis[0], factor=4)
     s[Apad].compute_at(s[B], xo)
     s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
     mod = tvm.tir.transform.Simplify()(mod._move())
 
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py 
b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 15f9940..1ab6bda 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -14,9 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy
+
 import tvm
 from tvm import te
-import numpy
+from tvm.driver.build_module import schedule_to_module
 
 
 def test_makeapi():
@@ -27,10 +29,7 @@ def test_makeapi():
     C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
     s = te.create_schedule(C.op)
 
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [n, A, B, C])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.Apply(
         lambda f: f.with_attr(
diff --git 
a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
 
b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 9c511f1..cc78b84 100644
--- 
a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ 
b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -14,20 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
-from tvm import te
 import numpy as np
+
+import tvm
 import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
 from tvm.topi.math import cast
 
 
 def run_passes(sch, args):
-    bounds = tvm.te.schedule.InferBound(sch)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(sch, args)
     return tvm.transform.Sequential(
         [
             tvm.tir.transform.StorageFlatten(64),
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py 
b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index cb8968c..b5620d7 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
-from tvm import relay
+from tvm import te, relay
+from tvm.driver.build_module import schedule_to_module
 from tvm.tir import const
 
 
@@ -39,11 +39,8 @@ def lower_sch(sch, args, target_bits):
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     sch = sch.normalize()
-    bounds = te.schedule.InferBound(sch)
-    stmt = te.schedule.ScheduleOps(sch, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(sch, args)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
 
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py 
b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 3722349..a51e926 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.driver.build_module import schedule_to_module
 from tvm.script import tir as T
 from tvm.relay import GlobalVar
 
@@ -30,14 +31,10 @@ def test_flatten2():
     s = te.create_schedule(A2.op)
     xo, xi = s[A2].split(A2.op.axis[0], 8)
     s[A1].compute_at(s[A2], xo)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
     A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, 
A2: A2b})
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b})
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
 
@@ -70,12 +67,8 @@ def test_flatten_storage_align():
 
     s = te.create_schedule(A2.op)
     s[A1].storage_align(A1.op.axis[0], 2, 1)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, A2])
     mod = tvm.transform.Sequential(
         [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
     )(mod)
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py 
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 9e738b1..5a91788 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.driver.build_module import schedule_to_module
 
 
 def test_storage_share():
@@ -28,12 +29,7 @@ def test_storage_share():
         B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)
 
     s = te.create_schedule(B.op)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     mod = tvm.tir.transform.Simplify()(mod)
@@ -169,12 +165,7 @@ def test_inplace_rule():
     AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA")
     B = te.compute((m,), lambda i: AA[i] + 1, name="B")
     s = te.create_schedule(B.op)
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     mod = tvm.tir.transform.Simplify()(mod)
@@ -206,11 +197,8 @@ def test_storage_combine():
     s = te.create_schedule(B.op)
     for S in stages[:-1]:
         s[S].set_scope("global:tag")
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+
+    mod = schedule_to_module(s, [A, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     mod = tvm.tir.transform.Simplify()(mod)
@@ -238,10 +226,7 @@ def test_storage_combine_with_vectorization():
     BB = s.cache_read(B, "global:tag", readers=[C])
     CC = s.cache_write(C, "global:tag")
     s[CC].vectorize(s[CC].op.axis[0])
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B, C])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.VectorizeLoop()(mod)
     mod = tvm.tir.transform.StorageRewrite()(mod)
@@ -285,11 +270,7 @@ def test_storage_share_gpu():
         s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx)
         s[A[2 * t + 1]].set_scope("shared")
 
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, 
None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A[0], A[-1]])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.Simplify()(mod)
     mod = tvm.tir.transform.StorageRewrite()(mod)
@@ -418,12 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 
* 1024 * 1024):
     A0L = s.cache_read(A0, scope_tb, [A2])
     A1L = s.cache_read(A1, scope_tb, [A2])
     A2L = s.cache_read(A2, scope_tb, [B])
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B, C, D])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     mod = tvm.tir.transform.Simplify()(mod)
@@ -511,12 +487,7 @@ def test_inplace_rule3():
     s[B10].compute_inline()
 
     s = s.normalize()
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, 
B], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
     mod = tvm.tir.transform.Simplify()(mod)

Reply via email to