This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab8a106  [MetaSchedule] Add Per-Store-Feature (#9860)
ab8a106 is described below

commit ab8a106f6a62c582c8910fffa8b06245c16c9b70
Author: Junru Shao <[email protected]>
AuthorDate: Fri Jan 7 23:09:46 2022 -0800

    [MetaSchedule] Add Per-Store-Feature (#9860)
    
    * [MetaSchedule] Add Per-Store-Feature
    
    Co-authored-by: Xiyou Zhou <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    
    * fix lint
    
    * fix lint
    
    * Update per_store_feature.py
    
    * address comments
    
    * fix lint
    
    Co-authored-by: Xiyou Zhou <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
---
 include/tvm/tir/stmt.h                             |    6 +-
 python/tvm/meta_schedule/__init__.py               |    2 +
 .../meta_schedule/feature_extractor/__init__.py    |    1 +
 .../feature_extractor/per_store_feature.py         |   60 +
 .../feature_extractor/per_store_feature.cc         | 1337 +++++++++++++++++
 src/meta_schedule/utils.h                          |    4 +
 src/tir/schedule/primitive/sampling.cc             |    8 +-
 src/tir/schedule/utils.h                           |   64 +-
 src/tir/transforms/unify_thread_binding.cc         |   13 +-
 ...schedule_feature_extractor_per_store_feature.py | 1555 ++++++++++++++++++++
 10 files changed, 3034 insertions(+), 16 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0664967..d3fbf0f 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1251,7 +1251,7 @@ constexpr const char* extern_scope = "extern_scope";
  *  This can hint some code generator to create a new function for compute.
  */
 constexpr const char* compute_scope = "compute_scope";
-/*! \brief Mark storage alignement requirement of buffers */
+/*! \brief Mark storage alignment requirement of buffers */
 constexpr const char* storage_alignment = "storage_alignment";
 /*! \brief Mark storage scope of realization */
 constexpr const char* realize_scope = "realize_scope";
@@ -1263,6 +1263,10 @@ constexpr const char* device_type = "device_type";
 constexpr const char* loop_scope = "loop_scope";
 /*! \brief Mark of reduce scope */
 constexpr const char* reduce_scope = "reduce_scope";
+/*! \brief Pragma: auto-unroll, max_step */
+constexpr const char* pragma_auto_unroll_max_step = 
"pragma_auto_unroll_max_step";
+/*! \brief Pragma: unroll explicit */
+constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
 /*! \brief Mark region is guarded by the pragma extension */
 constexpr const char* pragma_scope_prefix = "pragma_";
 /*! \brief Import C source or file into the final code gen module */
diff --git a/python/tvm/meta_schedule/__init__.py 
b/python/tvm/meta_schedule/__init__.py
index 8b6672c..e41e5b3 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -23,4 +23,6 @@ from . import space_generator
 from . import search_strategy
 from . import schedule_rule
 from . import integration
+from . import feature_extractor
 from .tune_context import TuneContext
+from .search_strategy import MeasureCandidate
diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py 
b/python/tvm/meta_schedule/feature_extractor/__init__.py
index f29c44b..83ac742 100644
--- a/python/tvm/meta_schedule/feature_extractor/__init__.py
+++ b/python/tvm/meta_schedule/feature_extractor/__init__.py
@@ -20,4 +20,5 @@ Meta Schedule feature extractors that extracts features from
 measure candidates for use in cost model.
 """
 from .feature_extractor import FeatureExtractor, PyFeatureExtractor
+from .per_store_feature import PerStoreFeature
 from .random_feature_extractor import RandomFeatureExtractor
diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py 
b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
new file mode 100644
index 0000000..306934d
--- /dev/null
+++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
+so we call this feature as "per-store" feature.
+"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .feature_extractor import FeatureExtractor
+
+
+@register_object("meta_schedule.PerStoreFeature")
+class PerStoreFeature(FeatureExtractor):
+    """PerStoreFeature extracts one feature vector per BufferStoreNode
+
+    Parameters
+    ----------
+    buffers_per_store : int
+        The number of buffers in each BufferStore; Pad or truncate if 
necessary.
+    arith_intensity_curve_num_samples : int
+        The number of samples used in the arithmetic intensity curve.
+    cache_line_bytes : int
+        The number of bytes in a cache line.
+    """
+
+    buffers_per_store: int
+    """The number of buffers in each BufferStore; Pad or truncate if 
necessary."""
+    arith_intensity_curve_num_samples: int  # pylint: disable=invalid-name
+    """The number of samples used in the arithmetic intensity curve."""
+    cache_line_bytes: int
+    """The number of bytes in a cache line."""
+    feature_vector_length: int
+    """Length of the feature vector."""
+
+    def __init__(
+        self,
+        buffers_per_store: int = 5,
+        arith_intensity_curve_num_samples: int = 10,
+        cache_line_bytes: int = 64,
+    ):
+        self.__init_handle_by_constructor__(
+            _ffi_api.FeatureExtractorPerStoreFeature,  # type: ignore # 
pylint: disable=no-member
+            buffers_per_store,
+            arith_intensity_curve_num_samples,
+            cache_line_bytes,
+        )
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc 
b/src/meta_schedule/feature_extractor/per_store_feature.cc
new file mode 100644
index 0000000..42cd5d5
--- /dev/null
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -0,0 +1,1337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/transform.h>
+
+#include <cmath>
+#include <memory>
+#include <numeric>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+/*! \brief Type for multi-dimensional index */
+using MultiIndex = std::vector<PrimExpr>;
+/*! \brief Vector of int64_t */
+using IntVec = std::vector<int64_t>;
+/*! \brief Vector of for loops */
+using ForVec = std::vector<const ForNode*>;
+
+/*!
+ * \brief An unordered_map for (for, buffer) => V
+ * \tparam V The value type
+ */
+template <class V>
+using ForBufferMap = std::unordered_map<const ForNode*, 
std::unordered_map<const BufferNode*, V>>;
+
+/*! \brief Given x, compute log2(|x| + 1) */
+inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x 
+ 1); }
+
+namespace utils {
+
+/*!
+ * \brief Get the shape of the buffer
+ * \param buffer The buffer
+ * \param analyzer The analyzer
+ * \return The shape of the buffer
+ */
+std::vector<int64_t> GetBufferShape(const Buffer& buffer, arith::Analyzer* 
analyzer) {
+  int ndim = buffer->shape.size();
+  std::vector<int64_t> result;
+  result.reserve(ndim);
+  for (const PrimExpr& i : buffer->shape) {
+    if (const IntImmNode* int_imm = i.as<IntImmNode>()) {
+      result.push_back(int_imm->value);
+      continue;
+    }
+    arith::ConstIntBound bound = analyzer->const_int_bound(i);
+    if (0 <= bound->max_value && bound->max_value < 
arith::ConstIntBound::kPosInf) {
+      result.push_back(bound->max_value);
+    } else {
+      result.push_back(1);
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if 
it exists
+ * \param loop The loop to be checked
+ * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if 
it does not exist
+ */
+int64_t GetPragmaAutoUnroll(const ForNode* loop) {
+  if (Optional<IntImm> auto_unroll = GetAnn<IntImm>(loop, 
tir::attr::pragma_auto_unroll_max_step)) {
+    return auto_unroll.value()->value;
+  }
+  return -1;
+}
+
+/*!
+ * \brief Given a list of loops, return the extent of the first loop if the 
list is not empty,
+ * and the first loop has constant extent. Otherwise returns the default value 
given
+ * \param loops The list of loops to be checked
+ * \param default_value The default value to be returned if the list is empty 
or the first loop
+ * does not have constant extent
+ * \return The extent of the first loop if the list is not empty, or the first 
loop has constant
+ * extent. Otherwise returns the default value
+ */
+int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
+  if (!loops.empty()) {
+    if (const int64_t* extent = GetLoopIntExtent(loops[0])) {
+      return *extent;
+    }
+  }
+  return default_value;
+}
+
+/*!
+ * \brief Relax each of the multi-indexing pattern according to the domains 
bound in the analyzer,
+ * and then union them into a single region
+ * \param multi_index_pattern A list of multi-index pattern to be relaxed
+ * \param numel The size of the single region after union
+ * \param analyzer The analyzer that contains the domain information
+ * \return The relaxed and unioned region
+ */
+IntVec RelaxAndUnion(const std::vector<MultiIndex>& multi_indices, int64_t* 
numel,
+                     arith::Analyzer* analyzer) {
+  *numel = 1;
+  if (multi_indices.empty()) {
+    return {};
+  }
+  int n_indices = multi_indices.size();
+  int ndim = multi_indices[0].size();
+  IntVec access_shape(ndim, 0);
+  for (int i = 0; i < ndim; ++i) {
+    int64_t minimum = arith::ConstIntBound::kPosInf;
+    int64_t maximum = arith::ConstIntBound::kNegInf;
+    for (int j = 0; j < n_indices; ++j) {
+      arith::ConstIntBound bound = 
analyzer->const_int_bound(multi_indices[j][i]);
+      minimum = std::min(minimum, bound->min_value);
+      maximum = std::max(maximum, bound->max_value);
+    }
+    *numel *= maximum - minimum + 1;
+    access_shape[i] = maximum - minimum + 1;
+  }
+  return access_shape;
+}
+
+/*!
+ * \brief Given a list of multi-index pattern, return the minimal stride of a 
variable on it
+ * \param multi_indices The list of multi-index pattern
+ * \param buffer_stride The stride of the buffer
+ * \param var The variable to be checked
+ * \return The minimal stride of the variable on the multi-index pattern
+ */
+int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const 
IntVec& buffer_stride,
+                     const Var& var) {
+  class CoefficientExtractor : private ExprVisitor {
+   public:
+    static int64_t Extract(const PrimExpr& expr, const Var& var) {
+      CoefficientExtractor extractor(var);
+      extractor.VisitExpr(expr);
+      return (extractor.visited_var && !extractor.visited_mul && 
!extractor.visited_add)
+                 ? 1
+                 : (extractor.visited_var ? extractor.stride : 0);
+    }
+
+   private:
+    explicit CoefficientExtractor(const Var& var)
+        : var(var), stride(0), visited_var(false), visited_add(false), 
visited_mul(false) {}
+
+    void VisitExpr_(const MulNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_add) {
+        if (const auto* a = node->a.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = a->value;
+        } else if (const auto* b = node->b.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = b->value;
+        }
+      }
+    }
+
+    void VisitExpr_(const AddNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_mul) {
+        visited_add = true;
+        stride = 1;
+      }
+    }
+
+    void VisitExpr_(const VarNode* node) override {
+      if (node == var.get()) {
+        visited_var = true;
+        stride = 2;
+      }
+    }
+
+    const Var& var;
+    int64_t stride;
+    bool visited_var;
+    bool visited_add;
+    bool visited_mul;
+  };
+
+  constexpr int64_t kNotFound = std::numeric_limits<int64_t>::max();
+  int ndim = buffer_stride.size();
+  // Calculate the min stride possible
+  int64_t result = kNotFound;
+  for (const MultiIndex& multi_index : multi_indices) {
+    ICHECK_EQ(multi_index.size(), buffer_stride.size());
+    // Find the rightest dimension that contains the given variable
+    for (int i = ndim - 1; i >= 0; --i) {
+      int64_t coef = CoefficientExtractor::Extract(multi_index[i], var);
+      if (coef != 0) {
+        result = std::min(result, std::abs(coef) * buffer_stride[i]);
+        break;
+      }
+    }
+  }
+  return (result == kNotFound) ? 0 : result;
+}
+
+/*!
+ * \brief Converts a 2-dimensional STL vector to a TVM NDArray
+ * \param src The source 2-dimensional STL vector
+ * \return The converted TVM NDArray
+ */
+runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
+  ICHECK(!src.empty());
+  int n = src.size();
+  int m = src[0].size();
+  runtime::NDArray tgt = runtime::NDArray::Empty(
+      /*shape=*/{n, m},
+      /*dtype=*/DLDataType{kDLFloat, 64, 1},
+      /*ctx=*/DLDevice{kDLCPU, 0});
+  double* data = static_cast<double*>(tgt->data);
+  for (const std::vector<double>& row : src) {
+    for (double v : row) {
+      *data++ = v;
+    }
+  }
+  return tgt;
+}
+
+}  // namespace utils
+
+namespace transform {
+
+/*!
+ * \brief Create a pass that simplifies the IR for feature extraction
+ * \return The pass created
+ */
+Pass SimplifyForFeatureExtraction() {
+  class Simplifier : private StmtExprMutator {
+   public:
+    static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); }
+
+   private:
+    PrimExpr VisitExpr_(const SelectNode* node) final { return 
make_const(node->dtype, 1.0); }
+
+    PrimExpr VisitExpr_(const VarNode* var) final {
+      if (unit_vars_.count(GetRef<Var>(var))) {
+        return make_const(var->dtype, 0.0);
+      }
+      return GetRef<Var>(var);
+    }
+
+    Stmt VisitStmt_(const ForNode* loop) final {
+      if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == 
ForKind::kSerial &&
+          loop->annotations.empty()) {
+        unit_vars_.insert(loop->loop_var);
+        return VisitStmt(loop->body);
+      } else {
+        return StmtExprMutator::VisitStmt_(loop);
+      }
+    }
+
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unit_vars_;
+  };
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    PrimFuncNode* n = f.CopyOnWrite();
+    n->body = Simplifier::Run(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyForFeatureExtraction", 
{});
+}
+
+/*!
+ * \brief Create a list of passes that preprocesses the IR for feature 
extraction
+ * \return The list of passes created
+ */
+Sequential PassListForPerStoreFeature() {
+  return Sequential({
+      tir::transform::SimplifyForFeatureExtraction(),
+      tir::transform::LowerCrossThreadReduction(),
+      tir::transform::LowerInitBlock(),
+      tir::transform::PlanAndUpdateBufferAllocationLocation(),
+      tir::transform::ConvertBlocksToOpaque(),
+      tir::transform::UnifyThreadBinding(),
+      tir::transform::CompactBufferAllocation(),
+      tir::transform::LowerMatchBuffer(),
+      tir::transform::Simplify(),
+  });
+}
+
+}  // namespace transform
+
+/*! \brief A data structure managing loop nests */
+struct LoopNest {
+  int64_t prod = 1;    // The product of the extents of all the loops
+  ForVec loops;        // All the loops
+  IntVec auto_unroll;  // The loops with auto unroll pragma
+  ForVec parallel;     // The loops whose ForKind are kParallel
+  ForVec vectorize;    // The loops whose ForKind are kVectorized
+  ForVec unroll;       // The loops whose ForKind are kUnrolled
+  ForVec blockIdx_x;   // The loops whose ForKind are kThreadBinding to 
blockIdx.x
+  ForVec blockIdx_y;   // The loops whose ForKind are kThreadBinding to 
blockIdx.y
+  ForVec blockIdx_z;   // The loops whose ForKind are kThreadBinding to 
blockIdx.z
+  ForVec threadIdx_x;  // The loops whose ForKind are kThreadBinding to 
threadIdx.x
+  ForVec threadIdx_y;  // The loops whose ForKind are kThreadBinding to 
threadIdx.y
+  ForVec threadIdx_z;  // The loops whose ForKind are kThreadBinding to 
threadIdx.z
+  ForVec vthread;      // The loops whose ForKind are kThreadBinding to 
vthread.*
+
+  /*!
+   * \brief Push a new loop into the loop nest
+   * \param loop The loop to be pushed
+   * \param auto_unroll_attr The auto unroll attribute of the loop
+   * \return A list of for loops that the loop is bound to
+   */
+  ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) {
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod *= *extent;
+    }
+    this->loops.push_back(loop);
+    if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) {
+      this->auto_unroll.push_back(*auto_unroll_attr);
+    }
+    ForVec* ref_loops = nullptr;
+    if (loop->kind == ForKind::kParallel) {
+      ref_loops = &parallel;
+    } else if (loop->kind == ForKind::kVectorized) {
+      ref_loops = &vectorize;
+    } else if (loop->kind == ForKind::kUnrolled) {
+      ref_loops = &unroll;
+    } else if (loop->kind == ForKind::kThreadBinding) {
+      std::string thread_tag = loop->thread_binding.value()->thread_tag;
+      if (thread_tag == "blockIdx.x") {
+        ref_loops = &blockIdx_x;
+      } else if (thread_tag == "blockIdx.y") {
+        ref_loops = &blockIdx_y;
+      } else if (thread_tag == "blockIdx.z") {
+        ref_loops = &blockIdx_z;
+      } else if (thread_tag == "threadIdx.x") {
+        ref_loops = &threadIdx_x;
+      } else if (thread_tag == "threadIdx.y") {
+        ref_loops = &threadIdx_y;
+      } else if (thread_tag == "threadIdx.z") {
+        ref_loops = &threadIdx_z;
+      } else if (support::StartsWith(thread_tag, "vthread")) {
+        ref_loops = &vthread;
+      } else {
+        LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << 
thread_tag;
+      }
+    }
+    if (ref_loops != nullptr) {
+      ref_loops->push_back(loop);
+    }
+    return ref_loops;
+  }
+
+  /*!
+   * \brief Pop the last loop from the loop nest
+   * \param loop The loop to be popped
+   * \param ref_loops The list of for loops that the loop is bound to
+   * \param auto_unroll_attr The auto unroll attribute of the loop
+   */
+  void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) {
+    if (ref_loops) {
+      ref_loops->pop_back();
+    }
+    if (auto_unroll_attr > 0) {
+      this->auto_unroll.pop_back();
+    }
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod /= *extent;
+    }
+    this->loops.pop_back();
+  }
+};
+
+/****** Group 1: Computation related features ******/
+
+namespace group1 {
+
+/*! \brief Group 1 features */
+struct Feature {
+  /*! \brief Arithmetic features */
+  struct ArithOps {
+    // Float-point arithmetic features
+    int64_t float_mad = 0;         // The number of float MAD (Multiply–add) 
ops
+    int64_t float_add_sub = 0;     // The number of float add and sub ops
+    int64_t float_mul = 0;         // The number of float multiply ops
+    int64_t float_div_mod = 0;     // The number of float div and mod ops
+    int64_t float_cmp = 0;         // The number of float comparison ops
+    int64_t float_math_func = 0;   // The number of float math func calls
+    int64_t float_other_func = 0;  // The number of other float func calls
+    // Integer arithmetic features
+    int64_t int_mad = 0;         // The number of integer MAD (Multiply–add) 
ops
+    int64_t int_add_sub = 0;     // The number of integer add and sub ops
+    int64_t int_mul = 0;         // The number of integer multiply ops
+    int64_t int_div_mod = 0;     // The number of integer div and mod ops
+    int64_t int_cmp = 0;         // The number of integer comparison ops
+    int64_t int_math_func = 0;   // The number of integer math func calls
+    int64_t int_other_func = 0;  // The number of other integer func calls
+    // Other arithmetic features
+    int64_t bool_op = 0;    // The number of bool ops
+    int64_t select_op = 0;  // The number of select ops
+
+    static constexpr int64_t kCount = 16;
+
+    ArithOps() = default;
+    ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent);
+
+    void Export(std::vector<double>* v) const {
+      double vs[] = {
+          slog(float_mad), slog(float_add_sub),   slog(float_mul),        
slog(float_div_mod),
+          slog(float_cmp), slog(float_math_func), slog(float_other_func),  //
+          slog(int_mad),   slog(int_add_sub),     slog(int_mul),          
slog(int_div_mod),
+          slog(int_cmp),   slog(int_math_func),   slog(int_other_func),  //
+          slog(bool_op),   slog(select_op),
+      };
+      v->insert(v->end(), std::begin(vs), std::end(vs));
+    }
+  };
+
+  /*! \brief Loop binding features */
+  struct ForKindFeature {
+    enum class Pos : int {
+      kPosNone = 0,           // Does not have this kind of annotation
+      kPosInnerSpatial = 1,   // The annotated iterator is the innermost 
spatial iterator
+      kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial 
iterator
+      kPosOuterSpatial = 3,   // The annotated iterator is the outermost 
spatial iterator
+      kPosInnerReduce = 4,    // The annotated iterator is the innermost 
reduce iterator
+      kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce 
iterator
+      kPosOuterReduce = 6,    // The annotated iterator is the outermost 
reduce iterator
+      kPosMixed = 7,          // The annotated iterator is a mixed space and 
reduce iterator
+    };
+    int64_t num = 0;           // The number of iterators with the annotation
+    int64_t prod = 0;          // The product of the lengths of iterators with 
the annotation
+    int64_t len = 0;           // The length of the innermost iterator with 
the annotation
+    Pos pos = Pos::kPosMixed;  // The position of the iterators with the 
annotation
+
+    static constexpr int64_t kCount = 11;
+
+    explicit ForKindFeature(const ForVec& loops);
+
+    void Export(std::vector<double>* v) const {
+      double vs[] = {
+          slog(num),
+          slog(prod),
+          slog(len),
+          static_cast<double>(static_cast<int>(pos) == 0),
+          static_cast<double>(static_cast<int>(pos) == 1),
+          static_cast<double>(static_cast<int>(pos) == 2),
+          static_cast<double>(static_cast<int>(pos) == 3),
+          static_cast<double>(static_cast<int>(pos) == 4),
+          static_cast<double>(static_cast<int>(pos) == 5),
+          static_cast<double>(static_cast<int>(pos) == 6),
+          static_cast<double>(static_cast<int>(pos) == 7),
+      };
+      v->insert(v->end(), std::begin(vs), std::end(vs));
+    }
+  };
+
+  ArithOps arith_ops;           // Arithmetic features
+  ForKindFeature vectorize;     // Loop binding features: kVectorize
+  ForKindFeature unroll;        // Loop binding features: kUnroll
+  ForKindFeature parallel;      // Loop binding features: kParallel
+  bool is_gpu = false;          // If the program is running on GPU
+  int64_t blockIdx_x_len = 1;   // The length of blockIdx.x
+  int64_t blockIdx_y_len = 1;   // The length of blockIdx.y
+  int64_t blockIdx_z_len = 1;   // The length of blockIdx.z
+  int64_t threadIdx_x_len = 1;  // The length of threadIdx.x
+  int64_t threadIdx_y_len = 1;  // The length of threadIdx.y
+  int64_t threadIdx_z_len = 1;  // The length of threadIdx.z
+  int64_t vthread_len = 1;      // The length of virtual thread
+
+  static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount 
* 3 + 8;
+
+  explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, 
bool is_gpu)
+      : arith_ops(store, loop_nest.prod),
+        vectorize(loop_nest.vectorize),
+        unroll(loop_nest.unroll),
+        parallel(loop_nest.parallel) {
+    if (is_gpu) {
+      this->is_gpu = true;
+      this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1);
+      this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1);
+      this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1);
+      this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1);
+      this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1);
+      this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1);
+      this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1);
+    }
+  }
+
+  void Export(std::vector<double>* v) const {
+    this->arith_ops.Export(v);
+    this->vectorize.Export(v);
+    this->unroll.Export(v);
+    this->parallel.Export(v);
+    double vs[] = {
+        static_cast<double>(is_gpu),  //
+        slog(blockIdx_x_len),        slog(blockIdx_y_len),  
slog(blockIdx_z_len),
+        slog(threadIdx_x_len),       slog(threadIdx_y_len), 
slog(threadIdx_z_len),
+        slog(vthread_len),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+};
+
/*!
 * \brief Count the arithmetic operations in the RHS of a buffer store.
 * \param store The store statement; only `store->value` is scanned, the store
 * indices are not counted.
 * \param prod_loop_extent The product of the extents of the surrounding loops.
 * Every operator found is weighted by this factor, so an operator inside a
 * nest of N total iterations counts as N operations.
 */
Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) {
  // Expression visitor that tallies each arithmetic node into `result_`,
  // weighting every occurrence by `prod_loop_extent_`.
  class ArithOpCounter : public ExprVisitor {
   public:
// Count `Type` into the single counter `Counter`, regardless of dtype.
#define TVM_FEATURE_SIMPLE(Type, Counter)       \
  void VisitExpr_(const Type* op) final {       \
    result_.Counter += this->prod_loop_extent_; \
    ExprVisitor::VisitExpr_(op);                \
  }
// Count `Type` into `FloatCounter` or `IntCounter` depending on the dtype of
// the expression node itself.
#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \
  void VisitExpr_(const Type* op) final {                  \
    if (op->dtype.is_float()) {                            \
      result_.FloatCounter += this->prod_loop_extent_;     \
    } else {                                               \
      result_.IntCounter += this->prod_loop_extent_;       \
    }                                                      \
    ExprVisitor::VisitExpr_(op);                           \
  }
    TVM_FEATURE_SIMPLE(AndNode, bool_op);
    TVM_FEATURE_SIMPLE(OrNode, bool_op);
    TVM_FEATURE_SIMPLE(NotNode, bool_op);
    TVM_FEATURE_SIMPLE(SelectNode, select_op);
    TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub);
    TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub);
    TVM_FEATURE_BINARY(MulNode, float_mul, int_mul);
    TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod);
    TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod);
    TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod);
    TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod);
    TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp);
    TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp);
#undef TVM_FEATURE_BINARY
#undef TVM_FEATURE_SIMPLE

    // Calls: pure calls (kPure / kExprAnnotation) are counted as math
    // functions, everything else as "other" functions, split by float/int
    // result dtype.
    void VisitExpr_(const CallNode* op) final {
      static auto op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
      TCallEffectKind effect_kind = op_call_effect_[Downcast<Op>(op->op)];
      bool is_pure =
          effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
      if (is_pure) {
        if (op->dtype.is_float()) {
          result_.float_math_func += prod_loop_extent_;
        } else {
          result_.int_math_func += prod_loop_extent_;
        }
      } else {
        if (op->dtype.is_float()) {
          result_.float_other_func += prod_loop_extent_;
        } else {
          result_.int_other_func += prod_loop_extent_;
        }
      }
      ExprVisitor::VisitExpr_(op);
    }

    int64_t prod_loop_extent_;  // weight applied to every counted operator
    ArithOps result_;           // accumulated counters
  };
  ArithOpCounter counter;
  counter.prod_loop_extent_ = prod_loop_extent;
  counter(store->value);
  *this = counter.result_;
}
+
+Feature::ForKindFeature::ForKindFeature(const ForVec& loops) {
+  if (loops.empty()) {
+    this->num = 0;
+    this->prod = 0;
+    this->len = 0;
+    this->pos = ForKindFeature::Pos::kPosNone;
+  } else {
+    const int64_t* last_loop_extent = GetLoopIntExtent(loops.back());
+    this->num = loops.size();
+    this->len = last_loop_extent ? *last_loop_extent : 1;
+    this->pos = ForKindFeature::Pos::kPosMixed;
+    int64_t& prod = this->prod = 1;
+    for (const ForNode* loop : loops) {
+      if (const int64_t* extent = GetLoopIntExtent(loop)) {
+        prod *= *extent;
+      }
+    }
+  }
+}
+
+}  // namespace group1
+
+namespace group2 {
+
/*! \brief Group 2 features: the buffer access patterns of a single store */
struct Feature {
  // How a buffer is accessed within the store statement
  enum class AccessType : int {
    /*! The buffer is read but not written */
    kRead = 0,
    /*! The buffer is written but not read */
    kWrite = 1,
    /*! The buffer is both read and written */
    kReadWrite = 2,
    /*! Unknown type */
    kUnknownRW = 3,
  };
  // The kind of data reuse detected for a buffer
  enum class ReuseType : int {
    /*! Buffer reuse because accessed on each iteration of a loop */
    kLoopMultipleRead = 0,
    /*! Buffer reuse because it is serially accessed */
    kSerialMultipleReadWrite = 1,
    /*! No buffer reuse */
    kNoReuse = 2,
  };

  /*! \brief Per-buffer features; one SubFeature per buffer the store touches */
  struct SubFeature {
    /*! \brief The buffer this feature is for */
    const BufferNode* buffer = nullptr;
    /*! \brief The access type of the buffer */
    AccessType access_type = AccessType::kUnknownRW;
    /*! \brief A list of multi-dimensonal indices used to access the buffer */
    std::vector<MultiIndex> multi_indices = {};
    // Access information
    /*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */
    std::vector<std::unordered_map<const BufferNode*, int64_t>> loop_accessed_numel = {};
    /*! \brief The shape of the data access */
    IntVec access_shape;
    /*! \brief The bytes that are continuously accessed */
    int64_t num_continuous_bytes = 1;
    // Stride information
    /*! \brief The min stride of the access */
    int64_t min_stride = 0;
    /*! \brief The innermost stride */
    int64_t innermost_stride = 0;
    /*! \brief The product of the non-strided loops */
    int64_t prod_non_strided_loop_extent = 0;
    // Reuse information
    /*! The type of data reuse */
    ReuseType reuse_type = ReuseType::kNoReuse;
    /*! The reuse distance in terms of number of iterations */
    double reuse_dis_iter = 0.0;
    /*! The reuse distance in terms of bytes */
    double reuse_dis_bytes = 0.0;
    /*! The reuse count */
    int64_t reuse_ct = 0;
    // Features (derived from the raw statistics above by SetFeature)
    /*! The touched memory in bytes */
    double bytes;
    /*! The touched unique memory in bytes */
    double unique_bytes;
    /*! The number of touched cache lines */
    double lines;
    /*! The number touched unique cache lines */
    double unique_lines;
    /*! bytes / reuse_ct */
    double bytes_d_reuse_ct;
    /*! unique_bytes / reuse_ct */
    double unique_bytes_d_reuse_ct;
    /*! lines / reuse_ct */
    double lines_d_reuse_ct;
    /*! unique_lines / reuse_ct */
    double unique_lines_d_reuse_ct;
    /*! The stride in access */
    double stride;

    // Number of exported feature slots per buffer
    static constexpr int64_t kCount = 18;

    // Flatten the 18 slots into `v`; note kUnknownRW has no one-hot slot
    void Export(std::vector<double>* v) const {
      double vs[] = {
          static_cast<double>(static_cast<int>(access_type) == 0),
          static_cast<double>(static_cast<int>(access_type) == 1),
          static_cast<double>(static_cast<int>(access_type) == 2),
          // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored
          slog(bytes),
          slog(unique_bytes),
          slog(lines),
          slog(unique_lines),
          static_cast<double>(static_cast<int>(reuse_type) == 0),
          static_cast<double>(static_cast<int>(reuse_type) == 1),
          static_cast<double>(static_cast<int>(reuse_type) == 2),
          slog(reuse_dis_iter),
          slog(reuse_dis_bytes),
          slog(reuse_ct),
          slog(bytes_d_reuse_ct),
          slog(unique_bytes_d_reuse_ct),
          slog(lines_d_reuse_ct),
          slog(unique_lines_d_reuse_ct),
          slog(stride),
      };
      v->insert(v->end(), std::begin(vs), std::end(vs));
    }

    // Emit kCount (= 18) zeros for an absent buffer slot
    static void Pad(std::vector<double>* v) { v->insert(v->end(), 18, 0.0); }

    // Compute min_stride / innermost_stride / prod_non_strided_loop_extent
    void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer);

    // Compute reuse_type / reuse_dis_iter / reuse_dis_bytes / reuse_ct
    void SetReuse(const LoopNest& loop_nest,     //
                  int64_t top_loop_touch_bytes,  //
                  const ForBufferMap<IntVec>& buffer_touched_under_loop);

    // Derive the exported features from the statistics gathered above
    void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes);

    explicit SubFeature(const BufferNode* buffer, AccessType access_type,
                        std::vector<MultiIndex> multi_indices, int n_loops)
        : buffer(buffer),
          access_type(access_type),
          multi_indices(multi_indices),
          loop_accessed_numel(n_loops) {}
  };

  // Export exactly `buffers_per_store` slots: extra sub-features are dropped,
  // missing ones are zero-padded
  void Export(std::vector<double>* v, int buffers_per_store) const {
    int n = sub_features.size();
    for (int i = 0; i < buffers_per_store; ++i) {
      if (i < n) {
        sub_features[i].Export(v);
      } else {
        SubFeature::Pad(v);
      }
    }
  }

  explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest,
                   int64_t cache_line_bytes, IntVec* for_touched_bytes,
                   ForBufferMap<IntVec>* buffer_touched_under_loop, arith::Analyzer* analyzer);

  // Create one SubFeature per buffer touched by `store`
  void Init(const BufferStoreNode* store, int n_loops);

  // Compute the regions the loops touch on each buffer
  void SetRegion(const LoopNest& loop_nest,                        //
                 IntVec* for_touched_bytes,                        //
                 ForBufferMap<IntVec>* buffer_touched_under_loop,  //
                 arith::Analyzer* analyzer);

  // One entry per buffer accessed by the store, sorted by importance
  std::vector<SubFeature> sub_features;
};
+
+void Feature::Init(const BufferStoreNode* store, int n_loops) {
+  struct Info {
+    AccessType access_type = AccessType::kUnknownRW;
+    std::vector<MultiIndex> multi_indices;
+  };
+  std::unordered_map<const BufferNode*, Info> buffer_info;
+  {
+    Info& info = buffer_info[store->buffer.get()];
+    info.access_type = AccessType::kWrite;
+    info.multi_indices.push_back({store->indices.begin(), 
store->indices.end()});
+  }
+  PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void {
+    if (const BufferLoadNode* load = obj.as<BufferLoadNode>()) {
+      const BufferNode* buffer = load->buffer.get();
+      Info& info = buffer_info[buffer];
+      switch (info.access_type) {
+        case AccessType::kRead:
+          break;
+        case AccessType::kWrite:
+          info.access_type = AccessType::kReadWrite;
+          break;
+        case AccessType::kReadWrite:
+          break;
+        case AccessType::kUnknownRW:
+        default:
+          info.access_type = AccessType::kRead;
+          break;
+      }
+      if (info.access_type != AccessType::kReadWrite) {
+        info.multi_indices.push_back({load->indices.begin(), 
load->indices.end()});
+      }
+    }
+  });
+  this->sub_features.reserve(buffer_info.size());
+  for (const auto& kv : buffer_info) {
+    this->sub_features.emplace_back(kv.first, kv.second.access_type,
+                                    std::move(kv.second.multi_indices), 
n_loops);
+  }
+}
+
/*!
 * \brief Compute, for each loop, the region it touches on each buffer.
 * \param loop_nest The surrounding loop nest (outermost first).
 * \param for_touched_bytes Output: for_touched_bytes[i] = total bytes touched
 * by all buffers when the loops at depth >= i iterate.
 * \param buffer_touched_under_loop Output: per (loop, buffer), the list of
 * element counts touched by each access under that loop.
 * \param analyzer The analyzer used to relax the access indices; its variable
 * bindings are overwritten by this function.
 */
void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes,
                        ForBufferMap<IntVec>* buffer_touched_under_loop,
                        arith::Analyzer* analyzer) {
  int n_loops = loop_nest.loops.size();
  const std::vector<const ForNode*>& loops = loop_nest.loops;
  // Step 1. Initialize and bind all the loop variables to a constant
  // (each loop var pinned to its minimum, i.e. a single-point region)
  *for_touched_bytes = IntVec(n_loops, 0);
  for (int i = 0; i < n_loops; ++i) {
    const ForNode* loop = loops[i];
    analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true);
  }
  // Step 2. Corner case: no loops
  if (n_loops == 0) {
    // In this case, the `access_shape` is not calculated
    for (SubFeature& feature : sub_features) {
      feature.access_shape = IntVec(feature.buffer->shape.size(), 1);
    }
    return;
  }
  // Step 3. Gradually bind the loops from inner to outer,
  // calculate the area the loops touch on each buffer.
  // After iteration i, loops [i, n) are bound to their full ranges while
  // loops [0, i) remain pinned to their minima, so the relaxed region is
  // exactly what the inner loops sweep.
  for (int i = n_loops - 1; i >= 0; --i) {
    const ForNode* loop = loops[i];
    analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent),
                   /*allow_override=*/true);
    int64_t& touched_bytes = (*for_touched_bytes)[i] = 0;
    for (SubFeature& feature : sub_features) {
      const BufferNode* buffer = feature.buffer;
      // Note: `feature.access_shape` for `i == 0` is the only one preserved,
      // while others are discarded
      int64_t numel;
      feature.access_shape = utils::RelaxAndUnion(feature.multi_indices, &numel, analyzer);
      feature.loop_accessed_numel[i][buffer] = numel;
      touched_bytes += numel * buffer->dtype.bytes();
      (*buffer_touched_under_loop)[loop][buffer].push_back(numel);
    }
  }
}
+
+void Feature::SubFeature::SetStride(const LoopNest& loop_nest, 
arith::Analyzer* analyzer) {
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  // For each buffer, we find the loop stride on it
+  const BufferNode* buffer = this->buffer;
+  int ndim = this->buffer->shape.size();
+  IntVec buffer_shape = utils::GetBufferShape(GetRef<Buffer>(buffer), 
analyzer);
+  // Calculate the buffer's stride from its shape
+  IntVec buffer_stride(ndim);
+  if (ndim >= 1) {
+    buffer_stride[ndim - 1] = 1;
+    for (int i = ndim - 2; i >= 0; --i) {
+      buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1];
+    }
+  }
+  // Calculate `num_continuous_bytes`
+  {
+    int64_t& num_continuous_bytes = this->num_continuous_bytes = 1;
+    const IntVec& access_shape = this->access_shape;
+    ICHECK_EQ(access_shape.size(), buffer_shape.size());
+    for (int i = ndim - 1; i >= 0; --i) {
+      if (access_shape[i] == buffer_shape[i]) {
+        num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes();
+        break;
+      }
+    }
+  }
+  // Enumerate loops from inner to outer
+  int i = 0;
+  // Calculate this->min_stride
+  int64_t& stride = this->min_stride = 0;
+  for (i = n_loops - 1; i >= 0; --i) {
+    stride = utils::GetVarStride(this->multi_indices, buffer_stride, 
loops[i]->loop_var);
+    if (stride != 0) {
+      break;
+    }
+  }
+  // Calculate this->innermost_stride
+  this->innermost_stride = (i == n_loops - 1) ? stride : 0;
+  // Calculate this->prod
+  int64_t& prod = this->prod_non_strided_loop_extent = 1;
+  for (int j = n_loops - 1; j > i; --j) {
+    if (const int64_t* extent = GetLoopIntExtent(loops[n_loops - 1])) {
+      prod *= *extent;
+    }
+  }
+}
+
/*!
 * \brief Detect data reuse for this buffer access and fill in reuse_type,
 * reuse_dis_iter, reuse_dis_bytes and reuse_ct.
 * \param loop_nest The surrounding loop nest (outermost first).
 * \param top_loop_touch_bytes Bytes charged when the reuse-carrying loop is
 * the innermost one.
 * \param buffer_touched_under_loop Per (loop, buffer), the element counts
 * touched by each access under that loop, as produced by SetRegion.
 */
void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes,
                                   const ForBufferMap<IntVec>& buffer_touched_under_loop) {
  const BufferNode* buffer = this->buffer;
  // Step 0. Collect all `Var`s that appears in the buffer region
  std::unordered_set<const VarNode*> region_vars;
  for (const MultiIndex& multi_index : this->multi_indices) {
    for (const PrimExpr& index : multi_index) {
      PostOrderVisit(index, [&region_vars](const ObjectRef& obj) -> void {
        if (const auto* var = obj.as<VarNode>()) {
          region_vars.insert(var);
        }
      });
    }
  }
  // Default case: no reuse
  ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse;
  double& reuse_dis_iter = this->reuse_dis_iter = 0;
  double& reuse_dis_bytes = this->reuse_dis_bytes = 0;
  int64_t& reuse_ct = this->reuse_ct = 0;

  // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse
  int n_loops = loop_nest.loops.size();
  const std::vector<const ForNode*>& loops = loop_nest.loops;
  for (int i = n_loops - 1; i >= 0; --i) {
    const ForNode* loop = loops[i];
    // Case 1. Find an invariant loop, i.e. reuse with kLoopMultipleRead:
    // the loop variable does not appear in any access index, so each of its
    // iterations re-touches the same region
    if (!region_vars.count(loop->loop_var.get())) {
      reuse_type = ReuseType::kLoopMultipleRead;
      // Reuse count = the invariant loop's extent (1 when non-constant)
      if (const int64_t* extent = GetLoopIntExtent(loop)) {
        reuse_ct = *extent;
      } else {
        reuse_ct = 1;
      }
      // Reuse distance in iterations = product of the extents of the loops
      // nested inside the invariant loop
      reuse_dis_iter = 1;
      for (int j = n_loops - 1; j > i; --j) {
        if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
          reuse_dis_iter *= *extent;
        }
      }
      // Reuse distance in bytes = bytes touched by all buffers under the loop
      // immediately inside the invariant loop
      reuse_dis_bytes = 0.0;
      if (i == n_loops - 1) {
        reuse_dis_bytes = top_loop_touch_bytes;
      } else {
        for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) {
          const BufferNode* buffer = iter.first;
          const IntVec& numels = iter.second;
          int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
          reuse_dis_bytes += numel * buffer->dtype.bytes();
        }
      }
      break;
    }
    // Case 2. Find serial reuse, i.e. reuse with kSerialMultipleReadWrite:
    // the same buffer is touched by at least two distinct accesses under this
    // loop, so consecutive accesses revisit the data
    const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer);
    if (touched.size() >= 2) {
      int64_t extent = 1;
      if (const int64_t* ext = GetLoopIntExtent(loop)) {
        extent = *ext;
      }
      reuse_type = ReuseType::kSerialMultipleReadWrite;
      reuse_ct = touched.size() - 1;
      // Distance starts from the smallest region among the accesses, then is
      // averaged over the loop's iterations below
      reuse_dis_iter = *std::min_element(touched.begin(), touched.end());
      reuse_dis_bytes = 0.0;
      for (const auto& iter : buffer_touched_under_loop.at(loop)) {
        const BufferNode* buffer = iter.first;
        const IntVec& numels = iter.second;
        int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
        reuse_dis_bytes += numel * buffer->dtype.bytes();
      }
      reuse_dis_iter /= extent;
      reuse_dis_bytes /= extent;
      break;
    }
  }
}
+
+void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t 
cache_line_bytes) {
+  int64_t dtype_bytes = this->buffer->dtype.bytes();
+  this->stride = this->innermost_stride;
+  this->bytes = dtype_bytes * loop_nest.prod;
+  if (loop_nest.loops.empty()) {
+    this->unique_bytes = 1;
+    this->lines = 1;
+    this->unique_lines = 1;
+  } else {
+    this->unique_bytes = this->loop_accessed_numel.front().at(buffer) * 
dtype_bytes;
+    this->lines = static_cast<double>(loop_nest.prod) / 
this->prod_non_strided_loop_extent *
+                  std::min(1.0, 1.0 * this->min_stride * dtype_bytes / 
cache_line_bytes);
+    this->lines = std::max(1.0, this->lines);
+    this->unique_lines = static_cast<double>(this->unique_bytes) /
+                         std::min(cache_line_bytes, 
this->num_continuous_bytes);
+    this->unique_lines = std::max(1.0, this->unique_lines);
+  }
+  double proxy_reuse_ct = this->reuse_ct > 0 ? this->reuse_ct : 0.5;
+  this->bytes_d_reuse_ct = this->bytes / proxy_reuse_ct;
+  this->unique_bytes_d_reuse_ct = this->unique_bytes / proxy_reuse_ct;
+  this->lines_d_reuse_ct = this->lines / proxy_reuse_ct;
+  this->unique_lines_d_reuse_ct = this->unique_lines / proxy_reuse_ct;
+}
+
+Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, 
int64_t cache_line_bytes,
+                 IntVec* for_touched_bytes, ForBufferMap<IntVec>* 
buffer_touched_under_loop,
+                 arith::Analyzer* analyzer) {
+  int n_loops = loop_nest.loops.size();
+  // Step 0. Initialize data structures
+  this->Init(store, n_loops);
+  // Step 1. Calculate region-related feature
+  this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, 
analyzer);
+  // Step 2. Calculate stride-related feature
+  for (auto& feature : sub_features) {
+    feature.SetStride(loop_nest, analyzer);
+  }
+  // Step 3. Calculate reuse-related feature
+  int64_t top_loop_touch_bytes = 0.0;
+  if (n_loops > 0) {
+    for (const SubFeature& feature : sub_features) {
+      int64_t bytes = feature.buffer->dtype.bytes();
+      int64_t n_buffer = feature.loop_accessed_numel[0].size();
+      top_loop_touch_bytes += bytes * n_buffer;
+    }
+  }
+  for (auto& feature : sub_features) {
+    feature.SetReuse(loop_nest, top_loop_touch_bytes, 
*buffer_touched_under_loop);
+  }
+  // Step 4. Calculate rest of the features
+  for (auto& feature : sub_features) {
+    feature.SetFeature(loop_nest, cache_line_bytes);
+  }
+  // Step 5. Sort the features
+  std::sort(sub_features.begin(), sub_features.end(), [](const SubFeature& a, 
const SubFeature& b) {
+    if (a.lines != b.lines) {
+      return a.lines > b.lines;
+    }
+    if (a.bytes != b.bytes) {
+      return a.bytes > b.bytes;
+    }
+    return a.buffer->name < b.buffer->name;
+  });
+}
+
+}  // namespace group2
+
+namespace group3 {
+
/*! \brief Group 3 feature */
struct Feature {
  /*!
   * \brief Arithmetic intensity (log-compute vs log-memory) sampled at
   * `n_samples` evenly spaced points along the loop nest.
   * See the wiki page [1] for details.
   *
   * [1] https://en.wikipedia.org/wiki/Roofline_model
   */
  std::vector<double> arith_intensity_curve;

  void Export(std::vector<double>* v) const {
    v->insert(v->end(), arith_intensity_curve.begin(), arith_intensity_curve.end());
  }

  /*!
   * \brief Build the arithmetic intensity curve.
   * \param n_samples Number of samples on the curve.
   * \param loop_nest The surrounding loop nest (outermost first).
   * \param for_touched_bytes for_touched_bytes[i] = bytes touched under loops[i].
   * \param arith_ops The floating-point operation counters from group 1.
   */
  explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec& for_touched_bytes,
                   const group1::Feature::ArithOps& arith_ops)
      : arith_intensity_curve(n_samples, 0.0) {
    const std::vector<const ForNode*>& loops = loop_nest.loops;
    ICHECK_EQ(loops.size(), for_touched_bytes.size());
    int n_loops = loops.size();
    // Calculate `memory_bytes`: log2 bytes touched, reordered so that index 0
    // corresponds to the innermost loop
    std::vector<double> memory_bytes;
    memory_bytes.resize(n_loops);
    for (int i = 0; i < n_loops; ++i) {
      memory_bytes[n_loops - 1 - i] = std::log2(for_touched_bytes[i]);
    }
    // Calculate `compute_ops` and `cur_compute_ops`: the float-op counters are
    // divided by the full nest's iteration count, then re-multiplied loop by
    // loop from inner to outer, giving cumulative log2 compute per loop level
    std::vector<double> compute_ops;
    double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub + arith_ops.float_mul +
                               arith_ops.float_div_mod + arith_ops.float_cmp +
                               arith_ops.float_math_func + arith_ops.float_other_func;
    total_compute_ops /= loop_nest.prod;
    for (int i = n_loops - 1; i >= 0; --i) {
      if (const int64_t* extent = GetLoopIntExtent(loops[i])) {
        total_compute_ops *= *extent;
      }
      compute_ops.push_back(std::log2(total_compute_ops));
    }
    // Fill the feature set: no float computation means a flat zero curve
    if (total_compute_ops <= 0 || compute_ops.empty()) {
      for (int i = 0; i < n_samples; ++i) {
        arith_intensity_curve[i] = 0.0;
      }
      return;
    }
    total_compute_ops = compute_ops.back();  // i.e. total_compute_ops = log2(total_compute_ops)
    int p = 0;
    for (int i = 0; i < n_samples; ++i) {
      double& result = arith_intensity_curve[i];
      double cur_compute_ops = static_cast<double>(i + 1) / n_samples * total_compute_ops;
      // Find the first `p` that `compute[p] >= total * (i + 1) / N`
      for (; p < n_loops; ++p) {
        if (compute_ops[p] >= cur_compute_ops - 1e-4) {
          break;
        }
      }
      CHECK_LT(p, n_loops);
      if (p == 0) {
        result = compute_ops[p] / memory_bytes[p];
      } else {
        // Linearly interpolate the intensity between loop levels p-1 and p
        double base = compute_ops[p - 1] / memory_bytes[p - 1];
        double slope =
            (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] / memory_bytes[p - 1]) /
            (compute_ops[p] - compute_ops[p - 1]);
        result = base + slope * (cur_compute_ops - compute_ops[p - 1]);
      }
    }
  }
};
+
+}  // namespace group3
+
+namespace group4 {
+
+/*! \brief Group 4 feature */
+struct Feature {
+  int64_t alloc_size = 0;        // The size of allocated buffer in bytes
+  int64_t alloc_prod = 0;        // alloc_outer_prod * alloc_inner_prod
+  int64_t alloc_outer_prod = 1;  // The product of lengths of loops outside 
the scope of the alloc
+
+  static constexpr int64_t kCount = 4;
+
+  void Export(std::vector<double>* v, int64_t outer_prod) const {
+    double vs[] = {
+        slog(alloc_size),
+        slog(alloc_prod),
+        slog(alloc_outer_prod),
+        slog(static_cast<double>(outer_prod) / alloc_outer_prod),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+
+  Feature() = default;
+
+  explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, 
arith::Analyzer* analyzer) {
+    std::vector<int64_t> shape = utils::GetBufferShape(buffer, analyzer);
+    int64_t numel = 1;
+    for (int64_t x : shape) {
+      numel *= x;
+    }
+    alloc_size = numel * buffer->dtype.bytes();
+    alloc_prod = numel * loop_nest.prod;
+    alloc_outer_prod = loop_nest.prod;
+  }
+};
+
+}  // namespace group4
+
+namespace group5 {
+
+/*! \brief Group 5 feature */
+struct Feature {
+  int64_t outer_prod;        // The product of lengths of outer loops
+  int num_loops;             // The number of outer loops
+  int auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
+
+  static constexpr int64_t kCount = 3;
+
+  void Export(std::vector<double>* v) const {
+    double vs[] = {
+        slog(outer_prod),
+        slog(num_loops),
+        slog(auto_unroll_max_step),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+
+  explicit Feature(const LoopNest& loop_nest) {
+    this->outer_prod = loop_nest.prod;
+    this->num_loops = loop_nest.loops.size();
+    this->auto_unroll_max_step = loop_nest.auto_unroll.empty() ? 0 : 
loop_nest.auto_unroll.back();
+  }
+};
+
+}  // namespace group5
+
/*! \brief The feature extracted for one stored-to buffer */
struct Feature {
  // The buffer written by the store(s) this feature describes
  const BufferNode* buffer = nullptr;
  // Order in which the buffer was first encountered; used for sorting
  int buffer_order = -1;
  // Feature groups, filled in lazily; group4 may remain null when the
  // allocation site is never seen and is default-constructed before export
  std::unique_ptr<group1::Feature> group1 = nullptr;
  std::unique_ptr<group2::Feature> group2 = nullptr;
  std::unique_ptr<group3::Feature> group3 = nullptr;
  std::unique_ptr<group4::Feature> group4 = nullptr;
  std::unique_ptr<group5::Feature> group5 = nullptr;

  // Order features by when their buffer was first stored to
  bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; }
};
+
/*! \brief The main feature extractor: walks the PrimFuncs of a module and
 * gathers one Feature per buffer that is the target of a non-constant store */
class PerStoreFeatureCollector : private StmtVisitor {
 public:
  /*!
   * \brief Collect the per-store features of every PrimFunc in the module.
   * \param is_gpu Whether the program targets a GPU.
   * \param cache_line_bytes The cache line size in bytes.
   * \param arith_intensity_curve_num_samples Number of samples for group 3.
   * \param mod The module to extract features from.
   * \return The features, sorted by the order the buffers were first stored to.
   */
  static std::vector<Feature> Collect(bool is_gpu, int64_t cache_line_bytes,
                                      int64_t arith_intensity_curve_num_samples,
                                      const IRModule& mod) {
    PerStoreFeatureCollector collector(is_gpu, cache_line_bytes, arith_intensity_curve_num_samples);
    for (const auto& kv : mod->functions) {
      if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
        collector(func->body);
        // Function parameters are treated as allocations at the outermost scope
        for (const auto& it : func->buffer_map) {
          collector.HandleBufferAlloc(it.second);
        }
      }
    }
    std::vector<Feature> result;
    result.reserve(collector.buffer_features_.size());
    for (auto& it : collector.buffer_features_) {
      Feature& feature = it.second;
      // Entries with a null buffer were created by HandleBufferAlloc for
      // buffers that are never stored to; they are dropped here
      if (feature.buffer != nullptr) {
        ICHECK(feature.group1);
        ICHECK(feature.group2);
        ICHECK(feature.group3);
        ICHECK(feature.group5);
        // No allocation site was seen: use an all-zero group 4
        if (feature.group4 == nullptr) {
          feature.group4 = std::make_unique<group4::Feature>();
        }
        result.push_back(std::move(feature));
      }
    }
    std::sort(result.begin(), result.end());
    return result;
  }

 private:
  // Maintain the loop-nest context while recursing into loop bodies
  void VisitStmt_(const ForNode* loop) final {
    int64_t auto_unroll;  // filled by LoopNest::Push — presumably the pragma value; confirm
    ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll);
    StmtVisitor::VisitStmt_(loop);
    loop_nest_.Pop(loop, for_vec, auto_unroll);
  }

  void VisitStmt_(const BufferStoreNode* store) final {
    // Stores of plain constants carry no computation; skip them
    if (store->value->IsInstance<IntImmNode>() || store->value->IsInstance<FloatImmNode>()) {
      return;
    }
    const BufferNode* buffer = store->buffer.get();
    Feature& feature = buffer_features_[buffer];
    if (feature.buffer == nullptr) {
      feature.buffer = buffer;
      feature.buffer_order = buffer_features_.size();
    }
    // Note: a later store to the same buffer overwrites groups 1/2/3/5
    feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_, is_gpu_);
    feature.group2 =
        std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_,
                                          &buffer_touched_under_loop_, &analyzer_);
    feature.group3 =
        std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
                                          for_touched_bytes_, feature.group1->arith_ops);
    feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
  }

  void VisitStmt_(const BlockNode* block) final {
    StmtVisitor::VisitStmt_(block);
    // Group 4 is computed where the buffer is allocated, i.e. at the block
    // whose alloc_buffers list contains it
    for (const Buffer& buffer : block->alloc_buffers) {
      HandleBufferAlloc(buffer);
    }
  }

  // Record the allocation-related (group 4) features of `buffer` at the
  // current loop depth
  void HandleBufferAlloc(const Buffer& buffer) {
    Feature& feature = buffer_features_[buffer.get()];
    feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer, &analyzer_);
  }

  explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes,
                                    int64_t arith_intensity_curve_num_samples)
      : is_gpu_(is_gpu),
        cache_line_bytes_(cache_line_bytes),
        arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {}

  bool is_gpu_;                             // whether features are for a GPU program
  int64_t cache_line_bytes_;                // cache line size in bytes
  int64_t arith_intensity_curve_num_samples_;  // group 3 curve resolution
  arith::Analyzer analyzer_;                // shared analyzer (bindings reused across stores)
  LoopNest loop_nest_ = {};                 // current loop-nest context
  IntVec for_touched_bytes_ = {};           // scratch output of group2, input of group3
  ForBufferMap<IntVec> buffer_touched_under_loop_ = {};  // scratch output of group2
  std::unordered_map<const BufferNode*, Feature> buffer_features_ = {};  // per-buffer results
};
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
/*! \brief The feature extractor that produces one flat feature vector per
 * buffer store in each measure candidate */
class PerStoreFeatureNode : public FeatureExtractorNode {
 public:
  // Number of buffer slots exported per store; extra buffers are dropped and
  // missing ones are zero-padded by group2's Export
  int buffers_per_store;
  // Number of samples on the group-3 arithmetic intensity curve
  int arith_intensity_curve_num_samples;
  // Cache line size in bytes, used by the group-2 features
  int cache_line_bytes;
  // Total length of the exported vector; checked after every extraction
  int feature_vector_length;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("buffers_per_store", &buffers_per_store);
    v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples);
    v->Visit("cache_line_bytes", &cache_line_bytes);
    v->Visit("feature_vector_length", &feature_vector_length);
  }

  /*!
   * \brief Extract the feature vectors of a single IRModule.
   * \param mod The module; it is normalized by the per-store-feature pass
   * list before collection.
   * \param is_gpu Whether the module targets a GPU.
   * \param results Output: one flattened vector (groups 1..5 concatenated in
   * order) per stored-to buffer.
   */
  void ExtractSingle(IRModule mod, bool is_gpu, std::vector<std::vector<double>>* results) {
    static transform::Sequential passes = tir::transform::PassListForPerStoreFeature();
    mod = passes(std::move(mod));
    std::vector<tir::Feature> features = tir::PerStoreFeatureCollector::Collect(
        is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
    int n_features = features.size();
    results->resize(n_features);
    for (int i = 0; i < n_features; ++i) {
      const tir::Feature& feature = features[i];
      std::vector<double>& result = (*results)[i];
      result.reserve(feature_vector_length);
      feature.group1->Export(&result);
      feature.group2->Export(&result, this->buffers_per_store);
      feature.group3->Export(&result);
      feature.group4->Export(&result, feature.group5->outer_prod);
      feature.group5->Export(&result);
      ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
    }
  }

  // Extract features for all candidates in parallel, one NDArray per candidate.
  // Note: only targets whose kind is exactly "cuda" are treated as GPU here.
  Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
                                      const Array<MeasureCandidate>& candidates) {
    bool is_gpu = tune_context->target.value()->kind->name == "cuda";
    std::vector<runtime::NDArray> results;
    results.resize(candidates.size());
    auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
      const auto& candidate = candidates[task_id];
      std::vector<std::vector<double>> features;
      ExtractSingle(candidate->sch->mod(), is_gpu, &features);
      results[task_id] = tir::utils::AsNDArray(features);
    };
    support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
    return results;
  }

  static constexpr const char* _type_key = "meta_schedule.PerStoreFeature";
  TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode);
};
+
+FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
+                                                   int 
arith_intensity_curve_num_samples,
+                                                   int cache_line_bytes) {
+  ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
+  n->buffers_per_store = buffers_per_store;
+  n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
+  n->cache_line_bytes = cache_line_bytes;
+  n->feature_vector_length = tir::group1::Feature::kCount +                    
              //
+                             tir::group2::Feature::SubFeature::kCount * 
buffers_per_store +  //
+                             arith_intensity_curve_num_samples +               
              //
+                             tir::group4::Feature::kCount +                    
              //
+                             tir::group5::Feature::kCount;
+  return FeatureExtractor(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode);
+TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature")
+    .set_body_typed(FeatureExtractor::PerStoreFeature);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 3e989e4..ef15f49 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -20,6 +20,7 @@
 #define TVM_META_SCHEDULE_UTILS_H_
 
 #include <dmlc/memory_io.h>
+#include <tvm/arith/analyzer.h>
 #include <tvm/meta_schedule/arg_info.h>
 #include <tvm/meta_schedule/builder.h>
 #include <tvm/meta_schedule/cost_model.h>
@@ -43,7 +44,10 @@
 #include "../printer/text_printer.h"
 #include "../support/array.h"
 #include "../support/base64.h"
+#include "../support/nd_int_set.h"
+#include "../support/utils.h"
 #include "../tir/schedule/primitive.h"
+#include "../tir/schedule/utils.h"
 
 namespace tvm {
 namespace meta_schedule {
diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc
index 83ef1e2..6d944b3 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -322,9 +322,9 @@ std::vector<int64_t> SamplePerfectTile(
    const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor,
     Optional<Array<Integer>>* decision) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
-  int64_t extent = GetLoopIntExtent(loop);
+  const int64_t* extent = GetLoopIntExtent(loop);
   std::vector<int64_t> result;
-  if (extent == -1) {
+  if (extent == nullptr) {
     // Case 1. Handle loops with non-constant length
     result = std::vector<int64_t>(n_splits, 1);
     result[0] = -1;
@@ -333,7 +333,7 @@ std::vector<int64_t> SamplePerfectTile(
     result = support::AsVector<Integer, int64_t>(decision->value());
     int n = result.size();
     ICHECK_GE(n, 2);
-    int64_t len = extent;
+    int64_t len = *extent;
     for (int i = n - 1; i > 0; --i) {
       int64_t& l = result[i];
      // A previous decision could become invalid because of the change of outer tiles
@@ -347,7 +347,7 @@ std::vector<int64_t> SamplePerfectTile(
     result[0] = len;
   } else {
     // Case 3. Use fresh new sampling result
-    result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor);
+    result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor);
     ICHECK_LE(result.back(), max_innermost_factor);
   }
   *decision = support::AsArray<int64_t, Integer>(result);
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index c66c2ca..860b3f6 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -215,21 +215,69 @@ inline Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
 /*!
  * \brief Get the extents of a loop
  * \param loop The loop to be queried
- * \return The extents of the loop
+ * \return The extent of the loop, nullptr if the extent is not constant
  */
-inline int64_t GetLoopIntExtent(const ForNode* loop) {
-  const auto* int_extent = loop->extent.as<IntImmNode>();
-  return int_extent ? int_extent->value : -1;
-}
+inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); }
 
 /*!
  * \brief Get the extents of a loop
  * \param loop_sref The loop to be queried
- * \return The extents of the loop
+ * \return The extent of the loop, nullptr if the extent is not constant
  */
-inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) {
+inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
-  return GetLoopIntExtent(loop);
+  return as_const_int(loop->extent);
+}
+
+/******** Annotation ********/
+
+/*!
+ * \brief Get the annotation on a Block/For
+ * \tparam TObjectRef The type of the annotation value
+ * \param stmt The pointer to the Block/For statement node whose annotations are queried
+ * \param ann_key The annotation key to be looked up
+ * \return NullOpt if not found; otherwise the annotation value
+ */
+template <class TObjectRef, class TStmtNode>
+inline Optional<TObjectRef> GetAnn(const TStmtNode* stmt, const String& ann_key) {
+  const Map<String, ObjectRef>* annotations = &stmt->annotations;
+  for (const auto& ann : *annotations) {
+    if (ann.first == ann_key) {
+      return Downcast<TObjectRef>(ann.second);
+    }
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Get the annotation on a Block/For
+ * \tparam TObjectRef The type of the annotation value
+ * \param sref The sref to the block or the for loop
+ * \param ann_key The annotation key to be looked up
+ * \return NullOpt if not found; otherwise the annotation value
+ */
+template <class TObjectRef>
+inline Optional<TObjectRef> GetAnn(const StmtSRef& sref, const String& ann_key) {
+  if (const auto* loop = sref->StmtAs<ForNode>()) {
+    return GetAnn<TObjectRef, ForNode>(loop, ann_key);
+  } else if (const auto* block = sref->StmtAs<BlockNode>()) {
+    return GetAnn<TObjectRef, BlockNode>(block, ann_key);
+  } else {
+    LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
+    throw;
+  }
+}
+
+/*!
+ * \brief Check if a Block/For has a specific pair of annotation key and values
+ * \param sref The sref to the block or the for loop
+ * \param ann_key The annotation key to be checked
+ * \param ann_val The annotation value to be checked
+ * \return Whether a Block/For has a specific pair of annotation key and values
+ */
+inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) {
+  Optional<String> result = GetAnn<String>(sref, ann_key);
+  return result.defined() && result.value() == ann_val;
 }
 
 }  // namespace tir
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index aa58684..9c1aab6 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -58,8 +58,15 @@ class ThreadBindingUnifier : public StmtExprMutator {
     if (op->kind != ForKind::kThreadBinding) {
       return StmtExprMutator::VisitStmt_(op);
     }
-    return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
-                                  Range::FromMinExtent(op->min, op->extent));
+    Map<String, ObjectRef> annotations = op->annotations;
+    Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
+                                       Range::FromMinExtent(op->min, op->extent));
+    if (annotations.empty()) {
+      return stmt;
+    }
+    For new_loop = Downcast<For>(stmt);
+    new_loop.CopyOnWrite()->annotations = std::move(annotations);
+    return new_loop;
   }
 
   template <typename Node>
@@ -70,7 +77,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
     const String& thread_tag = old_iter_var->thread_tag;
 
     // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the
-    // thread block depth is 0 before the increasement, it means we are entering a new kernel, and
+    // thread block depth is 0 before the increment, it means we are entering a new kernel, and
     // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have
     // thread axes with different extents.
     bool is_kernel_launch_scope = false;
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
new file mode 100644
index 0000000..7b6ef52
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
@@ -0,0 +1,1555 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import sys
+from typing import Callable, List
+
+from numpy.testing import assert_allclose
+import pytest
+
+import tvm
+from tvm import meta_schedule as ms, te, tir
+from tvm.script import tir as T
+
+N_FEATURES = 164
+
+
+@T.prim_func
+def matmul(
+    A: T.Buffer[(512, 512), "float32"],
+    B: T.Buffer[(512, 512), "float32"],
+    C: T.Buffer[(512, 512), "float32"],
+) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    # body
+    # with T.block("root")
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("C"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(C[i, j], A[i, k], B[k, j])
+            T.writes(C[i, j])
+            with T.init():
+                C[i, j] = T.float32(0)
+            C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+
+def _make_context(target) -> ms.TuneContext:
+    return ms.TuneContext(
+        target=target,
+        num_threads=1,
+    )
+
+
+def _make_candidate(f_sch: Callable[[], tir.Schedule]) -> ms.MeasureCandidate:
+    return ms.MeasureCandidate(sch=f_sch(), args_info=[])
+
+
+def _feature_names(  # pylint: disable=invalid-name
+    buffers_per_store: int = 5,
+    arith_intensity_curve_num_samples: int = 10,
+) -> List[str]:
+    result = [
+        "float_mad",
+        "float_addsub",
+        "float_mul",
+        "float_divmod",
+        "float_cmp",
+        "float_mathfunc",
+        "float_otherfunc",
+        "int_mad",
+        "int_addsub",
+        "int_mul",
+        "int_divmod",
+        "int_cmp",
+        "int_mathfunc",
+        "int_otherfunc",
+        "bool_op",
+        "select_op",
+        "vec_num",
+        "vec_prod",
+        "vec_len",
+        "vec_type.kPosNone",
+        "vec_type.kPosInnerSpatial",
+        "vec_type.kPosMiddleSpatial",
+        "vec_type.kPosOuterSpatial",
+        "vec_type.kPosInnerReduce",
+        "vec_type.kPosMiddleReduce",
+        "vec_type.kPosOuterReduce",
+        "vec_type.kPosMixed",
+        "unroll_num",
+        "unroll_prod",
+        "unroll_len",
+        "unroll_type.kPosNone",
+        "unroll_type.kPosInnerSpatial",
+        "unroll_type.kPosMiddleSpatial",
+        "unroll_type.kPosOuterSpatial",
+        "unroll_type.kPosInnerReduce",
+        "unroll_type.kPosMiddleReduce",
+        "unroll_type.kPosOuterReduce",
+        "unroll_type.kPosMixed",
+        "parallel_num",
+        "parallel_prod",
+        "parallel_len",
+        "parallel_type.kPosNone",
+        "parallel_type.kPosInnerSpatial",
+        "parallel_type.kPosMiddleSpatial",
+        "parallel_type.kPosOuterSpatial",
+        "parallel_type.kPosInnerReduce",
+        "parallel_type.kPosMiddleReduce",
+        "parallel_type.kPosOuterReduce",
+        "parallel_type.kPosMixed",
+        "is_gpu",
+        "blockIdx_x_len",
+        "blockIdx_y_len",
+        "blockIdx_z_len",
+        "threadIdx_x_len",
+        "threadIdx_y_len",
+        "threadIdx_z_len",
+        "vthread_len",
+    ]
+    for i in range(buffers_per_store):
+        result.extend(
+            f"B{i}.{s}"
+            for s in [
+                "acc_type.kRead",
+                "acc_type.kWrite",
+                "acc_type.kReadWrite",
+                "bytes",
+                "unique_bytes",
+                "lines",
+                "unique_lines",
+                "reuse_type.kLoopMultipleRead",
+                "reuse_type.kSerialMultipleReadWrite",
+                "reuse_type.kNoReuse",
+                "reuse_dis_iter",
+                "reuse_dis_bytes",
+                "reuse_ct",
+                "bytes_d_reuse_ct",
+                "unique_bytes_d_reuse_ct",
+                "lines_d_reuse_ct",
+                "unique_lines_d_reuse_ct",
+                "stride",
+            ]
+        )
+    result.extend(f"arith_intensity_curve_{i}" for i in range(arith_intensity_curve_num_samples))
+    result.extend(
+        [
+            "alloc_size",
+            "alloc_prod",
+            "alloc_outer_prod",
+            "alloc_inner_prod",
+            "outer_prod",
+            "num_loops",
+            "auto_unroll_max_step",
+        ]
+    )
+    # 57 + 18 * 5 + 10 + 4 + 3
+    assert len(result) == N_FEATURES
+    return result
+
+
+def _zip_feature(feature, names):
+    assert feature.ndim == 1
+    assert feature.shape[0] == N_FEATURES
+    assert len(names) == N_FEATURES
+    return list(zip(names, feature))
+
+
+def _print_feature(feature, st, ed):  # pylint: disable=invalid-name
+    named_feature = _zip_feature(feature, _feature_names())
+    for k, v in named_feature[st:ed]:
+        print("\t", k, v)
+
+
+def test_cpu_matmul():
+    def _create_schedule():
+        func = matmul
+        sch = tir.Schedule(func, debug_mask="all")
+        block = sch.get_block("C")
+        i, j, k = sch.get_loops(block)
+        i_o, i_i = sch.split(i, factors=[None, 16])  # outer: 32
+        j_o, j_i = sch.split(j, factors=[None, 8])  # outer: 64
+        sch.reorder(i_o, j_o, k, j_i, i_i)
+        sch.vectorize(j_i)
+        sch.parallel(i_o)
+        sch.parallel(j_o)
+        sch.unroll(k)
+        return sch
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("llvm")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (1, N_FEATURES)
+    f = feature[0]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[
+            # float math ops
+            0, 27, 27, 0, 0, 0, 0,
+            # int math ops
+            0, 29, 29, 0, 0, 0, 0,
+            # bool/select ops
+            0, 0,
+        ],
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[1.0, 3.169924, 3.169924, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[1.0, 9.002815, 9.002815, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[1.58496, 11.0007, 6.022368, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1,
+            0,
+            0,
+            29,
+            20,
+            27,
+            14,
+            1,
+            0,
+            0,
+            4.087463,
+            7.0552826,
+            3.169925,
+            26,
+            17,
+            24,
+            11.0007038,
+            9.002815,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            0.0,
+            1.0,
+            29.0,
+            20.000001907348633,
+            27.0,
+            14.00008773803711,
+            1.0,
+            0.0,
+            0.0,
+            7.011227130889893,
+            9.250298500061035,
+            9.002815246582031,
+            20.000001907348633,
+            11.000703811645508,
+            18.0000057220459,
+            5.044394016265869,
+            9.002815246582031,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Buffer B
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.0,
+            20.000001907348633,
+            19.000001907348633,
+            14.00008773803711,
+            1.0,
+            0.0,
+            0.0,
+            1.0,
+            3.700439691543579,
+            4.087462902069092,
+            25.0,
+            16.000022888183594,
+            15.000043869018555,
+            10.001408576965332,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[
+            0.7097842693328857,
+            0.7408391237258911,
+            0.8750449419021606,
+            0.9449487924575806,
+            1.0148526430130005,
+            1.0847564935684204,
+            1.113688349723816,
+            1.1394684314727783,
+            1.2119636535644531,
+            1.2971993684768677,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            20.000001907348633,
+            18.0000057220459,
+            1.0,
+            27.0,
+            27.0,
+            2.5849626064300537,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+def test_cpu_fusion():
+    # pylint: disable=all
+    @T.prim_func
+    def func(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [64, 32], dtype="float32")
+        B = T.match_buffer(b, [64, 32], dtype="float32")
+        C = T.match_buffer(c, [64, 32], dtype="float32")
+        for i, j in T.grid(64, 32):  # type: ignore
+            with T.block():
+                T.reads([A[i, j], B[i, j]])  # type: ignore
+                T.writes([B[i, j], C[i, j]])  # type: ignore
+                with T.block("B"):
+                    T.reads([A[i, j]])  # type: ignore
+                    T.writes([B[i, j]])  # type: ignore
+                    B[i, j] = A[i, j]  # type: ignore
+                with T.block("C"):
+                    T.reads([B[i, j]])  # type: ignore
+                    T.writes([C[i, j]])  # type: ignore
+                    C[i, j] = B[i, j]  # type: ignore
+
+    # pylint: enable=all
+
+    def _create_schedule():
+        return tir.Schedule(func, debug_mask="all")
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("llvm")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (2, N_FEATURES)
+    ## Features for BufferStore(B)
+    f = feature[0]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[0.0] * 16,
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer B
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            13.000176,
+            11.000703811645508,
+            1.0,
+            11.000703811645508,
+            11.000703811645508,
+            1.5849624872207642,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ## Features for BufferStore(C)
+    f = feature[1]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[0.0] * 16,
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            1.0,
+            0.0,
+            1.0,
+            4.087462902069092,
+            1.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            13.000176429748535,
+            11.000703811645508,
+            1.0,
+            11.000703811645508,
+            11.000703811645508,
+            1.5849624872207642,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+def test_gpu():
+    def _create_schedule():
+        func = matmul
+        sch = tir.Schedule(func, debug_mask="all")
+        c = sch.get_block("C")
+        c_local = sch.cache_write(c, 0, "local")
+        i, j, k = sch.get_loops(c)
+        # pylint: disable=invalid-name
+        i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1])  # outer: 1
+        j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16])  # outer: 8
+        k0, k1, k2 = sch.split(k, factors=[None, 1, 2])  # outer: 256
+        # pylint: enable=invalid-name
+        # fmt: off
+        sch.reorder(
+            i0, j0,  # S
+            i1, j1,  # S
+            i2, j2,  # S
+            k0,      # R
+            k1,      # R
+            i3, j3,  # S
+            k2,      # R
+            i4, j4,  # S
+        )
+        # fmt: on
+        # thread binding
+        i0_j0 = sch.fuse(i0, j0)
+        i1_j1 = sch.fuse(i1, j1)
+        i2_j2 = sch.fuse(i2, j2)
+        sch.bind(i0_j0, "blockIdx.x")
+        sch.bind(i1_j1, "vthread.x")
+        sch.bind(i2_j2, "threadIdx.x")
+        # fusion
+        sch.reverse_compute_at(c_local, i2_j2)
+        # cache read 'A'
+        a_shared = sch.cache_read(c, 1, "shared")
+        sch.compute_at(a_shared, k0)
+        _, _, _, _, a_i, a_j = sch.get_loops(a_shared)
+        a_ij = sch.fuse(a_i, a_j)
+        _, a_j = sch.split(a_ij, factors=[None, 16])  # outer: 64
+        sch.bind(a_j, "threadIdx.x")
+        # cache read 'B'
+        b_shared = sch.cache_read(c, 2, "shared")
+        sch.compute_at(b_shared, k0)
+        _, _, _, _, b_i, b_j = sch.get_loops(b_shared)
+        b_ij = sch.fuse(b_i, b_j)
+        _, b_j = sch.split(b_ij, factors=[None, 16])  # outer: 8
+        sch.bind(b_j, "threadIdx.x")
+        # auto unroll
+        sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024))
+        sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1))
+        return sch
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("cuda")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (4, N_FEATURES)
+    ### Check feature[0]: BufferStore(A_shared) <= A[...]
+    f = feature[0]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            24.000000085991324,
+            24.000000085991324,
+            24.000000085991324,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            25.000000042995662,
+            20.000001375860553,
+            23.00000017198264,
+            14.000088052430122,
+            1.0,
+            0.0,
+            0.0,
+            18.00000550343433,
+            20.00562591970089,
+            2.321928094887362,
+            23.00000017198264,
+            18.00000550343433,
+            21.000000687930438,
+            12.0003521774803,
+            12.0003521774803,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer A.shared
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            25.000000042995662,
+            12.0003521774803,
+            23.00000017198264,
+            9.002815015607053,
+            1.0,
+            0.0,
+            0.0,
+            6.022367813028454,
+            11.98049663618346,
+            8.005624549193879,
+            17.000011006847668,
+            4.087462841250339,
+            15.000044026886828,
+            1.584962500721156,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            12.0003521774803,
+            27.000000010748916,
+            17.000011006847668,
+            6.022367813028454,
+            23.00000017198264,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[1]: BufferStore(B_shared) <= B[...]
+    f = feature[1]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            22.00000034396526,
+            22.00000034396526,
+            21.000000687930438,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 
1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            22.00000034396526,
+            20.000001375860553,
+            20.000001375860553,
+            14.000088052430122,
+            1.0,
+            0.0,
+            0.0,
+            15.000044026886828,
+            20.17555076886471,
+            2.321928094887362,
+            20.000001375860553,
+            18.00000550343433,
+            18.00000550343433,
+            12.0003521774803,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer B.shared
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            22.00000034396526,
+            9.002815015607053,
+            20.000001375860553,
+            3.169925001442312,
+            1.0,
+            0.0,
+            0.0,
+            3.169925001442312,
+            10.001408194392809,
+            8.005624549193879,
+            14.000088052430122,
+            1.584962500721156,
+            12.0003521774803,
+            0.044394119358453436,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            9.002815015607053,
+            24.000000085991324,
+            17.000011006847668,
+            3.169925001442312,
+            20.000001375860553,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[2]: BufferStore(C_local) <= C_local[...] + A_shared[...] 
* B_shared[...]
+    f = feature[2]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            27.000000010748916,
+            27.000000010748916,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            28.000000005374456,
+            28.000000005374456,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 
1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B.shared
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.00000000268723,
+            9.002815015607053,
+            23.00000017198264,
+            3.169925001442312,
+            1.0,
+            0.0,
+            0.0,
+            5.044394119358453,
+            7.651051691178929,
+            5.044394119358453,
+            24.000000085991324,
+            4.087462841250339,
+            18.00000550343433,
+            0.32192809488736235,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C.local
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            0.0,
+            1.0,
+            29.00000000268723,
+            11.000704269011246,
+            23.00000017198264,
+            5.044394119358453,
+            1.0,
+            0.0,
+            0.0,
+            4.087462841250339,
+            7.05528243550119,
+            1.584962500721156,
+            28.000000005374456,
+            10.001408194392809,
+            22.00000034396526,
+            4.087462841250339,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Buffer A.shared
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.00000000268723,
+            12.0003521774803,
+            19.00000275171979,
+            9.002815015607053,
+            1.0,
+            0.0,
+            0.0,
+            1.0,
+            3.700439718141092,
+            4.087462841250339,
+            25.000000042995662,
+            8.005624549193879,
+            15.000044026886828,
+            5.044394119358453,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[
+            0.7097842504665767,
+            0.7548801745187567,
+            0.8775907547541741,
+            0.9957389916154509,
+            1.2446737395193135,
+            1.493608487423176,
+            1.7093103019954263,
+            1.8031580276850985,
+            1.9841832691827785,
+            2.204648076869754,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            11.000704269011246,
+            18.00000550343433,
+            9.002815015607053,
+            18.00000550343433,
+            27.000000010748916,
+            3.0,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[3]: BufferStore(C) <= C_local[...]
+    f = feature[3]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 
1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer C
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            20.000001375860553,
+            20.000001375860553,
+            14.000088052430122,
+            14.000088052430122,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            21.000000687930438,
+            21.000000687930438,
+            15.000044026886828,
+            15.000044026886828,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C.local
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            20.000001375860553,
+            11.000704269011246,
+            14.000088052430122,
+            5.044394119358453,
+            1.0,
+            0.0,
+            0.0,
+            9.002815015607053,
+            12.0003521774803,
+            4.087462841250339,
+            16.00002201361136,
+            7.011227255423254,
+            10.001408194392809,
+            1.584962500721156,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            20.000001375860553,
+            18.00000550343433,
+            1.0,
+            18.00000550343433,
+            18.00000550343433,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly: forward any CLI args to pytest
+    # and propagate pytest's return code as the process exit status.
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to