This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/metaschedule-get-shash-directly in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 6cfb888f73e35c8aebb8ba798a78e970dbe289cb
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Tue May 24 11:32:23 2022 -0700

    initial commit
---
 include/tvm/meta_schedule/database.h   | 6 ++++++
 src/meta_schedule/database/database.cc | 4 ++++
 2 files changed, 10 insertions(+)

diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index f07d8e1366..8c78c3a611 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -49,6 +49,12 @@ class WorkloadNode : public runtime::Object {
    * \return An array containing the structural hash and the base64 json string.
    */
   ObjectRef AsJSON() const;
+
+  /*!
+   * \brief Export the structural hash of the workload.
+   * \return A string representing the structural hash.
+   */
+  tvm::runtime::String GetHash() const;
 };
 
 /*!
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index fc7cc74de5..1eeb961226 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -37,6 +37,8 @@ Workload::Workload(IRModule mod, Workload::THashCode shash) {
   data_ = std::move(n);
 }
 
+tvm::runtime::String WorkloadNode::GetHash() const { return SHash2Str(this->shash); }
+
 ObjectRef WorkloadNode::AsJSON() const {
   // Convert `this->mod` to JSON
   std::string json_mod = tvm::SaveJSON(this->mod);
@@ -160,6 +162,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) {
 TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
     .set_body_method<Workload>(&WorkloadNode::AsJSON);
 TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
+TVM_REGISTER_GLOBAL("meta_schedule.GetHash")
+    .set_body_typed<tvm::runtime::String(Workload)>([](Workload w) { return w->GetHash(); });
 TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
     .set_body_typed([](tir::Trace trace, Array<FloatImm> run_secs, Workload workload, Target target,
                        Array<ArgInfo> args_info) {
