junrushao1994 commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785375497



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& 
block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at 
location (specified by
+    // the annotation), we colect the producer first, and transform the 
producer block later.

Review comment:
       ```suggestion
       // the annotation), we collect the producer first, and transform the 
producer block later.
   ```

##########
File path: src/tir/schedule/primitive/sampling.cc
##########
@@ -354,6 +354,40 @@ std::vector<int64_t> SamplePerfectTile(
   return result;
 }
 
+tir::StmtSRef SampleComputeLocation(tir::ScheduleState self,
+                                    
support::LinearCongruentialEngine::TRandState* rand_state,
+                                    const StmtSRef& block_sref, 
Optional<Integer>* decision) {
+  // Step 1. Collect all possible compute-at locations.
+  Array<tir::StmtSRef> location_srefs;
+  std::vector<int> location_indices;
+  std::tie(location_srefs, location_indices) = CollectComputeLocation(self, 
block_sref);
+  ICHECK_EQ(location_srefs.size(), location_indices.size());
+
+  // Step 2. If there was a previous decision, keep the decision unchanged if 
it exists in the
+  // location candidates. Otherwise, pick the location before the previous 
decision.
+  // Step 3. If there was not a previous decision, sample a decision from the 
collected locations.
+  if (decision->defined()) {
+    int64_t old_decision = Downcast<Integer>(*decision)->value;
+    auto it = std::lower_bound(location_indices.begin(), 
location_indices.end(), old_decision);
+    int idx = it - location_indices.begin();
+
+    if (it != location_indices.end() && *it == old_decision) {
+      *decision = Integer(old_decision);
+      return location_srefs[idx];
+    } else if (it != location_indices.begin()) {
+      *decision = Integer(*--it);

Review comment:
       nit: just to make it a bit clearer :-)
   
   ```suggestion
         *decision = Integer(location_indices[idx - 1]);
   ```

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& 
block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at 
location (specified by
+    // the annotation), we colect the producer first, and transform the 
producer block later.
+    // - The reason we collect the producer before transforming the input 
block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input 
block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), 
tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is 
needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) 
const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost 
loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 
1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the 
block's annotation.
+    if (tir::GetAnn<String>(block_sref, 
tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until 
success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const 
tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }
+    return sch;
+  }
+
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = 
"meta_schedule.RandomComputeLocation";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::RandomComputeLocation() {
+  ObjectPtr<RandomComputeLocationNode> n = 
make_object<RandomComputeLocationNode>();
+  return ScheduleRule(n);

Review comment:
       nit: can be merged into a single line:
   
   ```suggestion
     return ScheduleRule(make_object<RandomComputeLocationNode>());
   ```

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& 
block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at 
location (specified by
+    // the annotation), we colect the producer first, and transform the 
producer block later.
+    // - The reason we collect the producer before transforming the input 
block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input 
block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), 
tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is 
needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }

Review comment:
       To make sure I understand correctly, this means `Compute-At` could 
potentially happen twice when the block carries the annotation: first compute 
the producer onto this block, then move this block somewhere to one of its 
consumers.
   
   Is there any corner case we potentially want to check carefully and 
disallow? For example, do we allow inlining the producer?

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -646,6 +646,152 @@ BlockRealize GetBlockRealize(const ScheduleState& self, 
const StmtSRef& block_sr
   }
 }
 
+IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const Var& loop_var = loop->loop_var;
+  int n_spatial = 0;
+  int n_reduce = 0;
+  int n_other = 0;
+  auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& 
obj) -> bool {
+    if (const auto* realize = obj.as<BlockRealizeNode>()) {
+      const BlockNode* block = realize->block.get();
+      // Number of block vars and their bindings
+      ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size());
+      size_t n = realize->iter_values.size();
+      for (size_t i = 0; i < n; ++i) {
+        const IterVar& iter_var = block->iter_vars[i];
+        const PrimExpr& binding = realize->iter_values[i];
+        // Categorize the current block var
+        int* ref = nullptr;
+        if (iter_var->iter_type == IterVarType::kDataPar) {
+          ref = &n_spatial;
+        } else if (iter_var->iter_type == IterVarType::kCommReduce) {
+          ref = &n_reduce;
+        } else {
+          ref = &n_other;
+        }
+        // Visit the binding to see if `loop_var` appears
+        PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> 
void {
+          if (obj.same_as(loop_var)) {
+            (*ref) += 1;
+          }
+        });
+      }
+      return false;
+    }
+    return true;
+  };
+  PreOrderVisit(loop->body, f_visit);
+  if (n_other) {
+    return IterVarType::kOpaque;
+  } else if (n_spatial && n_reduce) {
+    return IterVarType::kOpaque;
+  } else if (n_reduce) {
+    return IterVarType::kCommReduce;
+  } else {
+    return IterVarType::kDataPar;
+  }
+}
+
+StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs) {
+  CHECK(!srefs.empty()) << "ValueError: The input array is required to have at 
least one sref";
+
+  std::unordered_map<const StmtSRefNode*, size_t> sref_visited_cnt;
+  for (const StmtSRef& sref : srefs) {
+    const StmtSRefNode* p = sref.get();
+    while (p != nullptr) {
+      ++sref_visited_cnt[p];
+      p = p->parent;
+    }
+  }
+  size_t n_sref = srefs.size();
+  const StmtSRefNode* p = srefs[0].get();
+  while (p != nullptr && sref_visited_cnt[p] != n_sref) {
+    p = p->parent;
+  }
+  ICHECK(p != nullptr);
+  return GetRef<StmtSRef>(p);
+}
+
+std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const 
ScheduleState& self,
+                                                                    const 
StmtSRef& block_sref) {
+  Array<StmtSRef> location_srefs;
+  std::vector<int> location_indices;
+
+  // Step 1. Add the "compute-root" candidate. Add the "compute-inline" 
candidate if the block can
+  // be inlined.
+  if (CanComputeInline(self, block_sref)) {
+    location_srefs.push_back(StmtSRef::InlineMark());
+    location_indices.push_back(-2);
+  }
+  location_srefs.push_back(StmtSRef::RootMark());
+  location_indices.push_back(-1);
+
+  // Step 2. If the block has no consumer, there is no more candidate.
+  Array<StmtSRef> consumers = GetConsumers(self, block_sref);
+  if (consumers.empty()) {
+    return std::make_pair(location_srefs, location_indices);
+  }
+
+  // Step 3. Get the deepest loop that the input block can be computed at 
(namely "boundary"). If
+  // such a loop cannot be found, there is no more candidate and we just 
return.
+  StmtSRef loop_boundary = consumers.size() > 1 ? 
GetSRefLowestCommonAncestor(consumers)
+                                                : 
GetRef<StmtSRef>(consumers[0]->parent);
+  if (loop_boundary->StmtAs<ForNode>() == nullptr) {
+    return std::make_pair(location_srefs, location_indices);
+  }
+
+  // Step 4. Collect the loops outside the first consumer and locate the 
boundary loop. The position
+  // of the boundary loop reveals the number of possible additional candidates.
+  Array<StmtSRef> loop_srefs = GetLoops(consumers[0]);
+  size_t lca_pos =
+      std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - 
loop_srefs.begin();
+  ICHECK_LT(lca_pos, loop_srefs.size());
+  size_t n_candidate = lca_pos + 1;
+
+  // Step 5. Find the position of the deepest data-parallel loop among the 
candidate loops. This
+  // position is used for removing the unwanted candidates from the 
perspective of performance.
+  std::vector<IterVarType> loop_iter_types;
+  loop_iter_types.reserve(n_candidate);
+  int i_last_datapar = -1;
+  for (size_t i = 0; i < n_candidate; ++i) {
+    IterVarType iter_type = GetLoopIterType(loop_srefs[i]);

Review comment:
       We might want to improve the performance of this snippet in the future, 
but it doesn't look like it's the bottleneck now :-)

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& 
block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at 
location (specified by
+    // the annotation), we colect the producer first, and transform the 
producer block later.
+    // - The reason we collect the producer before transforming the input 
block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input 
block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), 
tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is 
needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) 
const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost 
loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 
1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the 
block's annotation.
+    if (tir::GetAnn<String>(block_sref, 
tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until 
success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const 
tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }

Review comment:
       The try-catch loop here is not desirable. Shall we use `CanComputeAt` in 
`Sample-Compute-Location` to make sure every location it samples works?

##########
File path: python/tvm/meta_schedule/schedule_rule/__init__.py
##########
@@ -16,4 +16,6 @@
 Meta Schedule schedule rules are used for modification of
 blocks in a schedule. See also PostOrderApply.
 """
+

Review comment:
       No need for this blank line, I suppose.

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& 
block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at 
location (specified by
+    // the annotation), we colect the producer first, and transform the 
producer block later.
+    // - The reason we collect the producer before transforming the input 
block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input 
block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), 
tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is 
needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) 
const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);

Review comment:
       No need to use a reference here
   
   ```suggestion
       tir::StmtSRef block_sref = sch->GetSRef(block_rv);
   ```

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {

Review comment:
       nit: add a blank line above




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to