junrushao1994 commented on a change in pull request #5962:
URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449821953



##########
File path: src/ansor/search_task.h
##########
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_task.h
+ * \brief Meta information and hardware parameters for a search task.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_TASK_H_
+#define TVM_ANSOR_SEARCH_TASK_H_
+
+#include <tvm/target/target.h>
+
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+class HardwareParams;
+
+/*! \brief The parameters of target hardware used to guide the search process 
of SearchPolicy. */
+class HardwareParamsNode : public Object {
+ public:
+  /*! \brief The number of cores. */
+  int num_cores;
+  /*! \brief The width of vector units in bytes. */
+  int vector_unit_bytes;
+  /*! \brief The size of cache line in bytes. */
+  int cache_line_bytes;
+
+  // GPU related parameters got from device query API
+
+  /*! \brief The max shared memory per block. */
+  int max_shared_memory_per_block{INT32_MAX};
+  /*! \brief The max register memory per block. */
+  int max_registers_per_block{INT32_MAX};
+  /*! \brief The max threads per block. */
+  int max_threads_per_block{INT32_MAX};
+  /*! \brief The max vthread extent. */
+  int max_vthread_extent{INT32_MAX};
+  /*! \brief The thread numbers of a warp. */
+  int warp_size{INT32_MAX};
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_cores", &num_cores);
+    v->Visit("vector_unit_bytes", &vector_unit_bytes);
+    v->Visit("cache_line_bytes", &cache_line_bytes);
+    v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
+    v->Visit("max_registers_per_block", &max_registers_per_block);
+    v->Visit("max_threads_per_block", &max_threads_per_block);
+    v->Visit("max_vthread_extent", &max_vthread_extent);
+    v->Visit("warp_size", &warp_size);
+  }
+
+  /*!
+   * \brief Get the default hardware params.
+   * \param target A `tvm.target`.
+   * \param target_host A `tvm.target` for host device.
+   * \return A HardwareParams object.
+   */
+  static HardwareParams GetDefaultHardwareParams(const Target& target, const 
Target& target_host);

Review comment:
       Maybe move this to `HardwareParams` instead, given that we currently 
prefer to keep constructor-like functions in the container?

##########
File path: src/ansor/auto_schedule.h
##########
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/auto_schedule.h
+ * \brief The user interface of the Ansor auto-scheduler. This is the entry 
structure to get
+ * schedule search requirements from upper level (Python API), and returns a 
high performance
+ * schedule after search process.
+ */
+
+#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_
+#define TVM_ANSOR_AUTO_SCHEDULE_H_
+
+#include <utility>
+
+#include "measure.h"
+#include "search_policy/search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Tuning and measurement options. */
+class TuningOptionsNode : public Object {
+ public:
+  /*! \brief Number of total measurement trials. */
+  int num_measure_trials;
+  /*! \brief Stops early the tuning if no improvement after n measurements. */
+  int early_stopping;
+  /*! \brief The number of programs to be measured at each search round. */
+  int num_measures_per_round;
+  /*! \brief Verbosity level. 0 for silent, 1 to output information during 
schedule searching. */
+  int verbose;
+  /*! \brief ProgramBuilder which builds the program */
+  ProgramBuilder builder;
+  /*! \brief ProgramRunner which runs the program and measure time costs */
+  ProgramRunner runner;
+  /*! \brief MeasureCallback functions to be called after each measure batch */
+  Array<MeasureCallback> measure_callbacks;
+  /*! \brief SearchCallback functions to be called before schedule search */
+  Array<SearchCallback> pre_search_callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_measure_trials", &num_measure_trials);
+    v->Visit("early_stopping", &early_stopping);
+    v->Visit("num_measures_per_round", &num_measures_per_round);
+    v->Visit("verbose", &verbose);
+    v->Visit("builder", &builder);
+    v->Visit("runner", &runner);
+    v->Visit("measure_callbacks", &measure_callbacks);
+    v->Visit("pre_search_callbacks", &pre_search_callbacks);
+  }
+
+  static constexpr const char* _type_key = "ansor.TuningOptions";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TuningOptionsNode.
+ * \sa TuningOptionsNode
+ */
+class TuningOptions : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param num_measure_trials Number of total measurement trials.
+   * \param early_stopping Stops early the tuning if no improvement after n 
measurements.
+   * \param num_measures_per_round The number of programs to be measured at 
each search round.
+   * \param verbose Verbosity level. 0 for silent, 1 to output information 
during schedule
+   * search.
+   * \param builder ProgramBuilder which builds the program.
+   * \param runner ProgramRunner which runs the program and measure time costs.
+   * \param measure_callbacks MeasureCallback functions to be called after 
each measure batch.
+   * \param pre_search_callbacks SearchCallback functions to be called before 
schedule search.
+   */
+  TuningOptions(int num_measure_trials, int early_stopping, int 
num_measures_per_round, int verbose,
+                ProgramBuilder builder, ProgramRunner runner,
+                Array<MeasureCallback> measure_callbacks,
+                Array<SearchCallback> pre_search_callbacks);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
+};
+
+/*!
+ * \brief Auto schedule search for a given compute declaration.
+ * \param task The search task of the compute declaration.
+ * \param search_policy The search policy to be used for schedule search.
+ * \param tuning_options Tuning and measurement options.
+ * \return A `te::schedule` and the a Array of `te::Tensor` to be used in 
`tvm.lower` or
+ * `tvm.build`.
+ */
+std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task, 
SearchPolicy search_policy,

