This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a867bcb [Auto Scheduler] Add target host to measure record (#7046)
a867bcb is described below
commit a867bcbf1ecf537cfb061a2ca4790b16a9cc9748
Author: Zhao Wu <[email protected]>
AuthorDate: Tue Dec 8 14:46:29 2020 +0800
[Auto Scheduler] Add target host to measure record (#7046)
* [Auto Scheduler] Add target host to measure record
* Fix PyLint
* Fix lint
* Solve the serialization logic when we don't have hardware params
* update auto scheduler log
---
src/auto_scheduler/measure_record.cc | 12 ++++++++--
.../python/unittest/test_auto_scheduler_measure.py | 26 ++++++++++++++++++++++
2 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index d57e2f2..aad0abe 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -163,6 +163,9 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
writer->WriteArrayItem(std::string(data.workload_key));
writer->WriteArrayItem(data.target->str());
writer->WriteArrayItem(*data.hardware_params.get());
+ if (data.target_host.defined()) {
+ writer->WriteArrayItem(data.target_host->str());
+ }
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader,
::tvm::auto_scheduler::SearchTaskNode* data) {
@@ -183,7 +186,12 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
reader->Read(hardware_params_node.get());
s = reader->NextArrayItem();
data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
- ICHECK(!s);
+ if (s) {
+ reader->Read(&str_value);
+ data->target_host = ::tvm::Target(str_value);
+ s = reader->NextArrayItem();
+ ICHECK(!s);
+ }
}
}
};
@@ -271,7 +279,7 @@ namespace auto_scheduler {
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.3"; // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
RecordToFile::RecordToFile(String filename) {
auto node = make_object<RecordToFileNode>();
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index b214d9c..10bb0b4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -250,6 +250,31 @@ def test_measure_local_builder_rpc_runner_spawn():
p.join()
+@tvm.testing.requires_llvm
+def test_measure_target_host():
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test,
+ args=(512, 512, 512),
+ target="llvm",
+ target_host="llvm -mtriple=aarch64-linux-gnu",
+ )
+
+ inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state)
+ res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ auto_scheduler.save_records(fp.name, [inp], [res])
+
+ log_reader = auto_scheduler.RecordReader(fp.name)
+ inputs, results = log_reader.read_lines()
+ assert len(inputs) == 1
+
+ raw_inp = inputs[0]
+
+ recovered_inp = auto_scheduler.measure.recover_measure_input(raw_inp)
+ assert str(recovered_inp.task.target_host) == str(inp.task.target_host)
+
+
if __name__ == "__main__":
test_record_split_reorder_fuse_annotation()
test_record_compute_at_root_inline_cache_read_write()
@@ -258,3 +283,4 @@ if __name__ == "__main__":
test_recover_measure_input()
test_measure_local_builder_runner()
test_measure_local_builder_rpc_runner()
+ test_measure_target_host()