This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52d90da1d3 [MetaSchedule] TuningRecord Optional Arguments (#11598)
52d90da1d3 is described below

commit 52d90da1d3bc6b12611b1d30a38c02837fbf8d76
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Tue Jun 7 18:05:14 2022 -0700

    [MetaSchedule] TuningRecord Optional Arguments (#11598)
    
    In some situations, such as before the candidates have been measured, the
    `run_secs`, `target`, and `args_info` fields of a `TuningRecord` are not yet
    available. To support this, the `TuningRecord` constructor now takes its
    arguments in the order `trace, workload, run_secs, target, args_info`, with
    the last three optional. Because the argument order has changed, any caller
    that passes these arguments positionally (including several existing tests)
    must be updated accordingly.
---
 include/tvm/meta_schedule/database.h               | 17 +++----
 python/tvm/meta_schedule/database/database.py      | 26 +++++------
 python/tvm/meta_schedule/testing/utils.py          |  2 +-
 src/meta_schedule/database/database.cc             | 54 ++++++++++++++--------
 src/meta_schedule/database/json_database.cc        |  4 +-
 .../measure_callback/add_to_database.cc            |  2 +-
 .../python/unittest/test_meta_schedule_database.py | 26 +++++------
 .../unittest/test_meta_schedule_integration.py     |  2 +-
 .../unittest/test_meta_schedule_tune_relay.py      |  2 +-
 9 files changed, 75 insertions(+), 60 deletions(-)

diff --git a/include/tvm/meta_schedule/database.h 
b/include/tvm/meta_schedule/database.h
index 1353dec3ed..37a315bf74 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -103,19 +103,19 @@ class TuningRecordNode : public runtime::Object {
  public:
   /*! \brief The trace tuned. */
   tir::Trace trace;
-  /*! \brief The profiling result in seconds. */
-  Array<FloatImm> run_secs;
   /*! \brief The workload. */
   Workload workload{nullptr};
+  /*! \brief The profiling result in seconds. */
+  Optional<Array<FloatImm>> run_secs;
   /*! \brief The target for tuning. */
-  Target target;
+  Optional<Target> target;
   /*! \brief The argument information. */
-  Array<ArgInfo> args_info;
+  Optional<Array<ArgInfo>> args_info;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("trace", &trace);
-    v->Visit("run_secs", &run_secs);
     v->Visit("workload", &workload);
+    v->Visit("run_secs", &run_secs);
     v->Visit("target", &target);
     v->Visit("args_info", &args_info);
   }
@@ -140,13 +140,14 @@ class TuningRecord : public runtime::ObjectRef {
   /*!
    \brief Constructor of a tuning record.
    \param trace The trace of the tuning record.
-   \param run_secs The running time of the tuning record.
    \param workload The workload of the tuning record.
+   \param run_secs The running time of the tuning record.
    \param target The target of the tuning record.
    \param args_info The argument information of the tuning record.
   */
-  TVM_DLL explicit TuningRecord(tir::Trace trace, Array<FloatImm> run_secs, 
Workload workload,
-                                Target target, Array<ArgInfo> args_info);
+  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
+                                Optional<Array<FloatImm>> run_secs, 
Optional<Target> target,
+                                Optional<Array<ArgInfo>> args_info);
   /*!
    * \brief Create a tuning record from a json object.
    * \param json_obj The json object.
diff --git a/python/tvm/meta_schedule/database/database.py 
b/python/tvm/meta_schedule/database/database.py
index 314bf434c4..8e0c805410 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tuning record database"""
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Optional
 
 from tvm._ffi import register_object
 from tvm.ir.module import IRModule
@@ -82,35 +82,35 @@ class TuningRecord(Object):
     ----------
     trace : tvm.ir.Trace
         The trace of the tuning record.
-    run_secs : List[float]
-        The run time of the tuning record.
     workload : Workload
         The workload of the tuning record.
-    target : Target
+    run_secs : Optional[List[float]]
+        The run time of the tuning record.
+    target : Optional[Target]
         The target of the tuning record.
-    args_info : List[ArgInfo]
+    args_info : Optional[List[ArgInfo]]
         The argument information of the tuning record.
     """
 
     trace: Trace
-    run_secs: List[float]
     workload: Workload
-    target: Target
-    args_info: List[ArgInfo]
+    run_secs: Optional[List[float]]
+    target: Optional[Target]
+    args_info: Optional[List[ArgInfo]]
 
-    def __init__(
+    def __init__(  # type: ignore # pylint: disable=too-many-arguments
         self,
         trace: Trace,
-        run_secs: List[float],
         workload: Workload,
-        target: Target,
-        args_info: List[ArgInfo],
+        run_secs: Optional[List[float]] = None,
+        target: Optional[Target] = None,
+        args_info: Optional[List[ArgInfo]] = None,
     ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.TuningRecord,  # type: ignore # pylint: disable=no-member
             trace,
-            run_secs,
             workload,
+            run_secs,
             target,
             args_info,
         )
diff --git a/python/tvm/meta_schedule/testing/utils.py 
b/python/tvm/meta_schedule/testing/utils.py
index a832dfc6bc..62950fdd0b 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -155,7 +155,7 @@ def apply_fixed_schedules(
 
         if schedule_fn(task, sch):
             workload = database.commit_workload(mod)
-            tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
+            tune_rec = TuningRecord(sch.trace, workload, [0.0], target, [])
             database.commit_tuning_record(tune_rec)
 
     return database
diff --git a/src/meta_schedule/database/database.cc 
b/src/meta_schedule/database/database.cc
index fc7cc74de5..86d999e4fd 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -74,48 +74,62 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) {
 
 /******** TuningRecord ********/
 
-TuningRecord::TuningRecord(tir::Trace trace, Array<FloatImm> run_secs, 
Workload workload,
-                           Target target, Array<ArgInfo> args_info) {
+TuningRecord::TuningRecord(tir::Trace trace, Workload workload, 
Optional<Array<FloatImm>> run_secs,
+                           Optional<Target> target, Optional<Array<ArgInfo>> 
args_info) {
   ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>();
   n->trace = trace;
-  n->run_secs = run_secs;
   n->workload = workload;
+  n->run_secs = run_secs;
   n->target = target;
   n->args_info = args_info;
   this->data_ = n;
 }
 
 ObjectRef TuningRecordNode::AsJSON() const {
-  Array<ObjectRef> json_args_info;
-  json_args_info.reserve(args_info.size());
-  for (const ArgInfo& arg_info : args_info) {
-    json_args_info.push_back(arg_info->AsJSON());
+  Optional<Array<ObjectRef>> json_args_info{nullptr};
+  Optional<ObjectRef> json_target{nullptr};
+  if (args_info.defined()) {
+    Array<ObjectRef> info;
+    info.reserve(args_info.value().size());
+    for (const ArgInfo& arg_info : args_info.value()) {
+      info.push_back(arg_info->AsJSON());
+    }
+    json_args_info = info;
+  }
+  if (target.defined()) {
+    json_target = target.value()->Export();
   }
   return Array<ObjectRef>{trace->AsJSON(false),  //
                           run_secs,              //
-                          target->Export(),      //
+                          json_target,           //
                           json_args_info};
 }
 
 TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& 
workload) {
   tir::Trace trace{nullptr};
-  Array<FloatImm> run_secs{nullptr};
-  Target target{nullptr};
-  Array<ArgInfo> args_info;
+  Optional<Array<FloatImm>> run_secs{nullptr};
+  Optional<Target> target{nullptr};
+  Optional<Array<ArgInfo>> args_info{nullptr};
   try {
     const ArrayNode* json_array = json_obj.as<ArrayNode>();
     CHECK(json_array && json_array->size() == 4);
     // Load json[1] => run_secs
-    run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+    if (json_array->at(1).defined()) {
+      run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+    }
     // Load json[2] => target
-    target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
+    if (json_array->at(2).defined()) {
+      target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
+    }
     // Load json[3] => args_info
-    {
+    if (json_array->at(3).defined()) {
       const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
-      args_info.reserve(json_args_info->size());
+      Array<ArgInfo> info;
+      info.reserve(json_args_info->size());
       for (const ObjectRef& json_arg_info : *json_args_info) {
-        args_info.push_back(ArgInfo::FromJSON(json_arg_info));
+        info.push_back(ArgInfo::FromJSON(json_arg_info));
       }
+      args_info = info;
     }
     // Load json[0] => trace
     {
@@ -130,7 +144,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& 
json_obj, const Workload& w
     LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
                << "\nThe error is: " << e.what();
   }
-  return TuningRecord(trace, run_secs, workload, target, args_info);
+  return TuningRecord(trace, workload, run_secs, target, args_info);
 }
 
 /******** PyDatabase ********/
@@ -161,9 +175,9 @@ TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
     .set_body_method<Workload>(&WorkloadNode::AsJSON);
 
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
 TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
-    .set_body_typed([](tir::Trace trace, Array<FloatImm> run_secs, Workload 
workload, Target target,
-                       Array<ArgInfo> args_info) {
-      return TuningRecord(trace, run_secs, workload, target, args_info);
+    .set_body_typed([](tir::Trace trace, Workload workload, 
Optional<Array<FloatImm>> run_secs,
+                       Optional<Target> target, Optional<Array<ArgInfo>> 
args_info) {
+      return TuningRecord(trace, workload, run_secs, target, args_info);
     });
 TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
     .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
diff --git a/src/meta_schedule/database/json_database.cc 
b/src/meta_schedule/database/json_database.cc
index 2e76940fee..155d223217 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -40,8 +40,8 @@ struct SortTuningRecordByMeanRunSecs {
   }
 
   bool operator()(const TuningRecord& a, const TuningRecord& b) const {
-    double a_time = Mean(a->run_secs);
-    double b_time = Mean(b->run_secs);
+    double a_time = Mean(a->run_secs.value_or({}));
+    double b_time = Mean(b->run_secs.value_or({}));
     return a_time < b_time;
   }
 };
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc 
b/src/meta_schedule/measure_callback/add_to_database.cc
index 0988da0414..27b4e55a7d 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -47,8 +47,8 @@ class AddToDatabaseNode : public MeasureCallbackNode {
       }
       database->CommitTuningRecord(TuningRecord(
           /*trace=*/candidate->sch->trace().value(),
-          /*run_secs=*/run_secs,
           /*workload=*/workload,
+          /*run_secs=*/run_secs,
           /*target=*/target,
           /*args_info=*/candidate->args_info));
     }
diff --git a/tests/python/unittest/test_meta_schedule_database.py 
b/tests/python/unittest/test_meta_schedule_database.py
index d494f997c1..1edfbe6c7a 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -115,8 +115,8 @@ def test_meta_schedule_tuning_record_round_trip():
         workload = database.commit_workload(mod)
         record = TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
-            [1.5, 2.5, 1.8],
             workload,
+            [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
             ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
         )
@@ -140,8 +140,8 @@ def test_meta_schedule_database_has_workload():
         workload = database.commit_workload(mod)
         record = TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
-            [1.5, 2.5, 1.8],
             workload,
+            [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
             ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
         )
@@ -158,8 +158,8 @@ def test_meta_schedule_database_add_entry():
         workload = database.commit_workload(mod)
         record = TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
-            [1.5, 2.5, 1.8],
             workload,
+            [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
             ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
         )
@@ -178,8 +178,8 @@ def test_meta_schedule_database_missing():
         workload_2 = database.commit_workload(mod_2)
         record = TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
-            [1.5, 2.5, 1.8],
             workload,
+            [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
             ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
         )
@@ -197,43 +197,43 @@ def test_meta_schedule_database_sorting():
         records = [
             TuningRecord(
                 trace,
-                [7.0, 8.0, 9.0],
                 token,
+                [7.0, 8.0, 9.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [1.0, 2.0, 3.0],
                 token,
+                [1.0, 2.0, 3.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [4.0, 5.0, 6.0],
                 token,
+                [4.0, 5.0, 6.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [1.1, 1.2, 600.0],
                 token,
+                [1.1, 1.2, 600.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [1.0, 100.0, 6.0],
                 token,
+                [1.0, 100.0, 6.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [4.0, 9.0, 8.0],
                 token,
+                [4.0, 9.0, 8.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
@@ -259,22 +259,22 @@ def test_meta_schedule_database_reload():
         records = [
             TuningRecord(
                 trace,
-                [7.0, 8.0, 9.0],
                 token,
+                [7.0, 8.0, 9.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [1.0, 2.0, 3.0],
                 token,
+                [1.0, 2.0, 3.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
             TuningRecord(
                 trace,
-                [4.0, 5.0, 6.0],
                 token,
+                [4.0, 5.0, 6.0],
                 tvm.target.Target("llvm"),
                 ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
             ),
diff --git a/tests/python/unittest/test_meta_schedule_integration.py 
b/tests/python/unittest/test_meta_schedule_integration.py
index a423bdb48a..3b33039bd2 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -267,7 +267,7 @@ def test_meta_schedule_integration_apply_history_best():
     target = Target("llvm")
     workload = database.commit_workload(MockModule)
     database.commit_tuning_record(
-        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
+        TuningRecord(Schedule(MockModule).trace, workload, [1.0], target, [])
     )
     mod = env.query(task_name="mock-task", mod=mod, target=target, 
dispatched=[MockModule])
     assert tvm.ir.structural_equal(mod, workload.mod)
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py 
b/tests/python/unittest/test_meta_schedule_tune_relay.py
index e5076af520..e0883dbd22 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -307,8 +307,8 @@ def test_meta_schedule_relay_lowering():
         database.commit_tuning_record(
             TuningRecord(
                 Trace([], {}),
-                [0.0],
                 
database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
+                [0.0],
                 target=target,
                 args_info=[],
             )

Reply via email to