This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c279b94 [TE] Light refactoring of TE -> TIR paths. (#9263)
c279b94 is described below
commit c279b94a659837ca90b83ee1f5b8d500a4f6d5ac
Author: Lunderberg <[email protected]>
AuthorDate: Fri Oct 15 23:01:01 2021 -0500
[TE] Light refactoring of TE -> TIR paths. (#9263)
* [TE] Light refactoring of TE -> TIR paths.
- Added ScheduleToPrimFunc, extracting out common behavior in
ScheduleToModule and auto_scheduler's feature extraction.
- Added `tvm.driver.build_module.schedule_to_module`, to avoid the
4-line boilerplate needed to do so. Also makes deviations from
the usual path (e.g. `debug_keep_trivial_loop`) much more explicit.
* Removed schedule_to_primfunc, replaced usage with schedule_to_module.
* Returned the C++ function ScheduleToPrimFunc to be inside ScheduleToModule.
---
python/tvm/autotvm/feature.py | 10 ++---
python/tvm/driver/build_module.py | 5 +++
.../relay/backend/contrib/ethosu/tir/compiler.py | 15 +++----
src/auto_scheduler/feature.cc | 38 +++++------------
src/driver/driver_api.cc | 19 ++++-----
tests/python/integration/test_reduce.py | 8 ++--
tests/python/unittest/test_te_schedule_ops.py | 26 +++++-------
.../test_tir_transform_inject_copy_intrin.py | 18 ++-------
.../unittest/test_tir_transform_make_packed_api.py | 9 ++---
...form_merge_dynamic_shared_memory_allocations.py | 13 +++---
.../unittest/test_tir_transform_narrow_datatype.py | 9 ++---
.../unittest/test_tir_transform_storage_flatten.py | 13 ++----
.../unittest/test_tir_transform_storage_rewrite.py | 47 +++++-----------------
13 files changed, 75 insertions(+), 155 deletions(-)
diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py
index 8d2591d..f73c65f 100644
--- a/python/tvm/autotvm/feature.py
+++ b/python/tvm/autotvm/feature.py
@@ -31,7 +31,6 @@ import numpy as np
import tvm._ffi
from tvm.target import Target
-from tvm.te import schedule
from tvm.driver import build_module
@@ -39,13 +38,12 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
inject virtual threads
"""
- binds, _ = build_module.get_binds(args, compact=False, binds=binds)
sch = sch.normalize()
# Phase 0
- bounds = schedule.InferBound(sch)
- stmt = schedule.ScheduleOps(sch, bounds, True)
- func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
- mod = tvm.IRModule.from_expr(func._move())
+ context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop":
True})
+ with context:
+ mod = build_module.schedule_to_module(sch, args, binds=binds)
+
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
assert simple_mode
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 429b3e1..29fff77 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -67,6 +67,11 @@ def schedule_to_module(
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
"""According to the given schedule, form a function.
+
+ This is a low-level function intended for testing purposes, and
+ does not apply any optimization passes. In general, `tvm.lower`
+ and `tvm.build` should be used instead.
+
Parameters
----------
sch : tvm.te.schedule.Schedule
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index 3283e05..c792ade 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -19,7 +19,7 @@
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
-from tvm.driver.build_module import get_binds
+from tvm.driver.build_module import schedule_to_module
from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
from .scheduler import schedule
@@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"):
"no_unroll_loop_with_extent_one": True,
},
"tir.UnrollLoop": {"auto_max_depth": -1},
+ "tir.noalias": True,
+ "tir.debug_keep_trivial_loop": True,
}
# Merge two configs
curr_cfg = {**curr_cfg, **tir_compiler_cfg}
sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
- compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
- binds, arg_list = get_binds(args, compact, None)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-
- func = func.with_attr("global_symbol", name)
- func = func.with_attr("tir.noalias", True)
- mod = tvm.IRModule({name: func})
with tvm.transform.PassContext(config=curr_cfg):
+ mod = schedule_to_module(sch, args, name)
+
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index be78bc4..aaf7d48 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -26,6 +26,7 @@
#include <tvm/auto_scheduler/feature.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
+#include <tvm/driver/driver_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
@@ -45,13 +46,6 @@
#include "utils.h"
namespace tvm {
-// import the function from driver_api.cc
-void GetBinds(const Array<te::Tensor>& args, bool compact,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>*
out_arg_list);
-} // namespace tvm
-
-namespace tvm {
namespace auto_scheduler {
using namespace tvm::tir;
@@ -1268,35 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask&
task, const State& state, i
Array<te::Tensor> tensors;
std::tie(sch, tensors) =
task->compute_dag.ApplySteps(state->transform_steps);
+
+ // When inlining, replace const matrices with const values.
+ // Produces wrong IR, but good enough for feature extraction, and
+ // can improve the speed of feature extraction/search. Must be
+ // called before ScheduleToModule to have an effect.
sch = sch.normalize_for_feature_extraction();
- auto bounds = te::InferBound(sch);
try {
- auto stmt = te::ScheduleOps(sch, bounds, false);
- Map<te::Tensor, te::Buffer> out_binds;
- Array<ObjectRef> out_arg_list;
- bool compact = te::VerifyCompactBuffer(stmt);
const std::string& name = "main";
- GlobalVar global_var(name);
-
- // Copied from driver_api.cc::lower
auto pass_ctx = tvm::transform::PassContext::Current();
- GetBinds(tensors, compact, std::unordered_map<te::Tensor, te::Buffer>(),
&out_binds,
- &out_arg_list);
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias",
Bool(true)).value();
+ auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(),
tensors.end()}, name,
+ std::unordered_map<te::Tensor, te::Buffer>());
+
bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
- }
- auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
-
if (IsGPUTask(task)) {
auto pass_list = Array<tvm::transform::Pass>();
// Phase 0
@@ -1323,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask&
task, const State& state, i
const auto& optimize =
tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
mod = optimize(std::move(mod));
- const auto& it = mod->functions.find(global_var);
- ICHECK(it != mod->functions.end());
- const auto& prim_func = (*it).second.as<PrimFuncNode>();
+ PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
GetPerStoreFeature(prim_func->body,
task->hardware_params->cache_line_bytes, max_n_bufs,
feature);
} catch (Error& e) {
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e659421..2d57d6e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
using runtime::PackedFunc;
using runtime::TVMArgs;
@@ -287,24 +288,24 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential
seq) {
return mod;
}
+// Convert te schedule to IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>&
binds) {
- // Convert te schedule to IRModule
- Array<ObjectRef> out_arg_list;
- transform::PassContext pass_ctx = transform::PassContext::Current();
-
sch = sch.normalize();
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool debug_keep_trivial_loop =
+ pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop",
Bool(false)).value();
+
// Before TIR transformation.
- Map<tir::IterVar, Range> bounds = te::InferBound(sch);
- tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false);
+ tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch),
debug_keep_trivial_loop);
bool compact = te::VerifyCompactBuffer(stmt);
Map<te::Tensor, tir::Buffer> out_binds;
+ Array<ObjectRef> out_arg_list;
GetBinds(args, compact, binds, &out_binds, &out_arg_list);
- // Build the function
- // At this point binds is only te::Tensors
+ // Build the function, converting from te::Tensor to tir::Buffer
tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
@@ -325,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
const Map<te::Tensor, tir::Buffer>& binds) {
std::unordered_map<te::Tensor, tir::Buffer> c_binds;
// Check to make sure binds is not null before doing the conversion;
- if (binds.get() != nullptr) {
+ if (binds.defined()) {
for (auto kv : binds) {
c_binds.insert({kv.first, kv.second});
}
diff --git a/tests/python/integration/test_reduce.py
b/tests/python/integration/test_reduce.py
index ca09773..a40164d 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import numpy as np
import tvm
from tvm import te, topi
-import numpy as np
+from tvm.driver.build_module import schedule_to_module
import tvm.testing
import tvm.topi.testing
@@ -532,10 +533,7 @@ def test_reduce_storage_reuse():
target = tvm.target.Target("cuda")
def run_passes(sch, args):
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(sch, args)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target))(mod)
return tvm.transform.Sequential(
[
diff --git a/tests/python/unittest/test_te_schedule_ops.py
b/tests/python/unittest/test_te_schedule_ops.py
index bc4bc4f..ca3ab3a 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np
+
import tvm
from tvm import te
-import numpy as np
+from tvm.driver.build_module import schedule_to_module
def test_schedule0():
@@ -26,11 +28,8 @@ def test_schedule0():
A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
s = te.create_schedule(A1.op)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
- assert isinstance(func, tvm.tir.PrimFunc)
+ mod = schedule_to_module(s, [A, A1])
+ assert isinstance(mod["main"], tvm.tir.PrimFunc)
def test_schedule1():
@@ -42,12 +41,9 @@ def test_schedule1():
s = te.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
s[A1].pragma(xo, "auto_unroll_max_step", 10)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None)
- assert isinstance(func, tvm.tir.PrimFunc)
+ mod = schedule_to_module(s, [A, A1])
+ assert isinstance(mod["main"], tvm.tir.PrimFunc)
def test_schedule2():
@@ -60,11 +56,9 @@ def test_schedule2():
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
- assert isinstance(func, tvm.tir.PrimFunc)
+
+ mod = schedule_to_module(s, [A, A2])
+ assert isinstance(mod["main"], tvm.tir.PrimFunc)
def test_schedule_scan():
diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
index 86bf87d..aa0448c 100644
--- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
+++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
@@ -17,6 +17,7 @@
import tvm
import tvm.testing
from tvm import te
+from tvm.driver.build_module import schedule_to_module
def test_copy2d():
@@ -53,11 +54,7 @@ def test_copy_pad():
)
s = te.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
@@ -77,11 +74,7 @@ def test_single_point_test():
B = te.compute((1,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
@@ -105,11 +98,8 @@ def test_copy_pad_split():
xo, xi = s[B].split(B.op.axis[0], factor=4)
s[Apad].compute_at(s[B], xo)
s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py
b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 15f9940..1ab6bda 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy
+
import tvm
from tvm import te
-import numpy
+from tvm.driver.build_module import schedule_to_module
def test_makeapi():
@@ -27,10 +29,7 @@ def test_makeapi():
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
s = te.create_schedule(C.op)
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [n, A, B, C])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
diff --git
a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 9c511f1..cc78b84 100644
---
a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++
b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -14,20 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import tvm
-from tvm import te
import numpy as np
+
+import tvm
import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
from tvm.topi.math import cast
def run_passes(sch, args):
- bounds = tvm.te.schedule.InferBound(sch)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(sch, args)
return tvm.transform.Sequential(
[
tvm.tir.transform.StorageFlatten(64),
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py
b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index cb8968c..b5620d7 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
-from tvm import relay
+from tvm import te, relay
+from tvm.driver.build_module import schedule_to_module
from tvm.tir import const
@@ -39,11 +39,8 @@ def lower_sch(sch, args, target_bits):
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
- bounds = te.schedule.InferBound(sch)
- stmt = te.schedule.ScheduleOps(sch, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(sch, args)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py
b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 3722349..a51e926 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import te
+from tvm.driver.build_module import schedule_to_module
from tvm.script import tir as T
from tvm.relay import GlobalVar
@@ -30,14 +31,10 @@ def test_flatten2():
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab,
A2: A2b})
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b})
mod = tvm.tir.transform.StorageFlatten(64)(mod)
@@ -70,12 +67,8 @@ def test_flatten_storage_align():
s = te.create_schedule(A2.op)
s[A1].storage_align(A1.op.axis[0], 2, 1)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, A2])
mod = tvm.transform.Sequential(
[tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
)(mod)
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 9e738b1..5a91788 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import te
+from tvm.driver.build_module import schedule_to_module
def test_storage_share():
@@ -28,12 +29,7 @@ def test_storage_share():
B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)
s = te.create_schedule(B.op)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
@@ -169,12 +165,7 @@ def test_inplace_rule():
AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA")
B = te.compute((m,), lambda i: AA[i] + 1, name="B")
s = te.create_schedule(B.op)
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
@@ -206,11 +197,8 @@ def test_storage_combine():
s = te.create_schedule(B.op)
for S in stages[:-1]:
s[S].set_scope("global:tag")
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+
+ mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
@@ -238,10 +226,7 @@ def test_storage_combine_with_vectorization():
BB = s.cache_read(B, "global:tag", readers=[C])
CC = s.cache_write(C, "global:tag")
s[CC].vectorize(s[CC].op.axis[0])
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B, C])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.VectorizeLoop()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
@@ -285,11 +270,7 @@ def test_storage_share_gpu():
s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx)
s[A[2 * t + 1]].set_scope("shared")
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt,
None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A[0], A[-1]])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
@@ -418,12 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024
* 1024 * 1024):
A0L = s.cache_read(A0, scope_tb, [A2])
A1L = s.cache_read(A1, scope_tb, [A2])
A2L = s.cache_read(A2, scope_tb, [B])
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [A, B, C, D])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
@@ -511,12 +487,7 @@ def test_inplace_rule3():
s[B10].compute_inline()
s = s.normalize()
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5,
B], stmt, None)
- mod = tvm.IRModule.from_expr(func)
+ mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)