This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 52d90da1d3 [MetaSchedule] TuningRecord Optional Arguments (#11598)
52d90da1d3 is described below
commit 52d90da1d3bc6b12611b1d30a38c02837fbf8d76
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Tue Jun 7 18:05:14 2022 -0700
[MetaSchedule] TuningRecord Optional Arguments (#11598)
In some situations, such as before the candidates are measured, the arguments
`run_secs`, `target`, and `args_info` of `TuningRecord` are not needed. To
support this, the `TuningRecord` API now takes its arguments in the order
`trace, workload, run_secs, target, args_info`, with the last three being
optional. Because the argument order changed, callers that pass `run_secs` and
`workload` positionally must be updated; the affected tests are adjusted in
this commit.
---
include/tvm/meta_schedule/database.h | 17 +++----
python/tvm/meta_schedule/database/database.py | 26 +++++------
python/tvm/meta_schedule/testing/utils.py | 2 +-
src/meta_schedule/database/database.cc | 54 ++++++++++++++--------
src/meta_schedule/database/json_database.cc | 4 +-
.../measure_callback/add_to_database.cc | 2 +-
.../python/unittest/test_meta_schedule_database.py | 26 +++++------
.../unittest/test_meta_schedule_integration.py | 2 +-
.../unittest/test_meta_schedule_tune_relay.py | 2 +-
9 files changed, 75 insertions(+), 60 deletions(-)
diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 1353dec3ed..37a315bf74 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -103,19 +103,19 @@ class TuningRecordNode : public runtime::Object {
public:
/*! \brief The trace tuned. */
tir::Trace trace;
- /*! \brief The profiling result in seconds. */
- Array<FloatImm> run_secs;
/*! \brief The workload. */
Workload workload{nullptr};
+ /*! \brief The profiling result in seconds. */
+ Optional<Array<FloatImm>> run_secs;
/*! \brief The target for tuning. */
- Target target;
+ Optional<Target> target;
/*! \brief The argument information. */
- Array<ArgInfo> args_info;
+ Optional<Array<ArgInfo>> args_info;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("trace", &trace);
- v->Visit("run_secs", &run_secs);
v->Visit("workload", &workload);
+ v->Visit("run_secs", &run_secs);
v->Visit("target", &target);
v->Visit("args_info", &args_info);
}
@@ -140,13 +140,14 @@ class TuningRecord : public runtime::ObjectRef {
/*!
\brief Constructor of a tuning record.
\param trace The trace of the tuning record.
- \param run_secs The running time of the tuning record.
\param workload The workload of the tuning record.
+ \param run_secs The running time of the tuning record.
\param target The target of the tuning record.
\param args_info The argument information of the tuning record.
*/
- TVM_DLL explicit TuningRecord(tir::Trace trace, Array<FloatImm> run_secs,
Workload workload,
- Target target, Array<ArgInfo> args_info);
+ TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
+ Optional<Array<FloatImm>> run_secs,
Optional<Target> target,
+ Optional<Array<ArgInfo>> args_info);
/*!
* \brief Create a tuning record from a json object.
* \param json_obj The json object.
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index 314bf434c4..8e0c805410 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Tuning record database"""
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Optional
from tvm._ffi import register_object
from tvm.ir.module import IRModule
@@ -82,35 +82,35 @@ class TuningRecord(Object):
----------
trace : tvm.ir.Trace
The trace of the tuning record.
- run_secs : List[float]
- The run time of the tuning record.
workload : Workload
The workload of the tuning record.
- target : Target
+ run_secs : Optional[List[float]]
+ The run time of the tuning record.
+ target : Optional[Target]
The target of the tuning record.
- args_info : List[ArgInfo]
+ args_info : Optional[List[ArgInfo]]
The argument information of the tuning record.
"""
trace: Trace
- run_secs: List[float]
workload: Workload
- target: Target
- args_info: List[ArgInfo]
+ run_secs: Optional[List[float]]
+ target: Optional[Target]
+ args_info: Optional[List[ArgInfo]]
- def __init__(
+ def __init__( # type: ignore # pylint: disable=too-many-arguments
self,
trace: Trace,
- run_secs: List[float],
workload: Workload,
- target: Target,
- args_info: List[ArgInfo],
+ run_secs: Optional[List[float]] = None,
+ target: Optional[Target] = None,
+ args_info: Optional[List[ArgInfo]] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member
trace,
- run_secs,
workload,
+ run_secs,
target,
args_info,
)
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
index a832dfc6bc..62950fdd0b 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -155,7 +155,7 @@ def apply_fixed_schedules(
if schedule_fn(task, sch):
workload = database.commit_workload(mod)
- tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
+ tune_rec = TuningRecord(sch.trace, workload, [0.0], target, [])
database.commit_tuning_record(tune_rec)
return database
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index fc7cc74de5..86d999e4fd 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -74,48 +74,62 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) {
/******** TuningRecord ********/
-TuningRecord::TuningRecord(tir::Trace trace, Array<FloatImm> run_secs,
Workload workload,
- Target target, Array<ArgInfo> args_info) {
+TuningRecord::TuningRecord(tir::Trace trace, Workload workload,
Optional<Array<FloatImm>> run_secs,
+ Optional<Target> target, Optional<Array<ArgInfo>>
args_info) {
ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>();
n->trace = trace;
- n->run_secs = run_secs;
n->workload = workload;
+ n->run_secs = run_secs;
n->target = target;
n->args_info = args_info;
this->data_ = n;
}
ObjectRef TuningRecordNode::AsJSON() const {
- Array<ObjectRef> json_args_info;
- json_args_info.reserve(args_info.size());
- for (const ArgInfo& arg_info : args_info) {
- json_args_info.push_back(arg_info->AsJSON());
+ Optional<Array<ObjectRef>> json_args_info{nullptr};
+ Optional<ObjectRef> json_target{nullptr};
+ if (args_info.defined()) {
+ Array<ObjectRef> info;
+ info.reserve(args_info.value().size());
+ for (const ArgInfo& arg_info : args_info.value()) {
+ info.push_back(arg_info->AsJSON());
+ }
+ json_args_info = info;
+ }
+ if (target.defined()) {
+ json_target = target.value()->Export();
}
return Array<ObjectRef>{trace->AsJSON(false), //
run_secs, //
- target->Export(), //
+ json_target, //
json_args_info};
}
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload&
workload) {
tir::Trace trace{nullptr};
- Array<FloatImm> run_secs{nullptr};
- Target target{nullptr};
- Array<ArgInfo> args_info;
+ Optional<Array<FloatImm>> run_secs{nullptr};
+ Optional<Target> target{nullptr};
+ Optional<Array<ArgInfo>> args_info{nullptr};
try {
const ArrayNode* json_array = json_obj.as<ArrayNode>();
CHECK(json_array && json_array->size() == 4);
// Load json[1] => run_secs
- run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+ if (json_array->at(1).defined()) {
+ run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+ }
// Load json[2] => target
- target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
+ if (json_array->at(2).defined()) {
+ target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
+ }
// Load json[3] => args_info
- {
+ if (json_array->at(3).defined()) {
const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
- args_info.reserve(json_args_info->size());
+ Array<ArgInfo> info;
+ info.reserve(json_args_info->size());
for (const ObjectRef& json_arg_info : *json_args_info) {
- args_info.push_back(ArgInfo::FromJSON(json_arg_info));
+ info.push_back(ArgInfo::FromJSON(json_arg_info));
}
+ args_info = info;
}
// Load json[0] => trace
{
@@ -130,7 +144,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
<< "\nThe error is: " << e.what();
}
- return TuningRecord(trace, run_secs, workload, target, args_info);
+ return TuningRecord(trace, workload, run_secs, target, args_info);
}
/******** PyDatabase ********/
@@ -161,9 +175,9 @@ TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
.set_body_method<Workload>(&WorkloadNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
- .set_body_typed([](tir::Trace trace, Array<FloatImm> run_secs, Workload
workload, Target target,
- Array<ArgInfo> args_info) {
- return TuningRecord(trace, run_secs, workload, target, args_info);
+ .set_body_typed([](tir::Trace trace, Workload workload,
Optional<Array<FloatImm>> run_secs,
+ Optional<Target> target, Optional<Array<ArgInfo>>
args_info) {
+ return TuningRecord(trace, workload, run_secs, target, args_info);
});
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
.set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 2e76940fee..155d223217 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -40,8 +40,8 @@ struct SortTuningRecordByMeanRunSecs {
}
bool operator()(const TuningRecord& a, const TuningRecord& b) const {
- double a_time = Mean(a->run_secs);
- double b_time = Mean(b->run_secs);
+ double a_time = Mean(a->run_secs.value_or({}));
+ double b_time = Mean(b->run_secs.value_or({}));
return a_time < b_time;
}
};
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc
index 0988da0414..27b4e55a7d 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -47,8 +47,8 @@ class AddToDatabaseNode : public MeasureCallbackNode {
}
database->CommitTuningRecord(TuningRecord(
/*trace=*/candidate->sch->trace().value(),
- /*run_secs=*/run_secs,
/*workload=*/workload,
+ /*run_secs=*/run_secs,
/*target=*/target,
/*args_info=*/candidate->args_info));
}
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index d494f997c1..1edfbe6c7a 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -115,8 +115,8 @@ def test_meta_schedule_tuning_record_round_trip():
workload = database.commit_workload(mod)
record = TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
- [1.5, 2.5, 1.8],
workload,
+ [1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
)
@@ -140,8 +140,8 @@ def test_meta_schedule_database_has_workload():
workload = database.commit_workload(mod)
record = TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
- [1.5, 2.5, 1.8],
workload,
+ [1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
)
@@ -158,8 +158,8 @@ def test_meta_schedule_database_add_entry():
workload = database.commit_workload(mod)
record = TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
- [1.5, 2.5, 1.8],
workload,
+ [1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
)
@@ -178,8 +178,8 @@ def test_meta_schedule_database_missing():
workload_2 = database.commit_workload(mod_2)
record = TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
- [1.5, 2.5, 1.8],
workload,
+ [1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
)
@@ -197,43 +197,43 @@ def test_meta_schedule_database_sorting():
records = [
TuningRecord(
trace,
- [7.0, 8.0, 9.0],
token,
+ [7.0, 8.0, 9.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [1.0, 2.0, 3.0],
token,
+ [1.0, 2.0, 3.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [4.0, 5.0, 6.0],
token,
+ [4.0, 5.0, 6.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [1.1, 1.2, 600.0],
token,
+ [1.1, 1.2, 600.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [1.0, 100.0, 6.0],
token,
+ [1.0, 100.0, 6.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [4.0, 9.0, 8.0],
token,
+ [4.0, 9.0, 8.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
@@ -259,22 +259,22 @@ def test_meta_schedule_database_reload():
records = [
TuningRecord(
trace,
- [7.0, 8.0, 9.0],
token,
+ [7.0, 8.0, 9.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [1.0, 2.0, 3.0],
token,
+ [1.0, 2.0, 3.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
TuningRecord(
trace,
- [4.0, 5.0, 6.0],
token,
+ [4.0, 5.0, 6.0],
tvm.target.Target("llvm"),
ArgInfo.from_prim_func(func=mod["main"]), # pylint:
disable=unsubscriptable-object
),
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index a423bdb48a..3b33039bd2 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -267,7 +267,7 @@ def test_meta_schedule_integration_apply_history_best():
target = Target("llvm")
workload = database.commit_workload(MockModule)
database.commit_tuning_record(
- TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
+ TuningRecord(Schedule(MockModule).trace, workload, [1.0], target, [])
)
mod = env.query(task_name="mock-task", mod=mod, target=target,
dispatched=[MockModule])
assert tvm.ir.structural_equal(mod, workload.mod)
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index e5076af520..e0883dbd22 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -307,8 +307,8 @@ def test_meta_schedule_relay_lowering():
database.commit_tuning_record(
TuningRecord(
Trace([], {}),
- [0.0],
database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
+ [0.0],
target=target,
args_info=[],
)