This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 648a29a53a [MetaSchedule] Introduce `ScheduleFnDatabase` (#12626)
648a29a53a is described below
commit 648a29a53a641f1e923220600dce9c9215104879
Author: Junru Shao <[email protected]>
AuthorDate: Mon Aug 29 00:34:11 2022 -0700
[MetaSchedule] Introduce `ScheduleFnDatabase` (#12626)
Following #12520, this PR introduces `ScheduleFnDatabase`, a mocked
database that allows injecting handcrafted schedules provided by a
schedule function.
The schedule function comes with the following signature:
```python
def schedule_fn(
    sch: tir.Schedule,
) -> bool:
    task_name = sch.mod.attrs["task_name"]
    # ^^^ provides an optional name of the task queried
    ...
```
This mocked database incorporates the functionality of the existing testing
utility `apply_fixed_schedules` (removed in this PR) more formally into the
MetaSchedule-Relay build pipeline, and allows further extension to Relax
with the same interface.
As a next follow-up, we will introduce `ConcatDatabase`, which allows
mixing multiple databases, including the mocked one and ones loaded from
JSON files.
---
include/tvm/meta_schedule/database.h | 19 +++-
python/tvm/meta_schedule/database/__init__.py | 1 +
python/tvm/meta_schedule/database/database.py | 41 ++++++--
.../{__init__.py => schedule_fn_database.py} | 29 ++++--
python/tvm/meta_schedule/testing/utils.py | 83 -----------------
src/meta_schedule/database/database.cc | 13 ++-
src/meta_schedule/database/memory_database.cc | 10 +-
src/meta_schedule/database/schedule_fn_database.cc | 103 +++++++++++++++++++++
src/relay/backend/te_compiler_cache.cc | 5 +-
tests/python/unittest/test_link_params.py | 15 ++-
.../unittest/test_meta_schedule_multi_anchor.py | 8 +-
.../test_meta_schedule_relay_tir_compute.py | 18 ++--
.../unittest/test_meta_schedule_tune_relay.py | 7 +-
13 files changed, 210 insertions(+), 142 deletions(-)
diff --git a/include/tvm/meta_schedule/database.h
b/include/tvm/meta_schedule/database.h
index 0e7f45d393..88db2e2277 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -207,23 +207,29 @@ class DatabaseNode : public runtime::Object {
* \brief Query the best record of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
+ * \param workload_name The name of the workload to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
- virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target
target);
+ virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const
Target& target,
+ const String&
workload_name);
/*!
* \brief Query the best schedule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
+ * \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt
if not found.
*/
- virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
+ virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const
Target& target,
+ const String& workload_name);
/*!
* \brief Query the best IRModule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
+ * \param workload_name The name of the workload to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt
if not found.
*/
- virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
+ virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target&
target,
+ const String& workload_name);
static constexpr const char* _type_key = "meta_schedule.Database";
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
@@ -336,6 +342,13 @@ class Database : public runtime::ObjectRef {
public:
/*! An in-memory database. */
TVM_DLL static Database MemoryDatabase();
+ /*!
+ * \brief A database for injecting handcrafted schedule functions.
+ * \param schedule_fn The function to do scheduling, which takes a TIR
schedule,
+ * and returns a boolean indicating if the schedule is successful.
+ */
+ TVM_DLL static Database ScheduleFnDatabase(
+ runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn);
/*!
* \brief Create a default database that uses JSON file for tuning records.
* \param path_workload The path to the workload table.
diff --git a/python/tvm/meta_schedule/database/__init__.py
b/python/tvm/meta_schedule/database/__init__.py
index 2a87eea147..7726daf6eb 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -21,3 +21,4 @@ The database that stores serialized tuning records and
workloads
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
+from .schedule_fn_database import ScheduleFnDatabase
diff --git a/python/tvm/meta_schedule/database/database.py
b/python/tvm/meta_schedule/database/database.py
index 68283b4554..aa509b7151 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -235,7 +235,12 @@ class Database(Object):
"""
return _ffi_api.DatabaseSize(self) # type: ignore # pylint:
disable=no-member
- def query_tuning_record(self, mod: IRModule, target: Target) ->
Optional[TuningRecord]:
+ def query_tuning_record(
+ self,
+ mod: IRModule,
+ target: Target,
+ workload_name: str,
+ ) -> Optional[TuningRecord]:
"""Query the best record of the given workload from the database.
Parameters
@@ -244,15 +249,22 @@ class Database(Object):
The IRModule to be searched for.
target : Target
The target to be searched for.
+ workload_name : str
+ The name of the workload to be searched for.
Returns
-------
tuning_record : Optional[TuningRecord]
The best record of the given workload; None if not found.
"""
- return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type:
ignore # pylint: disable=no-member
+ return _ffi_api.DatabaseQueryTuningRecord(self, mod, target,
workload_name) # type: ignore # pylint: disable=no-member
- def query_schedule(self, mod: IRModule, target: Target) ->
Optional[Schedule]:
+ def query_schedule(
+ self,
+ mod: IRModule,
+ target: Target,
+ workload_name: str,
+ ) -> Optional[Schedule]:
"""Query the best schedule of the given workload from the database.
Parameters
@@ -261,15 +273,22 @@ class Database(Object):
The IRModule to be searched for.
target : Target
The target to be searched for.
+ workload_name : str
+ The name of the workload to be searched for.
Returns
-------
schedule : Optional[Schedule]
The best schedule of the given workload; None if not found.
"""
- return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type:
ignore # pylint: disable=no-member
+ return _ffi_api.DatabaseQuerySchedule(self, mod, target,
workload_name) # type: ignore # pylint: disable=no-member
- def query_ir_module(self, mod: IRModule, target: Target) ->
Optional[IRModule]:
+ def query_ir_module(
+ self,
+ mod: IRModule,
+ target: Target,
+ workload_name: str,
+ ) -> Optional[IRModule]:
"""Query the best IRModule of the given workload from the database.
Parameters
@@ -278,18 +297,22 @@ class Database(Object):
The IRModule to be searched for.
target : Target
The target to be searched for.
+ workload_name : str
+ The name of the workload to be searched for.
Returns
-------
ir_module : Optional[IRModule]
The best IRModule of the given workload; None if not found.
"""
- return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type:
ignore # pylint: disable=no-member
+ return _ffi_api.DatabaseQueryIRModule(self, mod, target,
workload_name) # type: ignore # pylint: disable=no-member
def query(
self,
mod: IRModule,
target: Target,
+ *,
+ workload_name: str = "main",
kind: Union[
Literal["schedule"],
Literal["record"],
@@ -313,11 +336,11 @@ class Database(Object):
The best optimization outcome of the given workload.
"""
if kind == "schedule":
- return self.query_schedule(mod, target)
+ return self.query_schedule(mod, target, workload_name)
if kind == "record":
- return self.query_tuning_record(mod, target)
+ return self.query_tuning_record(mod, target, workload_name)
if kind == "ir_module":
- return self.query_ir_module(mod, target)
+ return self.query_ir_module(mod, target, workload_name)
raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule",
"record", "ir_module"')
def __enter__(self) -> "Database":
diff --git a/python/tvm/meta_schedule/database/__init__.py
b/python/tvm/meta_schedule/database/schedule_fn_database.py
similarity index 55%
copy from python/tvm/meta_schedule/database/__init__.py
copy to python/tvm/meta_schedule/database/schedule_fn_database.py
index 2a87eea147..2918f05799 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/schedule_fn_database.py
@@ -14,10 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-The tvm.meta_schedule.database package.
-The database that stores serialized tuning records and workloads
-"""
-from .database import Database, PyDatabase, TuningRecord, Workload
-from .json_database import JSONDatabase
-from .memory_database import MemoryDatabase
+"""A database for injecting handcrafted schedule functions."""
+from typing import Callable
+
+from tvm._ffi import register_object
+from tvm.tir import Schedule
+
+from .. import _ffi_api
+from .database import Database
+
+
+@register_object("meta_schedule.ScheduleFnDatabase")
+class ScheduleFnDatabase(Database):
+ """A database for injecting handcrafted schedule functions."""
+
+ def __init__(
+ self,
+ schedule_fn: Callable[[Schedule], bool],
+ ) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.DatabaseScheduleFnDatabase, # type: ignore # pylint:
disable=no-member
+ schedule_fn,
+ )
diff --git a/python/tvm/meta_schedule/testing/utils.py
b/python/tvm/meta_schedule/testing/utils.py
deleted file mode 100644
index 5919fb47c8..0000000000
--- a/python/tvm/meta_schedule/testing/utils.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Testing utility functions in meta schedule"""
-from typing import Callable, Dict, Optional, Union
-
-from tvm import meta_schedule as ms
-from tvm.ir import IRModule, transform
-from tvm.relay import Function as RelayFunc
-from tvm.runtime import NDArray
-from tvm.target import Target
-from tvm.tir import Schedule
-
-
-def apply_fixed_schedules(
- relay_mod: Union[RelayFunc, IRModule],
- target: Union[str, Target],
- params: Optional[Dict[str, NDArray]],
- schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
- tir_converter: str = "default",
-):
- """Apply fixed schedules (manually written, without any tunable knobs) as
specified by
- schedule_fn to extracted tasks, and return a database that can be passed
to compilation.
-
- Parameters
- ----------
- mod : Union[RelayFunc, IRModule]
- The Relay module to apply fixed schedules.
- target : Union[str, Target]
- The target used to extract tasks.
- params : Optional[Dict[str, tvm.runtime.NDArray]]
- The associated parameters of the module.
- schedule_fn : Callable[[ExtractedTask, Schedule], bool]
- A callable that is applied for each extracted task and the
corresponding default schedule.
- Returns True if the given schedule should be committed to the
database, False otherwise.
- tir_converter : str
- The filter function to filter out the extracted tasks. Builtin filters:
- - "default"
- - "allow_extern"
- The converter is a PackedFunc registered as
f"relay.backend.tir_converter.{tir_converter}",
- with the signature below:
- (args: List[te.Tensor], constants: List[NDArray]) ->
Optional[tir.PrimFunc]
-
- Returns
- -------
- database : Database
- The database containing dummy tuning records for manually scheduled
traces.
- """
- target = Target(target) if isinstance(target, str) else target
- config = {"relay.backend.use_meta_schedule": True}
- for k, v in transform.PassContext.current().config.items():
- config[k] = v
-
- extracted_tasks = ms.extract_task_from_relay(
- relay_mod,
- target,
- params,
- tir_converter=tir_converter,
- )
- database = ms.database.MemoryDatabase()
- for task in extracted_tasks:
- mod = ms.default_config.mod(task.dispatched[0])
- sch = Schedule(mod)
-
- if schedule_fn(task, sch):
- workload = database.commit_workload(mod)
- tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0],
target, [])
- database.commit_tuning_record(tune_rec)
-
- return database
diff --git a/src/meta_schedule/database/database.cc
b/src/meta_schedule/database/database.cc
index fedd2aa352..d082ff7a39 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -156,7 +156,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef&
json_obj, const Workload& w
/******** Database ********/
-Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target
target) {
+Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod,
const Target& target,
+ const String&
workload_name) {
if (!this->HasWorkload(mod)) {
return NullOpt;
}
@@ -168,8 +169,9 @@ Optional<TuningRecord>
DatabaseNode::QueryTuningRecord(IRModule mod, Target targ
return records[0];
}
-Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target
target) {
- if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod,
target)) {
+Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const
Target& target,
+ const String&
workload_name) {
+ if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target,
workload_name)) {
TuningRecord record = opt_record.value();
tir::Schedule sch =
tir::Schedule::Traced(record->workload->mod, /*seed=*/-1,
/*debug_mask=*/0,
@@ -181,8 +183,9 @@ Optional<tir::Schedule>
DatabaseNode::QuerySchedule(IRModule mod, Target target)
}
}
-Optional<IRModule> DatabaseNode::QueryIRModule(IRModule mod, Target target) {
- if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target)) {
+Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const
Target& target,
+ const String& workload_name) {
+ if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target,
workload_name)) {
return opt_sch.value()->mod();
} else {
return NullOpt;
diff --git a/src/meta_schedule/database/memory_database.cc
b/src/meta_schedule/database/memory_database.cc
index a00d5501ad..b6c6355551 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -44,7 +44,7 @@ class MemoryDatabaseNode : public DatabaseNode {
return false;
}
- Workload CommitWorkload(const IRModule& mod) {
+ Workload CommitWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
return workload;
@@ -55,9 +55,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return workload;
}
- void CommitTuningRecord(const TuningRecord& record) {
records.push_back(record); }
+ void CommitTuningRecord(const TuningRecord& record) final {
records.push_back(record); }
- Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
+ Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
std::vector<std::pair<double, TuningRecord>> results;
results.reserve(this->records.size());
for (const TuningRecord& record : records) {
@@ -91,9 +91,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return ret;
}
- Array<TuningRecord> GetAllTuningRecords() { return records; }
+ Array<TuningRecord> GetAllTuningRecords() final { return records; }
- int64_t Size() { return records.size(); }
+ int64_t Size() final { return records.size(); }
};
Database Database::MemoryDatabase() {
diff --git a/src/meta_schedule/database/schedule_fn_database.cc
b/src/meta_schedule/database/schedule_fn_database.cc
new file mode 100644
index 0000000000..751721fe52
--- /dev/null
+++ b/src/meta_schedule/database/schedule_fn_database.cc
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class ScheduleFnDatabaseNode : public DatabaseNode {
+ public:
+ runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn;
+
+ void VisitAttrs(AttrVisitor* v) {
+ // `schedule_fn` is not visited.
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode);
+
+ public:
+ Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target&
target,
+ const String& workload_name) final {
+ if (Optional<tir::Schedule> sch = this->QuerySchedule(mod, target,
workload_name)) {
+ return TuningRecord(sch.value()->trace().value(),
+ /*workload=*/Workload(mod, 0), //
+ /*run_secs=*/NullOpt, //
+ /*target=*/target, //
+ /*arg_info=*/NullOpt);
+ }
+ return NullOpt;
+ }
+
+ Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target&
target,
+ const String& workload_name) final {
+ tir::Schedule sch =
+ tir::Schedule::Traced(WithAttr<IRModule>(mod, "task_name",
workload_name),
+ /*rand_state=*/-1,
+ /*debug_mode=*/0,
+
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+ if (!schedule_fn(sch)) {
+ return NullOpt;
+ }
+ return sch;
+ }
+
+ bool HasWorkload(const IRModule& mod) final {
+ LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload";
+ throw;
+ }
+
+ Workload CommitWorkload(const IRModule& mod) final {
+ LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload";
+ throw;
+ }
+
+ void CommitTuningRecord(const TuningRecord& record) final {
+ LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord";
+ throw;
+ }
+
+ Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
+ LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK";
+ throw;
+ }
+
+ Array<TuningRecord> GetAllTuningRecords() final {
+ LOG(FATAL) << "NotImplementedError:
ScheduleFnDatabase.GetAllTuningRecords";
+ throw;
+ }
+
+ int64_t Size() final {
+ LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size";
+ throw;
+ }
+};
+
+Database
Database::ScheduleFnDatabase(runtime::TypedPackedFunc<bool(tir::Schedule)>
schedule_fn) {
+ ObjectPtr<ScheduleFnDatabaseNode> n = make_object<ScheduleFnDatabaseNode>();
+ n->schedule_fn = std::move(schedule_fn);
+ return Database(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase")
+ .set_body_typed(Database::ScheduleFnDatabase);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/relay/backend/te_compiler_cache.cc
b/src/relay/backend/te_compiler_cache.cc
index 0e2a3e2702..1d7566ebe2 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -367,7 +367,8 @@ class ScheduleBuilder : public ExprVisitor {
if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
if (Optional<TuningRecord> opt_record =
database_.value()->QueryTuningRecord(
/*mod=*/backend::PrimFuncToIRModule(f.value()),
- /*target=*/target_)) {
+ /*target=*/target_,
+ /*workload_name=*/prim_fn_var->name_hint)) {
static InstructionKind kind_transform_layout =
InstructionKind::Get("TransformLayout");
TuningRecord record = opt_record.value();
for (const Instruction& inst : record->trace->insts) {
@@ -383,6 +384,8 @@ class ScheduleBuilder : public ExprVisitor {
ICHECK_EQ(mod->functions.size(), 1);
mod =
tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
+ } else {
+ LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
}
}
}
diff --git a/tests/python/unittest/test_link_params.py
b/tests/python/unittest/test_link_params.py
index c741ecb59a..b14c18e55f 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -29,7 +29,6 @@ import tvm.testing
from tvm import meta_schedule as ms
from tvm import relay
from tvm.contrib import utils
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
from tvm.relay.backend import Executor, Runtime
INPUT_SHAPE = (1, 3, 16, 16)
@@ -407,21 +406,21 @@ def test_tir_link_params():
target = "llvm"
params = {"weight": weight_np}
- def schedule_fn(task, sch):
- if "nn_dense" in task.task_name:
+ def schedule_fn(sch):
+ if "nn_dense" in sch.mod.attrs["task_name"]:
schedule_dense(sch)
return True
return False
link_params = True
- with tvm.transform.PassContext(config={"relay.FuseOps.link_params":
link_params}):
- database = apply_fixed_schedules(relay_mod, target, params,
schedule_fn)
-
with StringIO() as stderr_buf, redirect_stderr(stderr_buf):
- with database, tvm.transform.PassContext(
+ with ms.database.ScheduleFnDatabase(schedule_fn),
tvm.transform.PassContext(
opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
+ config={
+ "relay.backend.use_meta_schedule": True,
+ "relay.FuseOps.link_params": link_params,
+ },
):
executor = Executor("graph", {"link-params": link_params})
lib = relay.build(relay_mod, target=target, executor=executor)
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py
b/tests/python/unittest/test_meta_schedule_multi_anchor.py
index 1770017811..cb6f59c6e5 100644
--- a/tests/python/unittest/test_meta_schedule_multi_anchor.py
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -19,7 +19,6 @@ import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm import relay
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
def get_dense_dense(data_shape, weight_shape):
@@ -63,14 +62,13 @@ def test_dense_dense():
target = "llvm"
params = {"weight1": weight1_np, "weight2": weight2_np}
- def schedule_fn(task, sch):
- if "nn_dense_nn_dense" in task.task_name:
+ def schedule_fn(sch):
+ if "nn_dense_nn_dense" in sch.mod.attrs["task_name"]:
schedule_dense_dense(sch)
return True
return False
- database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
- with database:
+ with ms.database.ScheduleFnDatabase(schedule_fn):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_meta_schedule": True},
diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
index 939851a657..b373338036 100644
--- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
+++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
@@ -18,8 +18,9 @@ import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
-from tvm import autotvm, relay, te
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
+from tvm import autotvm
+from tvm import meta_schedule as ms
+from tvm import relay, te
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.script import tir as T
@@ -139,21 +140,14 @@ def test_conv2d():
target = "llvm"
params = {"weight": weight_np}
- def schedule_fn(task, sch):
- if "nn_conv2d" in task.task_name:
+ def schedule_fn(sch):
+ if "nn_conv2d" in sch.mod.attrs["task_name"]:
schedule_tir_conv2d_nchw_oihw(sch)
return True
return False
with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
- database = apply_fixed_schedules(
- relay_mod,
- target,
- params,
- schedule_fn,
- tir_converter="allow_extern",
- )
- with database, tvm.transform.PassContext(
+ with ms.database.ScheduleFnDatabase(schedule_fn),
tvm.transform.PassContext(
opt_level=3,
config={
"relay.backend.use_meta_schedule": True,
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py
b/tests/python/unittest/test_meta_schedule_tune_relay.py
index bc37fed7d6..b05b57feaf 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -29,7 +29,6 @@ from tvm._ffi import register_func
from tvm.contrib import graph_executor
from tvm.ir import IRModule
from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
from tvm.script import tir as T
from tvm.target.target import Target
from tvm.tir.schedule import BlockRV, Schedule
@@ -452,8 +451,8 @@ def manual_tir_common(do_tune=False):
)
else:
- def schedule_fn(task, sch):
- if "dense" not in task.task_name:
+ def schedule_fn(sch) -> bool:
+ if "dense" not in sch.mod.attrs["task_name"]:
return False
block = sch.get_block("compute")
@@ -468,7 +467,7 @@ def manual_tir_common(do_tune=False):
return True
- database = apply_fixed_schedules(relay_mod, target, params,
schedule_fn)
+ database = ms.database.ScheduleFnDatabase(schedule_fn)
with database, tvm.transform.PassContext(
opt_level=3,