Review comment:
       Let's prepend `TVM_DLL` as it directly interacts with FFI.
   
   ```suggestion
   TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask 
task, SearchPolicy search_policy,
   ```

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {

Review comment:
       I suppose this is a shallow copy, but the doc says it is a deep copy?

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+  auto node = make_object<MeasureInputNode>();
+  node->task = task;
+  node->state = state;
+  return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int 
error_no, String error_msg,
+                         double time_cost) {
+  auto node = make_object<BuildResultNode>();
+  node->filename = std::move(filename);
+  node->args = std::move(args);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->time_cost = time_cost;
+  data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String 
error_msg, double all_cost,
+                             double timestamp) {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = std::move(costs);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = costs;
+  node->error_no = error_no;
+  node->error_msg = error_msg;
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  return MeasureResult(node);
+}
+
+/********** LocalBuilder **********/
+LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& 
build_func) {
+  auto node = make_object<LocalBuilderNode>();
+  node->timeout = timeout;
+  node->n_parallel = n_parallel;
+  node->build_func = build_func;
+  data_ = std::move(node);
+}
+
+Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, 
int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) {
+    Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_builder.build is not registered";
+  }
+  return Array<BuildResult>();
+}
+
+/********** LocalRunner **********/
+LocalRunner::LocalRunner(int timeout, int number, int repeat, int 
min_repeat_ms,
+                         double cooldown_interval) {
+  ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>();
+  node->timeout = timeout;
+  node->number = number;
+  node->repeat = repeat;
+  node->min_repeat_ms = min_repeat_ms;
+  node->cooldown_interval = cooldown_interval;
+  data_ = std::move(node);
+}
+
+Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs,
+                                          const Array<BuildResult>& 
build_results, int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) {
+    Array<MeasureResult> results = (*f)(inputs, build_results, timeout, 
number, repeat,
+                                        min_repeat_ms, cooldown_interval, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_runner.run is not registered";
+  }
+  return Array<MeasureResult>();

Review comment:
       Same here: no need to return anything after `LOG(FATAL)`, since it is 
fatal anyway.

##########
File path: src/ansor/utils.h
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/utils.h
+ * \brief Common utilities.
+ */
+
+#ifndef TVM_ANSOR_UTILS_H_
+#define TVM_ANSOR_UTILS_H_
+
+#include <dmlc/common.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <deque>
+#include <exception>
+#include <future>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace std {
+
+/*! \brief Hash function for std::pair */
+template <typename T1, typename T2>
+struct hash<std::pair<T1, T2>> {
+  std::size_t operator()(const std::pair<T1, T2>& k) const {
+    return ::dmlc::HashCombine(std::hash<T1>()(k.first), 
std::hash<T2>()(k.second));
+  }
+};
+
+/*! \brief Hash function for std::tuple */
+template <typename T1, typename T2, typename T3>
+struct hash<std::tuple<T1, T2, T3>> {
+  std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
+    return ::dmlc::HashCombine(
+        ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), 
std::hash<T2>()(std::get<1>(k))),
+        std::hash<T3>()(std::get<2>(k)));
+  }
+};
+
+}  // namespace std
+
+namespace tvm {
+namespace ansor {
+
+/********** Utilities for Array, std::string **********/
+/*! \brief Get the first appearance index of elements in an Array */
+template <typename T>
+inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, 
Array<Integer>* indices) {
+  for (const auto& v : to_locate) {
+    auto it = std::find(array.begin(), array.end(), v);
+    if (it != array.end()) {
+      indices->push_back(it - array.begin());
+    } else {
+      LOG(FATAL) << "Cannot find the item";
+    }
+  }
+}
+
+/*! \brief Get the first appearance index of an element in an Array */
+template <typename T>
+inline int GetIndex(const Array<T>& array, const T& to_locate) {
+  for (size_t i = 0; i < array.size(); ++i) {
+    if (array[i] == to_locate) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "Cannot find the item";
+  return -1;
+}
+
+/*! \brief Replace a sub-string to another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const 
std::string& to) {
+  auto pos = base->find(from);
+  while (pos != std::string::npos) {
+    base->replace(pos, from.size(), to);
+    pos = base->find(from, pos + to.size());
+  }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
+  double sum = 0;
+  if (float_array.empty()) {
+    return 0.0;
+  }
+
+  for (const auto& x : float_array) {
+    auto floatimm = x.as<tir::FloatImmNode>();
+    CHECK(floatimm != nullptr);
+    sum += floatimm->value;
+  }
+  return sum / float_array.size();
+}
+
+/********** Other Utilities **********/
+/*! \brief Get an int value from an Expr */
+inline int64_t GetIntImm(const PrimExpr& expr) {
+  auto pint = expr.as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+/*! \brief Compute the product of the lengths of axes */
+inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
+  int64_t ret = 1.0;
+  for (const auto& x : axes) {
+    if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
+      ret *= imm->value;
+    } else {
+      return -1.0;
+    }
+  }
+  return ret;
+}
+
+/*!
+ * \brief Clean the name of an iterator to make it valid in python code.
+ * \param str The original name.
+ * \return The cleaned name.
+ */
+inline std::string CleanName(const std::string& str) {
+  std::string ret = str;
+  StrReplace(&ret, ".", "_");
+  StrReplace(&ret, "@", "_");
+  StrReplace(&ret, "outer", "o");
+  StrReplace(&ret, "inner", "i");
+  return ret;
+}
+
+/*! \brief An empty output stream */
+class NullStream : public std::ostream {
+ public:
+  NullStream() : std::ostream(nullptr) {}
+  NullStream(const NullStream&) : std::ostream(nullptr) {}
+  static NullStream& Global();
+};
+
+template <class T>
+NullStream& operator<<(NullStream& os, const T& value) {
+  return os;
+}
+
+/*! \brief Get std cout with verbose control */
+inline std::ostream& StdCout(int verbose) {
+  return verbose == 1 ? std::cout : NullStream::Global();

Review comment:
       Do we really want to use `std::cout`? Maybe `LOG(INFO)` or 
`LOG(VERBOSE)`?

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+  auto node = make_object<MeasureInputNode>();
+  node->task = task;
+  node->state = state;
+  return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int 
error_no, String error_msg,
+                         double time_cost) {
+  auto node = make_object<BuildResultNode>();
+  node->filename = std::move(filename);
+  node->args = std::move(args);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->time_cost = time_cost;
+  data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String 
error_msg, double all_cost,
+                             double timestamp) {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = std::move(costs);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {

Review comment:
       This is a shallow copy too.

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);

Review comment:
       Are there specific rules for why some of them are registered using 
`OBJECT_TYPE` while the others use `NODE_TYPE`? IIRC `NODE_TYPE` helps register 
a creator into the vtable, so my understanding is that if we don't customize our 
own creator, we should just use `NODE_TYPE` instead.

##########
File path: src/ansor/measure.h
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.h
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ * The flow of data structures is MeasureInput -> BuildeResult -> 
MeasureResult.
+ */
+
+#ifndef TVM_ANSOR_MEASURE_H_
+#define TVM_ANSOR_MEASURE_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "loop_state.h"
+#include "search_task.h"
+
+namespace tvm {
+namespace ansor {
+
+class SearchPolicy;
+class MeasureInput;
+class MeasureResult;
+
+/*! \brief The error code of one measurement */
+enum MeasureErrorNO {
+  /*! \brief No error. */
+  kNoError = 0,
+  /*! \brief Errors happen when apply transform steps from init state. */
+  kInstantiationError = 1,
+  /*! \brief Errors happen when compiling code on host. (when build module) */
+  kCompileHostError = 2,
+  /*! \brief Errors happen when compiling code on device. (when load module) */
+  kCompileDeviceError = 3,
+  /*! \brief Errors happen when run program on device. */
+  kRuntimeDeviceError = 4,
+  /*! \brief Answer is wrong when compared to a reference output. */
+  kWrongAnswerError = 5,
+  /*! \brief Timeout during compilation. */
+  kBuildTimeoutError = 6,
+  /*! \brief Timeout during run. */
+  kRunTimeoutError = 7,
+  /*! \brief Unknown error. */
+  kUnknonwError = 8,
+};
+
+// Inputs and results of one measurement
+
+/*! \brief Store the input of a measurement */
+class MeasureInputNode : public Object {
+ public:
+  /*! \brief The search task. */
+  SearchTask task;
+  /*! \brief The program state to be measured. */
+  State state;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("task", &task);
+    v->Visit("state", &state);
+  }
+
+  /*! \brief Do deep copy. */
+  MeasureInput copy() const;

Review comment:
       Use `ShallowCopy`/`DeepCopy` to deliver clearer information.

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+  auto node = make_object<MeasureInputNode>();
+  node->task = task;
+  node->state = state;
+  return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int 
error_no, String error_msg,
+                         double time_cost) {
+  auto node = make_object<BuildResultNode>();
+  node->filename = std::move(filename);
+  node->args = std::move(args);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->time_cost = time_cost;
+  data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String 
error_msg, double all_cost,
+                             double timestamp) {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = std::move(costs);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = costs;
+  node->error_no = error_no;
+  node->error_msg = error_msg;
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  return MeasureResult(node);
+}
+
+/********** LocalBuilder **********/
+LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& 
build_func) {
+  auto node = make_object<LocalBuilderNode>();
+  node->timeout = timeout;
+  node->n_parallel = n_parallel;
+  node->build_func = build_func;
+  data_ = std::move(node);
+}
+
+Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, 
int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) {
+    Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_builder.build is not registered";
+  }
+  return Array<BuildResult>();

Review comment:
       There is no need to return anything here, because `LOG(FATAL)` is fatal anyway.
   
   ```suggestion
     }
     LOG(FATAL) << "ansor.local_builder.build is not registered";
     throw;
   ```

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+  auto node = make_object<MeasureInputNode>();
+  node->task = task;
+  node->state = state;
+  return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int 
error_no, String error_msg,
+                         double time_cost) {
+  auto node = make_object<BuildResultNode>();
+  node->filename = std::move(filename);
+  node->args = std::move(args);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->time_cost = time_cost;
+  data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String 
error_msg, double all_cost,
+                             double timestamp) {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = std::move(costs);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = costs;
+  node->error_no = error_no;
+  node->error_msg = error_msg;
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  return MeasureResult(node);
+}
+
+/********** LocalBuilder **********/
+LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& 
build_func) {
+  auto node = make_object<LocalBuilderNode>();
+  node->timeout = timeout;
+  node->n_parallel = n_parallel;
+  node->build_func = build_func;
+  data_ = std::move(node);
+}
+
+Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, 
int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) {
+    Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_builder.build is not registered";
+  }
+  return Array<BuildResult>();
+}
+
+/********** LocalRunner **********/
+LocalRunner::LocalRunner(int timeout, int number, int repeat, int 
min_repeat_ms,
+                         double cooldown_interval) {
+  ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>();
+  node->timeout = timeout;
+  node->number = number;
+  node->repeat = repeat;
+  node->min_repeat_ms = min_repeat_ms;
+  node->cooldown_interval = cooldown_interval;
+  data_ = std::move(node);
+}
+
+Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs,
+                                          const Array<BuildResult>& 
build_results, int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) {
+    Array<MeasureResult> results = (*f)(inputs, build_results, timeout, 
number, repeat,
+                                        min_repeat_ms, cooldown_interval, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_runner.run is not registered";
+  }
+  return Array<MeasureResult>();
+}
+
+/********** ProgramMeasurer **********/
+ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
+                                 Array<MeasureCallback> callbacks, int verbose,
+                                 int max_continous_error) {
+  auto node = make_object<ProgramMeasurerNode>();
+  node->builder = std::move(builder);
+  node->runner = std::move(runner);
+  node->callbacks = std::move(callbacks);
+  node->verbose = verbose;
+  node->max_continous_error = max_continous_error < 0
+                                  ? 
ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR
+                                  : max_continous_error;
+  data_ = std::move(node);
+}
+
+void ProgramMeasurerNode::Reset() {
+  ct = error_ct = 0;
+  best_flops.clear();
+  best_ct.clear();
+  best_state.clear();
+}
+
+void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& 
policy,
+                                  const Array<MeasureInput>& inputs, 
Array<MeasureResult>* results,
+                                  int batch_size) {
+  results->clear();
+  results->reserve(inputs.size());
+
+  if (batch_size == -1) {
+    // set default batch size
+    batch_size = builder->n_parallel * 2;
+  }
+
+  StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This 
may take a while)"
+                   << std::endl;
+
+  for (size_t i = 0; i < inputs.size(); i += batch_size) {
+    Array<MeasureInput> input_batch(inputs.begin() + i,
+                                    inputs.begin() + std::min(i + batch_size, 
inputs.size()));
+    Array<MeasureResult> result_batch;
+
+    // build and run
+    SilentMeasure(task, input_batch, &result_batch);
+
+    // update current best state according to the new measure result
+    for (size_t j = 0; j < input_batch.size(); ++j) {
+      double flops;
+      if (result_batch[j]->error_no == 0) {
+        flops = task->compute_dag->flop_ct / 
FloatArrayMean(result_batch[j]->costs);
+        error_ct = 0;
+      } else {
+        flops = 0.0;
+        error_ct++;
+      }
+
+      const String& workload_key = input_batch[j]->task->workload_key;
+      if (flops > best_flops[workload_key]) {
+        best_flops[workload_key] = flops;
+        best_state[workload_key] = input_batch[j]->state;
+        best_ct[workload_key] = ct;
+      }
+
+      ct++;
+      StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) 
<< "\n"
+                       << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
+                       << best_flops[workload_key] / 1e9 << "\tresults: " << 
result_batch[j] << "\n"
+                       << Chars('=', 50) << "\n"
+                       << input_batch[j]->state << "\n";
+    }
+
+    // Call callback functions
+    for (const auto& callback : callbacks) {
+      callback->Callback(policy, input_batch, result_batch);
+    }
+
+    // Store result batch
+    for (auto& res : result_batch) {
+      results->push_back(res);
+    }
+
+    if (error_ct > max_continous_error) {
+      LOG(FATAL) << "Too many errors happened during tuning";
+    }
+  }
+}
+
+void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const 
Array<MeasureInput>& inputs,
+                                        Array<MeasureResult>* results) {
+  // Close the thread pool to avoid the conflits with python environment
+  ThreadPool::Global().Abort();
+
+  results->clear();
+  results->reserve(inputs.size());
+  Array<MeasureInput> input_batch(inputs.begin(), inputs.end());

Review comment:
       Why do we need a copy here?

##########
File path: src/ansor/measure.cc
##########
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs 
of tensor programs.
+ */
+
+#include "measure.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(MeasureInputNode);
+TVM_REGISTER_NODE_TYPE(BuildResultNode);
+TVM_REGISTER_NODE_TYPE(MeasureResultNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
+
+static const char* ErrorNoToStr[] = {
+    "NoError",
+    "InstantiationError",
+    "CompileHostError",
+    "CompileDeviceError",
+    "RuntimeDeviceError",
+    "WrongAnswerError",
+    "BuildTimeoutError",
+    "RunTimeoutError",
+    "UnknownError",
+};
+
+/********** Measure input and result **********/
+MeasureInput::MeasureInput(SearchTask task, State state) {
+  auto node = make_object<MeasureInputNode>();
+  node->task = std::move(task);
+  node->state = std::move(state);
+  data_ = std::move(node);
+}
+
+MeasureInput MeasureInputNode::copy() const {
+  auto node = make_object<MeasureInputNode>();
+  node->task = task;
+  node->state = state;
+  return MeasureInput(node);
+}
+
+BuildResult::BuildResult(String filename, Array<te::Tensor> args, int 
error_no, String error_msg,
+                         double time_cost) {
+  auto node = make_object<BuildResultNode>();
+  node->filename = std::move(filename);
+  node->args = std::move(args);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->time_cost = time_cost;
+  data_ = std::move(node);
+}
+
+MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String 
error_msg, double all_cost,
+                             double timestamp) {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = std::move(costs);
+  node->error_no = error_no;
+  node->error_msg = std::move(error_msg);
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  data_ = std::move(node);
+}
+
+MeasureResult MeasureResultNode::copy() const {
+  auto node = make_object<MeasureResultNode>();
+  node->costs = costs;
+  node->error_no = error_no;
+  node->error_msg = error_msg;
+  node->all_cost = all_cost;
+  node->timestamp = timestamp;
+  return MeasureResult(node);
+}
+
+/********** LocalBuilder **********/
+LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& 
build_func) {
+  auto node = make_object<LocalBuilderNode>();
+  node->timeout = timeout;
+  node->n_parallel = n_parallel;
+  node->build_func = build_func;
+  data_ = std::move(node);
+}
+
+Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, 
int verbose) {
+  if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) {
+    Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, 
verbose);
+    return results;
+  } else {
+    LOG(FATAL) << "ansor.local_builder.build is not registered";

Review comment:
       I think we should document where `ansor.local_builder.build` is defined, 
and hint at some potential reasons why it might not be found (e.g. the Python 
package is not fully loaded).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to